
Commit 4367317

Fully overlap communication with computation in the ring-attention backward pass. Well, in theory, anyway: it won't work reliably until #20884 is merged and we can use a scan instead of unrolling the loop.
1 parent: af4317a
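
The core idea of the change, per the diff to ring_attention.py below, is to start the ppermute for the next chunk before the local mha_bwd call and only consume its result afterwards, so XLA can run the collective concurrently with the kernel. A minimal sketch of that pattern; the names and shapes are illustrative, not the actual ring_bwd code, and the function is assumed to run inside a shard_map/pmap with the named axis:

    import jax
    import jax.numpy as jnp

    def ring_step(axis_name, axis_size, compute_grads, carry):
        """One step of the ring: start the permute, compute locally, then consume."""
        k2, v2, acc = carry
        # Start sending this device's current k/v chunk to its neighbour first...
        k2_next, v2_next = jax.lax.ppermute(
            (k2, v2), axis_name,
            [(i, (i + 1) % axis_size) for i in range(axis_size)])
        # ...then run the local backward computation on the chunk we already hold.
        acc = acc + compute_grads(k2, v2)
        # Nothing depends on k2_next/v2_next until the return, so the
        # collective-permute and the kernel are free to overlap.
        return (k2_next, v2_next, acc)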

6 files changed: +101 -46 lines

src/flash_attn_jax/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -1,2 +1,2 @@
 from .flash import flash_mha
-__version__ = 'v0.1.0'
+__version__ = 'v0.2.0'

src/flash_attn_jax/ring_attention.py

Lines changed: 25 additions & 13 deletions

@@ -97,26 +97,38 @@ def f(acc, _):
         # 0: ix < ix2
         # 1: ix = ix2
         # 2: ix > ix2
+        def skip():
+            return (jnp.zeros(q.shape, q.dtype), jnp.zeros(k.shape, k.dtype), jnp.zeros(v.shape, v.dtype))
+        def causal():
+            return mha_bwd(do,q,k2,v2,o,lse, softmax_scale=softmax_scale, is_causal=True, window_size=(-1,-1))
+        def non_causal():
+            return mha_bwd(do,q,k2,v2,o,lse, softmax_scale=softmax_scale, is_causal=False, window_size=(-1,-1))
+
+        (dk2_,dv2_) = jax.lax.ppermute((dk2,dv2), axis_name, [(i, (i+1)%axis_size) for i in range(axis_size)])
+        (k2_,v2_,ix2_) = jax.lax.ppermute((k2,v2,ix2), axis_name, [(i, (i+1)%axis_size) for i in range(axis_size)])
+
         if is_causal:
-            dqa, dka, dva = jax.lax.switch(cmp, (
-                lambda: (jnp.zeros(q.shape, q.dtype), jnp.zeros(k.shape, k.dtype), jnp.zeros(v.shape, v.dtype)),
-                lambda: mha_bwd(do,q,k2,v2,o,lse, softmax_scale=softmax_scale, is_causal=True, window_size=(-1,-1)),
-                lambda: mha_bwd(do,q,k2,v2,o,lse, softmax_scale=softmax_scale, is_causal=False, window_size=(-1,-1))
-            ))
+            (dqa, dka, dva) = jax.lax.switch(cmp, [skip, causal, non_causal])
         else:
-            dqa, dka, dva = mha_bwd(do,q,k2,v2,o,lse, softmax_scale=softmax_scale, is_causal=False, window_size=(-1,-1))
-
-        dq += dqa
-        dk2 += dka
-        dv2 += dva
-
-        (k2,v2,dk2,dv2,ix2) = jax.lax.ppermute((k2,v2,dk2,dv2,ix2), axis_name, [(i, (i+1)%axis_size) for i in range(axis_size)])
+            (dqa, dka, dva) = non_causal()
 
-        return ((k2,v2,dk2,dv2,ix2, dq), None)
+        # Send/receive of dk/dv retires here (because the following depends on it).
+        if is_causal:
+            (dq, dk2_, dv2_) = jax.lax.switch(cmp, [
+                lambda: (dq, dk2_, dv2_),
+                lambda: (dq+dqa, dk2_+dka, dv2_+dva),
+                lambda: (dq+dqa, dk2_+dka, dv2_+dva)
+            ])
+        else:
+            dq, dk2_, dv2_ = (dq+dqa, dk2_+dka, dv2_+dva)
+
+        return ((k2_,v2_,dk2_,dv2_,ix2_, dq), None)
     acc = (k,v,dk,dv,ix, dq)
     # See above (#20884).
     for _ in range(axis_size):
         acc, _ = f(acc, None)
         acc = _optimization_barrier(acc)
+    # acc, _ = jax.lax.scan(f,acc,None,axis_size)
     (k,v,dk,dv,ix2, dq) = acc
+    (dk,dv) = jax.lax.ppermute((dk,dv), axis_name, [(i, (i+1)%axis_size) for i in range(axis_size)])
     return dq.astype(q.dtype),dk.astype(q.dtype),dv.astype(q.dtype)
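
The commented-out jax.lax.scan line above records where this is headed: once the unrolling workaround is no longer needed, the Python loop plus _optimization_barrier collapses into a single scan over axis_size steps. A self-contained sketch of that pattern, with a dummy step body standing in for the real f:

    import jax
    import jax.numpy as jnp

    def f(acc, _):
        # Stand-in for the real step body in ring_bwd, which permutes the k/v
        # chunks around the ring and accumulates dq/dk/dv; here it just counts.
        return acc + 1, None

    axis_size = 4
    acc = jnp.zeros(())
    # Equivalent of:  for _ in range(axis_size): acc, _ = f(acc, None)
    acc, _ = jax.lax.scan(f, acc, None, length=axis_size)
    print(acc)  # 4.0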

tests/__init__.py

Lines changed: 0 additions & 1 deletion

@@ -1 +0,0 @@
-0

tests/ref_mha.py

Lines changed: 5 additions & 7 deletions

@@ -38,13 +38,11 @@ def ref_fwd(q,k,v, is_causal=False, window_size=(-1,-1), softmax_scale=None):
         lse = einops.rearrange(lse, 'n h x l -> n (h x) l')
         return o.astype(q.dtype), lse.astype(jnp.float32)
     else:
-        att = jnp.einsum('nlhd,nLhd->nhlL',q,k)*softmax_scale
-        [_, _, l, L] = att.shape
-        mask = make_mask(l,L,is_causal,window_size)
-        att = jnp.where(mask, att, float('-inf'))
-        lse = jax.nn.logsumexp(att, axis=-1) #nhl
-        att = jnp.exp(att - lse[...,None])
-        o = jnp.einsum('nhlL,nLhd->nlhd',att,v)
+        S = jnp.einsum('nlhd,nLhd->nhlL',q,k)
+        S = jnp.where(mask, S, float('-inf'))
+        lse = jax.nn.logsumexp(S*softmax_scale, axis=-1) #nhl
+        P = jax.nn.softmax(S*softmax_scale, axis=-1) #jnp.exp(att - lse[...,None])
+        o = jnp.einsum('nhlL,nLhd->nlhd',P,v)
         return o.astype(q.dtype), lse.astype(jnp.float32)
 
 def ref_bwd(do,q,k,v,o,lse, is_causal=False, window_size=(-1,-1), softmax_scale=None):
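
The rewritten branch leans on the identity softmax(S*scale) == exp(S*scale - lse[..., None]), which the inline comment also notes. A quick, self-contained numerical check of that identity; the shapes here are arbitrary, not taken from the tests:

    import jax
    import jax.numpy as jnp

    S = jax.random.normal(jax.random.PRNGKey(0), (2, 4, 8, 8))  # n h l L, illustrative
    scale = 0.125

    lse = jax.nn.logsumexp(S * scale, axis=-1)          # what ref_fwd returns as lse
    P_explicit = jnp.exp(S * scale - lse[..., None])    # the old formulation
    P_softmax = jax.nn.softmax(S * scale, axis=-1)      # the new one
    assert jnp.allclose(P_explicit, P_softmax, atol=1e-6)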

tests/test_ring.py

Lines changed: 11 additions & 5 deletions

@@ -17,6 +17,7 @@
 from jax.experimental.shard_map import shard_map
 from functools import partial
 import einops
+import math
 
 from flash_attn_jax.ring_attention import ring_fwd, ring_bwd
 from .ref_mha import ref_fwd, ref_bwd
@@ -87,15 +88,18 @@ def ring(q,k,v):
 
 @pytest.mark.parametrize("causal", ['causal',''])
 @pytest.mark.parametrize("m", [1,2])
-@pytest.mark.parametrize("d", [8])
-@pytest.mark.parametrize("h", [1])
-@pytest.mark.parametrize("seqlen", [2])
+@pytest.mark.parametrize("d", [32])
+@pytest.mark.parametrize("h", [4])
+@pytest.mark.parametrize("seqlen", [128])
 def test_ring_bwd(seqlen, h, d, m, causal):
     window_size = (-1,-1)
 
     devices = jax.devices(backend='cpu')
     n_device = len(devices)
 
+    n = 1
+    A = 1.0 / math.sqrt(n * seqlen * h * d)
+
     with Mesh(np.array(devices), axis_names=('x',)) as mesh:
         @jax.jit
         def ref(q,k,v,do):
@@ -114,11 +118,13 @@ def ring(q,k,v,do):
         q = jax.random.normal(jax.random.PRNGKey(0), [1, seqlen, h*m, d], dtype=jnp.float32)
         k = jax.random.normal(jax.random.PRNGKey(1), [1, seqlen, h, d], dtype=jnp.float32)
         v = jax.random.normal(jax.random.PRNGKey(2), [1, seqlen, h, d], dtype=jnp.float32)
-        do = jax.random.normal(jax.random.PRNGKey(3), [1, seqlen, h*m, d], dtype=jnp.float32)
+        do = jax.random.normal(jax.random.PRNGKey(3), [1, seqlen, h*m, d], dtype=jnp.float32) * A
         o_ref = ref(q,k,v,do)
         o_ring = ring(q,k,v,do)
+        # print(jnp.stack([o_ref[0], o_ring[0], o_ref[0] - o_ring[0]], axis=-1))
+        print(jnp.stack([o_ref[2], o_ring[2], o_ref[2] - o_ring[2]], axis=-1))
         for i in range(3):
-            assert jnp.allclose(o_ref[i], o_ring[i], rtol=1e-2, atol=1e-3)
+            assert jnp.allclose(o_ref[i], o_ring[i], rtol=1e-2, atol=1e-3), i
 
 if __name__ == '__main__':
     test_ref()
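
The new A = 1.0 / math.sqrt(n * seqlen * h * d) factor appears to normalize do so the contracted quantities stay roughly unit scale at the larger test sizes, keeping the fixed atol=1e-3 meaningful; the loss in these tests has the form (o * do).sum(), as shown explicitly in the test_sharding.py diff below. A small sketch of the effect; the interpretation and the numbers here are mine, not from the test:

    import math
    import jax
    import jax.numpy as jnp

    n, seqlen, h, d = 1, 128, 4, 32
    A = 1.0 / math.sqrt(n * seqlen * h * d)

    do_unscaled = jax.random.normal(jax.random.PRNGKey(3), [n, seqlen, h, d])
    do = do_unscaled * A
    o = jax.random.normal(jax.random.PRNGKey(4), [n, seqlen, h, d])

    # Without the scaling, the scalar (o * do).sum() has standard deviation
    # growing like sqrt(n*seqlen*h*d); with it, it stays O(1).
    print(jnp.abs((o * do_unscaled).sum()))  # typically ~1e2 for these sizes
    print(jnp.abs((o * do).sum()))           # typically O(1)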

tests/test_sharding.py

Lines changed: 59 additions & 19 deletions

@@ -29,9 +29,11 @@ def pretty(tensor):
     std = jnp.std(tensor)
     return f'[{shape}: {mn:.3g} | {mean:.3g}±{std:.3g} | {mx:.3g}]'
 
-def check(ref_out, jax_out, out):
+def check(ref_out, jax_out, out, eps=3):
     def check1(ref_out, jax_out, out):
-        assert jnp.max(jnp.abs(out - ref_out)).item() <= 3 * jnp.max(jnp.abs(jax_out - ref_out)).item(), (pretty(jnp.abs(out - ref_out)), 'vs', pretty(jnp.abs(jax_out - ref_out)))
+        out_diff = jnp.abs(out - ref_out)
+        jax_diff = jnp.abs(jax_out - ref_out)
+        assert jnp.max(out_diff) <= eps * jnp.max(jax_diff), (pretty(out_diff), 'vs', pretty(jax_diff))
     tree_map(check1, ref_out, jax_out, out)
 
 @pytest.mark.skipif(len(jax.local_devices()) < 2, reason='Requires >1 gpu device')
@@ -130,9 +132,11 @@ def with_sharding(sharding):
         assert 'dynamic-slice' not in hlo
         assert 'collective-permute' in hlo
         # Should always run concurrently, meaning custom-call is always between start and done.
-        import re
-        collectives = ''.join(re.findall(" collective-permute-start| collective-permute-done| custom-call", hlo))
-        assert 'collective-permute-start collective-permute-done' not in collectives, hlo
+        # import re
+        # collectives = ''.join(re.findall(" collective-permute-start| collective-permute-done| custom-call", hlo))
+        # assert 'collective-permute-start collective-permute-done' not in collectives, hlo
+        print(hlo)
+        assert 'collective-permute-start collective-permute-done' not in decode_hlo(hlo), decode_hlo(hlo)
 
 @pytest.mark.skipif(len(jax.local_devices()) < 2, reason='Requires >1 gpu device')
 @pytest.mark.parametrize("dtype", [jnp.float16, jnp.bfloat16])
@@ -162,7 +166,7 @@ def flash(qkv):
     q = q.astype(dtype)
     k = k.astype(dtype)
     v = v.astype(dtype)
-    ref16_out = flash((q,k,v))
+    ref16_out = ref((q,k,v))
 
     def check_sharding(sharding,q,k,v):
         (q,k,v) = jax.device_put((q,k,v), sharding)
@@ -193,37 +197,73 @@ def test_flash_bwd_sharded(seqlen, h, d, m, causal, local, dtype):
     devices = jax.local_devices()
     n = len(devices)
 
+    A = 1.0 / math.sqrt(n * seqlen * h * d)
+
     @jax.jit
     @jax.grad
-    def ref(qkv):
-        return ref_mha(*qkv, is_causal=bool(causal), window_size=window_size).sum()
+    def ref(qkv, do):
+        o = ref_mha(*qkv, is_causal=bool(causal), window_size=window_size)
+        return (o * do).sum()
     @jax.jit
     @jax.grad
-    def flash(qkv):
-        return flash_mha(*qkv, is_causal=bool(causal), window_size=window_size).sum()
+    def flash(qkv, do):
+        o = flash_mha(*qkv, is_causal=bool(causal), window_size=window_size)
+        return (o * do).sum()
     q = jax.random.normal(jax.random.PRNGKey(0), [n, seqlen, h*m, d], dtype=jnp.float32)
     k = jax.random.normal(jax.random.PRNGKey(1), [n, seqlen, h, d], dtype=jnp.float32)
     v = jax.random.normal(jax.random.PRNGKey(2), [n, seqlen, h, d], dtype=jnp.float32)
+    do = jax.random.normal(jax.random.PRNGKey(3), [n, seqlen, h*m, d], dtype=jnp.float32) * A
 
-    ref_out = ref((q,k,v))
+    ref_out = ref((q,k,v), do)
     q = q.astype(dtype)
     k = k.astype(dtype)
     v = v.astype(dtype)
-    ref16_out = flash((q,k,v))
+    do = do.astype(dtype)
+    ref16_out = ref((q,k,v), do)
 
-    def check_sharding(sharding,q,k,v):
-        (q,k,v) = jax.device_put((q,k,v), sharding)
-        out = flash((q,k,v))
-        check(ref_out,ref16_out,out)
+    def check_sharding(sharding):
+        (qs,ks,vs,dos) = jax.device_put((q,k,v,do), sharding)
+        out = flash((qs,ks,vs),dos)
+        check(ref_out,ref16_out,out, eps=4)
 
-    check_sharding(PositionalSharding(devices).reshape(n,1,1,1),q,k,v)
-    check_sharding(PositionalSharding(devices).reshape(1,1,n,1),q,k,v)
+    check_sharding(PositionalSharding(devices).reshape(n,1,1,1))
+    check_sharding(PositionalSharding(devices).reshape(1,1,n,1))
 
     if not local:
         # Ring attention
         with Mesh(np.array(devices), axis_names=('x',)) as mesh:
            sharding = NamedSharding(mesh, P(None,'x',None,None))
-            check_sharding(sharding,q,k,v)
+            check_sharding(sharding)
+
+def decode_hlo(hlo):
+    computations = {}
+    current_name = None
+    current_lines = []
+    for line in hlo.splitlines():
+        if line.startswith('%') or line.startswith('ENTRY'):
+            if current_name is not None:
+                computations[current_name] = current_lines
+            current_name = line.split()[0]
+            current_lines = []
+        elif line.lstrip().startswith('%') or line.lstrip().startswith('ROOT'):
+            current_lines.append(line)
+    if current_lines:
+        computations[current_name] = current_lines
+
+    def visit(name):
+        for line in computations[name]:
+            if 'custom-call(' in line:
+                yield 'custom-call'
+            elif any('calls='+target in line for target in computations.keys()):
+                target = [target for target in computations.keys() if 'calls='+target in line][0]
+                for item in visit(target):
+                    yield item
+            elif 'collective-permute-start(' in line:
+                yield 'collective-permute-start'
+            elif 'collective-permute-done(' in line:
+                yield 'collective-permute-done'
+
+    return ' '.join(visit('ENTRY'))
 
 if __name__ == '__main__':
     test_flash_fwd_sharded_hlo(128,4,32,False,False,jnp.float16)
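
For illustration, here is what decode_hlo yields on a toy HLO dump where the flash kernel call is correctly sandwiched between the permute's start and done ops. The HLO text is fabricated and the import path assumes the tests package is importable; neither is part of this commit:

    from tests.test_sharding import decode_hlo  # assumption: tests/ is on the import path

    # Fabricated, minimal HLO with an overlapped schedule.
    hlo = """ENTRY %main (p0: f32[4]) -> f32[4] {
      %cps = (f32[4], f32[4]) collective-permute-start(%p0)
      %cc = f32[4] custom-call(%p0), custom_call_target="flash_mha_fwd"
      %cpd = f32[4] collective-permute-done(%cps)
      ROOT %out = f32[4] add(%cc, %cpd)
    }"""

    tokens = decode_hlo(hlo)
    print(tokens)
    # -> 'collective-permute-start custom-call collective-permute-done'
    # The assertion added in the hunk above checks that the adjacent pair
    # 'collective-permute-start collective-permute-done' never occurs, i.e. some
    # kernel is always launched between starting and finishing each permute.
    assert 'collective-permute-start collective-permute-done' not in tokens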
