@@ -148,40 +148,16 @@ def setUp(self):
148148 if jtu .test_device_matches (["tpu" ]):
149149 self .skipTest ("Not intended for TPU" )
150150
151- @jtu .parameterized_filterable (
152- kwargs = [
153- dict (
154- batch_size = batch_size ,
155- seq_len = seq_len ,
156- num_heads = num_heads ,
157- head_dim = head_dim ,
158- causal = causal ,
159- use_fwd = use_fwd ,
160- use_segment_ids = use_segment_ids ,
161- kwargs = kwargs ,
162- )
163- for (
164- batch_size ,
165- seq_len ,
166- num_heads ,
167- head_dim ,
168- causal ,
169- use_fwd ,
170- use_segment_ids ,
171- kwargs ,
172- ) in [
173- (1 , 384 , 1 , 64 , False , False , True , {}),
174- (1 , 384 , 1 , 64 , False , False , False , {}),
175- (2 , 384 , 2 , 64 , False , False , True , {}),
176- (1 , 384 , 1 , 64 , True , False , True , {}),
177- # (2, 384, 2, 64, True, False, True, {}), # TODO(sharadmv): Investigate.
178- (1 , 384 , 8 , 64 , True , True , True , {}),
179- (1 , 384 , 8 , 64 , True , True , False , {}),
180- (2 , 384 , 8 , 64 , True , True , True , {}),
181- # regression test: https://github.com/jax-ml/jax/pull/17314
182- (1 , 384 , 8 , 64 , True , False , False , {'block_q' : 128 , 'block_k' : 64 }),
183- ]
184- ]
151+ @jtu .sample_product (
152+ batch_size = (1 , 2 ),
153+ seq_len = (128 , 384 ),
154+ num_heads = (1 , 2 , 8 ),
155+ head_dim = (32 , 64 , 128 ),
156+ block_q = (64 , 128 ),
157+ block_k = (64 , 128 ),
158+ causal = (True , False ),
159+ use_fwd = (True , False ),
160+ use_segment_ids = (True , False ),
185161 )
186162 def test_fused_attention_fwd (
187163 self ,
@@ -190,10 +166,11 @@ def test_fused_attention_fwd(
190166 seq_len ,
191167 num_heads ,
192168 head_dim ,
169+ block_q ,
170+ block_k ,
193171 causal ,
194172 use_fwd ,
195173 use_segment_ids ,
196- kwargs ,
197174 ):
198175 k1 , k2 , k3 = random .split (random .key (0 ), 3 )
199176 q = random .normal (
@@ -218,8 +195,12 @@ def test_fused_attention_fwd(
218195 def impl (q , k , v ):
219196 v , _ = jax .vjp (
220197 functools .partial (
221- attention .mha , causal = causal , segment_ids = segment_ids ,
222- interpret = self .INTERPRET , ** kwargs
198+ attention .mha ,
199+ block_q = block_q ,
200+ block_k = block_k ,
201+ causal = causal ,
202+ segment_ids = segment_ids ,
203+ interpret = self .INTERPRET ,
223204 ),
224205 q ,
225206 k ,
@@ -229,42 +210,34 @@ def impl(q, k, v):
229210
230211 else :
231212 impl = functools .partial (
232- attention .mha , causal = causal , segment_ids = segment_ids ,
233- interpret = self .INTERPRET , ** kwargs
213+ attention .mha ,
214+ block_q = block_q ,
215+ block_k = block_k ,
216+ causal = causal ,
217+ segment_ids = segment_ids ,
218+ interpret = self .INTERPRET ,
234219 )
235220 o = impl (q , k , v )
236221 o_ref = attention .mha_reference (q , k , v , segment_ids , causal = causal )
237222 np .testing .assert_allclose (o , o_ref , atol = 0.05 )
238223
239- @jtu .parameterized_filterable (
240- kwargs = [
241- dict (
242- batch_size = batch_size ,
243- seq_len = seq_len ,
244- num_heads = num_heads ,
245- head_dim = head_dim ,
246- causal = causal ,
247- use_segment_ids = use_segment_ids ,
248- )
249- for (
250- batch_size ,
251- seq_len ,
252- num_heads ,
253- head_dim ,
254- causal ,
255- use_segment_ids ,
256- ) in [
257- (1 , 384 , 1 , 32 , False , True ),
258- (1 , 384 , 1 , 32 , False , False ),
259- (2 , 384 , 2 , 32 , False , True ),
260- (2 , 384 , 2 , 32 , False , False ),
261- (1 , 384 , 1 , 32 , True , True ),
262- (2 , 384 , 2 , 32 , True , True ),
263- ]
264- ]
224+ @jtu .sample_product (
225+ batch_size = (1 , 2 ),
226+ seq_len = (128 , 384 ),
227+ num_heads = (1 , 2 , 4 ),
228+ head_dim = (32 ,),
229+ causal = (True , False ),
230+ use_segment_ids = (True , False ),
265231 )
266232 def test_fused_attention_bwd (
267- self , * , batch_size , seq_len , num_heads , head_dim , causal , use_segment_ids
233+ self ,
234+ * ,
235+ batch_size ,
236+ seq_len ,
237+ num_heads ,
238+ head_dim ,
239+ causal ,
240+ use_segment_ids ,
268241 ):
269242 k1 , k2 , k3 = random .split (random .key (0 ), 3 )
270243 q = random .normal (