
Commit 1471299

Authored by: shxjames, github-actions[bot], claude, kunal-mansukhani

Add challenge 74: GPT-2 Inference (Medium) (#194)

* gpt2 transformer block challenge

* simplify solve signature + update html description to include nice diagrams

* Rename challenge 73 to 74 and add zero input test

  Challenge 73 is already taken by All-Pairs Shortest Paths on main. Rename
  medium/73_gpt2_block → medium/74_gpt2_block to avoid conflict. Also add an
  explicit zero-input test case (x=zeros) to generate_functional_test() to
  satisfy the CLAUDE.md requirement that functional tests include zero inputs.

* challenge.html update

* Improve challenge.html clarity: spell out acronyms, fix LaTeX rendering, add weight indexing example

  - Replace MHA/FFN acronyms with MultiHeadAttn/FeedForward in equations and steps
  - Fix LaTeX \_ rendering issue inside \text{} (plain _ works in MathJax/KaTeX)
  - Clarify that no causal mask is applied in attention
  - Fix ambiguous attn_out variable naming across steps (use A, P, F)
  - Add concrete weight buffer indexing example (W_qkv[i][j])
  - Document the LaTeX underscore rule in CLAUDE.md

* Fix challenge 74: add missing Examples section and fix JAX starter comment

  - Add required <h2>Example</h2> section to challenge.html (was missing; the
    checklist requires Implementation Requirements, Example(s), Constraints)
  - Fix starter.jax.py comment: "on the GPU" → "on GPU" to match the CLAUDE.md
    JAX template format

* Fix checklist violations in challenge 74 (GPT-2 block)

  - Add missing Examples section to challenge.html
  - Add torch.manual_seed(0) to make generate_example_test() deterministic
  - Fix starter.jax.py comment: "on the GPU" -> "on GPU" (matches template)

* Fix checklist issues in challenge 74 GPT-2 Transformer Block

  - Add missing <h2>Example</h2> section to challenge.html
  - Fix device assertion to verify CUDA (assert x.device.type == "cuda")

* Remove duplicate Example sections from challenge.html

  Previous bot commits had already added Example sections; consolidate to a
  single <h2>Example</h2> section after Weight Layout.

* Add missing Examples section to GPT-2 block challenge.html

  The checklist requires <h2> sections for Implementation Requirements,
  Example(s), and Constraints. The Example section was missing. Since D=768 is
  a fixed architecture dimension, exact tensor values cannot be shown, so the
  example describes input/output shapes for seq_len=4 (matching
  generate_example_test()).

* Remove duplicate Example section inadvertently added to challenge.html

  The previous commit added an Example section before Weight Layout, but one
  already existed after Weight Layout. This removes the newly-added duplicate,
  leaving only the Example section before Constraints.

Co-authored-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com>
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Kunal Mansukhani <kmmansukhani@gmail.com>
1 parent d21d957 commit 1471299

File tree

9 files changed: +468 −0 lines changed


CLAUDE.md

Lines changed: 1 addition & 0 deletions
@@ -108,6 +108,7 @@ HTML fragment with four required sections:
 **Formatting rules:**
 - `<code>` for variables/functions; `<pre>` for 1D examples, LaTeX `\begin{bmatrix}` for matrices
 - `&le;`, `&ge;`, `&times;` for math symbols
+- **LaTeX underscores**: Inside `\text{}`, use plain `_` (not `\_`). The backslash-escaped form renders literally as `\_` in MathJax/KaTeX.
 - **Performance test size bullet**: Must include a bullet documenting the exact parameters used in `generate_performance_test()`, formatted as:
   - `<li>Performance is measured with <code>param</code> = value</li>`
 - Use commas for numbers ≥ 1,000 (e.g., `25,000,000`)
Lines changed: 234 additions & 0 deletions

@@ -0,0 +1,234 @@
<p>
Implement a single GPT-2 transformer decoder block. Given an input tensor
\(x\) of shape <code>(seq_len, 768)</code> and a packed weight buffer containing
all block parameters, compute the output using pre-norm architecture with
multi-head self-attention and a feed-forward network with GELU activation.
</p>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 320 600" width="320" height="600" style="display:block; margin:20px auto;">
<defs>
<marker id="ah" viewBox="0 0 10 10" refX="9" refY="5" markerWidth="6" markerHeight="6" orient="auto-start-reverse">
<path d="M0 0L10 5L0 10z" fill="#999"/>
</marker>
</defs>

<!-- Input label -->
<text x="130" y="18" text-anchor="middle" fill="#ccc" font-size="13" font-family="monospace">x (seq_len, 768)</text>

<!-- Arrow: input -> LN1 -->
<line x1="130" y1="26" x2="130" y2="44" stroke="#999" stroke-width="1.5" marker-end="url(#ah)"/>

<!-- Residual 1: fork right, down, back left to Add1 -->
<line x1="130" y1="33" x2="260" y2="33" stroke="#555" stroke-width="1.5" stroke-dasharray="5,4"/>
<line x1="260" y1="33" x2="260" y2="270" stroke="#555" stroke-width="1.5" stroke-dasharray="5,4"/>
<line x1="260" y1="270" x2="145" y2="270" stroke="#555" stroke-width="1.5" stroke-dasharray="5,4" marker-end="url(#ah)"/>
<text x="268" y="155" fill="#666" font-size="10" font-family="sans-serif" transform="rotate(90,268,155)">residual</text>

<!-- LN1 -->
<rect x="55" y="47" width="150" height="30" rx="5" fill="#333" stroke="#777" stroke-width="1"/>
<text x="130" y="67" text-anchor="middle" fill="#ccc" font-size="12" font-family="sans-serif">LayerNorm 1</text>
<line x1="130" y1="77" x2="130" y2="95" stroke="#999" stroke-width="1.5" marker-end="url(#ah)"/>

<!-- QKV Proj -->
<rect x="55" y="98" width="150" height="30" rx="5" fill="#1e2d4d" stroke="#4477bb" stroke-width="1"/>
<text x="130" y="118" text-anchor="middle" fill="#aaccee" font-size="12" font-family="sans-serif">QKV Projection</text>
<line x1="130" y1="128" x2="130" y2="146" stroke="#999" stroke-width="1.5" marker-end="url(#ah)"/>

<!-- MHA -->
<rect x="55" y="149" width="150" height="30" rx="5" fill="#1e2d4d" stroke="#4477bb" stroke-width="1"/>
<text x="130" y="169" text-anchor="middle" fill="#aaccee" font-size="12" font-family="sans-serif">Multi-Head Attention</text>
<line x1="130" y1="179" x2="130" y2="197" stroke="#999" stroke-width="1.5" marker-end="url(#ah)"/>

<!-- Attn Out Proj -->
<rect x="55" y="200" width="150" height="30" rx="5" fill="#1e2d4d" stroke="#4477bb" stroke-width="1"/>
<text x="130" y="220" text-anchor="middle" fill="#aaccee" font-size="12" font-family="sans-serif">Output Projection</text>
<line x1="130" y1="230" x2="130" y2="258" stroke="#999" stroke-width="1.5" marker-end="url(#ah)"/>

<!-- Add 1 -->
<circle cx="130" cy="270" r="12" fill="#222" stroke="#999" stroke-width="1.5"/>
<text x="130" y="275" text-anchor="middle" fill="#ccc" font-size="15" font-family="sans-serif" font-weight="bold">+</text>
<line x1="130" y1="282" x2="130" y2="306" stroke="#999" stroke-width="1.5" marker-end="url(#ah)"/>

<!-- Residual 2: fork right, down, back left to Add2 -->
<line x1="130" y1="292" x2="260" y2="292" stroke="#555" stroke-width="1.5" stroke-dasharray="5,4"/>
<line x1="260" y1="292" x2="260" y2="530" stroke="#555" stroke-width="1.5" stroke-dasharray="5,4"/>
<line x1="260" y1="530" x2="145" y2="530" stroke="#555" stroke-width="1.5" stroke-dasharray="5,4" marker-end="url(#ah)"/>
<text x="268" y="415" fill="#666" font-size="10" font-family="sans-serif" transform="rotate(90,268,415)">residual</text>

<!-- LN2 -->
<rect x="55" y="309" width="150" height="30" rx="5" fill="#333" stroke="#777" stroke-width="1"/>
<text x="130" y="329" text-anchor="middle" fill="#ccc" font-size="12" font-family="sans-serif">LayerNorm 2</text>
<line x1="130" y1="339" x2="130" y2="357" stroke="#999" stroke-width="1.5" marker-end="url(#ah)"/>

<!-- FC -->
<rect x="55" y="360" width="150" height="30" rx="5" fill="#1e3d2d" stroke="#44aa66" stroke-width="1"/>
<text x="130" y="380" text-anchor="middle" fill="#aaeebb" font-size="12" font-family="sans-serif">Linear (768 → 3072)</text>
<line x1="130" y1="390" x2="130" y2="408" stroke="#999" stroke-width="1.5" marker-end="url(#ah)"/>

<!-- GELU -->
<rect x="55" y="411" width="150" height="30" rx="5" fill="#1e3d2d" stroke="#44aa66" stroke-width="1"/>
<text x="130" y="431" text-anchor="middle" fill="#aaeebb" font-size="12" font-family="sans-serif">GELU</text>
<line x1="130" y1="441" x2="130" y2="459" stroke="#999" stroke-width="1.5" marker-end="url(#ah)"/>

<!-- Proj -->
<rect x="55" y="462" width="150" height="30" rx="5" fill="#1e3d2d" stroke="#44aa66" stroke-width="1"/>
<text x="130" y="482" text-anchor="middle" fill="#aaeebb" font-size="12" font-family="sans-serif">Linear (3072 → 768)</text>
<line x1="130" y1="492" x2="130" y2="518" stroke="#999" stroke-width="1.5" marker-end="url(#ah)"/>

<!-- Add 2 -->
<circle cx="130" cy="530" r="12" fill="#222" stroke="#999" stroke-width="1.5"/>
<text x="130" y="535" text-anchor="middle" fill="#ccc" font-size="15" font-family="sans-serif" font-weight="bold">+</text>
<line x1="130" y1="542" x2="130" y2="566" stroke="#999" stroke-width="1.5" marker-end="url(#ah)"/>

<!-- Output label -->
<text x="130" y="584" text-anchor="middle" fill="#ccc" font-size="13" font-family="monospace">output (seq_len, 768)</text>
</svg>
<p>The block uses GPT-2's <strong>pre-norm</strong> architecture: LayerNorm is applied
<em>before</em> each sub-layer (attention and feed-forward), not after. At a high level:</p>

\[
\begin{aligned}
x' &= x + \text{MultiHeadAttn}\!\left(\text{LN}_1(x)\right) \\[4pt]
\text{output} &= x' + \text{FeedForward}\!\left(\text{LN}_2(x')\right)
\end{aligned}
\]

<p>where the sub-layers are defined as:</p>

\[
\begin{aligned}
\text{LN}(z) &= \frac{z - \mu}{\sqrt{\sigma^2 + \epsilon}} \odot \gamma + \beta, \quad \mu = \frac{1}{d}\sum_i z_i, \quad \sigma^2 = \frac{1}{d}\sum_i (z_i - \mu)^2 \\[8pt]
[Q \mid K \mid V] &= \text{LN}_1(x) \cdot W_{qkv} + b_{qkv} \\[4pt]
\text{head}_i &= \text{softmax}\!\left(\frac{Q_i K_i^\top}{\sqrt{d_k}}\right) V_i, \quad d_k = 64 \\[4pt]
\text{MultiHeadAttn}(z) &= \text{Concat}(\text{head}_1, \ldots, \text{head}_{12}) \cdot W_{\text{attn}} + b_{\text{attn}} \\[8pt]
\text{FeedForward}(z) &= \text{GELU}\!\left(z \cdot W_{fc} + b_{fc}\right) \cdot W_{\text{proj}} + b_{\text{proj}}
\end{aligned}
\]
<p>Expanding into individual steps:</p>

<ol>
<li><strong>Layer Norm 1:</strong> \(x_{\text{norm}} = \text{LN}_1(x)\) with parameters \(\gamma_1, \beta_1\)</li>
<li><strong>QKV Projection:</strong> \(QKV = x_{\text{norm}} \cdot W_{qkv} + b_{qkv}\), split into \(Q, K, V\) each of shape <code>(seq_len, 768)</code></li>
<li><strong>Multi-Head Attention:</strong> Reshape \(Q, K, V\) into 12 heads of dimension 64, compute per-head scaled dot-product attention (no causal mask), then concatenate heads into \(A\)</li>
<li><strong>Output Projection:</strong> \(P = A \cdot W_{\text{attn}} + b_{\text{attn}}\)</li>
<li><strong>Residual 1:</strong> \(x' = x + P\)</li>
<li><strong>Layer Norm 2:</strong> \(h_{\text{norm}} = \text{LN}_2(x')\) with parameters \(\gamma_2, \beta_2\)</li>
<li><strong>Feed-Forward:</strong> \(F = \text{GELU}(h_{\text{norm}} \cdot W_{fc} + b_{fc}) \cdot W_{\text{proj}} + b_{\text{proj}}\)</li>
<li><strong>Residual 2:</strong> \(\text{output} = x' + F\)</li>
</ol>
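The eight steps above can be sketched as a plain NumPy reference model. This is an illustrative CPU sketch, not the challenge's CUDA/PyTorch/JAX solution: the dimensions are parametrized (rather than fixed at 768/12/3072) so it runs quickly at small sizes, and the weight-dictionary keys (`gamma1`, `W_qkv`, and so on) are names chosen here to mirror the equations, not part of the challenge's API.

```python
import numpy as np

def gelu_tanh(x):
    # GELU tanh approximation (matches the formula in Implementation Requirements)
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

def layer_norm(z, gamma, beta, eps=1e-5):
    # Normalize over the feature dimension, then scale and shift
    mu = z.mean(axis=-1, keepdims=True)
    var = z.var(axis=-1, keepdims=True)
    return (z - mu) / np.sqrt(var + eps) * gamma + beta

def softmax(z):
    # Numerically stable softmax over the last axis
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def gpt2_block(x, p, n_heads):
    seq_len, d_model = x.shape
    d_k = d_model // n_heads
    # Steps 1-2: LN1, then fused QKV projection, split into Q, K, V
    x_norm = layer_norm(x, p["gamma1"], p["beta1"])
    qkv = x_norm @ p["W_qkv"] + p["b_qkv"]          # (seq_len, 3 * d_model)
    q, k, v = np.split(qkv, 3, axis=-1)
    # Step 3: per-head scaled dot-product attention, no causal mask
    heads = []
    for h in range(n_heads):
        qh = q[:, h * d_k:(h + 1) * d_k]
        kh = k[:, h * d_k:(h + 1) * d_k]
        vh = v[:, h * d_k:(h + 1) * d_k]
        attn = softmax(qh @ kh.T / np.sqrt(d_k))
        heads.append(attn @ vh)
    A = np.concatenate(heads, axis=-1)              # concat heads
    # Steps 4-5: output projection P, then first residual
    x2 = x + (A @ p["W_attn"] + p["b_attn"])
    # Steps 6-8: LN2, feed-forward F with GELU, second residual
    h_norm = layer_norm(x2, p["gamma2"], p["beta2"])
    F = gelu_tanh(h_norm @ p["W_fc"] + p["b_fc"]) @ p["W_proj"] + p["b_proj"]
    return x2 + F
```

With zero biases and zero input, every intermediate stays zero (LayerNorm of a zero row yields \(\beta\)), which is exactly what the zero-input functional test mentioned in the commit message exercises.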
<h2>Implementation Requirements</h2>
<ul>
<li>Use only native features (external libraries are not permitted)</li>
<li>The <code>solve</code> function signature must remain unchanged</li>
<li>The final result must be stored in the <code>output</code> tensor</li>
<li>LayerNorm uses \(\epsilon = 10^{-5}\)</li>
<li>Use the <a href="https://docs.pytorch.org/docs/stable/generated/torch.nn.GELU.html" target="_blank">GELU tanh approximation</a>: \(\text{GELU}(x) = 0.5\,x\!\left(1 + \tanh\!\left(\sqrt{\tfrac{2}{\pi}}\left(x + 0.044715\,x^3\right)\right)\right)\)</li>
</ul>
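As a quick sanity check on the tanh approximation (an aside, not part of the challenge requirements), the snippet below compares it against the exact erf-based GELU; the two agree to within roughly 1e-3 over typical activation ranges.

```python
import math

def gelu_tanh(x):
    # Tanh approximation from the requirement above
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)))

def gelu_exact(x):
    # Exact GELU via the Gaussian CDF, expressed with erf
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

# The approximation tracks the exact definition closely on [-3, 3]
for x in [-3.0, -1.0, -0.5, 0.0, 0.5, 1.0, 3.0]:
    assert abs(gelu_tanh(x) - gelu_exact(x)) < 5e-3
```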
<h2>Weight Layout</h2>
<p>All block parameters are packed into a single contiguous <code>weights</code> buffer
(7,087,872 floats) in the following order. Index into the buffer using the offsets below
(e.g. \(W_{qkv}[i][j]\) is at <code>weights[1536 + i * 2304 + j]</code>).
All 2D matrices are stored in row-major order.</p>

<table style="border-collapse:separate; border-spacing:16px 8px;">
<tr>
<th style="text-align:left;">Parameter</th>
<th style="text-align:left;">Shape</th>
<th style="text-align:right;">Size</th>
<th style="text-align:right;">Offset</th>
</tr>
<tr>
<td>\(\gamma_1\) (LN1 weight)</td>
<td>(768,)</td>
<td style="text-align:right;">768</td>
<td style="text-align:right;">0</td>
</tr>
<tr>
<td>\(\beta_1\) (LN1 bias)</td>
<td>(768,)</td>
<td style="text-align:right;">768</td>
<td style="text-align:right;">768</td>
</tr>
<tr>
<td>\(W_{qkv}\)</td>
<td>(768, 2304)</td>
<td style="text-align:right;">1,769,472</td>
<td style="text-align:right;">1,536</td>
</tr>
<tr>
<td>\(b_{qkv}\)</td>
<td>(2304,)</td>
<td style="text-align:right;">2,304</td>
<td style="text-align:right;">1,771,008</td>
</tr>
<tr>
<td>\(W_{\text{attn}}\)</td>
<td>(768, 768)</td>
<td style="text-align:right;">589,824</td>
<td style="text-align:right;">1,773,312</td>
</tr>
<tr>
<td>\(b_{\text{attn}}\)</td>
<td>(768,)</td>
<td style="text-align:right;">768</td>
<td style="text-align:right;">2,363,136</td>
</tr>
<tr>
<td>\(\gamma_2\) (LN2 weight)</td>
<td>(768,)</td>
<td style="text-align:right;">768</td>
<td style="text-align:right;">2,363,904</td>
</tr>
<tr>
<td>\(\beta_2\) (LN2 bias)</td>
<td>(768,)</td>
<td style="text-align:right;">768</td>
<td style="text-align:right;">2,364,672</td>
</tr>
<tr>
<td>\(W_{fc}\)</td>
<td>(768, 3072)</td>
<td style="text-align:right;">2,359,296</td>
<td style="text-align:right;">2,365,440</td>
</tr>
<tr>
<td>\(b_{fc}\)</td>
<td>(3072,)</td>
<td style="text-align:right;">3,072</td>
<td style="text-align:right;">4,724,736</td>
</tr>
<tr>
<td>\(W_{\text{proj}}\)</td>
<td>(3072, 768)</td>
<td style="text-align:right;">2,359,296</td>
<td style="text-align:right;">4,727,808</td>
</tr>
<tr>
<td>\(b_{\text{proj}}\)</td>
<td>(768,)</td>
<td style="text-align:right;">768</td>
<td style="text-align:right;">7,087,104</td>
</tr>
</table>
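Each offset in the table is simply the running sum of the preceding parameter sizes. The sketch below derives them from the shapes; the parameter names are illustrative labels for this sketch, not identifiers from the challenge code.

```python
# GPT-2 124M block dimensions from the table
D, F = 768, 3072

# (name, shape) in packing order
shapes = [
    ("gamma1", (D,)), ("beta1", (D,)),
    ("W_qkv", (D, 3 * D)), ("b_qkv", (3 * D,)),
    ("W_attn", (D, D)), ("b_attn", (D,)),
    ("gamma2", (D,)), ("beta2", (D,)),
    ("W_fc", (D, F)), ("b_fc", (F,)),
    ("W_proj", (F, D)), ("b_proj", (D,)),
]

offsets = {}
off = 0
for name, shape in shapes:
    offsets[name] = off          # offset = sum of all earlier sizes
    size = 1
    for s in shape:
        size *= s
    off += size
total = off                      # 7,087,872 floats in all

# Row-major indexing: W_qkv[i][j] lives at weights[offsets["W_qkv"] + i * 2304 + j],
# since each row of the (768, 2304) matrix occupies 2304 consecutive floats.
```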
<h2>Example</h2>
<p>With <code>seq_len</code> = 4, <code>x</code> uniformly drawn from [−1, 1], and weights randomly initialized
(see Weight Layout for the packing structure):</p>
<pre>
Input:  x.shape = (4, 768)           # 4 token embeddings
        weights.shape = (7,087,872,) # packed weight buffer
        seq_len = 4
Output: output.shape = (4, 768)      # transformed token embeddings
</pre>
<h2>Constraints</h2>
<ul>
<li><code>d_model</code> = 768, <code>n_heads</code> = 12, <code>ffn_dim</code> = 3,072 (GPT-2 124M architecture)</li>
<li>1 &le; <code>seq_len</code> &le; 4,096</li>
<li>All tensors use 32-bit floating point</li>
<li>Performance is measured with <code>seq_len</code> = 1,024</li>
</ul>
