Skip to content

Commit 3ca9f14

Browse files
Merge pull request jax-ml#25361 from Rifur13:regression
PiperOrigin-RevId: 704885039
2 parents 1c1a17e + e1e174f commit 3ca9f14

File tree

1 file changed

+39
-66
lines changed

1 file changed

+39
-66
lines changed

tests/pallas/gpu_ops_test.py

Lines changed: 39 additions & 66 deletions
Original file line number | Diff line number | Diff line change
@@ -148,40 +148,16 @@ def setUp(self):
148148
if jtu.test_device_matches(["tpu"]):
149149
self.skipTest("Not intended for TPU")
150150

151-
@jtu.parameterized_filterable(
152-
kwargs=[
153-
dict(
154-
batch_size=batch_size,
155-
seq_len=seq_len,
156-
num_heads=num_heads,
157-
head_dim=head_dim,
158-
causal=causal,
159-
use_fwd=use_fwd,
160-
use_segment_ids=use_segment_ids,
161-
kwargs=kwargs,
162-
)
163-
for (
164-
batch_size,
165-
seq_len,
166-
num_heads,
167-
head_dim,
168-
causal,
169-
use_fwd,
170-
use_segment_ids,
171-
kwargs,
172-
) in [
173-
(1, 384, 1, 64, False, False, True, {}),
174-
(1, 384, 1, 64, False, False, False, {}),
175-
(2, 384, 2, 64, False, False, True, {}),
176-
(1, 384, 1, 64, True, False, True, {}),
177-
# (2, 384, 2, 64, True, False, True, {}), # TODO(sharadmv): Investigate.
178-
(1, 384, 8, 64, True, True, True, {}),
179-
(1, 384, 8, 64, True, True, False, {}),
180-
(2, 384, 8, 64, True, True, True, {}),
181-
# regression test: https://github.com/jax-ml/jax/pull/17314
182-
(1, 384, 8, 64, True, False, False, {'block_q': 128, 'block_k': 64}),
183-
]
184-
]
151+
@jtu.sample_product(
152+
batch_size=(1, 2),
153+
seq_len=(128, 384),
154+
num_heads=(1, 2, 8),
155+
head_dim=(32, 64, 128),
156+
block_q=(64, 128),
157+
block_k=(64, 128),
158+
causal=(True, False),
159+
use_fwd=(True, False),
160+
use_segment_ids=(True, False),
185161
)
186162
def test_fused_attention_fwd(
187163
self,
@@ -190,10 +166,11 @@ def test_fused_attention_fwd(
190166
seq_len,
191167
num_heads,
192168
head_dim,
169+
block_q,
170+
block_k,
193171
causal,
194172
use_fwd,
195173
use_segment_ids,
196-
kwargs,
197174
):
198175
k1, k2, k3 = random.split(random.key(0), 3)
199176
q = random.normal(
@@ -218,8 +195,12 @@ def test_fused_attention_fwd(
218195
def impl(q, k, v):
219196
v, _ = jax.vjp(
220197
functools.partial(
221-
attention.mha, causal=causal, segment_ids=segment_ids,
222-
interpret=self.INTERPRET, **kwargs
198+
attention.mha,
199+
block_q=block_q,
200+
block_k=block_k,
201+
causal=causal,
202+
segment_ids=segment_ids,
203+
interpret=self.INTERPRET,
223204
),
224205
q,
225206
k,
@@ -229,42 +210,34 @@ def impl(q, k, v):
229210

230211
else:
231212
impl = functools.partial(
232-
attention.mha, causal=causal, segment_ids=segment_ids,
233-
interpret=self.INTERPRET, **kwargs
213+
attention.mha,
214+
block_q=block_q,
215+
block_k=block_k,
216+
causal=causal,
217+
segment_ids=segment_ids,
218+
interpret=self.INTERPRET,
234219
)
235220
o = impl(q, k, v)
236221
o_ref = attention.mha_reference(q, k, v, segment_ids, causal=causal)
237222
np.testing.assert_allclose(o, o_ref, atol=0.05)
238223

239-
@jtu.parameterized_filterable(
240-
kwargs=[
241-
dict(
242-
batch_size=batch_size,
243-
seq_len=seq_len,
244-
num_heads=num_heads,
245-
head_dim=head_dim,
246-
causal=causal,
247-
use_segment_ids=use_segment_ids,
248-
)
249-
for (
250-
batch_size,
251-
seq_len,
252-
num_heads,
253-
head_dim,
254-
causal,
255-
use_segment_ids,
256-
) in [
257-
(1, 384, 1, 32, False, True),
258-
(1, 384, 1, 32, False, False),
259-
(2, 384, 2, 32, False, True),
260-
(2, 384, 2, 32, False, False),
261-
(1, 384, 1, 32, True, True),
262-
(2, 384, 2, 32, True, True),
263-
]
264-
]
224+
@jtu.sample_product(
225+
batch_size=(1, 2),
226+
seq_len=(128, 384),
227+
num_heads=(1, 2, 4),
228+
head_dim=(32,),
229+
causal=(True, False),
230+
use_segment_ids=(True, False),
265231
)
266232
def test_fused_attention_bwd(
267-
self, *, batch_size, seq_len, num_heads, head_dim, causal, use_segment_ids
233+
self,
234+
*,
235+
batch_size,
236+
seq_len,
237+
num_heads,
238+
head_dim,
239+
causal,
240+
use_segment_ids,
268241
):
269242
k1, k2, k3 = random.split(random.key(0), 3)
270243
q = random.normal(

0 commit comments

Comments (0)