Skip to content

Commit 6423760

Browse files
IgorTavcar and claude
committed
Add transformer components and improve autograd engine
- Simplify backward pass: replace _backward closures with _local_grads tuples (from karpathy#115)
- Zero grads before backward for idempotent backward() calls (from karpathy#102)
- Add exp, log, tanh, softmax to Value class
- Add transformer components: Linear, Embedding, LayerNorm, Attention, MultiHeadAttention, FeedForward, TransformerBlock, Transformer, cross_entropy
- Move single-output unwrapping from Layer to MLP (from karpathy#111)
- Add input shape assertion in Neuron (from karpathy#107)
- Add MLP test (from karpathy#111)
- Expand .gitignore with standard Python patterns

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent c911406 commit 6423760

File tree

4 files changed

+305
-30
lines changed

4 files changed

+305
-30
lines changed

.gitignore

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,58 @@
11
.ipynb_checkpoints/
2+
3+
# Byte-compiled / optimized / DLL files
4+
__pycache__/
5+
*.py[codz]
6+
*$py.class
7+
8+
# Distribution / packaging
9+
.Python
10+
build/
11+
develop-eggs/
12+
dist/
13+
downloads/
14+
eggs/
15+
.eggs/
16+
lib/
17+
lib64/
18+
parts/
19+
sdist/
20+
var/
21+
wheels/
22+
share/python-wheels/
23+
*.egg-info/
24+
.installed.cfg
25+
*.egg
26+
MANIFEST
27+
28+
# Jupyter Notebook
29+
.ipynb_checkpoints
30+
31+
# IPython
32+
profile_default/
33+
ipython_config.py
34+
35+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
36+
__pypackages__/
37+
38+
# Environments
39+
.env
40+
.envrc
41+
.venv
42+
env/
43+
venv/
44+
ENV/
45+
env.bak/
46+
venv.bak/
47+
48+
# pytype static type analyzer
49+
.pytype/
50+
51+
# Cython debug symbols
52+
cython_debug/
53+
54+
# PyPI configuration file
55+
.pypirc
56+
57+
# VSCode
58+
.vscode/

micrograd/engine.py

Lines changed: 33 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,56 +1,55 @@
1+
import math
12

23
class Value:
34
""" stores a single scalar value and its gradient """
45

5-
def __init__(self, data, _children=(), _op=''):
6+
def __init__(self, data, _children=(), _op='', _local_grads=()):
67
self.data = data
78
self.grad = 0
89
# internal variables used for autograd graph construction
9-
self._backward = lambda: None
10-
self._prev = set(_children)
10+
self._prev = _children
1111
self._op = _op # the op that produced this node, for graphviz / debugging / etc
12+
self._local_grads = _local_grads # local derivative of this node w.r.t. its children
1213

1314
def __add__(self, other):
1415
other = other if isinstance(other, Value) else Value(other)
15-
out = Value(self.data + other.data, (self, other), '+')
16-
17-
def _backward():
18-
self.grad += out.grad
19-
other.grad += out.grad
20-
out._backward = _backward
21-
16+
out = Value(self.data + other.data, (self, other), '+', (1, 1))
2217
return out
2318

2419
def __mul__(self, other):
2520
other = other if isinstance(other, Value) else Value(other)
26-
out = Value(self.data * other.data, (self, other), '*')
27-
28-
def _backward():
29-
self.grad += other.data * out.grad
30-
other.grad += self.data * out.grad
31-
out._backward = _backward
32-
21+
out = Value(self.data * other.data, (self, other), '*', (other.data, self.data))
3322
return out
3423

3524
def __pow__(self, other):
3625
assert isinstance(other, (int, float)), "only supporting int/float powers for now"
37-
out = Value(self.data**other, (self,), f'**{other}')
38-
39-
def _backward():
40-
self.grad += (other * self.data**(other-1)) * out.grad
41-
out._backward = _backward
42-
26+
out = Value(self.data**other, (self,), f'**{other}', (other * self.data**(other-1),))
4327
return out
4428

4529
def relu(self):
46-
out = Value(0 if self.data < 0 else self.data, (self,), 'ReLU')
30+
out = Value(0 if self.data < 0 else self.data, (self,), 'ReLU', (float(self.data > 0),))
31+
return out
4732

48-
def _backward():
49-
self.grad += (out.data > 0) * out.grad
50-
out._backward = _backward
33+
def exp(self):
    """e ** self. The local derivative of exp(x) is exp(x) itself,
    so the output value doubles as the stored local gradient."""
    y = math.exp(self.data)
    out = Value(y, (self,), 'exp', (y,))
    return out
37+
38+
def log(self):
    """Natural logarithm; local derivative is 1/x (requires self.data > 0)."""
    # Compute the log first so a non-positive input raises math's ValueError
    # before any division is attempted.
    logx = math.log(self.data)
    out = Value(logx, (self,), 'log', (1.0 / self.data,))
    return out
5141

42+
def tanh(self):
    """Hyperbolic tangent; d(tanh x)/dx = 1 - tanh(x)**2."""
    y = math.tanh(self.data)
    out = Value(y, (self,), 'tanh', (1 - y**2,))
    return out
5346

47+
@staticmethod
def softmax(logits):
    """Softmax over a list of Value logits, returned as a list of Values.

    Shifts every logit by the maximum logit before exponentiating so that
    math.exp cannot overflow for large inputs. The shift is a constant
    w.r.t. each logit, so the result and its gradients are unchanged.
    """
    if not logits:
        return []
    # Max over raw floats; subtracting a Python float goes through Value.__add__.
    m = max(logit.data for logit in logits)
    counts = [(logit + (-m)).exp() for logit in logits]
    total = sum(counts)
    return [count / total for count in counts]
52+
5453
def backward(self):
5554

5655
# topological order all of the children in the graph
@@ -64,10 +63,16 @@ def build_topo(v):
6463
topo.append(v)
6564
build_topo(self)
6665

66+
# zero the grads of each node prior to accumulating so that calling
67+
# L.backward() twice in a row doesn't produce the wrong answer.
68+
for v in reversed(topo):
69+
v.grad = 0.0
70+
6771
# go one variable at a time and apply the chain rule to get its gradient
6872
self.grad = 1
6973
for v in reversed(topo):
70-
v._backward()
74+
for child, local_grad in zip(v._prev, v._local_grads):
75+
child.grad += local_grad * v.grad
7176

7277
def __neg__(self): # -self
7378
return self * -1

micrograd/nn.py

Lines changed: 184 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ def __init__(self, nin, nonlin=True):
1818
self.nonlin = nonlin
1919

2020
def __call__(self, x):
21+
assert len(x) == len(self.w), "Shape mismatch between input and given nin value"
2122
act = sum((wi*xi for wi,xi in zip(self.w, x)), self.b)
2223
return act.relu() if self.nonlin else act
2324

@@ -34,7 +35,7 @@ def __init__(self, nin, nout, **kwargs):
3435

3536
def __call__(self, x):
3637
out = [n(x) for n in self.neurons]
37-
return out[0] if len(out) == 1 else out
38+
return out
3839

3940
def parameters(self):
4041
return [p for n in self.neurons for p in n.parameters()]
@@ -51,10 +52,191 @@ def __init__(self, nin, nouts):
5152
def __call__(self, x):
5253
for layer in self.layers:
5354
x = layer(x)
54-
return x
55+
return x[0] if len(x) == 1 else x
5556

5657
def parameters(self):
5758
return [p for layer in self.layers for p in layer.parameters()]
5859

5960
def __repr__(self):
6061
return f"MLP of [{', '.join(str(layer) for layer in self.layers)}]"
62+
63+
# --- Transformer components ---
64+
65+
class Linear(Module):
    """Linear projection (no nonlinearity), with optional bias.

    Weights are initialized uniformly in [-1/sqrt(nin), 1/sqrt(nin)];
    the bias (when enabled) starts at zero.
    """

    def __init__(self, nin, nout, bias=True):
        scale = nin ** -0.5
        self.w = [[Value(random.uniform(-scale, scale)) for _ in range(nin)] for _ in range(nout)]
        self.b = [Value(0.0) for _ in range(nout)] if bias else None

    def __call__(self, x):
        # One dot product per output row: out[i] = w[i] . x (+ b[i]).
        out = [sum(wi * xi for wi, xi in zip(row, x)) for row in self.w]
        # Test identity against None rather than truthiness: an enabled bias
        # must be applied even in the degenerate nout == 0 case where the
        # bias list is empty (and therefore falsy).
        if self.b is not None:
            out = [oi + bi for oi, bi in zip(out, self.b)]
        return out

    def parameters(self):
        params = [v for row in self.w for v in row]
        if self.b is not None:
            params += self.b
        return params

    def __repr__(self):
        nout, nin = len(self.w), len(self.w[0])
        return f"Linear({nin}, {nout}, bias={self.b is not None})"
88+
89+
class Embedding(Module):
    """Lookup table that maps integer indices to dense vectors."""

    def __init__(self, num_embeddings, embedding_dim):
        # num_embeddings x embedding_dim table, small gaussian init.
        self.weight = []
        for _ in range(num_embeddings):
            self.weight.append([Value(random.gauss(0, 0.02)) for _ in range(embedding_dim)])

    def __call__(self, idx):
        # Plain row lookup; gradients flow into the selected row only.
        return self.weight[idx]

    def parameters(self):
        flat = []
        for row in self.weight:
            flat.extend(row)
        return flat

    def __repr__(self):
        return f"Embedding({len(self.weight)}, {len(self.weight[0])})"
104+
105+
class LayerNorm(Module):
    """Layer normalization over the last dimension.

    Normalizes a vector of Values to zero mean / unit variance and applies
    a learnable per-feature scale (gamma) and shift (beta).
    """

    def __init__(self, dim, eps=1e-5):
        self.gamma = [Value(1.0) for _ in range(dim)]
        self.beta = [Value(0.0) for _ in range(dim)]
        self.eps = eps  # numerical-stability floor added to the variance

    def __call__(self, x):
        n = len(x)
        mean = sum(x) * (1.0 / n)
        var = sum((xi - mean) ** 2 for xi in x) * (1.0 / n)
        # Hoist the inverse std out of the per-feature comprehension: the
        # previous version rebuilt the identical (var + eps) ** -0.5 subgraph
        # once per feature. Sharing one node yields the same values and the
        # same accumulated gradients with a smaller autograd graph.
        inv_std = (var + self.eps) ** -0.5
        return [(xi - mean) * inv_std * g + b
                for xi, g, b in zip(x, self.gamma, self.beta)]

    def parameters(self):
        return self.gamma + self.beta

    def __repr__(self):
        return f"LayerNorm({len(self.gamma)})"
124+
125+
class Attention(Module):
    """Single-head scaled dot-product attention."""

    def __init__(self, dim, head_dim):
        self.query = Linear(dim, head_dim, bias=False)
        self.key = Linear(dim, head_dim, bias=False)
        self.value = Linear(dim, head_dim, bias=False)
        self.head_dim = head_dim

    def __call__(self, x, mask=False):
        # x: list of seq_len vectors, each a list of `dim` Values.
        n = len(x)
        Q = [self.query(t) for t in x]
        K = [self.key(t) for t in x]
        V = [self.value(t) for t in x]
        scale = self.head_dim ** 0.5
        result = []
        for i in range(n):
            scores = []
            for j in range(n):
                if mask and j > i:
                    scores.append(Value(-1e9))  # causal mask
                else:
                    dot = sum(q * k for q, k in zip(Q[i], K[j]))
                    scores.append(dot * (1.0 / scale))
            weights = Value.softmax(scores)
            # Weighted sum of value vectors, one component at a time.
            row = []
            for d in range(self.head_dim):
                row.append(sum(w * V[j][d] for j, w in enumerate(weights)))
            result.append(row)
        return result

    def parameters(self):
        return self.query.parameters() + self.key.parameters() + self.value.parameters()
155+
156+
class MultiHeadAttention(Module):
    """Multi-head attention with output projection."""

    def __init__(self, dim, num_heads):
        assert dim % num_heads == 0
        head_dim = dim // num_heads
        self.heads = [Attention(dim, head_dim) for _ in range(num_heads)]
        self.proj = Linear(dim, dim)

    def __call__(self, x, mask=False):
        per_head = [h(x, mask) for h in self.heads]
        projected = []
        for i in range(len(x)):
            # Concatenate this position's vectors from every head, then project.
            merged = []
            for head_out in per_head:
                merged.extend(head_out[i])
            projected.append(self.proj(merged))
        return projected

    def parameters(self):
        collected = []
        for head in self.heads:
            collected += head.parameters()
        return collected + self.proj.parameters()
174+
175+
class FeedForward(Module):
    """Two-layer feed-forward network with ReLU."""

    def __init__(self, dim, hidden_dim=None):
        # Truthiness test mirrors the `hidden_dim or 4 * dim` default:
        # None (and 0) both fall back to the conventional 4x expansion.
        if not hidden_dim:
            hidden_dim = 4 * dim
        self.up = Linear(dim, hidden_dim)
        self.down = Linear(hidden_dim, dim)

    def __call__(self, x):
        hidden = self.up(x)
        activated = [h.relu() for h in hidden]
        return self.down(activated)

    def parameters(self):
        return self.up.parameters() + self.down.parameters()
188+
189+
class TransformerBlock(Module):
    """Pre-norm transformer block: LN -> Attention -> Residual -> LN -> FFN -> Residual."""

    def __init__(self, dim, num_heads):
        self.ln1 = LayerNorm(dim)
        self.attn = MultiHeadAttention(dim, num_heads)
        self.ln2 = LayerNorm(dim)
        self.ff = FeedForward(dim)

    def __call__(self, x, mask=False):
        # Attention sublayer with residual connection.
        normed = [self.ln1(xi) for xi in x]
        attn_out = self.attn(normed, mask)
        x = [[u + v for u, v in zip(xi, ai)] for xi, ai in zip(x, attn_out)]
        # Feed-forward sublayer with residual connection.
        ff_out = [self.ff(self.ln2(xi)) for xi in x]
        x = [[u + v for u, v in zip(xi, fi)] for xi, fi in zip(x, ff_out)]
        return x

    def parameters(self):
        return (self.ln1.parameters() + self.attn.parameters()
                + self.ln2.parameters() + self.ff.parameters())
210+
211+
class Transformer(Module):
    """Decoder-only transformer (GPT-style)."""

    def __init__(self, vocab_size, dim, num_heads, num_layers, max_seq_len):
        self.token_emb = Embedding(vocab_size, dim)
        self.pos_emb = Embedding(max_seq_len, dim)
        self.blocks = [TransformerBlock(dim, num_heads) for _ in range(num_layers)]
        self.ln_f = LayerNorm(dim)
        self.output = Linear(dim, vocab_size, bias=False)

    def __call__(self, tokens):
        # tokens: list of integer token ids
        x = []
        for i, tok in enumerate(tokens):
            tok_vec = self.token_emb(tok)
            pos_vec = self.pos_emb(i)
            x.append([t + p for t, p in zip(tok_vec, pos_vec)])
        # Causal mask is always on: this is a decoder-only model.
        for block in self.blocks:
            x = block(x, mask=True)
        return [self.output(self.ln_f(xi)) for xi in x]

    def parameters(self):
        params = self.token_emb.parameters() + self.pos_emb.parameters()
        for block in self.blocks:
            params += block.parameters()
        return params + self.ln_f.parameters() + self.output.parameters()

    def __repr__(self):
        return f"Transformer({len(self.parameters())} parameters)"
238+
239+
def cross_entropy(logits, target):
    """Cross-entropy loss. logits: list of Values, target: integer index."""
    probs = Value.softmax(logits)
    target_logprob = probs[target].log()
    return -target_logprob

0 commit comments

Comments
 (0)