Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CLAUDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ HTML fragment with four required sections:
**Formatting rules:**
- `<code>` for variables/functions; `<pre>` for 1D examples, LaTeX `\begin{bmatrix}` for matrices
- `&le;`, `&ge;`, `&times;` for math symbols
- **LaTeX underscores**: Inside `\text{}`, use plain `_` (not `\_`). The backslash-escaped form renders literally as `\_` in MathJax/KaTeX.
- **Performance test size bullet**: Must include a bullet documenting the exact parameters used in `generate_performance_test()`, formatted as:
- `<li>Performance is measured with <code>param</code> = value</li>`
- Use commas for numbers ≥ 1,000 (e.g., `25,000,000`)
Expand Down
224 changes: 224 additions & 0 deletions challenges/medium/74_gpt2_block/challenge.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,224 @@
<p>
Implement a single GPT-2 transformer decoder block. Given an input tensor
\(x\) of shape <code>(seq_len, 768)</code> and a packed weight buffer containing
all block parameters, compute the output using pre-norm architecture with
multi-head self-attention and a feed-forward network with GELU activation.
</p>

<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 320 600" width="320" height="600" style="display:block; margin:20px auto;">
<defs>
<marker id="ah" viewBox="0 0 10 10" refX="9" refY="5" markerWidth="6" markerHeight="6" orient="auto-start-reverse">
<path d="M0 0L10 5L0 10z" fill="#999"/>
</marker>
</defs>

<!-- Input label -->
<text x="130" y="18" text-anchor="middle" fill="#ccc" font-size="13" font-family="monospace">x (seq_len, 768)</text>

<!-- Arrow: input -> LN1 -->
<line x1="130" y1="26" x2="130" y2="44" stroke="#999" stroke-width="1.5" marker-end="url(#ah)"/>

<!-- Residual 1: fork right, down, back left to Add1 -->
<line x1="130" y1="33" x2="260" y2="33" stroke="#555" stroke-width="1.5" stroke-dasharray="5,4"/>
<line x1="260" y1="33" x2="260" y2="270" stroke="#555" stroke-width="1.5" stroke-dasharray="5,4"/>
<line x1="260" y1="270" x2="145" y2="270" stroke="#555" stroke-width="1.5" stroke-dasharray="5,4" marker-end="url(#ah)"/>
<text x="268" y="155" fill="#666" font-size="10" font-family="sans-serif" transform="rotate(90,268,155)">residual</text>

<!-- LN1 -->
<rect x="55" y="47" width="150" height="30" rx="5" fill="#333" stroke="#777" stroke-width="1"/>
<text x="130" y="67" text-anchor="middle" fill="#ccc" font-size="12" font-family="sans-serif">LayerNorm 1</text>
<line x1="130" y1="77" x2="130" y2="95" stroke="#999" stroke-width="1.5" marker-end="url(#ah)"/>

<!-- QKV Proj -->
<rect x="55" y="98" width="150" height="30" rx="5" fill="#1e2d4d" stroke="#4477bb" stroke-width="1"/>
<text x="130" y="118" text-anchor="middle" fill="#aaccee" font-size="12" font-family="sans-serif">QKV Projection</text>
<line x1="130" y1="128" x2="130" y2="146" stroke="#999" stroke-width="1.5" marker-end="url(#ah)"/>

<!-- MHA -->
<rect x="55" y="149" width="150" height="30" rx="5" fill="#1e2d4d" stroke="#4477bb" stroke-width="1"/>
<text x="130" y="169" text-anchor="middle" fill="#aaccee" font-size="12" font-family="sans-serif">Multi-Head Attention</text>
<line x1="130" y1="179" x2="130" y2="197" stroke="#999" stroke-width="1.5" marker-end="url(#ah)"/>

<!-- Attn Out Proj -->
<rect x="55" y="200" width="150" height="30" rx="5" fill="#1e2d4d" stroke="#4477bb" stroke-width="1"/>
<text x="130" y="220" text-anchor="middle" fill="#aaccee" font-size="12" font-family="sans-serif">Output Projection</text>
<line x1="130" y1="230" x2="130" y2="258" stroke="#999" stroke-width="1.5" marker-end="url(#ah)"/>

<!-- Add 1 -->
<circle cx="130" cy="270" r="12" fill="#222" stroke="#999" stroke-width="1.5"/>
<text x="130" y="275" text-anchor="middle" fill="#ccc" font-size="15" font-family="sans-serif" font-weight="bold">+</text>
<line x1="130" y1="282" x2="130" y2="306" stroke="#999" stroke-width="1.5" marker-end="url(#ah)"/>

<!-- Residual 2: fork right, down, back left to Add2 -->
<line x1="130" y1="292" x2="260" y2="292" stroke="#555" stroke-width="1.5" stroke-dasharray="5,4"/>
<line x1="260" y1="292" x2="260" y2="530" stroke="#555" stroke-width="1.5" stroke-dasharray="5,4"/>
<line x1="260" y1="530" x2="145" y2="530" stroke="#555" stroke-width="1.5" stroke-dasharray="5,4" marker-end="url(#ah)"/>
<text x="268" y="415" fill="#666" font-size="10" font-family="sans-serif" transform="rotate(90,268,415)">residual</text>

<!-- LN2 -->
<rect x="55" y="309" width="150" height="30" rx="5" fill="#333" stroke="#777" stroke-width="1"/>
<text x="130" y="329" text-anchor="middle" fill="#ccc" font-size="12" font-family="sans-serif">LayerNorm 2</text>
<line x1="130" y1="339" x2="130" y2="357" stroke="#999" stroke-width="1.5" marker-end="url(#ah)"/>

<!-- FC -->
<rect x="55" y="360" width="150" height="30" rx="5" fill="#1e3d2d" stroke="#44aa66" stroke-width="1"/>
<text x="130" y="380" text-anchor="middle" fill="#aaeebb" font-size="12" font-family="sans-serif">Linear (768 → 3072)</text>
<line x1="130" y1="390" x2="130" y2="408" stroke="#999" stroke-width="1.5" marker-end="url(#ah)"/>

<!-- GELU -->
<rect x="55" y="411" width="150" height="30" rx="5" fill="#1e3d2d" stroke="#44aa66" stroke-width="1"/>
<text x="130" y="431" text-anchor="middle" fill="#aaeebb" font-size="12" font-family="sans-serif">GELU</text>
<line x1="130" y1="441" x2="130" y2="459" stroke="#999" stroke-width="1.5" marker-end="url(#ah)"/>

<!-- Proj -->
<rect x="55" y="462" width="150" height="30" rx="5" fill="#1e3d2d" stroke="#44aa66" stroke-width="1"/>
<text x="130" y="482" text-anchor="middle" fill="#aaeebb" font-size="12" font-family="sans-serif">Linear (3072 → 768)</text>
<line x1="130" y1="492" x2="130" y2="518" stroke="#999" stroke-width="1.5" marker-end="url(#ah)"/>

<!-- Add 2 -->
<circle cx="130" cy="530" r="12" fill="#222" stroke="#999" stroke-width="1.5"/>
<text x="130" y="535" text-anchor="middle" fill="#ccc" font-size="15" font-family="sans-serif" font-weight="bold">+</text>
<line x1="130" y1="542" x2="130" y2="566" stroke="#999" stroke-width="1.5" marker-end="url(#ah)"/>

<!-- Output label -->
<text x="130" y="584" text-anchor="middle" fill="#ccc" font-size="13" font-family="monospace">output (seq_len, 768)</text>
</svg>

<p>The block uses GPT-2's <strong>pre-norm</strong> architecture: LayerNorm is applied
<em>before</em> each sub-layer (attention and feed-forward), not after. At a high level:</p>

\[
\begin{aligned}
x' &= x + \text{MultiHeadAttn}\!\left(\text{LN}_1(x)\right) \\[4pt]
\text{output} &= x' + \text{FeedForward}\!\left(\text{LN}_2(x')\right)
\end{aligned}
\]

<p>where the sub-layers are defined as:</p>

\[
\begin{aligned}
\text{LN}(z) &= \frac{z - \mu}{\sqrt{\sigma^2 + \epsilon}} \odot \gamma + \beta, \quad \mu = \frac{1}{d}\sum_i z_i, \quad \sigma^2 = \frac{1}{d}\sum_i (z_i - \mu)^2 \\[8pt]
[Q \mid K \mid V] &= \text{LN}_1(x) \cdot W_{qkv} + b_{qkv} \\[4pt]
\text{head}_i &= \text{softmax}\!\left(\frac{Q_i K_i^\top}{\sqrt{d_k}}\right) V_i, \quad d_k = 64 \\[4pt]
\text{MultiHeadAttn}(z) &= \text{Concat}(\text{head}_1, \ldots, \text{head}_{12}) \cdot W_{\text{attn}} + b_{\text{attn}} \\[8pt]
\text{FeedForward}(z) &= \text{GELU}\!\left(z \cdot W_{fc} + b_{fc}\right) \cdot W_{\text{proj}} + b_{\text{proj}}
\end{aligned}
\]

<p>Expanding into individual steps:</p>

<ol>
<li><strong>Layer Norm 1:</strong> \(x_{\text{norm}} = \text{LN}_1(x)\) with parameters \(\gamma_1, \beta_1\)</li>
<li><strong>QKV Projection:</strong> \(QKV = x_{\text{norm}} \cdot W_{qkv} + b_{qkv}\), split into \(Q, K, V\) each of shape <code>(seq_len, 768)</code></li>
<li><strong>Multi-Head Attention:</strong> Reshape \(Q, K, V\) into 12 heads of dimension 64, compute per-head scaled dot-product attention (no causal mask), then concatenate heads into \(A\)</li>
<li><strong>Output Projection:</strong> \(P = A \cdot W_{\text{attn}} + b_{\text{attn}}\)</li>
<li><strong>Residual 1:</strong> \(x' = x + P\)</li>
<li><strong>Layer Norm 2:</strong> \(h_{\text{norm}} = \text{LN}_2(x')\) with parameters \(\gamma_2, \beta_2\)</li>
<li><strong>Feed-Forward:</strong> \(F = \text{GELU}(h_{\text{norm}} \cdot W_{fc} + b_{fc}) \cdot W_{\text{proj}} + b_{\text{proj}}\)</li>
<li><strong>Residual 2:</strong> \(\text{output} = x' + F\)</li>
</ol>

<h2>Implementation Requirements</h2>
<ul>
<li>Use only native features (external libraries are not permitted)</li>
<li>The <code>solve</code> function signature must remain unchanged</li>
<li>The final result must be stored in the <code>output</code> tensor</li>
<li>LayerNorm uses \(\epsilon = 10^{-5}\)</li>
<li>Use the <a href="https://docs.pytorch.org/docs/stable/generated/torch.nn.GELU.html" target="_blank">GELU tanh approximation</a>: \(\text{GELU}(x) = 0.5\,x\!\left(1 + \tanh\!\left(\sqrt{\tfrac{2}{\pi}}\left(x + 0.044715\,x^3\right)\right)\right)\)</li>
</ul>

<h2>Weight Layout</h2>
<p>All block parameters are packed into a single contiguous <code>weights</code> buffer
(7,087,872 floats) in the following order. Index into the buffer using the offsets below
(e.g. \(W_{qkv}[i][j]\) is at <code>weights[1536 + i * 2304 + j]</code>).
All 2D matrices are stored in row-major order.</p>

<table style="border-collapse:separate; border-spacing:16px 8px;">
<tr>
<th style="text-align:left;">Parameter</th>
<th style="text-align:left;">Shape</th>
<th style="text-align:right;">Size</th>
<th style="text-align:right;">Offset</th>
</tr>
<tr>
<td>\(\gamma_1\) (LN1 weight)</td>
<td>(768,)</td>
<td style="text-align:right;">768</td>
<td style="text-align:right;">0</td>
</tr>
<tr>
<td>\(\beta_1\) (LN1 bias)</td>
<td>(768,)</td>
<td style="text-align:right;">768</td>
<td style="text-align:right;">768</td>
</tr>
<tr>
<td>\(W_{qkv}\)</td>
<td>(768, 2304)</td>
<td style="text-align:right;">1,769,472</td>
<td style="text-align:right;">1,536</td>
</tr>
<tr>
<td>\(b_{qkv}\)</td>
<td>(2304,)</td>
<td style="text-align:right;">2,304</td>
<td style="text-align:right;">1,771,008</td>
</tr>
<tr>
<td>\(W_{\text{attn}}\)</td>
<td>(768, 768)</td>
<td style="text-align:right;">589,824</td>
<td style="text-align:right;">1,773,312</td>
</tr>
<tr>
<td>\(b_{\text{attn}}\)</td>
<td>(768,)</td>
<td style="text-align:right;">768</td>
<td style="text-align:right;">2,363,136</td>
</tr>
<tr>
<td>\(\gamma_2\) (LN2 weight)</td>
<td>(768,)</td>
<td style="text-align:right;">768</td>
<td style="text-align:right;">2,363,904</td>
</tr>
<tr>
<td>\(\beta_2\) (LN2 bias)</td>
<td>(768,)</td>
<td style="text-align:right;">768</td>
<td style="text-align:right;">2,364,672</td>
</tr>
<tr>
<td>\(W_{fc}\)</td>
<td>(768, 3072)</td>
<td style="text-align:right;">2,359,296</td>
<td style="text-align:right;">2,365,440</td>
</tr>
<tr>
<td>\(b_{fc}\)</td>
<td>(3072,)</td>
<td style="text-align:right;">3,072</td>
<td style="text-align:right;">4,724,736</td>
</tr>
<tr>
<td>\(W_{\text{proj}}\)</td>
<td>(3072, 768)</td>
<td style="text-align:right;">2,359,296</td>
<td style="text-align:right;">4,727,808</td>
</tr>
<tr>
<td>\(b_{\text{proj}}\)</td>
<td>(768,)</td>
<td style="text-align:right;">768</td>
<td style="text-align:right;">7,087,104</td>
</tr>
</table>

<h2>Constraints</h2>
<ul>
<li><code>d_model</code> = 768, <code>n_heads</code> = 12, <code>ffn_dim</code> = 3,072 (GPT-2 124M architecture)</li>
<li>1 &le; <code>seq_len</code> &le; 4,096</li>
<li>All tensors use 32-bit floating point</li>
<li>Performance is measured with <code>seq_len</code> = 1,024</li>
</ul>
Loading
Loading