@@ -29,9 +29,11 @@ def pretty(tensor):
    std = jnp.std(tensor)
    return f'[{shape}: {mn:.3g} | {mean:.3g}±{std:.3g} | {mx:.3g}]'

-def check(ref_out, jax_out, out):
+def check(ref_out, jax_out, out, eps=3):
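+    # eps bounds how much larger the tested output's error may be than the plain-JAX
+    # low-precision baseline's error, both measured against the float32 reference (presumed intent).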
    def check1(ref_out, jax_out, out):
-        assert jnp.max(jnp.abs(out - ref_out)).item() <= 3 * jnp.max(jnp.abs(jax_out - ref_out)).item(), (pretty(jnp.abs(out - ref_out)), 'vs', pretty(jnp.abs(jax_out - ref_out)))
+        out_diff = jnp.abs(out - ref_out)
+        jax_diff = jnp.abs(jax_out - ref_out)
+        assert jnp.max(out_diff) <= eps * jnp.max(jax_diff), (pretty(out_diff), 'vs', pretty(jax_diff))
    tree_map(check1, ref_out, jax_out, out)

@pytest.mark.skipif(len(jax.local_devices()) < 2, reason='Requires >1 gpu device')
@@ -130,9 +132,11 @@ def with_sharding(sharding):
    assert 'dynamic-slice' not in hlo
    assert 'collective-permute' in hlo
    # Should always run concurrently, meaning custom-call is always between start and done.
-    import re
-    collectives = ''.join(re.findall(" collective-permute-start| collective-permute-done| custom-call", hlo))
-    assert 'collective-permute-start collective-permute-done' not in collectives, hlo
+    # import re
+    # collectives = ''.join(re.findall(" collective-permute-start| collective-permute-done| custom-call", hlo))
+    # assert 'collective-permute-start collective-permute-done' not in collectives, hlo
+    print(hlo)
+    assert 'collective-permute-start collective-permute-done' not in decode_hlo(hlo), decode_hlo(hlo)
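+    # decode_hlo (added at module level below) inlines called computations before checking op
+    # order, so a custom-call that lives inside a sub-computation still counts as sitting between
+    # start and done; the commented-out regex scan missed that case (presumed motivation).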

@pytest.mark.skipif(len(jax.local_devices()) < 2, reason='Requires >1 gpu device')
@pytest.mark.parametrize("dtype", [jnp.float16, jnp.bfloat16])
@@ -162,7 +166,7 @@ def flash(qkv):
    q = q.astype(dtype)
    k = k.astype(dtype)
    v = v.astype(dtype)
-    ref16_out = flash((q,k,v))
+    ref16_out = ref((q,k,v))
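+    # The low-precision baseline should come from the reference implementation, not from the
+    # kernel under test; otherwise check() compares flash against itself (presumed intent of the fix).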

    def check_sharding(sharding,q,k,v):
        (q,k,v) = jax.device_put((q,k,v), sharding)
@@ -193,37 +197,73 @@ def test_flash_bwd_sharded(seqlen, h, d, m, causal, local, dtype):
    devices = jax.local_devices()
    n = len(devices)

+    A = 1.0 / math.sqrt(n * seqlen * h * d)
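+    # A shrinks the random output-gradient do defined below so the fp16/bf16 backward pass
+    # stays well inside the representable range (assumed purpose of the scaling).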
+
    @jax.jit
    @jax.grad
-    def ref(qkv):
-        return ref_mha(*qkv, is_causal=bool(causal), window_size=window_size).sum()
+    def ref(qkv, do):
+        o = ref_mha(*qkv, is_causal=bool(causal), window_size=window_size)
+        return (o * do).sum()
    @jax.jit
    @jax.grad
-    def flash(qkv):
-        return flash_mha(*qkv, is_causal=bool(causal), window_size=window_size).sum()
+    def flash(qkv, do):
+        o = flash_mha(*qkv, is_causal=bool(causal), window_size=window_size)
+        return (o * do).sum()
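+    # jax.grad differentiates with respect to the first argument (qkv) only, so these functions
+    # return the gradient of (o * do).sum(), i.e. the VJP of the attention output with cotangent
+    # do rather than the previous all-ones output gradient.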
    q = jax.random.normal(jax.random.PRNGKey(0), [n, seqlen, h*m, d], dtype=jnp.float32)
    k = jax.random.normal(jax.random.PRNGKey(1), [n, seqlen, h, d], dtype=jnp.float32)
    v = jax.random.normal(jax.random.PRNGKey(2), [n, seqlen, h, d], dtype=jnp.float32)
+    do = jax.random.normal(jax.random.PRNGKey(3), [n, seqlen, h*m, d], dtype=jnp.float32) * A

-    ref_out = ref((q,k,v))
+    ref_out = ref((q,k,v), do)
    q = q.astype(dtype)
    k = k.astype(dtype)
    v = v.astype(dtype)
-    ref16_out = flash((q,k,v))
+    do = do.astype(dtype)
+    ref16_out = ref((q,k,v), do)

-    def check_sharding(sharding, q, k, v):
-        (q, k, v) = jax.device_put((q,k,v), sharding)
-        out = flash((q, k, v))
-        check(ref_out, ref16_out, out)
+    def check_sharding(sharding):
+        (qs, ks, vs, dos) = jax.device_put((q,k,v,do), sharding)
+        out = flash((qs, ks, vs), dos)
+        check(ref_out, ref16_out, out, eps=4)

-    check_sharding(PositionalSharding(devices).reshape(n,1,1,1), q, k, v)
-    check_sharding(PositionalSharding(devices).reshape(1,1,n,1), q, k, v)
+    check_sharding(PositionalSharding(devices).reshape(n,1,1,1))
+    check_sharding(PositionalSharding(devices).reshape(1,1,n,1))

    if not local:
        # Ring attention
        with Mesh(np.array(devices), axis_names=('x',)) as mesh:
            sharding = NamedSharding(mesh, P(None,'x',None,None))
-            check_sharding(sharding, q, k, v)
+            check_sharding(sharding)
+
+def decode_hlo(hlo):
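+    # Split the HLO dump into named computations, then walk them starting at ENTRY, inlining
+    # computations reached via calls=, and return the flattened sequence of custom-call /
+    # collective-permute-start / collective-permute-done ops as a single space-joined string.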
+    computations = {}
+    current_name = None
+    current_lines = []
+    for line in hlo.splitlines():
+        if line.startswith('%') or line.startswith('ENTRY'):
+            if current_name is not None:
+                computations[current_name] = current_lines
+            current_name = line.split()[0]
+            current_lines = []
+        elif line.lstrip().startswith('%') or line.lstrip().startswith('ROOT'):
+            current_lines.append(line)
+    if current_lines:
+        computations[current_name] = current_lines
+
+    def visit(name):
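+        # Yield interesting ops in instruction order, recursing into any computation that a
+        # line references via calls=.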
+        for line in computations[name]:
+            if 'custom-call(' in line:
+                yield 'custom-call'
+            elif any('calls=' + target in line for target in computations.keys()):
+                target = [target for target in computations.keys() if 'calls=' + target in line][0]
+                for item in visit(target):
+                    yield item
+            elif 'collective-permute-start(' in line:
+                yield 'collective-permute-start'
+            elif 'collective-permute-done(' in line:
+                yield 'collective-permute-done'
+
+    return ' '.join(visit('ENTRY'))

if __name__ == '__main__':
    test_flash_fwd_sharded_hlo(128,4,32,False,False,jnp.float16)