<p>
  Implement a single GPT-2 transformer decoder block. Given an input tensor
  \(x\) of shape <code>(seq_len, 768)</code> and a packed weight buffer containing
  all block parameters, compute the output using the pre-norm architecture with
  multi-head self-attention and a feed-forward network with GELU activation.
</p>

<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 320 600" width="320" height="600" style="display:block; margin:20px auto;">
  <defs>
    <marker id="ah" viewBox="0 0 10 10" refX="9" refY="5" markerWidth="6" markerHeight="6" orient="auto-start-reverse">
      <path d="M0 0L10 5L0 10z" fill="#999"/>
    </marker>
  </defs>

  <!-- Input label -->
  <text x="130" y="18" text-anchor="middle" fill="#ccc" font-size="13" font-family="monospace">x (seq_len, 768)</text>

  <!-- Arrow: input -> LN1 -->
  <line x1="130" y1="26" x2="130" y2="44" stroke="#999" stroke-width="1.5" marker-end="url(#ah)"/>

  <!-- Residual 1: fork right, down, back left to Add1 -->
  <line x1="130" y1="33" x2="260" y2="33" stroke="#555" stroke-width="1.5" stroke-dasharray="5,4"/>
  <line x1="260" y1="33" x2="260" y2="270" stroke="#555" stroke-width="1.5" stroke-dasharray="5,4"/>
  <line x1="260" y1="270" x2="145" y2="270" stroke="#555" stroke-width="1.5" stroke-dasharray="5,4" marker-end="url(#ah)"/>
  <text x="268" y="155" fill="#666" font-size="10" font-family="sans-serif" transform="rotate(90,268,155)">residual</text>

  <!-- LN1 -->
  <rect x="55" y="47" width="150" height="30" rx="5" fill="#333" stroke="#777" stroke-width="1"/>
  <text x="130" y="67" text-anchor="middle" fill="#ccc" font-size="12" font-family="sans-serif">LayerNorm 1</text>
  <line x1="130" y1="77" x2="130" y2="95" stroke="#999" stroke-width="1.5" marker-end="url(#ah)"/>

  <!-- QKV Proj -->
  <rect x="55" y="98" width="150" height="30" rx="5" fill="#1e2d4d" stroke="#4477bb" stroke-width="1"/>
  <text x="130" y="118" text-anchor="middle" fill="#aaccee" font-size="12" font-family="sans-serif">QKV Projection</text>
  <line x1="130" y1="128" x2="130" y2="146" stroke="#999" stroke-width="1.5" marker-end="url(#ah)"/>

  <!-- MHA -->
  <rect x="55" y="149" width="150" height="30" rx="5" fill="#1e2d4d" stroke="#4477bb" stroke-width="1"/>
  <text x="130" y="169" text-anchor="middle" fill="#aaccee" font-size="12" font-family="sans-serif">Multi-Head Attention</text>
  <line x1="130" y1="179" x2="130" y2="197" stroke="#999" stroke-width="1.5" marker-end="url(#ah)"/>

  <!-- Attn Out Proj -->
  <rect x="55" y="200" width="150" height="30" rx="5" fill="#1e2d4d" stroke="#4477bb" stroke-width="1"/>
  <text x="130" y="220" text-anchor="middle" fill="#aaccee" font-size="12" font-family="sans-serif">Output Projection</text>
  <line x1="130" y1="230" x2="130" y2="258" stroke="#999" stroke-width="1.5" marker-end="url(#ah)"/>

  <!-- Add 1 -->
  <circle cx="130" cy="270" r="12" fill="#222" stroke="#999" stroke-width="1.5"/>
  <text x="130" y="275" text-anchor="middle" fill="#ccc" font-size="15" font-family="sans-serif" font-weight="bold">+</text>
  <line x1="130" y1="282" x2="130" y2="306" stroke="#999" stroke-width="1.5" marker-end="url(#ah)"/>

  <!-- Residual 2: fork right, down, back left to Add2 -->
  <line x1="130" y1="292" x2="260" y2="292" stroke="#555" stroke-width="1.5" stroke-dasharray="5,4"/>
  <line x1="260" y1="292" x2="260" y2="530" stroke="#555" stroke-width="1.5" stroke-dasharray="5,4"/>
  <line x1="260" y1="530" x2="145" y2="530" stroke="#555" stroke-width="1.5" stroke-dasharray="5,4" marker-end="url(#ah)"/>
  <text x="268" y="415" fill="#666" font-size="10" font-family="sans-serif" transform="rotate(90,268,415)">residual</text>

  <!-- LN2 -->
  <rect x="55" y="309" width="150" height="30" rx="5" fill="#333" stroke="#777" stroke-width="1"/>
  <text x="130" y="329" text-anchor="middle" fill="#ccc" font-size="12" font-family="sans-serif">LayerNorm 2</text>
  <line x1="130" y1="339" x2="130" y2="357" stroke="#999" stroke-width="1.5" marker-end="url(#ah)"/>

  <!-- FC -->
  <rect x="55" y="360" width="150" height="30" rx="5" fill="#1e3d2d" stroke="#44aa66" stroke-width="1"/>
  <text x="130" y="380" text-anchor="middle" fill="#aaeebb" font-size="12" font-family="sans-serif">Linear (768 → 3072)</text>
  <line x1="130" y1="390" x2="130" y2="408" stroke="#999" stroke-width="1.5" marker-end="url(#ah)"/>

  <!-- GELU -->
  <rect x="55" y="411" width="150" height="30" rx="5" fill="#1e3d2d" stroke="#44aa66" stroke-width="1"/>
  <text x="130" y="431" text-anchor="middle" fill="#aaeebb" font-size="12" font-family="sans-serif">GELU</text>
  <line x1="130" y1="441" x2="130" y2="459" stroke="#999" stroke-width="1.5" marker-end="url(#ah)"/>

  <!-- Proj -->
  <rect x="55" y="462" width="150" height="30" rx="5" fill="#1e3d2d" stroke="#44aa66" stroke-width="1"/>
  <text x="130" y="482" text-anchor="middle" fill="#aaeebb" font-size="12" font-family="sans-serif">Linear (3072 → 768)</text>
  <line x1="130" y1="492" x2="130" y2="518" stroke="#999" stroke-width="1.5" marker-end="url(#ah)"/>

  <!-- Add 2 -->
  <circle cx="130" cy="530" r="12" fill="#222" stroke="#999" stroke-width="1.5"/>
  <text x="130" y="535" text-anchor="middle" fill="#ccc" font-size="15" font-family="sans-serif" font-weight="bold">+</text>
  <line x1="130" y1="542" x2="130" y2="566" stroke="#999" stroke-width="1.5" marker-end="url(#ah)"/>

  <!-- Output label -->
  <text x="130" y="584" text-anchor="middle" fill="#ccc" font-size="13" font-family="monospace">output (seq_len, 768)</text>
</svg>

<p>The block uses GPT-2's <strong>pre-norm</strong> architecture: LayerNorm is applied
<em>before</em> each sub-layer (attention and feed-forward), not after. At a high level:</p>

\[
\begin{aligned}
x' &= x + \text{MultiHeadAttn}\!\left(\text{LN}_1(x)\right) \\[4pt]
\text{output} &= x' + \text{FeedForward}\!\left(\text{LN}_2(x')\right)
\end{aligned}
\]

<p>where the sub-layers are defined as:</p>

\[
\begin{aligned}
\text{LN}(z) &= \frac{z - \mu}{\sqrt{\sigma^2 + \epsilon}} \odot \gamma + \beta, \quad \mu = \frac{1}{d}\sum_i z_i, \quad \sigma^2 = \frac{1}{d}\sum_i (z_i - \mu)^2 \\[8pt]
[Q \mid K \mid V] &= z \cdot W_{qkv} + b_{qkv} \\[4pt]
\text{head}_i &= \text{softmax}\!\left(\frac{Q_i K_i^\top}{\sqrt{d_k}}\right) V_i, \quad d_k = 64 \\[4pt]
\text{MultiHeadAttn}(z) &= \text{Concat}(\text{head}_1, \ldots, \text{head}_{12}) \cdot W_{\text{attn}} + b_{\text{attn}} \\[8pt]
\text{FeedForward}(z) &= \text{GELU}\!\left(z \cdot W_{fc} + b_{fc}\right) \cdot W_{\text{proj}} + b_{\text{proj}}
\end{aligned}
\]

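<p>The per-head formula above can be sketched in plain Python, with nested lists standing in for tensors. This is an illustration only, not the required implementation; function names here are informal.</p>

```python
import math

def softmax(row):
    # Numerically stable softmax over one row of attention scores.
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def attention_head(Q, K, V, d_k):
    # softmax(Q K^T / sqrt(d_k)) V for a single head, where Q, K, V
    # are lists of rows and each row has d_k floats.
    scale = 1.0 / math.sqrt(d_k)
    scores = [[scale * sum(q * k for q, k in zip(q_row, k_row)) for k_row in K]
              for q_row in Q]
    probs = [softmax(row) for row in scores]
    return [[sum(p * v_row[j] for p, v_row in zip(p_row, V))
             for j in range(len(V[0]))]
            for p_row in probs]
```

<p>In the full block this runs once per head (12 heads of \(d_k = 64\)) on the corresponding 64-column slices of \(Q\), \(K\), and \(V\).</p>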
<p>Expanding into individual steps:</p>

<ol>
  <li><strong>Layer Norm 1:</strong> \(x_{\text{norm}} = \text{LN}_1(x)\) with parameters \(\gamma_1, \beta_1\)</li>
  <li><strong>QKV Projection:</strong> \(QKV = x_{\text{norm}} \cdot W_{qkv} + b_{qkv}\), split into \(Q, K, V\) each of shape <code>(seq_len, 768)</code></li>
  <li><strong>Multi-Head Attention:</strong> Reshape \(Q, K, V\) into 12 heads of dimension 64, compute per-head scaled dot-product attention (no causal mask), then concatenate heads into \(A\)</li>
  <li><strong>Output Projection:</strong> \(P = A \cdot W_{\text{attn}} + b_{\text{attn}}\)</li>
  <li><strong>Residual 1:</strong> \(x' = x + P\)</li>
  <li><strong>Layer Norm 2:</strong> \(h_{\text{norm}} = \text{LN}_2(x')\) with parameters \(\gamma_2, \beta_2\)</li>
  <li><strong>Feed-Forward:</strong> \(F = \text{GELU}(h_{\text{norm}} \cdot W_{fc} + b_{fc}) \cdot W_{\text{proj}} + b_{\text{proj}}\)</li>
  <li><strong>Residual 2:</strong> \(\text{output} = x' + F\)</li>
</ol>
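<p>The eight steps reduce to two residual updates. A minimal sketch of the pre-norm wiring, with the sub-layers passed in as callables (hypothetical helper names; each token is reduced to a scalar here purely for illustration):</p>

```python
def transformer_block(x, ln1, attn, ln2, ffn):
    # Pre-norm: normalize first, run the sub-layer, then add the
    # result back onto the un-normalized input (the residual).
    h = [xi + ai for xi, ai in zip(x, attn(ln1(x)))]    # steps 1-5
    return [hi + fi for hi, fi in zip(h, ffn(ln2(h)))]  # steps 6-8
```

<p>Note the residuals add onto \(x\) and \(x'\), not onto the normalized values; getting this wrong is a common source of small numerical mismatches.</p>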

<h2>Implementation Requirements</h2>
<ul>
  <li>Use only native features (external libraries are not permitted)</li>
  <li>The <code>solve</code> function signature must remain unchanged</li>
  <li>The final result must be stored in the <code>output</code> tensor</li>
  <li>LayerNorm uses \(\epsilon = 10^{-5}\)</li>
  <li>Use the <a href="https://docs.pytorch.org/docs/stable/generated/torch.nn.GELU.html" target="_blank">GELU tanh approximation</a>: \(\text{GELU}(x) = 0.5\,x\!\left(1 + \tanh\!\left(\sqrt{\tfrac{2}{\pi}}\left(x + 0.044715\,x^3\right)\right)\right)\)</li>
</ul>
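<p>For reference, the two numerical details above (the \(\epsilon = 10^{-5}\) and the tanh-approximated GELU) can be sketched in plain Python. This is illustrative only; the graded solution must still follow the <code>solve</code> signature.</p>

```python
import math

def layer_norm(z, gamma, beta, eps=1e-5):
    # Normalize one length-d row to zero mean / unit variance, then
    # apply the learned elementwise scale (gamma) and shift (beta).
    d = len(z)
    mu = sum(z) / d
    var = sum((v - mu) ** 2 for v in z) / d
    inv_std = 1.0 / math.sqrt(var + eps)
    return [(v - mu) * inv_std * g + b for v, g, b in zip(z, gamma, beta)]

def gelu(v):
    # GELU, tanh approximation (matches torch.nn.GELU(approximate="tanh")).
    return 0.5 * v * (1.0 + math.tanh(math.sqrt(2.0 / math.pi)
                                      * (v + 0.044715 * v ** 3)))
```

<p>Both LayerNorms normalize over the feature dimension (768), i.e. <code>layer_norm</code> is applied independently to each of the <code>seq_len</code> rows.</p>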

<h2>Weight Layout</h2>
<p>All block parameters are packed into a single contiguous <code>weights</code> buffer
(7,087,872 floats) in the following order. Index into the buffer using the offsets below
(e.g. \(W_{qkv}[i][j]\) is at <code>weights[1536 + i * 2304 + j]</code>).
All 2D matrices are stored in row-major order.</p>

<table style="border-collapse:separate; border-spacing:16px 8px;">
  <tr>
    <th style="text-align:left;">Parameter</th>
    <th style="text-align:left;">Shape</th>
    <th style="text-align:right;">Size</th>
    <th style="text-align:right;">Offset</th>
  </tr>
  <tr>
    <td>\(\gamma_1\) (LN1 weight)</td>
    <td>(768,)</td>
    <td style="text-align:right;">768</td>
    <td style="text-align:right;">0</td>
  </tr>
  <tr>
    <td>\(\beta_1\) (LN1 bias)</td>
    <td>(768,)</td>
    <td style="text-align:right;">768</td>
    <td style="text-align:right;">768</td>
  </tr>
  <tr>
    <td>\(W_{qkv}\)</td>
    <td>(768, 2304)</td>
    <td style="text-align:right;">1,769,472</td>
    <td style="text-align:right;">1,536</td>
  </tr>
  <tr>
    <td>\(b_{qkv}\)</td>
    <td>(2304,)</td>
    <td style="text-align:right;">2,304</td>
    <td style="text-align:right;">1,771,008</td>
  </tr>
  <tr>
    <td>\(W_{\text{attn}}\)</td>
    <td>(768, 768)</td>
    <td style="text-align:right;">589,824</td>
    <td style="text-align:right;">1,773,312</td>
  </tr>
  <tr>
    <td>\(b_{\text{attn}}\)</td>
    <td>(768,)</td>
    <td style="text-align:right;">768</td>
    <td style="text-align:right;">2,363,136</td>
  </tr>
  <tr>
    <td>\(\gamma_2\) (LN2 weight)</td>
    <td>(768,)</td>
    <td style="text-align:right;">768</td>
    <td style="text-align:right;">2,363,904</td>
  </tr>
  <tr>
    <td>\(\beta_2\) (LN2 bias)</td>
    <td>(768,)</td>
    <td style="text-align:right;">768</td>
    <td style="text-align:right;">2,364,672</td>
  </tr>
  <tr>
    <td>\(W_{fc}\)</td>
    <td>(768, 3072)</td>
    <td style="text-align:right;">2,359,296</td>
    <td style="text-align:right;">2,365,440</td>
  </tr>
  <tr>
    <td>\(b_{fc}\)</td>
    <td>(3072,)</td>
    <td style="text-align:right;">3,072</td>
    <td style="text-align:right;">4,724,736</td>
  </tr>
  <tr>
    <td>\(W_{\text{proj}}\)</td>
    <td>(3072, 768)</td>
    <td style="text-align:right;">2,359,296</td>
    <td style="text-align:right;">4,727,808</td>
  </tr>
  <tr>
    <td>\(b_{\text{proj}}\)</td>
    <td>(768,)</td>
    <td style="text-align:right;">768</td>
    <td style="text-align:right;">7,087,104</td>
  </tr>
</table>

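<p>The offsets in the table are just cumulative sums of the parameter sizes. A small sketch that recomputes them (the dictionary keys are informal labels matching the table; this is bookkeeping only):</p>

```python
D, FFN = 768, 3072  # d_model, ffn_dim

offsets, cursor = {}, 0
for name, size in [
    ("gamma1", D), ("beta1", D),
    ("W_qkv", D * 3 * D), ("b_qkv", 3 * D),
    ("W_attn", D * D), ("b_attn", D),
    ("gamma2", D), ("beta2", D),
    ("W_fc", D * FFN), ("b_fc", FFN),
    ("W_proj", FFN * D), ("b_proj", D),
]:
    offsets[name] = cursor
    cursor += size

# cursor is now 7,087,872, the total buffer length. Row-major indexing:
# W_qkv[i][j] lives at weights[offsets["W_qkv"] + i * 2304 + j].
```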
<h2>Example</h2>
<p>With <code>seq_len</code> = 4, <code>x</code> uniformly drawn from [−1, 1], and weights randomly initialized
(see Weight Layout for the packing structure):</p>
<pre>
Input:  x.shape = (4, 768)          # 4 token embeddings
        weights.shape = (7087872,)  # packed weight buffer
        seq_len = 4
Output: output.shape = (4, 768)     # transformed token embeddings
</pre>

<h2>Constraints</h2>
<ul>
  <li><code>d_model</code> = 768, <code>n_heads</code> = 12, <code>ffn_dim</code> = 3,072 (GPT-2 124M architecture)</li>
  <li>1 ≤ <code>seq_len</code> ≤ 4,096</li>
  <li>All tensors use 32-bit floating point</li>
  <li>Performance is measured with <code>seq_len</code> = 1,024</li>
</ul>