import math
from .flash_hlo import _flash_mha_fwd_hlo, _flash_mha_bwd_hlo
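+# Ring attention forward/backward now live in .ring_attention; the inlined copies below are removed.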
+from .ring_attention import ring_fwd, ring_bwd

# ==== Sharding ====

from jax._src.ad_checkpoint import _optimization_barrier

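+# A sharding counts as replicated when every device holds the full array: either a
+# trivial PositionalSharding of shape (1,) or a NamedSharding with an empty spec.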
+def is_replicated(sharding):
+    return (isinstance(sharding, PositionalSharding) and sharding.shape == (1,)) or (isinstance(sharding, NamedSharding) and len(sharding.spec) == 0)
+
def partition_fwd(softmax_scale, is_causal, window_size, mesh, arg_shapes, result_shape):
    result_shardings = jax.tree_map(lambda x: x.sharding, result_shape)
    arg_shardings = jax.tree_map(lambda x: x.sharding, arg_shapes)

    q_sharding = arg_shardings[0]
-    if isinstance(q_sharding, PositionalSharding):
+    k_sharding = arg_shardings[1]
+    v_sharding = arg_shardings[2]
+    assert q_sharding == k_sharding and q_sharding == v_sharding, "Only support q, k, v sharing the same sharding."
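+    # Fully replicated q/k/v need no ring communication: every device already holds
+    # the whole arrays, so both outputs (o and lse) stay replicated as well.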
+    if is_replicated(q_sharding):
+        result_shardings = (q_sharding, q_sharding)
+    elif isinstance(q_sharding, PositionalSharding):
        (n,l,h,d) = q_sharding.shape
        assert d == 1, "Sharding across `d` won't be efficient, so it's not supported."
        assert l == 1, "For ring attention, use `with Mesh(...) as mesh` and NamedSharding."
@@ -53,7 +62,7 @@ def partition_fwd(softmax_scale, is_causal, window_size, mesh, arg_shapes, result_shape):
        axis_name = l
        axis_size = mesh.shape[axis_name]
        # ring attention
-        return mesh, partial(ring_fwd, softmax_scale, is_causal, axis_name, axis_size), result_shardings, arg_shardings
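+        # ring_fwd is now imported from .ring_attention; it receives the single-device
+        # kernel explicitly via mha_fwd and the remaining options as keyword arguments.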
+        return mesh, partial(ring_fwd, softmax_scale=softmax_scale, is_causal=is_causal, axis_name=axis_name, axis_size=axis_size, mha_fwd=_flash_mha_fwd_hlo), result_shardings, arg_shardings
    else:
        result_shardings = q_sharding, NamedSharding(mesh, P(n,h,l))
        arg_shardings = q_sharding, q_sharding, q_sharding
@@ -64,7 +73,12 @@ def fwd(q,k,v):
def infer_sharding_fwd(softmax_scale, is_causal, window_size, mesh, arg_shapes, result_shape):
    arg_shardings = jax.tree_map(lambda x: x.sharding, arg_shapes)
    q_sharding = arg_shardings[0]
-    if isinstance(q_sharding, PositionalSharding):
+    k_sharding = arg_shardings[1]
+    v_sharding = arg_shardings[2]
+    assert q_sharding == k_sharding and q_sharding == v_sharding, "Only support q, k, v sharing the same sharding."
+    if is_replicated(q_sharding):
+        result_sharding = (q_sharding, q_sharding)
+    elif isinstance(q_sharding, PositionalSharding):
        [n,l,h,d] = q_sharding.shape
        result_sharding = (q_sharding, # [n,l,h,d]
                           q_sharding.replicate(3).reshape(n,l,h).transpose((0,2,1)) # [n,h,l]
@@ -73,6 +87,8 @@ def infer_sharding_fwd(softmax_scale, is_causal, window_size, mesh, arg_shapes, result_shape):
        [n,l,h,d] = q_sharding.spec
        result_sharding = (q_sharding,
                           NamedSharding(q_sharding.mesh, P(n,h,l)))
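+    # Any other sharding type is rejected explicitly instead of falling through.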
+    else:
+        raise ValueError("Unsupported sharding type.", type(q_sharding))
    return result_sharding

_flash_mha_fwd_hlo_sharded.def_partition(
@@ -99,7 +115,10 @@ def partition_bwd(softmax_scale, is_causal, window_size, mesh, arg_shapes, result_shape):
    v_sharding = arg_shardings[3]
    o_sharding = arg_shardings[4]
    lse_sharding = arg_shardings[5]
-    if isinstance(q_sharding, PositionalSharding):
+    assert q_sharding == k_sharding and q_sharding == v_sharding, "Only support q, k, v sharing the same sharding."
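+    # Replicated case: each device holds full q, k, v, o and lse, so dq, dk, dv are
+    # computed locally and remain replicated.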
+    if is_replicated(q_sharding):
+        result_shardings = (q_sharding,)*3
+    elif isinstance(q_sharding, PositionalSharding):
        assert q_sharding == k_sharding, "Expect q and k sharding to match"
        assert q_sharding == v_sharding, "Expect q and v sharding to match"
        [n, l, h, d] = q_sharding.shape
@@ -121,7 +140,7 @@ def partition_bwd(softmax_scale, is_causal, window_size, mesh, arg_shapes, result_shape):
        axis_name = l
        axis_size = mesh.shape[axis_name]
        # ring attention
-        return mesh, partial(ring_bwd, softmax_scale, is_causal, axis_name, axis_size), result_shardings, arg_shardings
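+        # As in the forward pass, ring_bwd now comes from .ring_attention and receives
+        # the single-device backward kernel via mha_bwd.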
+        return mesh, partial(ring_bwd, softmax_scale=softmax_scale, is_causal=is_causal, axis_name=axis_name, axis_size=axis_size, mha_bwd=_flash_mha_bwd_hlo), result_shardings, arg_shardings
    else:
        result_shardings = q_sharding, q_sharding, q_sharding
        lse_sharding = NamedSharding(mesh, P(n,h,l))
@@ -133,103 +152,3 @@ def fwd(*args):
_flash_mha_bwd_hlo_sharded.def_partition(
    infer_sharding_from_operands=infer_sharding_bwd,
    partition=partition_bwd)
-
-# ==== Ring Forward ====
-
-def ring_fwd(softmax_scale, is_causal, axis_name, axis_size, q, k, v):
-    [n,l,h,d] = q.shape
-
-    q_ix = jax.lax.axis_index(axis_name)
-    k_ix = jax.lax.axis_index(axis_name)
-
-    o = jnp.zeros([n,l,h,d], jnp.float32)
-    lse = jnp.full([n,h,l], float('-inf'), jnp.float32)
-
-    # scan :: (c -> a -> (c, b)) -> c -> [a] -> (c, [b])
-    def f(c, a):
-        (k, v, o, lse, k_ix) = c
-
-        o1, lse1 = o, lse
-        if is_causal:
-            o2, lse2 = jax.lax.switch((k_ix < q_ix).astype(jnp.int32) + (k_ix <= q_ix).astype(jnp.int32),
-                                      [
-                                          lambda q,k,v: (jnp.zeros([n,l,h,d], q.dtype), jnp.full([n,h,l], float('-inf'), jnp.float32)),
-                                          lambda q,k,v: _flash_mha_fwd_hlo(q,k,v, softmax_scale=softmax_scale, is_causal=True, window_size=(-1,-1)),
-                                          lambda q,k,v: _flash_mha_fwd_hlo(q,k,v, softmax_scale=softmax_scale, is_causal=False, window_size=(-1,-1)),
-                                      ], q, k, v)
-        else:
-            o2, lse2 = _flash_mha_fwd_hlo(q,k,v, softmax_scale=softmax_scale, is_causal=False, window_size=(-1,-1))
-        o2 = o2.astype(jnp.float32)
-
-        mx = jnp.maximum(lse1, lse2)
-        mn = jnp.minimum(lse1, lse2)
-        lse = jnp.log1p(jnp.exp(mn - mx)) + mx
-
-        o = (o1 * rearrange(jnp.exp(lse1 - lse), 'n h l -> n l h 1') +
-             o2 * rearrange(jnp.exp(lse2 - lse), 'n h l -> n l h 1'))
-
-        k2 = jax.lax.ppermute(k, axis_name, [(i, (i+1) % axis_size) for i in range(axis_size)])
-        v2 = jax.lax.ppermute(v, axis_name, [(i, (i+1) % axis_size) for i in range(axis_size)])
-        k_ix = jax.lax.ppermute(k_ix, axis_name, [(i, (i+1) % axis_size) for i in range(axis_size)])
-
-        return ((k2, v2, o, lse, k_ix), None)
-    acc = (k, v, o, lse, k_ix)
-    # We sadly have to manually unroll this because scan breaks the axis context preventing us from using ppermute (unroll=axis_size doesn't help either).
-    # Optimization barrier prevents instruction reordering so that ppermute and flash_mha execute concurrently.
-    for _ in range(axis_size):
-        acc, _ = f(acc, None)
-        acc = _optimization_barrier(acc)
-    (_, _, o, lse, _) = acc
-    # (_,_,o,lse), _ = jax.lax.scan(f, init, None, axis_size)
-    return o.astype(q.dtype), lse
-
-# ==== Ring Backward ===
-
-# This doesn't seem like the most efficient way to do this, kind of wasting compute by calculating every dq,dk,dv twice.
-# Should we send the accumulator for dk,dv cross-device instead? Relying on the fact that after a full cycle, they return to the starting device.
-def ring_bwd(softmax_scale, is_causal, axis_name, axis_size, do, q, k, v, o, lse):
-    [n,l,h,d] = q.shape
-
-    ix = jax.lax.axis_index(axis_name)
-
-    dq = jnp.zeros([n,l,h,d], jnp.float32)
-    dk = jnp.zeros([n,l,h,d], jnp.float32)
-    dv = jnp.zeros([n,l,h,d], jnp.float32)
-
-    # scan :: (c -> a -> (c, b)) -> c -> [a] -> (c, [b])
-    def f(acc, a):
-        (do2,q2,k2,v2,o2,lse2,ix2, dq,dk,dv) = acc
-
-        cmp = (ix2 < ix).astype(jnp.int32) + (ix2 <= ix).astype(jnp.int32)
-        # 0: ix < ix2
-        # 1: ix = ix2
-        # 2: ix > ix2
-        if is_causal:
-            dqa = jax.lax.switch(cmp, [
-                lambda q,k,v: jnp.zeros([n,l,h,d], q.dtype),
-                lambda q,k,v: _flash_mha_bwd_hlo(do,q,k2,v2,o,lse, softmax_scale=softmax_scale, is_causal=True, window_size=(-1,-1))[0],
-                lambda q,k,v: _flash_mha_bwd_hlo(do,q,k2,v2,o,lse, softmax_scale=softmax_scale, is_causal=False, window_size=(-1,-1))[0],
-            ], q, k, v)
-            dka, dva = jax.lax.switch(cmp, [
-                lambda q,k,v: _flash_mha_bwd_hlo(do2,q2,k,v,o2,lse2, softmax_scale=softmax_scale, is_causal=False, window_size=(-1,-1))[1:],
-                lambda q,k,v: _flash_mha_bwd_hlo(do2,q2,k,v,o2,lse2, softmax_scale=softmax_scale, is_causal=True, window_size=(-1,-1))[1:],
-                lambda q,k,v: (jnp.zeros([n,l,h,d], q.dtype), jnp.zeros([n,l,h,d], q.dtype)),
-            ], q, k, v)
-        else:
-            dqa, _, _ = _flash_mha_bwd_hlo(do,q,k2,v2,o,lse, softmax_scale=softmax_scale, is_causal=False, window_size=(-1,-1))
-            _, dka, dva = _flash_mha_bwd_hlo(do2,q2,k,v,o2,lse2, softmax_scale=softmax_scale, is_causal=False, window_size=(-1,-1))
-
-        dq += dqa
-        dk += dka
-        dv += dva
-
-        (do2,q2,k2,v2,o2,lse2,ix2) = jax.lax.ppermute((do2,q2,k2,v2,o2,lse2,ix2), axis_name, [(i, (i+1) % axis_size) for i in range(axis_size)])
-
-        return ((do2,q2,k2,v2,o2,lse2,ix2, dq,dk,dv), None)
-    acc = (do, q, k, v, o, lse, ix, dq, dk, dv)
-    # Unrolled as above.
-    for _ in range(axis_size):
-        acc, _ = f(acc, None)
-        acc = _optimization_barrier(acc)
-    (do2,q2,k2,v2,o2,lse2,ix2, dq,dk,dv) = acc
-    return dq.astype(q.dtype), dk.astype(q.dtype), dv.astype(q.dtype)