
Commit 66ef3e8

Fixed the ring attention backward to not do 2x work, although this means it doesn't quite overlap the communications.
1 parent 6cbb831 commit 66ef3e8

9 files changed: +412 −125 lines

setup.py

Lines changed: 1 addition & 1 deletion
@@ -106,7 +106,7 @@ def append_nvcc_threads(nvcc_extra_args):
     # cc_flag.append("arch=compute_90,code=sm_90")
     ext_modules.append(
         CUDAExtension(
-            name="flash_attn_jax.flash_api",
+            name="flash_attn_jax_lib.flash_api",
             sources=[
                 "csrc/flash_attn/flash_api.cpp",
                 "csrc/flash_attn/flash_common.cpp",

src/flash_attn_jax/flash.py

Lines changed: 1 addition & 1 deletion
@@ -39,7 +39,7 @@
 _flash_mha_bwd_p.def_impl(partial(xla.apply_primitive, _flash_mha_bwd_p))

 try:
-    # JAX 0.4.24 and above requires this.
+    # JAX 0.4.24 and above requires this because of custom partitioning.
     import jax._src.dispatch
     jax._src.dispatch.prim_requires_devices_during_lowering.add(_flash_mha_bwd_p)
     jax._src.dispatch.prim_requires_devices_during_lowering.add(_flash_mha_fwd_p)

src/flash_attn_jax/flash_hlo.py

Lines changed: 1 addition & 1 deletion
@@ -21,7 +21,7 @@
 from einops import rearrange
 import math

-import flash_attn_jax.flash_api as flash_api
+import flash_attn_jax_lib.flash_api as flash_api

 # ==== Register primitives ====

src/flash_attn_jax/flash_sharding.py

Lines changed: 24 additions & 105 deletions
@@ -22,6 +22,7 @@
 import math

 from .flash_hlo import _flash_mha_fwd_hlo, _flash_mha_bwd_hlo
+from .ring_attention import ring_fwd, ring_bwd

 # ==== Sharding ====

@@ -30,12 +31,20 @@

 from jax._src.ad_checkpoint import _optimization_barrier

+def is_replicated(sharding):
+    return (isinstance(sharding, PositionalSharding) and sharding.shape == (1,)) or (isinstance(sharding, NamedSharding) and len(sharding.spec) == 0)
+
 def partition_fwd(softmax_scale, is_causal, window_size, mesh, arg_shapes, result_shape):
     result_shardings = jax.tree_map(lambda x: x.sharding, result_shape)
     arg_shardings = jax.tree_map(lambda x: x.sharding, arg_shapes)

     q_sharding = arg_shardings[0]
-    if isinstance(q_sharding, PositionalSharding):
+    k_sharding = arg_shardings[1]
+    v_sharding = arg_shardings[2]
+    assert q_sharding == k_sharding and q_sharding == v_sharding, "Only support q, k, v sharing the same sharding."
+    if is_replicated(q_sharding):
+        result_sharding = (q_sharding, q_sharding)
+    elif isinstance(q_sharding, PositionalSharding):
         (n,l,h,d) = q_sharding.shape
         assert d == 1, "Sharding across `d` won't be efficient, so it's not supported."
         assert l == 1, "For ring attention, use `with Mesh(...) as mesh` and NamedSharding."
@@ -53,7 +62,7 @@ def partition_fwd(softmax_scale, is_causal, window_size, mesh, arg_shapes, result_shape):
             axis_name = l
             axis_size = mesh.shape[axis_name]
             # ring attention
-            return mesh, partial(ring_fwd, softmax_scale, is_causal, axis_name, axis_size), result_shardings, arg_shardings
+            return mesh, partial(ring_fwd, softmax_scale=softmax_scale, is_causal=is_causal, axis_name=axis_name, axis_size=axis_size, mha_fwd=_flash_mha_fwd_hlo), result_shardings, arg_shardings
         else:
             result_shardings = q_sharding, NamedSharding(mesh, P(n,h,l))
             arg_shardings = q_sharding, q_sharding, q_sharding
@@ -64,7 +73,12 @@ def fwd(q,k,v):
 def infer_sharding_fwd(softmax_scale, is_causal, window_size, mesh, arg_shapes, result_shape):
     arg_shardings = jax.tree_map(lambda x: x.sharding, arg_shapes)
     q_sharding = arg_shardings[0]
-    if isinstance(q_sharding, PositionalSharding):
+    k_sharding = arg_shardings[1]
+    v_sharding = arg_shardings[2]
+    assert q_sharding == k_sharding and q_sharding == v_sharding, "Only support q, k, v sharing the same sharding."
+    if is_replicated(q_sharding):
+        result_sharding = (q_sharding, q_sharding)
+    elif isinstance(q_sharding, PositionalSharding):
         [n,l,h,d] = q_sharding.shape
         result_sharding = (q_sharding, # [n,l,h,d]
                            q_sharding.replicate(3).reshape(n,l,h).transpose((0,2,1)) # [n,h,l]
@@ -73,6 +87,8 @@ def infer_sharding_fwd(softmax_scale, is_causal, window_size, mesh, arg_shapes, result_shape):
         [n,l,h,d] = q_sharding.spec
         result_sharding = (q_sharding,
                            NamedSharding(q_sharding.mesh, P(n,h,l)))
+    else:
+        raise ValueError("Unsupported sharding type.", type(q_sharding))
     return result_sharding

 _flash_mha_fwd_hlo_sharded.def_partition(
@@ -99,7 +115,10 @@ def partition_bwd(softmax_scale, is_causal, window_size, mesh, arg_shapes, result_shape):
     v_sharding = arg_shardings[3]
     o_sharding = arg_shardings[4]
     lse_sharding = arg_shardings[5]
-    if isinstance(q_sharding, PositionalSharding):
+    assert q_sharding == k_sharding and q_sharding == v_sharding, "Only support q, k, v sharing the same sharding."
+    if is_replicated(q_sharding):
+        result_shardings = (q_sharding,)*3
+    elif isinstance(q_sharding, PositionalSharding):
         assert q_sharding == k_sharding, "Expect q and k sharding to match"
         assert q_sharding == v_sharding, "Expect q and v sharding to match"
         [n, l, h, d] = q_sharding.shape
@@ -121,7 +140,7 @@ def partition_bwd(softmax_scale, is_causal, window_size, mesh, arg_shapes, result_shape):
             axis_name = l
             axis_size = mesh.shape[axis_name]
             # ring attention
-            return mesh, partial(ring_bwd, softmax_scale, is_causal, axis_name, axis_size), result_shardings, arg_shardings
+            return mesh, partial(ring_bwd, softmax_scale=softmax_scale, is_causal=is_causal, axis_name=axis_name, axis_size=axis_size, mha_bwd=_flash_mha_bwd_hlo), result_shardings, arg_shardings
         else:
             result_shardings = q_sharding, q_sharding, q_sharding
             lse_sharding = NamedSharding(mesh, P(n,h,l))
@@ -133,103 +152,3 @@ def fwd(*args):
 _flash_mha_bwd_hlo_sharded.def_partition(
     infer_sharding_from_operands=infer_sharding_bwd,
     partition=partition_bwd)
-
-# ==== Ring Forward ====
-
-def ring_fwd(softmax_scale, is_causal, axis_name, axis_size, q,k,v):
-    [n,l,h,d] = q.shape
-
-    q_ix = jax.lax.axis_index(axis_name)
-    k_ix = jax.lax.axis_index(axis_name)
-
-    o = jnp.zeros([n,l,h,d], jnp.float32)
-    lse = jnp.full([n,h,l], float('-inf'), jnp.float32)
-
-    # scan :: (c -> a -> (c, b)) -> c -> [a] -> (c, [b])
-    def f(c, a):
-        (k, v, o, lse, k_ix) = c
-
-        o1, lse1 = o, lse
-        if is_causal:
-            o2, lse2 = jax.lax.switch((k_ix < q_ix).astype(jnp.int32) + (k_ix <= q_ix).astype(jnp.int32),
-                [
-                    lambda q,k,v: (jnp.zeros([n,l,h,d], q.dtype), jnp.full([n,h,l], float('-inf'), jnp.float32)),
-                    lambda q,k,v: _flash_mha_fwd_hlo(q,k,v, softmax_scale=softmax_scale, is_causal=True, window_size=(-1,-1)),
-                    lambda q,k,v: _flash_mha_fwd_hlo(q,k,v, softmax_scale=softmax_scale, is_causal=False, window_size=(-1,-1)),
-                ], q, k, v)
-        else:
-            o2, lse2 = _flash_mha_fwd_hlo(q,k,v, softmax_scale=softmax_scale, is_causal=False, window_size=(-1,-1))
-        o2 = o2.astype(jnp.float32)
-
-        mx = jnp.maximum(lse1,lse2)
-        mn = jnp.minimum(lse1,lse2)
-        lse = jnp.log1p(jnp.exp(mn-mx)) + mx
-
-        o = (o1 * rearrange(jnp.exp(lse1 - lse), 'n h l -> n l h 1') +
-             o2 * rearrange(jnp.exp(lse2 - lse), 'n h l -> n l h 1'))
-
-        k2 = jax.lax.ppermute(k, axis_name, [(i, (i+1)%axis_size) for i in range(axis_size)])
-        v2 = jax.lax.ppermute(v, axis_name, [(i, (i+1)%axis_size) for i in range(axis_size)])
-        k_ix = jax.lax.ppermute(k_ix, axis_name, [(i, (i+1)%axis_size) for i in range(axis_size)])
-
-        return ((k2, v2, o, lse, k_ix), None)
-    acc = (k,v,o,lse,k_ix)
-    # We sadly have to manually unroll this because scan breaks the axis context preventing us from using ppermute (unroll=axis_size doesn't help either).
-    # Optimization barrier prevents instruction reordering so that ppermute and flash_mha execute concurrently.
-    for _ in range(axis_size):
-        acc, _ = f(acc, None)
-        acc = _optimization_barrier(acc)
-    (_,_,o,lse,_) = acc
-    # (_,_,o,lse), _ = jax.lax.scan(f,init,None,axis_size)
-    return o.astype(q.dtype), lse
-
-# ==== Ring Backward ===
-
-# This doesn't seem like the most efficient way to do this, kind of wasting compute by calculating every dq,dk,dv twice.
-# Should we send the accumulator for dk,dv cross-device instead? Relying on the fact that after a full cycle, they return to the starting device.
-def ring_bwd(softmax_scale, is_causal, axis_name, axis_size, do,q,k,v,o,lse):
-    [n,l,h,d] = q.shape
-
-    ix = jax.lax.axis_index(axis_name)
-
-    dq = jnp.zeros([n,l,h,d], jnp.float32)
-    dk = jnp.zeros([n,l,h,d], jnp.float32)
-    dv = jnp.zeros([n,l,h,d], jnp.float32)
-
-    # scan :: (c -> a -> (c, b)) -> c -> [a] -> (c, [b])
-    def f(acc, a):
-        (do2,q2,k2,v2,o2,lse2,ix2, dq,dk,dv) = acc
-
-        cmp = (ix2 < ix).astype(jnp.int32) + (ix2 <= ix).astype(jnp.int32)
-        # 0: ix < ix2
-        # 1: ix = ix2
-        # 2: ix > ix2
-        if is_causal:
-            dqa = jax.lax.switch(cmp, [
-                lambda q,k,v: jnp.zeros([n,l,h,d], q.dtype),
-                lambda q,k,v: _flash_mha_bwd_hlo(do,q,k2,v2,o,lse, softmax_scale=softmax_scale, is_causal=True, window_size=(-1,-1))[0],
-                lambda q,k,v: _flash_mha_bwd_hlo(do,q,k2,v2,o,lse, softmax_scale=softmax_scale, is_causal=False, window_size=(-1,-1))[0],
-            ], q, k, v)
-            dka,dva = jax.lax.switch(cmp, [
-                lambda q,k,v: _flash_mha_bwd_hlo(do2,q2,k,v,o2,lse2, softmax_scale=softmax_scale, is_causal=False, window_size=(-1,-1))[1:],
-                lambda q,k,v: _flash_mha_bwd_hlo(do2,q2,k,v,o2,lse2, softmax_scale=softmax_scale, is_causal=True, window_size=(-1,-1))[1:],
-                lambda q,k,v: (jnp.zeros([n,l,h,d], q.dtype),jnp.zeros([n,l,h,d], q.dtype)),
-            ], q, k, v)
-        else:
-            dqa,_,_ = _flash_mha_bwd_hlo(do,q,k2,v2,o,lse, softmax_scale=softmax_scale, is_causal=False, window_size=(-1,-1))
-            _,dka,dva = _flash_mha_bwd_hlo(do2,q2,k,v,o2,lse2, softmax_scale=softmax_scale, is_causal=False, window_size=(-1,-1))
-
-        dq += dqa
-        dk += dka
-        dv += dva
-
-        (do2,q2,k2,v2,o2,lse2,ix2) = jax.lax.ppermute((do2,q2,k2,v2,o2,lse2,ix2), axis_name, [(i, (i+1)%axis_size) for i in range(axis_size)])
-
-        return ((do2,q2,k2,v2,o2,lse2,ix2, dq,dk,dv), None)
-    acc = (do,q,k,v,o,lse,ix,dq,dk,dv)
-    # Unrolled as above.
-    for _ in range(axis_size):
-        acc, _ = f(acc, None)
-        acc = _optimization_barrier(acc)
-    (do2,q2,k2,v2,o2,lse2,ix2, dq,dk,dv) = acc
-    return dq.astype(q.dtype),dk.astype(q.dtype),dv.astype(q.dtype)
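
For context: the NamedSharding branch above dispatches to ring attention when q, k and v are sharded along the sequence axis of a named mesh. A minimal usage sketch follows; it assumes the package's public flash_mha entry point, and the mesh axis name, shapes and dtype are illustrative only.

# Illustrative sketch, not part of this commit. Ring attention is selected
# because the sequence dimension of q/k/v is sharded over the mesh axis 'l'.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from flash_attn_jax import flash_mha  # assumed public wrapper

n, l, h, d = 2, 4096, 8, 64
with Mesh(np.array(jax.devices()), axis_names=('l',)) as mesh:
    # Shard the sequence axis of the [n, l, h, d] inputs across devices.
    sharding = NamedSharding(mesh, P(None, 'l', None, None))
    x = jax.device_put(jnp.zeros([n, l, h, d], jnp.float16), sharding)
    o = jax.jit(lambda q, k, v: flash_mha(q, k, v, is_causal=True))(x, x, x)

With replicated or purely batch-sharded inputs, the same call falls through to the per-device kernel instead of the ring path.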
src/flash_attn_jax/ring_attention.py

Lines changed: 121 additions & 0 deletions
@@ -0,0 +1,121 @@
+from functools import partial, wraps
+
+import numpy as np
+import jax
+import jax.numpy as jnp
+from jax import core, dtypes
+from jax.core import ShapedArray
+from jax.interpreters import batching
+from jax.interpreters import mlir
+from jax.interpreters import xla
+from jax.interpreters.mlir import ir
+from jax.lib import xla_client
+from jaxlib.hlo_helpers import custom_call
+from jax.experimental.custom_partitioning import custom_partitioning
+
+from jax.sharding import PartitionSpec as P
+from jax.sharding import Mesh
+from jax.sharding import NamedSharding
+from jax.sharding import PositionalSharding
+from jax._src.ad_checkpoint import _optimization_barrier
+
+from einops import rearrange
+import math
+
+# ==== Ring Forward ====
+
+def ring_fwd(q,k,v, axis_name, axis_size, mha_fwd, softmax_scale=None, is_causal=False):
+    [n,l,h,d] = q.shape
+    if softmax_scale is None:
+        softmax_scale = 1/math.sqrt(d)
+
+    q_ix = jax.lax.axis_index(axis_name)
+    k_ix = jax.lax.axis_index(axis_name)
+
+    o = jnp.zeros([n,l,h,d], jnp.float32)
+    lse = jnp.full([n,h,l], float('-inf'), jnp.float32)
+
+    # scan :: (c -> a -> (c, b)) -> c -> [a] -> (c, [b])
+    def f(c, a):
+        (k, v, o, lse, k_ix) = c
+
+        o1, lse1 = o, lse
+        if is_causal:
+            cmp = (k_ix < q_ix).astype(jnp.int32) + (k_ix <= q_ix).astype(jnp.int32)
+            o2, lse2 = jax.lax.switch(cmp,
+                [
+                    lambda: (jnp.zeros([n,l,h,d], q.dtype), jnp.full([n,h,l], float('-inf'), jnp.float32)),
+                    lambda: mha_fwd(q,k,v, softmax_scale=softmax_scale, is_causal=True, window_size=(-1,-1)),
+                    lambda: mha_fwd(q,k,v, softmax_scale=softmax_scale, is_causal=False, window_size=(-1,-1)),
+                ])
+        else:
+            o2, lse2 = mha_fwd(q,k,v, softmax_scale=softmax_scale, is_causal=False, window_size=(-1,-1))
+        o2 = o2.astype(jnp.float32)
+
+        mx = jnp.maximum(lse1,lse2)
+        mn = jnp.minimum(lse1,lse2)
+        lse = jnp.log1p(jnp.exp(mn-mx)) + mx
+
+        o = (o1 * rearrange(jnp.exp(lse1 - lse), 'n h l -> n l h 1') +
+             o2 * rearrange(jnp.exp(lse2 - lse), 'n h l -> n l h 1'))
+
+        k2 = jax.lax.ppermute(k, axis_name, [(i, (i+1)%axis_size) for i in range(axis_size)])
+        v2 = jax.lax.ppermute(v, axis_name, [(i, (i+1)%axis_size) for i in range(axis_size)])
+        k_ix = jax.lax.ppermute(k_ix, axis_name, [(i, (i+1)%axis_size) for i in range(axis_size)])
+
+        return ((k2, v2, o, lse, k_ix), None)
+    acc = (k,v,o,lse,k_ix)
+    # Manually unroll this until https://github.com/google/jax/pull/20884 is merged.
+    # Optimization barrier prevents instruction reordering across loop iters, so that
+    # ppermute and flash_mha execute concurrently (though this is unreliable).
+    for _ in range(axis_size):
+        acc, _ = f(acc, None)
+        acc = _optimization_barrier(acc)
+    (_,_,o,lse,_) = acc
+    # (_,_,o,lse,_), _ = jax.lax.scan(f,acc,None,axis_size)
+    return o.astype(q.dtype), lse
+
+# ==== Ring Backward ===
+
+def ring_bwd(do,q,k,v,o,lse, axis_name, axis_size, mha_bwd, softmax_scale=None, is_causal=False):
+    [n,l,h,d] = q.shape
+    if softmax_scale is None:
+        softmax_scale = 1/math.sqrt(d)
+
+    ix = jax.lax.axis_index(axis_name)
+
+    dq = jnp.zeros([n,l,h,d], jnp.float32)
+    dk = jnp.zeros([n,l,h,d], jnp.float32)
+    dv = jnp.zeros([n,l,h,d], jnp.float32)
+
+    # scan :: (c -> a -> (c, b)) -> c -> [a] -> (c, [b])
+    def f(acc, _):
+        (k2,v2,dk2,dv2,ix2, dq) = acc
+
+        cmp = (ix2 < ix).astype(jnp.int32) + (ix2 <= ix).astype(jnp.int32)
+        # 0: ix < ix2
+        # 1: ix = ix2
+        # 2: ix > ix2
+        if is_causal:
+            dqa, dka, dva = jax.lax.switch(cmp, (
+                lambda: (jnp.zeros(q.shape, q.dtype), jnp.zeros(k.shape, k.dtype), jnp.zeros(v.shape, v.dtype)),
+                lambda: mha_bwd(do,q,k2,v2,o,lse, softmax_scale=softmax_scale, is_causal=True, window_size=(-1,-1)),
+                lambda: mha_bwd(do,q,k2,v2,o,lse, softmax_scale=softmax_scale, is_causal=False, window_size=(-1,-1))
+            ))
+        else:
+            dqa, dka, dva = mha_bwd(do,q,k2,v2,o,lse, softmax_scale=softmax_scale, is_causal=False, window_size=(-1,-1))
+
+        dq += dqa
+        dk2 += dka
+        dv2 += dva
+
+        (k2,v2,dk2,dv2,ix2) = jax.lax.ppermute((k2,v2,dk2,dv2,ix2), axis_name, [(i, (i+1)%axis_size) for i in range(axis_size)])
+
+        return ((k2,v2,dk2,dv2,ix2, dq), None)
+    acc = (k,v,dk,dv,ix, dq)
+    # See above (#20884).
+    for _ in range(axis_size):
+        acc, _ = f(acc, None)
+        acc = _optimization_barrier(acc)
+    (k,v,dk,dv,ix2, dq) = acc
+    return dq.astype(q.dtype),dk.astype(q.dtype),dv.astype(q.dtype)
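
For reference, the (o, lse) merge in ring_fwd above combines per-block attention partials so that, after every k/v block has been visited, the result equals softmax attention over the full sequence. Below is a standalone single-device sketch of that rule in plain jnp, with no ppermute or flash kernels; the function names are illustrative, not part of the library.

import jax
import jax.numpy as jnp
from einops import rearrange

def block_attention(q, k, v, scale):
    # Partial attention of q against a single k/v block, returning the block's
    # normalized output [n,l,h,d] and the log-sum-exp of its scaled scores [n,h,l],
    # i.e. the same pair of quantities the forward kernel produces.
    s = jnp.einsum('nlhd,nmhd->nhlm', q, k) * scale   # scores [n,h,l,m]
    lse = jax.nn.logsumexp(s, axis=-1)                # [n,h,l]
    o = jnp.einsum('nhlm,nmhd->nlhd', jnp.exp(s - lse[..., None]), v)
    return o, lse

def merge_partials(o1, lse1, o2, lse2):
    # The merge rule used in ring_fwd: track the softmax normalizer in log space
    # and reweight each partial output by its share of the combined normalizer.
    mx, mn = jnp.maximum(lse1, lse2), jnp.minimum(lse1, lse2)
    lse = jnp.log1p(jnp.exp(mn - mx)) + mx            # log(exp(lse1) + exp(lse2))
    o = (o1 * rearrange(jnp.exp(lse1 - lse), 'n h l -> n l h 1') +
         o2 * rearrange(jnp.exp(lse2 - lse), 'n h l -> n l h 1'))
    return o, lse

Folding the k/v blocks one at a time with merge_partials, starting from o = 0 and lse = -inf, matches attention over the concatenated key/value sequence; that is the invariant the unrolled ring loop maintains as each k/v shard passes through every device.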
