
Commit e43f346

Implement Ring Attention forward pass, and unit tests for it.
1 parent 161c739 commit e43f346

File tree: 3 files changed, +101 -57 lines

  .gitignore
  src/flash_attn_jax/flash_sharding.py
  tests/test_sharding.py


.gitignore

Lines changed: 1 addition & 0 deletions
@@ -2,6 +2,7 @@
 __pycache__/
 *.py[cod]
 .pytest_cache
+.cache
 
 # C extensions
 *.so

src/flash_attn_jax/flash_sharding.py

Lines changed: 66 additions & 33 deletions
@@ -28,49 +28,82 @@
 _flash_mha_fwd_hlo_sharded = custom_partitioning(_flash_mha_fwd_hlo, static_argnums=(3,4,5))
 _flash_mha_bwd_hlo_sharded = custom_partitioning(_flash_mha_bwd_hlo, static_argnums=(6,7,8))
 
+from jax._src.ad_checkpoint import _optimization_barrier
+
+def ring_fwd(softmax_scale, is_causal, axis_name, axis_size, q,k,v):
+    [n,l,h,d] = q.shape
+
+    q_ix = jax.lax.axis_index(axis_name)
+    k_ix = jax.lax.axis_index(axis_name)
+
+    o = jnp.zeros([n,l,h,d], jnp.float32)
+    lse = jnp.full([n,h,l], float('-inf'), jnp.float32)
+
+    # scan :: (c -> a -> (c, b)) -> c -> [a] -> (c, [b])
+    def f(c, a):
+        (k, v, o, lse, k_ix) = c
+
+        o1, lse1 = o, lse
+        if is_causal:
+            o2, lse2 = jax.lax.switch((k_ix < q_ix).astype(jnp.int32) + (k_ix <= q_ix).astype(jnp.int32),
+                                      [
+                                          lambda q,k,v: (jnp.zeros([n,l,h,d], q.dtype), jnp.full([n,h,l], float('-inf'), jnp.float32)),
+                                          lambda q,k,v: _flash_mha_fwd_hlo(q,k,v, softmax_scale=softmax_scale, is_causal=True, window_size=(-1,-1)),
+                                          lambda q,k,v: _flash_mha_fwd_hlo(q,k,v, softmax_scale=softmax_scale, is_causal=False, window_size=(-1,-1)),
+                                      ], q, k, v)
+        else:
+            o2, lse2 = _flash_mha_fwd_hlo(q,k,v, softmax_scale=softmax_scale, is_causal=False, window_size=(-1,-1))
+        o2 = o2.astype(jnp.float32)
+
+        mx = jnp.maximum(lse1,lse2)
+        mn = jnp.minimum(lse1,lse2)
+        lse = jnp.log1p(jnp.exp(mn-mx)) + mx
+
+        o = (o1 * rearrange(jnp.exp(lse1 - lse), 'n h l -> n l h 1') +
+             o2 * rearrange(jnp.exp(lse2 - lse), 'n h l -> n l h 1'))
+
+        k2 = jax.lax.ppermute(k, axis_name, [(i, (i+1)%axis_size) for i in range(axis_size)])
+        v2 = jax.lax.ppermute(v, axis_name, [(i, (i+1)%axis_size) for i in range(axis_size)])
+        k_ix = jax.lax.ppermute(k_ix, axis_name, [(i, (i+1)%axis_size) for i in range(axis_size)])
+
+        return ((k2, v2, o, lse, k_ix), None)
+    acc = (k,v,o,lse,k_ix)
+    # We sadly have to manually unroll this because scan breaks the axis context preventing us from using ppermute (unroll=axis_size doesn't help either).
+    # Optimization barrier prevents instruction reordering so that ppermute and flash_mha execute concurrently.
+    for _ in range(axis_size):
+        acc, _ = f(acc, None)
+        acc = _optimization_barrier(acc)
+    (_,_,o,lse,_) = acc
+    # (_,_,o,lse), _ = jax.lax.scan(f,init,None,axis_size)
+    return o.astype(q.dtype), lse
+
 def partition_fwd(softmax_scale, is_causal, window_size, mesh, arg_shapes, result_shape):
     result_shardings = jax.tree_map(lambda x: x.sharding, result_shape)
     arg_shardings = jax.tree_map(lambda x: x.sharding, arg_shapes)
 
     q_sharding = arg_shardings[0]
     if isinstance(q_sharding, PositionalSharding):
-        if not is_causal and window_size == (-1,-1):
-            # We can handle Q that's sharded across the L dimension
-            # without replicating Q by executing it as a cross
-            # attention:
-            #
-            # q : n [L/devices] h d
-            # kv : n L h d
-            # -> o : n [L/devices] h d
-            #
-            # TODO: We could handle q sharded across L even with
-            # causal/local if we could communicate the slice offset
-            # (of q in kv) to the c++ driver. But it's unclear how to
-            # do that since the HLO has to be identical (SPMD).
-            q_sharding = q_sharding.replicate(3)
-            kv_sharding = q_sharding.replicate(1)
-            (n,l,h,d) = q_sharding.shape
-            result_shardings = q_sharding, q_sharding.reshape((n,l,h)).transpose(0,2,1) # n h l
-            arg_shardings = q_sharding, kv_sharding, kv_sharding
-        else:
-            # We need to replicate d always.
-            q_sharding = q_sharding.replicate((1,3))
-            (n,l,h,d) = q_sharding.shape # l=1, d=1
-            result_shardings = q_sharding, q_sharding.reshape((n,l,h)).transpose(0,2,1)
-            arg_shardings = q_sharding, q_sharding, q_sharding
+        (n,l,h,d) = q_sharding.shape
+        assert d == 1, "Sharding across `d` won't be efficient, so it's not supported."
+        assert l == 1, "For ring attention, use `with Mesh(...) as mesh` and NamedSharding."
+        result_shardings = q_sharding, q_sharding.reshape((n,h,1)) # n h l
+        arg_shardings = q_sharding, q_sharding, q_sharding
     elif isinstance(q_sharding, NamedSharding):
         mesh = q_sharding.mesh
         [n,l,h,d] = q_sharding.spec
-        if not is_causal and window_size == (-1,-1):
-            q_sharding = NamedSharding(mesh, P(n,l,h,None))
-            kv_sharding = NamedSharding(mesh, P(n,None,h,None))
-            lse_sharding = NamedSharding(mesh, P(n,h,l))
+        assert d == None, "Sharding across `d` won't be efficient, so it's not supported."
+        if l != None:
+            # assert not is_causal and window_size == (-1,-1), "Ring attention doesn't support causal or local masking yet."
+            assert window_size == (-1,-1), "Ring attention doesn't support local masking yet."
+            result_shardings = q_sharding, NamedSharding(mesh, P(n,h,l))
+            arg_shardings = q_sharding, q_sharding, q_sharding
+            axis_name = l
+            axis_size = mesh.shape[axis_name]
+            # ring attention
+            return mesh, partial(ring_fwd, softmax_scale, is_causal, axis_name, axis_size), result_shardings, arg_shardings
         else:
-            q_sharding = NamedSharding(mesh, P(n,None,h,None))
-            kv_sharding = q_sharding
-            lse_sharding = NamedSharding(mesh, P(n,h,None))
-        result_sharding = (q_sharding, lse_sharding)
-        arg_shardings = (q_sharding, kv_sharding, kv_sharding)
+            result_shardings = q_sharding, NamedSharding(mesh, P(n,h,l))
+            arg_shardings = q_sharding, q_sharding, q_sharding
     def fwd(q,k,v):
         return _flash_mha_fwd_hlo(q,k,v, softmax_scale=softmax_scale, is_causal=is_causal, window_size=window_size)
     return mesh, fwd, result_shardings, arg_shardings
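
The accumulation step in ring_fwd above is the standard log-sum-exp merge of partial attention results: each ring step folds the block computed against the currently held key/value chunk into the running (o, lse) pair. Below is a minimal sketch of that identity in plain JAX, checked against unfused attention; the `naive_attn` helper and the tiny shapes are illustrative assumptions, not part of this repository.

# Sketch: the per-step merge in ring_fwd, verified against plain attention.
# `naive_attn` is an illustrative reference helper, not the library API.
import jax
import jax.numpy as jnp
from einops import rearrange

def naive_attn(q, k, v, softmax_scale):
    # Returns (output [n,l,h,d], logsumexp [n,h,l]), the same pair the flash kernel produces.
    s = jnp.einsum('nlhd,nLhd->nhlL', q, k) * softmax_scale
    lse = jax.scipy.special.logsumexp(s, axis=-1)            # n h l
    o = jnp.einsum('nhlL,nLhd->nlhd', jnp.exp(s - lse[..., None]), v)
    return o, lse

def merge(o1, lse1, o2, lse2):
    # Combine partial results over two disjoint key blocks (same math as ring_fwd's loop body).
    mx, mn = jnp.maximum(lse1, lse2), jnp.minimum(lse1, lse2)
    lse = jnp.log1p(jnp.exp(mn - mx)) + mx                   # log(exp(lse1) + exp(lse2)), computed stably
    o = (o1 * rearrange(jnp.exp(lse1 - lse), 'n h l -> n l h 1') +
         o2 * rearrange(jnp.exp(lse2 - lse), 'n h l -> n l h 1'))
    return o, lse

n, l, h, d = 1, 8, 2, 4
q, k, v = jax.random.normal(jax.random.PRNGKey(0), (3, n, l, h, d))
scale = d ** -0.5
o_full, lse_full = naive_attn(q, k, v, scale)
o_a, lse_a = naive_attn(q, k[:, :l//2], v[:, :l//2], scale)  # first half of the keys
o_b, lse_b = naive_attn(q, k[:, l//2:], v[:, l//2:], scale)  # second half of the keys
o_merged, lse_merged = merge(o_a, lse_a, o_b, lse_b)
assert jnp.allclose(o_merged, o_full, atol=1e-5)
assert jnp.allclose(lse_merged, lse_full, atol=1e-5)

Because the merge goes through log1p, it also behaves sensibly when one of the two inputs is the (zeros, -inf) pair produced by the skip branch of the causal switch: that block simply contributes nothing.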

tests/test_sharding.py

Lines changed: 34 additions & 24 deletions
@@ -1,17 +1,21 @@
-import sys, glob
+import glob
+import sys
+
 if glob.glob('build/lib.linux-*'):
     sys.path.append(glob.glob('build/lib.linux-*')[0])
 
 import jax
 import jax.numpy as jnp
 import numpy as np
 import pytest
-
+from jax.sharding import Mesh, NamedSharding
+from jax.sharding import PartitionSpec as P
 from jax.sharding import PositionalSharding
 from jax.tree_util import tree_map
 
 from flash_attn_jax import flash_mha
 
+
 def ref_mha(q,k,v, is_causal=False, window_size=(-1,-1)):
     softmax_scale = 1/np.sqrt(q.shape[-1])
     att = jnp.einsum('nlhd,nLhd->nhlL',q,k)
@@ -79,13 +83,18 @@ def with_sharding(q_sharding, kv_sharding=None):
     assert 'all-gather' not in hlo
     assert 'dynamic-slice' not in hlo
 
-    # With q sharded and kv replicated, should need no communication
-    # (handle it as a cross attention), as long as causal and local
-    # are both false.
-    hlo = with_sharding(PositionalSharding(devices).reshape(1,n,1,1), PositionalSharding(devices).replicate())
-    if not (causal or local):
-        assert 'all-gather' not in hlo
-        assert 'dynamic-slice' not in hlo
+    if not local:
+        with Mesh(np.array(devices), axis_names=('x',)) as mesh:
+            sharding = NamedSharding(mesh, P(None,'x',None,None))
+            hlo = with_sharding(sharding)
+            # No resharding should occur, only manual collective-permute.
+            assert 'all-gather' not in hlo
+            assert 'dynamic-slice' not in hlo
+            assert 'collective-permute' in hlo
+            # Should always run concurrently, meaning custom-call is always between start and done.
+            import re
+            collectives = ''.join(re.findall(" collective-permute-start| collective-permute-done| custom-call", hlo))
+            assert 'collective-permute-start collective-permute-done' not in collectives, hlo
 
 
 @pytest.mark.parametrize("dtype", [jnp.float16, jnp.bfloat16])
@@ -128,8 +137,7 @@ def flash(qkv):
 @pytest.mark.parametrize("d", [32])
 @pytest.mark.parametrize("h", [4, 8])
 @pytest.mark.parametrize("seqlen", [128])
-@pytest.mark.parametrize("shard_dim", [0,1,2,3])
-def test_flash_fwd_sharded(seqlen, h, d, causal, local, dtype, shard_dim):
+def test_flash_fwd_sharded(seqlen, h, d, causal, local, dtype):
     window_size = (3,3) if local else (-1,-1)
 
     devices = jax.local_devices()
@@ -145,23 +153,25 @@ def flash(qkv):
     k = jax.random.normal(jax.random.PRNGKey(1), [n, seqlen, h, d], dtype=jnp.float32)
     v = jax.random.normal(jax.random.PRNGKey(2), [n, seqlen, h, d], dtype=jnp.float32)
 
-    if q.shape[shard_dim] % n != 0:
-        pytest.skip(f"{q.shape[shard_dim]} doesn't divide into {n} so we can't shard it.")
-
     ref_out = ref((q,k,v))
     q = q.astype(dtype)
     k = k.astype(dtype)
     v = v.astype(dtype)
-    repl_out = flash((q,k,v))
+    ref16_out = flash((q,k,v))
 
-    shape = [1,1,1,1]
-    shape[shard_dim] = n
-    sharding = PositionalSharding(devices).reshape(shape)
+    def check_sharding(sharding,q,k,v):
+        (q,k,v) = jax.device_put((q,k,v), sharding)
+        out = flash((q,k,v))
+        check(ref_out,ref16_out,out)
 
-    (q,k,v) = jax.device_put((q,k,v), sharding)
-    hlo = flash.lower((q,k,v)).compile().as_text()
-    out = flash((q,k,v))
-    check(ref_out, repl_out, out)
+    check_sharding(PositionalSharding(devices).reshape(n,1,1,1),q,k,v)
+    check_sharding(PositionalSharding(devices).reshape(1,1,n,1),q,k,v)
+
+    if not local:
+        # Ring attention
+        with Mesh(np.array(devices), axis_names=('x',)) as mesh:
+            sharding = NamedSharding(mesh, P(None,'x',None,None))
+            check_sharding(sharding,q,k,v)
 
 
 @pytest.mark.skipif(len(jax.local_devices()) < 2, reason='Requires >1 gpu device')
@@ -171,7 +181,7 @@ def flash(qkv):
 @pytest.mark.parametrize("d", [32])
 @pytest.mark.parametrize("h", [4, 8])
 @pytest.mark.parametrize("seqlen", [128])
-@pytest.mark.parametrize("shard_dim", [0,1,2,3])
+@pytest.mark.parametrize("shard_dim", [0,2])
 def test_flash_bwd_sharded(seqlen, h, d, causal, local, dtype, shard_dim):
     window_size = (3,3) if local else (-1,-1)
 
@@ -209,4 +219,4 @@ def flash(qkv):
     check(ref_out, repl_out, out)
 
 if __name__ == '__main__':
-    test_flash_bwd_sharded_hlo(128,4,32,False,False,jnp.float16)
+    test_flash_fwd_sharded_hlo(128,4,32,False,False,jnp.float16)
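
As the new test shows, the ring path is selected purely by how the inputs are sharded: put q, k and v on a one-dimensional mesh with the sequence axis partitioned, and partition_fwd dispatches to ring_fwd. Below is a minimal usage sketch along the same lines; the shapes, the float16 dtype and the jitted flash_mha(q, k, v) call are illustrative assumptions rather than anything fixed by this commit.

# Sketch: trigger ring attention by sharding the sequence (L) dimension.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P
from flash_attn_jax import flash_mha

devices = jax.local_devices()
n, seqlen, h, d = 2, 128, 4, 32                              # illustrative sizes
q = jax.random.normal(jax.random.PRNGKey(0), [n, seqlen, h, d], dtype=jnp.float16)
k = jax.random.normal(jax.random.PRNGKey(1), [n, seqlen, h, d], dtype=jnp.float16)
v = jax.random.normal(jax.random.PRNGKey(2), [n, seqlen, h, d], dtype=jnp.float16)

with Mesh(np.array(devices), axis_names=('x',)) as mesh:
    # Shard L across the mesh axis; batch, heads and head_dim stay replicated.
    sharding = NamedSharding(mesh, P(None, 'x', None, None))
    q, k, v = jax.device_put((q, k, v), sharding)
    out = jax.jit(flash_mha)(q, k, v)                        # forward pass runs as ring attention

Sharding any other dimension (batch or heads) still goes through the ordinary fully local kernel, which is what the PositionalSharding cases in the test above continue to cover.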
