
Commit 945983a

claude[bot] and github-actions[bot] authored
Add challenge 80: Grouped Query Attention (Medium) (#215)
* Add challenge 80: Grouped Query Attention (Medium)

  Implements a GQA forward pass challenge inspired by real-world LLM inference
  (LLaMA-3, Mistral, Gemma). Solvers must correctly handle Q/K/V tensors with
  different head counts and implement scaled dot-product attention with softmax
  over grouped KV heads.

* Fix challenge 80: add example output and convert to LaTeX bmatrix format

  - Add missing output matrices to the Example section (required by checklist)
  - Convert example from <pre> notation to LaTeX \begin{bmatrix} for all Q, K,
    V, and output head matrices (required for 2D/3D data per CLAUDE.md)

Co-authored-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com>
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent d9fbdf6 commit 945983a

File tree

8 files changed, +453 -0 lines changed

Lines changed: 187 additions & 0 deletions
@@ -0,0 +1,187 @@
<p>
Implement Grouped Query Attention (GQA), the attention mechanism used in modern large language
models such as LLaMA-3, Mistral, and Gemma. GQA reduces the KV-cache memory footprint during
inference by sharing key and value heads across groups of query heads. Given query tensor
<code>Q</code> with <code>num_q_heads</code> heads and key/value tensors <code>K</code>,
<code>V</code> each with <code>num_kv_heads</code> heads, compute scaled dot-product attention
where every group of <code>num_q_heads / num_kv_heads</code> consecutive query heads attends to
the same key and value head. All tensors use <code>float32</code>.
</p>

<svg width="700" height="260" viewBox="0 0 700 260" xmlns="http://www.w3.org/2000/svg" style="display:block; margin:20px auto;">
<rect width="700" height="260" fill="#222" rx="10"/>
<!-- Title -->
<text x="350" y="28" fill="#ccc" font-family="monospace" font-size="14" text-anchor="middle">Grouped Query Attention (num_q_heads=4, num_kv_heads=2, groups=2)</text>

<!-- Q heads -->
<text x="80" y="60" fill="#aaa" font-family="monospace" font-size="12" text-anchor="middle">Q heads</text>
<rect x="20" y="70" width="60" height="36" fill="#2563eb" rx="4"/>
<text x="50" y="93" fill="#fff" font-family="monospace" font-size="12" text-anchor="middle">Q[0]</text>
<rect x="100" y="70" width="60" height="36" fill="#2563eb" rx="4"/>
<text x="130" y="93" fill="#fff" font-family="monospace" font-size="12" text-anchor="middle">Q[1]</text>
<rect x="180" y="70" width="60" height="36" fill="#7c3aed" rx="4"/>
<text x="210" y="93" fill="#fff" font-family="monospace" font-size="12" text-anchor="middle">Q[2]</text>
<rect x="260" y="70" width="60" height="36" fill="#7c3aed" rx="4"/>
<text x="290" y="93" fill="#fff" font-family="monospace" font-size="12" text-anchor="middle">Q[3]</text>

<!-- KV heads -->
<text x="80" y="175" fill="#aaa" font-family="monospace" font-size="12" text-anchor="middle">KV heads</text>
<rect x="20" y="185" width="120" height="36" fill="#1d4ed8" rx="4"/>
<text x="80" y="208" fill="#fff" font-family="monospace" font-size="12" text-anchor="middle">K[0], V[0]</text>
<rect x="180" y="185" width="120" height="36" fill="#5b21b6" rx="4"/>
<text x="240" y="208" fill="#fff" font-family="monospace" font-size="12" text-anchor="middle">K[1], V[1]</text>

<!-- Arrows group 0 -->
<line x1="50" y1="106" x2="70" y2="185" stroke="#60a5fa" stroke-width="1.5" marker-end="url(#arr)"/>
<line x1="130" y1="106" x2="90" y2="185" stroke="#60a5fa" stroke-width="1.5" marker-end="url(#arr)"/>

<!-- Arrows group 1 -->
<line x1="210" y1="106" x2="230" y2="185" stroke="#c4b5fd" stroke-width="1.5" marker-end="url(#arr)"/>
<line x1="290" y1="106" x2="250" y2="185" stroke="#c4b5fd" stroke-width="1.5" marker-end="url(#arr)"/>

<!-- Output boxes -->
<text x="80" y="245" fill="#aaa" font-family="monospace" font-size="11" text-anchor="middle">group 0</text>
<text x="240" y="245" fill="#aaa" font-family="monospace" font-size="11" text-anchor="middle">group 1</text>

<!-- bracket labels -->
<text x="430" y="88" fill="#60a5fa" font-family="monospace" font-size="12">Q[0], Q[1] attend to K[0], V[0]</text>
<text x="430" y="112" fill="#c4b5fd" font-family="monospace" font-size="12">Q[2], Q[3] attend to K[1], V[1]</text>
<text x="430" y="150" fill="#4ade80" font-family="monospace" font-size="12">scale = 1 / sqrt(head_dim)</text>
<text x="430" y="174" fill="#4ade80" font-family="monospace" font-size="12">scores = Q @ K^T * scale</text>
<text x="430" y="198" fill="#4ade80" font-family="monospace" font-size="12">weights = softmax(scores)</text>
<text x="430" y="222" fill="#4ade80" font-family="monospace" font-size="12">output = weights @ V</text>

<defs>
<marker id="arr" markerWidth="6" markerHeight="6" refX="3" refY="3" orient="auto">
<path d="M0,0 L0,6 L6,3 z" fill="#888"/>
</marker>
</defs>
</svg>

<h2>Implementation Requirements</h2>
<ul>
<li>Implement the function <code>solve(Q, K, V, output, num_q_heads, num_kv_heads, seq_len, head_dim)</code>.</li>
<li>Do not change the function signature or use external libraries beyond the standard GPU frameworks.</li>
<li>Write the result into the provided <code>output</code> buffer.</li>
<li><code>num_q_heads</code> is always divisible by <code>num_kv_heads</code>.</li>
<li>Use scaled dot-product attention with scale factor <code>1 / sqrt(head_dim)</code> and a softmax over the key dimension.</li>
</ul>

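The requirements above amount to per-head scaled dot-product attention where query head `qh` reads KV head `qh // (num_q_heads / num_kv_heads)`. As a sanity reference only (plain Python on nested lists, hypothetical name `gqa_forward`; the challenge itself expects a GPU `solve` that writes into the `output` buffer), the computation can be sketched as:

```python
import math

def gqa_forward(Q, K, V, num_q_heads, num_kv_heads, seq_len, head_dim):
    """Plain-Python GQA forward pass on nested lists (reference sketch).

    Q is [num_q_heads][seq_len][head_dim]; K and V are
    [num_kv_heads][seq_len][head_dim]. Returns the attention output
    with the same shape as Q.
    """
    num_groups = num_q_heads // num_kv_heads
    scale = 1.0 / math.sqrt(head_dim)
    output = []
    for qh in range(num_q_heads):
        kv = qh // num_groups  # consecutive query heads share one KV head
        rows = []
        for i in range(seq_len):
            # Scaled dot-product scores of query row i against every key row
            scores = [scale * sum(Q[qh][i][d] * K[kv][j][d] for d in range(head_dim))
                      for j in range(seq_len)]
            # Softmax over the key dimension (max-shifted for float safety)
            m = max(scores)
            exps = [math.exp(s - m) for s in scores]
            total = sum(exps)
            weights = [e / total for e in exps]
            # Weighted sum of the value rows
            rows.append([sum(weights[j] * V[kv][j][d] for j in range(seq_len))
                         for d in range(head_dim)])
        output.append(rows)
    return output
```

The `qh // num_groups` mapping is what makes this GQA rather than standard multi-head attention: with `num_kv_heads == num_q_heads` it degenerates to MHA, and with `num_kv_heads == 1` to multi-query attention.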
<h2>Example</h2>
<p>
With <code>num_q_heads</code> = 4, <code>num_kv_heads</code> = 2 (groups of 2), <code>seq_len</code> = 3,
<code>head_dim</code> = 4:
</p>
<p>
<strong>Input:</strong><br>
\(Q_0\) (3&times;4):
\[
\begin{bmatrix}
1 & 0 & 0 & 1 \\
0 & 1 & 1 & 0 \\
1 & 1 & 0 & 0
\end{bmatrix}
\]
\(Q_1\) (3&times;4):
\[
\begin{bmatrix}
0 & 1 & 0 & 1 \\
1 & 0 & 1 & 0 \\
0 & 0 & 1 & 1
\end{bmatrix}
\]
\(Q_2\) (3&times;4):
\[
\begin{bmatrix}
-1 & 0 & 0.5 & 0 \\
0 & -1 & 0 & 0.5 \\
0.5 & 0 & -1 & 0
\end{bmatrix}
\]
\(Q_3\) (3&times;4):
\[
\begin{bmatrix}
0 & 0.5 & 0 & -1 \\
0.5 & 0 & 0 & -1 \\
0 & 0 & 0.5 & 0.5
\end{bmatrix}
\]
\(K_0\) (3&times;4):
\[
\begin{bmatrix}
1 & 0 & 1 & 0 \\
0 & 1 & 0 & 1 \\
1 & 1 & 1 & 1
\end{bmatrix}
\]
\(K_1\) (3&times;4):
\[
\begin{bmatrix}
0 & 1 & 0 & -1 \\
-1 & 0 & 1 & 0 \\
0 & -1 & 0 & 1
\end{bmatrix}
\]
\(V_0\) (3&times;4):
\[
\begin{bmatrix}
1 & 2 & 3 & 4 \\
5 & 6 & 7 & 8 \\
9 & 10 & 11 & 12
\end{bmatrix}
\]
\(V_1\) (3&times;4):
\[
\begin{bmatrix}
-1 & -2 & -3 & -4 \\
2 & 3 & 4 & 5 \\
6 & 7 & 8 & 9
\end{bmatrix}
\]
Groups: \(Q_0, Q_1 \to K_0, V_0\); \(Q_2, Q_3 \to K_1, V_1\)
</p>
<p>
<strong>Output</strong> (values rounded to 2 decimal places):<br>
\(\text{output}_0\) (3&times;4):
\[
\begin{bmatrix}
5.71 & 6.71 & 7.71 & 8.71 \\
5.71 & 6.71 & 7.71 & 8.71 \\
5.71 & 6.71 & 7.71 & 8.71
\end{bmatrix}
\]
\(\text{output}_1\) (3&times;4):
\[
\begin{bmatrix}
6.07 & 7.07 & 8.07 & 9.07 \\
5.00 & 6.00 & 7.00 & 8.00 \\
5.71 & 6.71 & 7.71 & 8.71
\end{bmatrix}
\]
\(\text{output}_2\) (3&times;4):
\[
\begin{bmatrix}
2.24 & 2.76 & 3.27 & 3.79 \\
3.96 & 4.70 & 5.44 & 6.17 \\
2.40 & 2.60 & 2.79 & 2.98
\end{bmatrix}
\]
\(\text{output}_3\) (3&times;4):
\[
\begin{bmatrix}
0.76 & 0.58 & 0.40 & 0.22 \\
1.17 & 1.08 & 1.00 & 0.91 \\
2.84 & 3.37 & 3.91 & 4.44
\end{bmatrix}
\]
</p>

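As a spot-check of the rounded values above (an illustrative hand calculation, not part of the challenge files): the first row of \(\text{output}_0\) comes from query row \([1, 0, 0, 1]\) of \(Q_0\) attending to \(K_0\) and \(V_0\):

```python
import math

# Head 0, query row 0 of the example: q = Q_0[0], keys/values from K_0/V_0
q = [1.0, 0.0, 0.0, 1.0]
K0 = [[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0], [1.0, 1.0, 1.0, 1.0]]
V0 = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0]]

scale = 1.0 / math.sqrt(4)  # 1 / sqrt(head_dim) = 0.5
scores = [scale * sum(a * b for a, b in zip(q, k)) for k in K0]  # [0.5, 0.5, 1.0]
exps = [math.exp(s) for s in scores]
weights = [e / sum(exps) for e in exps]  # softmax over the 3 key positions
row = [sum(w * v[d] for w, v in zip(weights, V0)) for d in range(4)]

print([round(x, 2) for x in row])  # [5.71, 6.71, 7.71, 8.71]
```

This matches the first row of \(\text{output}_0\); all three query rows of \(Q_0\) happen to produce identical score patterns, which is why that matrix has three equal rows.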
<h2>Constraints</h2>
<ul>
<li>1 &le; <code>num_kv_heads</code> &le; <code>num_q_heads</code> &le; 64</li>
<li><code>num_q_heads</code> is divisible by <code>num_kv_heads</code></li>
<li>1 &le; <code>seq_len</code> &le; 4,096</li>
<li>8 &le; <code>head_dim</code> &le; 256; <code>head_dim</code> is a multiple of 8</li>
<li>All tensor values are <code>float32</code></li>
<li>Performance is measured with <code>num_q_heads</code> = 32, <code>num_kv_heads</code> = 8, <code>seq_len</code> = 1,024, <code>head_dim</code> = 128</li>
</ul>
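A practical note for solvers (an observation, not a stated requirement): with `seq_len` up to 4,096, the softmax over the key dimension is usually computed in max-shifted form inside the kernel, since `exp` of raw `float32` scores can overflow while the constant shift cancels in the normalization. A minimal sketch of the equivalence:

```python
import math

def softmax(xs):
    """Max-shifted softmax: exp(x - max(x)) never overflows upward,
    and the shift cancels when dividing by the sum."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

# Agrees with the naive form on small inputs...
naive = [math.exp(x) for x in (0.5, 0.5, 1.0)]
naive = [e / sum(naive) for e in naive]
assert all(abs(a - b) < 1e-12 for a, b in zip(softmax([0.5, 0.5, 1.0]), naive))

# ...and still returns a valid distribution where math.exp(1000.0) would overflow.
w = softmax([1000.0, 1001.0, 1002.0])
assert abs(sum(w) - 1.0) < 1e-12
```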
Lines changed: 179 additions & 0 deletions
@@ -0,0 +1,179 @@
import ctypes
import math
from typing import Any, Dict, List

import torch
from core.challenge_base import ChallengeBase


class Challenge(ChallengeBase):
    def __init__(self):
        super().__init__(
            name="Grouped Query Attention",
            atol=1e-04,
            rtol=1e-04,
            num_gpus=1,
            access_tier="free",
        )

    def reference_impl(
        self,
        Q: torch.Tensor,
        K: torch.Tensor,
        V: torch.Tensor,
        output: torch.Tensor,
        num_q_heads: int,
        num_kv_heads: int,
        seq_len: int,
        head_dim: int,
    ):
        assert Q.shape == (num_q_heads, seq_len, head_dim)
        assert K.shape == (num_kv_heads, seq_len, head_dim)
        assert V.shape == (num_kv_heads, seq_len, head_dim)
        assert output.shape == (num_q_heads, seq_len, head_dim)
        assert Q.dtype == K.dtype == V.dtype == output.dtype == torch.float32
        assert Q.device.type == "cuda"
        assert K.device.type == "cuda"
        assert V.device.type == "cuda"
        assert output.device.type == "cuda"
        assert num_q_heads % num_kv_heads == 0

        num_groups = num_q_heads // num_kv_heads
        scale = 1.0 / math.sqrt(head_dim)

        # Expand K, V from (num_kv_heads, seq_len, head_dim)
        # to (num_q_heads, seq_len, head_dim) by repeating each KV head num_groups times
        K_expanded = K.repeat_interleave(num_groups, dim=0)
        V_expanded = V.repeat_interleave(num_groups, dim=0)

        # Scaled dot-product attention: (num_q_heads, seq_len, seq_len)
        scores = torch.bmm(Q, K_expanded.transpose(1, 2)) * scale

        # Softmax over the key dimension
        attn_weights = torch.softmax(scores, dim=-1)

        # Weighted sum of values: (num_q_heads, seq_len, head_dim)
        output.copy_(torch.bmm(attn_weights, V_expanded))

    def get_solve_signature(self) -> Dict[str, tuple]:
        return {
            "Q": (ctypes.POINTER(ctypes.c_float), "in"),
            "K": (ctypes.POINTER(ctypes.c_float), "in"),
            "V": (ctypes.POINTER(ctypes.c_float), "in"),
            "output": (ctypes.POINTER(ctypes.c_float), "out"),
            "num_q_heads": (ctypes.c_int, "in"),
            "num_kv_heads": (ctypes.c_int, "in"),
            "seq_len": (ctypes.c_int, "in"),
            "head_dim": (ctypes.c_int, "in"),
        }

    def _make_test_case(self, num_q_heads, num_kv_heads, seq_len, head_dim, zero_inputs=False):
        dtype = torch.float32
        device = "cuda"
        if zero_inputs:
            Q = torch.zeros(num_q_heads, seq_len, head_dim, device=device, dtype=dtype)
            K = torch.zeros(num_kv_heads, seq_len, head_dim, device=device, dtype=dtype)
            V = torch.zeros(num_kv_heads, seq_len, head_dim, device=device, dtype=dtype)
        else:
            Q = torch.randn(num_q_heads, seq_len, head_dim, device=device, dtype=dtype)
            K = torch.randn(num_kv_heads, seq_len, head_dim, device=device, dtype=dtype)
            V = torch.randn(num_kv_heads, seq_len, head_dim, device=device, dtype=dtype)
        output = torch.zeros(num_q_heads, seq_len, head_dim, device=device, dtype=dtype)
        return {
            "Q": Q,
            "K": K,
            "V": V,
            "output": output,
            "num_q_heads": num_q_heads,
            "num_kv_heads": num_kv_heads,
            "seq_len": seq_len,
            "head_dim": head_dim,
        }

    def generate_example_test(self) -> Dict[str, Any]:
        torch.manual_seed(0)
        dtype = torch.float32
        device = "cuda"
        num_q_heads = 4
        num_kv_heads = 2
        seq_len = 3
        head_dim = 4

        Q = torch.tensor(
            [
                [[1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 1.0, 0.0], [1.0, 1.0, 0.0, 0.0]],
                [[0.0, 1.0, 0.0, 1.0], [1.0, 0.0, 1.0, 0.0], [0.0, 0.0, 1.0, 1.0]],
                [[-1.0, 0.0, 0.5, 0.0], [0.0, -1.0, 0.0, 0.5], [0.5, 0.0, -1.0, 0.0]],
                [[0.0, 0.5, 0.0, -1.0], [0.5, 0.0, 0.0, -1.0], [0.0, 0.0, 0.5, 0.5]],
            ],
            device=device,
            dtype=dtype,
        )
        K = torch.tensor(
            [
                [[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0], [1.0, 1.0, 1.0, 1.0]],
                [[0.0, 1.0, 0.0, -1.0], [-1.0, 0.0, 1.0, 0.0], [0.0, -1.0, 0.0, 1.0]],
            ],
            device=device,
            dtype=dtype,
        )
        V = torch.tensor(
            [
                [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0]],
                [[-1.0, -2.0, -3.0, -4.0], [2.0, 3.0, 4.0, 5.0], [6.0, 7.0, 8.0, 9.0]],
            ],
            device=device,
            dtype=dtype,
        )
        output = torch.zeros(num_q_heads, seq_len, head_dim, device=device, dtype=dtype)
        return {
            "Q": Q,
            "K": K,
            "V": V,
            "output": output,
            "num_q_heads": num_q_heads,
            "num_kv_heads": num_kv_heads,
            "seq_len": seq_len,
            "head_dim": head_dim,
        }

    def generate_functional_test(self) -> List[Dict[str, Any]]:
        torch.manual_seed(42)
        tests = []

        # Edge case: MQA (num_kv_heads=1), single token
        tests.append(self._make_test_case(4, 1, 1, 8))

        # Edge case: GQA with groups=2, tiny seq
        tests.append(self._make_test_case(2, 1, 2, 4))

        # Zero inputs
        tests.append(self._make_test_case(4, 2, 4, 8, zero_inputs=True))

        # Power-of-2: groups=4 (LLaMA-3 style ratio)
        tests.append(self._make_test_case(8, 2, 16, 32))

        # Power-of-2: seq_len=32, head_dim=64
        tests.append(self._make_test_case(4, 2, 32, 64))

        # Non-power-of-2 seq_len
        tests.append(self._make_test_case(4, 2, 30, 32))

        # Non-power-of-2 seq_len, different grouping
        tests.append(self._make_test_case(6, 3, 100, 32))

        # MQA with groups=8, seq_len=255
        tests.append(self._make_test_case(8, 1, 255, 64))

        # MHA equivalent (num_q_heads == num_kv_heads)
        tests.append(self._make_test_case(8, 8, 64, 32))

        # Realistic small inference batch
        tests.append(self._make_test_case(8, 2, 128, 64))

        return tests

    def generate_performance_test(self) -> Dict[str, Any]:
        torch.manual_seed(0)
        # LLaMA-3 8B style: 32 Q heads, 8 KV heads, head_dim=128
        return self._make_test_case(32, 8, 1024, 128)
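The reference implementation above materializes `K_expanded`/`V_expanded` with `repeat_interleave`; a custom kernel can skip that copy entirely and index KV head `q_head // num_groups` directly. A tiny pure-Python check (a list-based stand-in for `torch.repeat_interleave` along `dim=0`, for illustration only) that the two views agree:

```python
def repeat_interleave(items, repeats):
    # Stand-in for torch.repeat_interleave along dim=0: each element
    # is repeated `repeats` times in place, preserving order.
    return [x for x in items for _ in range(repeats)]

num_q_heads, num_kv_heads = 4, 2
num_groups = num_q_heads // num_kv_heads

kv_labels = ["kv0", "kv1"]
expanded = repeat_interleave(kv_labels, num_groups)  # ['kv0', 'kv0', 'kv1', 'kv1']

# Direct index mapping usable in a kernel: no expanded copy is needed.
mapped = [kv_labels[qh // num_groups] for qh in range(num_q_heads)]
assert expanded == mapped
```

Avoiding the expanded copy is exactly the memory saving GQA is meant to provide, so solutions that materialize full-size K/V buffers give it back.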
Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
#include <cuda_runtime.h>

// Q, K, V, output are device pointers
extern "C" void solve(const float* Q, const float* K, const float* V, float* output,
                      int num_q_heads, int num_kv_heads, int seq_len, int head_dim) {}
Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
import cutlass
import cutlass.cute as cute


# Q, K, V, output are tensors on the GPU
@cute.jit
def solve(
    Q: cute.Tensor,
    K: cute.Tensor,
    V: cute.Tensor,
    output: cute.Tensor,
    num_q_heads: cute.Int32,
    num_kv_heads: cute.Int32,
    seq_len: cute.Int32,
    head_dim: cute.Int32,
):
    pass
