@@ -18,6 +18,7 @@ def __init__(self, nin, nonlin=True):
         self.nonlin = nonlin

     def __call__(self, x):
+        assert len(x) == len(self.w), "Shape mismatch between input and given nin value"
         act = sum((wi * xi for wi, xi in zip(self.w, x)), self.b)
         return act.relu() if self.nonlin else act

@@ -34,7 +35,7 @@ def __init__(self, nin, nout, **kwargs):

     def __call__(self, x):
         out = [n(x) for n in self.neurons]
-        return out[0] if len(out) == 1 else out
+        return out

     def parameters(self):
         return [p for n in self.neurons for p in n.parameters()]
@@ -51,10 +52,191 @@ def __init__(self, nin, nouts):
     def __call__(self, x):
         for layer in self.layers:
             x = layer(x)
-        return x
+        return x[0] if len(x) == 1 else x

     def parameters(self):
         return [p for layer in self.layers for p in layer.parameters()]

     def __repr__(self):
         return f"MLP of [{', '.join(str(layer) for layer in self.layers)}]"
+
+# --- Transformer components ---
+
+class Linear(Module):
+    """Linear projection (no nonlinearity), with optional bias."""
+
+    def __init__(self, nin, nout, bias=True):
+        scale = nin ** -0.5
+        self.w = [[Value(random.uniform(-scale, scale)) for _ in range(nin)] for _ in range(nout)]
+        self.b = [Value(0.0) for _ in range(nout)] if bias else None
+
+    def __call__(self, x):
+        out = [sum(wi * xi for wi, xi in zip(row, x)) for row in self.w]
+        if self.b:
+            out = [oi + bi for oi, bi in zip(out, self.b)]
+        return out
+
+    def parameters(self):
+        params = [v for row in self.w for v in row]
+        if self.b:
+            params += self.b
+        return params
+
+    def __repr__(self):
+        nout, nin = len(self.w), len(self.w[0])
+        return f"Linear({nin}, {nout}, bias={self.b is not None})"
+
+class Embedding(Module):
+    """Lookup table that maps integer indices to dense vectors."""
+
+    def __init__(self, num_embeddings, embedding_dim):
+        self.weight = [[Value(random.gauss(0, 0.02)) for _ in range(embedding_dim)]
+                       for _ in range(num_embeddings)]
+
+    def __call__(self, idx):
+        return self.weight[idx]
+
+    def parameters(self):
+        return [v for row in self.weight for v in row]
+
+    def __repr__(self):
+        return f"Embedding({len(self.weight)}, {len(self.weight[0])})"
+
+class LayerNorm(Module):
+    """Layer normalization over the last dimension."""
+
+    def __init__(self, dim, eps=1e-5):
+        self.gamma = [Value(1.0) for _ in range(dim)]
+        self.beta = [Value(0.0) for _ in range(dim)]
+        self.eps = eps
+
+    def __call__(self, x):
+        mean = sum(x) * (1.0 / len(x))
+        var = sum((xi - mean) ** 2 for xi in x) * (1.0 / len(x))
+        return [(xi - mean) * (var + self.eps) ** -0.5 * g + b
+                for xi, g, b in zip(x, self.gamma, self.beta)]
+
+    def parameters(self):
+        return self.gamma + self.beta
+
+    def __repr__(self):
+        return f"LayerNorm({len(self.gamma)})"
+
+class Attention(Module):
+    """Single-head scaled dot-product attention."""
+
+    def __init__(self, dim, head_dim):
+        self.query = Linear(dim, head_dim, bias=False)
+        self.key = Linear(dim, head_dim, bias=False)
+        self.value = Linear(dim, head_dim, bias=False)
+        self.head_dim = head_dim
+
+    def __call__(self, x, mask=False):
+        # x: list of vectors (seq_len x dim)
+        Q = [self.query(xi) for xi in x]
+        K = [self.key(xi) for xi in x]
+        V = [self.value(xi) for xi in x]
+        scale = self.head_dim ** 0.5
+        out = []
+        for i in range(len(x)):
+            scores = []
+            for j in range(len(x)):
+                if mask and j > i:
+                    scores.append(Value(-1e9))  # causal mask
+                else:
+                    scores.append(sum(qi * ki for qi, ki in zip(Q[i], K[j])) * (1.0 / scale))
+            weights = Value.softmax(scores)
+            out.append([sum(w * V[j][d] for j, w in enumerate(weights))
+                        for d in range(self.head_dim)])
+        return out
+
+    def parameters(self):
+        return self.query.parameters() + self.key.parameters() + self.value.parameters()
+
+class MultiHeadAttention(Module):
+    """Multi-head attention with output projection."""
+
+    def __init__(self, dim, num_heads):
+        assert dim % num_heads == 0
+        head_dim = dim // num_heads
+        self.heads = [Attention(dim, head_dim) for _ in range(num_heads)]
+        self.proj = Linear(dim, dim)
+
+    def __call__(self, x, mask=False):
+        head_outs = [head(x, mask) for head in self.heads]
+        # concatenate heads at each position, then project
+        concat = [[v for ho in head_outs for v in ho[i]] for i in range(len(x))]
+        return [self.proj(ci) for ci in concat]
+
+    def parameters(self):
+        params = [p for h in self.heads for p in h.parameters()]
+        return params + self.proj.parameters()
+
+class FeedForward(Module):
+    """Two-layer feed-forward network with ReLU."""
+
+    def __init__(self, dim, hidden_dim=None):
+        hidden_dim = hidden_dim or 4 * dim
+        self.up = Linear(dim, hidden_dim)
+        self.down = Linear(hidden_dim, dim)
+
+    def __call__(self, x):
+        return self.down([h.relu() for h in self.up(x)])
+
+    def parameters(self):
+        return self.up.parameters() + self.down.parameters()
+
+class TransformerBlock(Module):
+    """Pre-norm transformer block: LN -> Attention -> Residual -> LN -> FFN -> Residual."""
+
+    def __init__(self, dim, num_heads):
+        self.ln1 = LayerNorm(dim)
+        self.attn = MultiHeadAttention(dim, num_heads)
+        self.ln2 = LayerNorm(dim)
+        self.ff = FeedForward(dim)
+
+    def __call__(self, x, mask=False):
+        # attention + residual
+        attn_out = self.attn([self.ln1(xi) for xi in x], mask)
+        x = [[a + b for a, b in zip(xv, av)] for xv, av in zip(x, attn_out)]
+        # feedforward + residual
+        ff_out = [self.ff(self.ln2(xi)) for xi in x]
+        x = [[a + b for a, b in zip(xv, fv)] for xv, fv in zip(x, ff_out)]
+        return x
+
+    def parameters(self):
+        return self.ln1.parameters() + self.attn.parameters() + \
+               self.ln2.parameters() + self.ff.parameters()
+
+class Transformer(Module):
+    """Decoder-only transformer (GPT-style)."""
+
+    def __init__(self, vocab_size, dim, num_heads, num_layers, max_seq_len):
+        self.token_emb = Embedding(vocab_size, dim)
+        self.pos_emb = Embedding(max_seq_len, dim)
+        self.blocks = [TransformerBlock(dim, num_heads) for _ in range(num_layers)]
+        self.ln_f = LayerNorm(dim)
+        self.output = Linear(dim, vocab_size, bias=False)
+
+    def __call__(self, tokens):
+        # tokens: list of integer token ids
+        x = [[t + p for t, p in zip(self.token_emb(tok), self.pos_emb(i))]
+             for i, tok in enumerate(tokens)]
+        for block in self.blocks:
+            x = block(x, mask=True)
+        return [self.output(self.ln_f(xi)) for xi in x]
+
+    def parameters(self):
+        params = self.token_emb.parameters() + self.pos_emb.parameters()
+        for block in self.blocks:
+            params += block.parameters()
+        params += self.ln_f.parameters() + self.output.parameters()
+        return params
+
+    def __repr__(self):
+        return f"Transformer({len(self.parameters())} parameters)"
+
+def cross_entropy(logits, target):
+    """Cross-entropy loss. logits: list of Values, target: integer index."""
+    probs = Value.softmax(logits)
+    return -probs[target].log()
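
A usage sketch (not part of the commit above): the snippet below wires the new pieces into one next-token training step. It assumes the micrograd-style Value in this repo exposes backward(), .data and .grad in addition to the softmax and log used above, and that Transformer and cross_entropy are in scope; the hyperparameters and learning rate are illustrative only.

    # Minimal sketch under the assumptions stated above; not part of the diff.
    import random

    vocab_size, dim, num_heads, num_layers, max_seq_len = 16, 8, 2, 2, 8
    model = Transformer(vocab_size, dim, num_heads, num_layers, max_seq_len)

    tokens = [random.randrange(vocab_size) for _ in range(max_seq_len)]

    # Forward pass: one list of vocab_size logits per position (causal mask applied inside).
    logits = model(tokens)

    # Position i predicts tokens[i + 1]; average the per-position losses.
    losses = [cross_entropy(logits[i], tokens[i + 1]) for i in range(len(tokens) - 1)]
    loss = sum(losses) * (1.0 / len(losses))

    # Backward pass and a plain SGD update on every parameter.
    for p in model.parameters():
        p.grad = 0.0
    loss.backward()
    for p in model.parameters():
        p.data -= 0.1 * p.grad

If the base Module keeps micrograd's zero_grad(), that call can replace the manual grad-reset loop.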