diff --git a/CLAUDE.md b/CLAUDE.md index bb2844a..c5103fa 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -108,6 +108,7 @@ HTML fragment with four required sections: **Formatting rules:** - `` for variables/functions; `
` for 1D examples, LaTeX `\begin{bmatrix}` for matrices
 - `≤`, `≥`, `×` for math symbols
+- **LaTeX underscores**: Inside `\text{}`, use plain `_` (not `\_`). The backslash-escaped form renders literally as `\_` in MathJax/KaTeX.
 - **Performance test size bullet**: Must include a bullet documenting the exact parameters used in `generate_performance_test()`, formatted as:
   - `
  • Performance is measured with param = value
  • ` - Use commas for numbers ≥ 1,000 (e.g., `25,000,000`) diff --git a/challenges/medium/74_gpt2_block/challenge.html b/challenges/medium/74_gpt2_block/challenge.html new file mode 100644 index 0000000..12fc408 --- /dev/null +++ b/challenges/medium/74_gpt2_block/challenge.html @@ -0,0 +1,234 @@ +

    Implement a single GPT-2 transformer decoder block. Given an input tensor \(x\) of shape (seq_len, 768) and a packed weight buffer containing all block parameters, compute the output using pre-norm architecture with multi-head self-attention and a feed-forward network with GELU activation.

    [Block diagram: x (seq_len, 768) → LayerNorm 1 → QKV Projection → Multi-Head Attention → Output Projection → (+ residual) → LayerNorm 2 → Linear (768 → 3072) → GELU → Linear (3072 → 768) → (+ residual) → output (seq_len, 768)]

    The block uses GPT-2's pre-norm architecture: LayerNorm is applied before each sub-layer (attention and feed-forward), not after. At a high level:

\[
\begin{aligned}
x' &= x + \text{MultiHeadAttn}\!\left(\text{LN}_1(x)\right) \\[4pt]
\text{output} &= x' + \text{FeedForward}\!\left(\text{LN}_2(x')\right)
\end{aligned}
\]

    where the sub-layers are defined as:

\[
\begin{aligned}
\text{LN}(z) &= \frac{z - \mu}{\sqrt{\sigma^2 + \epsilon}} \odot \gamma + \beta, \quad \mu = \frac{1}{d}\sum_i z_i, \quad \sigma^2 = \frac{1}{d}\sum_i (z_i - \mu)^2 \\[8pt]
[Q \mid K \mid V] &= \text{LN}_1(x) \cdot W_{qkv} + b_{qkv} \\[4pt]
\text{head}_i &= \text{softmax}\!\left(\frac{Q_i K_i^\top}{\sqrt{d_k}}\right) V_i, \quad d_k = 64 \\[4pt]
\text{MultiHeadAttn}(z) &= \text{Concat}(\text{head}_1, \ldots, \text{head}_{12}) \cdot W_{\text{attn}} + b_{\text{attn}} \\[8pt]
\text{FeedForward}(z) &= \text{GELU}\!\left(z \cdot W_{fc} + b_{fc}\right) \cdot W_{\text{proj}} + b_{\text{proj}}
\end{aligned}
\]
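The LN formula above uses the biased (population) variance, i.e. \(1/d\) rather than \(1/(d-1)\). A minimal NumPy sketch of just this sub-layer (a hypothetical helper, using the eps = 1e-5 that the reference implementation passes to `F.layer_norm`):

```python
import numpy as np

def layer_norm(z, gamma, beta, eps=1e-5):
    # Normalize each row over its d features, then scale by gamma and shift by beta.
    mu = z.mean(axis=-1, keepdims=True)
    var = ((z - mu) ** 2).mean(axis=-1, keepdims=True)  # population variance: 1/d, not 1/(d-1)
    return (z - mu) / np.sqrt(var + eps) * gamma + beta

z = np.random.default_rng(0).standard_normal((4, 768))
out = layer_norm(z, np.ones(768), np.zeros(768))
# with gamma = 1 and beta = 0, each row comes out with mean ~0 and variance ~1
```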

    Expanding into individual steps:

1. Layer Norm 1: \(x_{\text{norm}} = \text{LN}_1(x)\) with parameters \(\gamma_1, \beta_1\)
2. QKV Projection: \(QKV = x_{\text{norm}} \cdot W_{qkv} + b_{qkv}\), split into \(Q, K, V\), each of shape (seq_len, 768)
3. Multi-Head Attention: reshape \(Q, K, V\) into 12 heads of dimension 64, compute per-head scaled dot-product attention (no causal mask), then concatenate heads into \(A\)
4. Output Projection: \(P = A \cdot W_{\text{attn}} + b_{\text{attn}}\)
5. Residual 1: \(x' = x + P\)
6. Layer Norm 2: \(h_{\text{norm}} = \text{LN}_2(x')\) with parameters \(\gamma_2, \beta_2\)
7. Feed-Forward: \(F = \text{GELU}(h_{\text{norm}} \cdot W_{fc} + b_{fc}) \cdot W_{\text{proj}} + b_{\text{proj}}\), where GELU uses the tanh approximation
8. Residual 2: \(\text{output} = x' + F\)
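The head split in step 3 is the easiest place to get wrong. A NumPy sketch of the reshape, per-head attention, and concatenation (illustrative only; dimension names follow the challenge):

```python
import numpy as np

seq_len, D, H, DH = 4, 768, 12, 64  # challenge dimensions: D = H * DH
rng = np.random.default_rng(0)
q = rng.standard_normal((seq_len, D))
k = rng.standard_normal((seq_len, D))
v = rng.standard_normal((seq_len, D))

# split into heads: (seq_len, D) -> (H, seq_len, DH)
qh = q.reshape(seq_len, H, DH).transpose(1, 0, 2)
kh = k.reshape(seq_len, H, DH).transpose(1, 0, 2)
vh = v.reshape(seq_len, H, DH).transpose(1, 0, 2)

# scaled dot-product attention per head, no causal mask
scores = qh @ kh.transpose(0, 2, 1) / np.sqrt(DH)        # (H, seq_len, seq_len)
w = np.exp(scores - scores.max(axis=-1, keepdims=True))  # numerically stable softmax
w /= w.sum(axis=-1, keepdims=True)
heads = w @ vh                                           # (H, seq_len, DH)

# concatenate heads back: (H, seq_len, DH) -> (seq_len, D)
A = heads.transpose(1, 0, 2).reshape(seq_len, D)
```

Note that the transpose back to (seq_len, H, DH) must happen before the final reshape; reshaping directly from (H, seq_len, DH) interleaves tokens instead of heads.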

    Implementation Requirements


    Weight Layout

    +

    All block parameters are packed into a single contiguous weights buffer (7,087,872 floats) in the following order. Index into the buffer using the offsets below (e.g. \(W_{qkv}[i][j]\) is at weights[1536 + i * 2304 + j]). All 2D matrices are stored in row-major order.

| Parameter | Shape | Size | Offset |
|---|---|---|---|
| \(\gamma_1\) (LN1 weight) | (768,) | 768 | 0 |
| \(\beta_1\) (LN1 bias) | (768,) | 768 | 768 |
| \(W_{qkv}\) | (768, 2304) | 1,769,472 | 1,536 |
| \(b_{qkv}\) | (2304,) | 2,304 | 1,771,008 |
| \(W_{\text{attn}}\) | (768, 768) | 589,824 | 1,773,312 |
| \(b_{\text{attn}}\) | (768,) | 768 | 2,363,136 |
| \(\gamma_2\) (LN2 weight) | (768,) | 768 | 2,363,904 |
| \(\beta_2\) (LN2 bias) | (768,) | 768 | 2,364,672 |
| \(W_{fc}\) | (768, 3072) | 2,359,296 | 2,365,440 |
| \(b_{fc}\) | (3072,) | 3,072 | 4,724,736 |
| \(W_{\text{proj}}\) | (3072, 768) | 2,359,296 | 4,727,808 |
| \(b_{\text{proj}}\) | (768,) | 768 | 7,087,104 |
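Since every offset in the layout is determined by the shapes and the packing order, the table can be rederived programmatically. A small, hypothetical check script (mirroring the offset constants in challenge.py):

```python
# Derive packed-buffer offsets from the parameter shapes (row-major, in packing order).
shapes = [
    ("ln1_w", (768,)), ("ln1_b", (768,)),
    ("W_qkv", (768, 2304)), ("b_qkv", (2304,)),
    ("W_attn", (768, 768)), ("b_attn", (768,)),
    ("ln2_w", (768,)), ("ln2_b", (768,)),
    ("W_fc", (768, 3072)), ("b_fc", (3072,)),
    ("W_proj", (3072, 768)), ("b_proj", (768,)),
]

offsets = {}
cursor = 0
for name, shape in shapes:
    offsets[name] = cursor
    size = 1
    for d in shape:
        size *= d
    cursor += size

# cursor ends at the total buffer length, 7,087,872, and
# W_qkv[i][j] lives at weights[offsets["W_qkv"] + i * 2304 + j]
```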

    Example

    +

    With seq_len = 4, x uniformly drawn from [−1, 1], and weights randomly initialized (see Weight Layout for the packing structure):

    +
    +Input:  x.shape       = (4, 768)       # 4 token embeddings
    +        weights.shape = (7,087,872,)   # packed weight buffer
    +        seq_len       = 4
    +Output: output.shape  = (4, 768)       # transformed token embeddings
    +
    + +
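One detail the formulas leave implicit: the reference implementation applies GELU with the tanh approximation (`F.gelu(..., approximate="tanh")`), not the exact erf form. A NumPy equivalent for cross-checking a custom kernel (illustrative only):

```python
import numpy as np

def gelu_tanh(x):
    # GPT-2's tanh approximation: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x ** 3)))
```

The approximation differs from exact erf-based GELU by up to ~1e-3 for moderate inputs, which matters at the challenge's atol/rtol of 1e-3.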

    Constraints

    + diff --git a/challenges/medium/74_gpt2_block/challenge.py b/challenges/medium/74_gpt2_block/challenge.py new file mode 100644 index 0000000..349c84e --- /dev/null +++ b/challenges/medium/74_gpt2_block/challenge.py @@ -0,0 +1,184 @@ +import ctypes +import math +from typing import Any, Dict, List + +import torch +import torch.nn.functional as F +from core.challenge_base import ChallengeBase + +# GPT-2 124M fixed dimensions +D = 768 +H = 12 +DH = D // H # 64 +FFN = 3072 + +# Weight layout offsets in the packed buffer +O_LN1_W = 0 +O_LN1_B = O_LN1_W + D +O_WQKV = O_LN1_B + D +O_BQKV = O_WQKV + D * 3 * D +O_WAPROJ = O_BQKV + 3 * D +O_BAPROJ = O_WAPROJ + D * D +O_LN2_W = O_BAPROJ + D +O_LN2_B = O_LN2_W + D +O_WFC = O_LN2_B + D +O_BFC = O_WFC + D * FFN +O_WPROJ = O_BFC + FFN +O_BPROJ = O_WPROJ + FFN * D +TOTAL_WEIGHTS = O_BPROJ + D + + +class Challenge(ChallengeBase): + def __init__(self): + super().__init__( + name="GPT-2 Transformer Block", + atol=1e-03, + rtol=1e-03, + num_gpus=1, + access_tier="free", + ) + + def reference_impl( + self, + x: torch.Tensor, + output: torch.Tensor, + weights: torch.Tensor, + seq_len: int, + ): + assert x.shape == (seq_len, D) + assert output.shape == (seq_len, D) + assert weights.shape == (TOTAL_WEIGHTS,) + assert x.dtype == output.dtype == weights.dtype + assert x.device.type == "cuda" + assert output.device.type == "cuda" + assert weights.device.type == "cuda" + + # unpack weights + ln1_w = weights[O_LN1_W:O_LN1_B] + ln1_b = weights[O_LN1_B:O_WQKV] + W_qkv = weights[O_WQKV:O_BQKV].view(D, 3 * D) + b_qkv = weights[O_BQKV:O_WAPROJ] + W_attn = weights[O_WAPROJ:O_BAPROJ].view(D, D) + b_attn = weights[O_BAPROJ:O_LN2_W] + ln2_w = weights[O_LN2_W:O_LN2_B] + ln2_b = weights[O_LN2_B:O_WFC] + W_fc = weights[O_WFC:O_BFC].view(D, FFN) + b_fc = weights[O_BFC:O_WPROJ] + W_proj = weights[O_WPROJ:O_BPROJ].view(FFN, D) + b_proj = weights[O_BPROJ : O_BPROJ + D] + + # layer norm 1 + x_norm = F.layer_norm(x, [D], ln1_w, ln1_b, eps=1e-5) + + # qkv 
projection + qkv = x_norm @ W_qkv + b_qkv + q, k, v = qkv.split(D, dim=-1) + + # reshape for multi-head attention: (H, seq_len, DH) + q = q.view(seq_len, H, DH).transpose(0, 1) + k = k.view(seq_len, H, DH).transpose(0, 1) + v = v.view(seq_len, H, DH).transpose(0, 1) + + # scaled dot-product attention + scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(DH) + attn_weights = torch.softmax(scores, dim=-1) + attn_out = torch.matmul(attn_weights, v) + + # concat heads and project + attn_out = attn_out.transpose(0, 1).contiguous().view(seq_len, D) + attn_proj = attn_out @ W_attn + b_attn + + # residual connection 1 + hidden = x + attn_proj + + # layer norm 2 + h_norm = F.layer_norm(hidden, [D], ln2_w, ln2_b, eps=1e-5) + + # ffn: linear -> gelu (tanh approx) -> linear + fc = h_norm @ W_fc + b_fc + fc = F.gelu(fc, approximate="tanh") + proj = fc @ W_proj + b_proj + + # residual connection 2 + output.copy_(hidden + proj) + + def get_solve_signature(self) -> Dict[str, tuple]: + return { + "x": (ctypes.POINTER(ctypes.c_float), "in"), + "output": (ctypes.POINTER(ctypes.c_float), "out"), + "weights": (ctypes.POINTER(ctypes.c_float), "in"), + "seq_len": (ctypes.c_int, "in"), + } + + def _make_weights(self, device, dtype): + scale = 0.02 + ln1_w = torch.empty(D, device=device, dtype=dtype).uniform_(0.8, 1.2) + ln1_b = torch.empty(D, device=device, dtype=dtype).uniform_(-0.1, 0.1) + W_qkv = torch.empty(D, 3 * D, device=device, dtype=dtype).normal_(0, scale) + b_qkv = torch.zeros(3 * D, device=device, dtype=dtype) + W_attn = torch.empty(D, D, device=device, dtype=dtype).normal_(0, scale) + b_attn = torch.zeros(D, device=device, dtype=dtype) + ln2_w = torch.empty(D, device=device, dtype=dtype).uniform_(0.8, 1.2) + ln2_b = torch.empty(D, device=device, dtype=dtype).uniform_(-0.1, 0.1) + W_fc = torch.empty(D, FFN, device=device, dtype=dtype).normal_(0, scale) + b_fc = torch.zeros(FFN, device=device, dtype=dtype) + W_proj = torch.empty(FFN, D, device=device, 
dtype=dtype).normal_(0, scale) + b_proj = torch.zeros(D, device=device, dtype=dtype) + return torch.cat( + [ + ln1_w, + ln1_b, + W_qkv.flatten(), + b_qkv, + W_attn.flatten(), + b_attn, + ln2_w, + ln2_b, + W_fc.flatten(), + b_fc, + W_proj.flatten(), + b_proj, + ] + ) + + def _make_test_case(self, seq_len, zero_x=False): + dtype = torch.float32 + device = "cuda" + weights = self._make_weights(device, dtype) + if zero_x: + x = torch.zeros(seq_len, D, device=device, dtype=dtype) + else: + x = torch.empty(seq_len, D, device=device, dtype=dtype).uniform_(-1.0, 1.0) + return { + "x": x, + "output": torch.empty(seq_len, D, device=device, dtype=dtype), + "weights": weights, + "seq_len": seq_len, + } + + def generate_example_test(self) -> Dict[str, Any]: + torch.manual_seed(0) + return self._make_test_case(4) + + def generate_functional_test(self) -> List[Dict[str, Any]]: + tests = [] + # single token + tests.append(self._make_test_case(1)) + # zero input + tests.append(self._make_test_case(4, zero_x=True)) + # small edge cases + tests.append(self._make_test_case(2)) + tests.append(self._make_test_case(4)) + # power-of-2 + tests.append(self._make_test_case(16)) + tests.append(self._make_test_case(64)) + # non-power-of-2 + tests.append(self._make_test_case(30)) + tests.append(self._make_test_case(100)) + # realistic + tests.append(self._make_test_case(128)) + tests.append(self._make_test_case(256)) + return tests + + def generate_performance_test(self) -> Dict[str, Any]: + return self._make_test_case(1024) diff --git a/challenges/medium/74_gpt2_block/starter/starter.cu b/challenges/medium/74_gpt2_block/starter/starter.cu new file mode 100644 index 0000000..3bc17a3 --- /dev/null +++ b/challenges/medium/74_gpt2_block/starter/starter.cu @@ -0,0 +1,4 @@ +#include + +// x, output, weights are device pointers +extern "C" void solve(const float* x, float* output, const float* weights, int seq_len) {} diff --git a/challenges/medium/74_gpt2_block/starter/starter.cute.py 
b/challenges/medium/74_gpt2_block/starter/starter.cute.py new file mode 100644 index 0000000..f019e7c --- /dev/null +++ b/challenges/medium/74_gpt2_block/starter/starter.cute.py @@ -0,0 +1,13 @@ +import cutlass +import cutlass.cute as cute + + +# x, output, weights are tensors on the GPU +@cute.jit +def solve( + x: cute.Tensor, + output: cute.Tensor, + weights: cute.Tensor, + seq_len: cute.Int32, +): + pass diff --git a/challenges/medium/74_gpt2_block/starter/starter.jax.py b/challenges/medium/74_gpt2_block/starter/starter.jax.py new file mode 100644 index 0000000..d3cb8d1 --- /dev/null +++ b/challenges/medium/74_gpt2_block/starter/starter.jax.py @@ -0,0 +1,9 @@ +import jax +import jax.numpy as jnp + + +# x, weights are tensors on GPU +@jax.jit +def solve(x: jax.Array, weights: jax.Array, seq_len: int) -> jax.Array: + # return output tensor directly + pass diff --git a/challenges/medium/74_gpt2_block/starter/starter.mojo b/challenges/medium/74_gpt2_block/starter/starter.mojo new file mode 100644 index 0000000..55275dc --- /dev/null +++ b/challenges/medium/74_gpt2_block/starter/starter.mojo @@ -0,0 +1,9 @@ +from gpu.host import DeviceContext +from gpu.id import block_dim, block_idx, thread_idx +from memory import UnsafePointer +from math import ceildiv + +# x, output, weights are device pointers +@export +def solve(x: UnsafePointer[Float32], output: UnsafePointer[Float32], weights: UnsafePointer[Float32], seq_len: Int32): + pass diff --git a/challenges/medium/74_gpt2_block/starter/starter.pytorch.py b/challenges/medium/74_gpt2_block/starter/starter.pytorch.py new file mode 100644 index 0000000..ae42c1d --- /dev/null +++ b/challenges/medium/74_gpt2_block/starter/starter.pytorch.py @@ -0,0 +1,6 @@ +import torch + + +# x, output, weights are tensors on the GPU +def solve(x: torch.Tensor, output: torch.Tensor, weights: torch.Tensor, seq_len: int): + pass diff --git a/challenges/medium/74_gpt2_block/starter/starter.triton.py 
b/challenges/medium/74_gpt2_block/starter/starter.triton.py new file mode 100644 index 0000000..7bf7bfc --- /dev/null +++ b/challenges/medium/74_gpt2_block/starter/starter.triton.py @@ -0,0 +1,8 @@ +import torch +import triton +import triton.language as tl + + +# x, output, weights are tensors on the GPU +def solve(x: torch.Tensor, output: torch.Tensor, weights: torch.Tensor, seq_len: int): + pass