Skip to content

Commit c2ae2d9

Browse files
committed
Add energy based model with gibbs sampling
1 parent 88b60bb commit c2ae2d9

File tree

1 file changed

+89
-0
lines changed

1 file changed

+89
-0
lines changed

ebm.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
import torch
2+
import torch.nn as nn
3+
import torch.optim as optim
4+
import itertools
5+
import numpy as np
6+
import matplotlib.pyplot as plt
7+
8+
# --- Bars and Stripes Dataset (4x4) ---
9+
def generate_bars_stripes(n=4):
10+
images = []
11+
for row_pattern in itertools.product([0, 1], repeat=n):
12+
image = np.tile(np.array(row_pattern).reshape(n, 1), (1, n))
13+
images.append(image)
14+
for col_pattern in itertools.product([0, 1], repeat=n):
15+
image = np.tile(np.array(col_pattern).reshape(1, n), (n, 1))
16+
images.append(image)
17+
# Remove duplicates
18+
unique = []
19+
for img in images:
20+
if not any(np.array_equal(img, u) for u in unique):
21+
unique.append(img)
22+
return np.array(unique).astype(np.float32)
23+
24+
data_np = generate_bars_stripes(4)
25+
data = torch.tensor(data_np.reshape(len(data_np), -1)) # shape: (N, 16)
26+
27+
# --- EBM: MLP Energy Model ---
28+
class EBM(nn.Module):
29+
def __init__(self, input_dim):
30+
super().__init__()
31+
self.net = nn.Sequential(
32+
nn.Linear(input_dim, 64),
33+
nn.ReLU(),
34+
nn.Linear(64, 1)
35+
)
36+
37+
def forward(self, x):
38+
return self.net(x).squeeze(-1) # (N,)
39+
40+
model = EBM(16)
41+
optimizer = optim.Adam(model.parameters(), lr=1e-3)
42+
43+
# --- Sampling (Gibbs-style) ---
44+
@torch.no_grad()
45+
def gibbs_sample(model, x_init, steps=30):
46+
x = x_init.clone()
47+
for _ in range(steps):
48+
for i in range(x.shape[1]):
49+
x_flip = x.clone()
50+
x_flip[:, i] = 1 - x_flip[:, i] # Flip bit i
51+
e_orig = model(x)
52+
e_flip = model(x_flip)
53+
prob = torch.sigmoid(e_orig - e_flip) # Lower energy = more likely
54+
mask = (torch.rand(x.size(0)) < prob).float()
55+
x[:, i] = x[:, i] * mask + x_flip[:, i] * (1 - mask)
56+
return x
57+
58+
# --- Training Loop ---
59+
epochs = 1000
60+
batch_size = 64
61+
for epoch in range(epochs):
62+
idx = torch.randint(0, data.size(0), (batch_size,))
63+
x_data = data[idx]
64+
65+
x_noise = torch.bernoulli(torch.full_like(x_data, 0.5))
66+
x_neg = gibbs_sample(model, x_noise, steps=40)
67+
68+
energy_pos = model(x_data)
69+
energy_neg = model(x_neg)
70+
71+
loss = (energy_pos - energy_neg).mean()
72+
optimizer.zero_grad()
73+
loss.backward()
74+
optimizer.step()
75+
76+
if epoch % 100 == 0:
77+
print(f"Epoch {epoch}: Loss {loss.item():.4f}")
78+
79+
# --- Visualize Generated Samples ---
80+
samples = gibbs_sample(model, torch.bernoulli(torch.full_like(data, 0.5)), steps=100)
81+
samples = samples[:16].reshape(-1, 4, 4)
82+
83+
fig, axs = plt.subplots(4, 4, figsize=(6, 6))
84+
for ax, img in zip(axs.flat, samples):
85+
ax.imshow(img, cmap="gray", vmin=0, vmax=1)
86+
ax.axis("off")
87+
plt.tight_layout()
88+
plt.show()
89+

0 commit comments

Comments
 (0)