from jax.interpreters import xla
from jax.interpreters.mlir import ir
from jax.lib import xla_client
-from jaxlib.hlo_helpers import custom_call
from jax.experimental.custom_partitioning import custom_partitioning

from jax.sharding import PartitionSpec as P
from jax.sharding import PositionalSharding

from einops import rearrange
+import einops
import math

import flash_attn_jax_lib.flash_api as flash_api
@@ -33,6 +33,10 @@
_flash_mha_bwd_hlo_p.multiple_results = True
_flash_mha_bwd_hlo_p.def_impl(partial(xla.apply_primitive, _flash_mha_bwd_hlo_p))

+_custom_call_p = core.Primitive("custom_call")
+_custom_call_p.multiple_results = True
+_custom_call_p.def_impl(partial(xla.apply_primitive, _custom_call_p))
+
# ==== Primitive wrapper ====

def _flash_mha_fwd_hlo(q, k, v, softmax_scale, is_causal, window_size):
@@ -43,6 +47,9 @@ def _flash_mha_bwd_hlo(dout, q, k, v, out, lse, softmax_scale, is_causal, window
    dq, dk, dv = _flash_mha_bwd_hlo_p.bind(dout, q, k, v, out, lse, softmax_scale=softmax_scale, is_causal=is_causal, window_size=window_size)
    return dq, dk, dv

+def custom_call(*args, call_target_name, result_types, backend_config, operand_layouts, result_layouts):
+    return _custom_call_p.bind(*args, call_target_name=call_target_name, result_types=result_types, backend_config=backend_config, operand_layouts=operand_layouts, result_layouts=result_layouts)
+
# ==== HLO lowerings ====

# Register functions defined in gpu_ops as custom call target for GPUs
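Editor's note: the helper removed from the imports above (jaxlib.hlo_helpers.custom_call) and its replacement mlir.custom_call both emit an MLIR op directly, so they are only usable inside a lowering rule. The custom_call defined in this hunk instead binds the new _custom_call_p primitive, so it can be called from ordinary traced code (for example inside the function handed to mlir.lower_fun further down); its abstract eval and lowering are registered at the end of the file. A rough sketch of the intended call shape, with a hypothetical "my_kernel" target and assuming result_types, opaque and the layout lists are prepared the same way the lowerings below prepare them:

# Sketch only: hypothetical "my_kernel" target; result_types, opaque and the
# layout lists are assumed to be built as in the lowerings below.
y, = custom_call(
    x,
    call_target_name=b"my_kernel",
    result_types=result_types,
    backend_config=opaque,
    operand_layouts=operand_layouts,
    result_layouts=result_layouts,
)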
@@ -112,7 +119,7 @@ def _flash_mha_fwd_hlo_lowering(ctx, q, k, v, softmax_scale=None, is_causal=Fals

        out_types = [ir.RankedTensorType.get(o_shape, element_type), lse_type]

-        (o, lse) = custom_call(
+        (o, lse) = mlir.custom_call(
            b"flash_mha_fwd",
            result_types=out_types,
            operands=[q_padded, k_padded, v_padded],
@@ -125,7 +132,7 @@ def _flash_mha_fwd_hlo_lowering(ctx, q, k, v, softmax_scale=None, is_causal=Fals
        return (o, lse)
    else:
        out_types = [ir.RankedTensorType.get([n, l, h, d], element_type), lse_type]
-        out = custom_call(
+        out = mlir.custom_call(
            b"flash_mha_fwd",
            result_types=out_types,
            operands=[q, k, v],
@@ -155,6 +162,7 @@ def _flash_mha_bwd_hlo_lowering(ctx, dout, q, k, v, out, lse, softmax_scale=None
    assert q_type == v_type
    assert q_type == out_type
    assert type(lse_type) in [ir.F32Type]
+    dtype = q_type

    dout_shape = ir.RankedTensorType(dout.type).shape
    q_shape = ir.RankedTensorType(q.type).shape
@@ -184,49 +192,45 @@ def _flash_mha_bwd_hlo_lowering(ctx, dout, q, k, v, out, lse, softmax_scale=None
        flash_api.BF16 if type(q_type) == ir.BF16Type else flash_api.FP16,
        0)

-    if d % 8 != 0:
-        # We need padding. It's better to let xla's allocator handle it here than directly call cudaMalloc.
-        dpad = 8 - d % 8
-
-        z = np.array(0.0, dtype=ir_type_to_dtype(q_type))
-        z = mlir.ir_constant(z)
-        q_padded = mlir.hlo.PadOp(q, z, [0,0,0,0], [0,0,0,dpad], [0,0,0,0]).result
-        k_padded = mlir.hlo.PadOp(k, z, [0,0,0,0], [0,0,0,dpad], [0,0,0,0]).result
-        v_padded = mlir.hlo.PadOp(v, z, [0,0,0,0], [0,0,0,dpad], [0,0,0,0]).result
-        out_padded = mlir.hlo.PadOp(out, z, [0,0,0,0], [0,0,0,dpad], [0,0,0,0]).result
-        dout_padded = mlir.hlo.PadOp(dout, z, [0,0,0,0], [0,0,0,dpad], [0,0,0,0]).result
-
-        # Outputs are the same shape as the q,k,v (including padding)
-        out_types = [q_padded.type, k_padded.type, v_padded.type]
+    def fwd(dout, q, k, v, out, lse):
+        dpad = (8 - d % 8) % 8
+        if dpad > 0:
+            # We need padding. It's better to let xla's allocator handle it here than directly call cudaMalloc.
+            q = jnp.pad(q, ((0,0),(0,0),(0,0),(0,dpad)), 'constant')
+            k = jnp.pad(k, ((0,0),(0,0),(0,0),(0,dpad)), 'constant')
+            v = jnp.pad(v, ((0,0),(0,0),(0,0),(0,dpad)), 'constant')
+            out = jnp.pad(out, ((0,0),(0,0),(0,0),(0,dpad)), 'constant')
+            dout = jnp.pad(dout, ((0,0),(0,0),(0,0),(0,dpad)), 'constant')
+
+        # For MQA/GQA, hq != hk, but we pass a hq sized output tensor to the kernel and sum over it afterwards to reduce the size.
+        out_types = [ir.RankedTensorType.get([n, lq, hq, d + dpad], dtype),
+                     ir.RankedTensorType.get([n, lk, hq, d + dpad], dtype),
+                     ir.RankedTensorType.get([n, lk, hq, d + dpad], dtype)]
+        out_layouts = default_layouts([n, lq, hq, d + dpad], [n, lk, hq, d + dpad], [n, lk, hq, d + dpad])

        dq, dk, dv = custom_call(
-            b"flash_mha_bwd",
-            result_types=out_types,
-            operands=[dout_padded, q_padded, k_padded, v_padded, out_padded, lse],
+            dout, q, k, v, out, lse,
+            call_target_name=b"flash_mha_bwd",
+            operand_layouts=default_layouts(dout.shape, q.shape, k.shape, v.shape, out.shape, lse.shape),
            backend_config=opaque,
-            operand_layouts=value_layouts(dout_padded, q_padded, k_padded, v_padded, out_padded, lse),
-            result_layouts=value_layouts(q_padded, k_padded, v_padded), # dq, dk, dv
-        ).results
-
-        dq = mlir.hlo.SliceOp(dq, [0,0,0,0], tuple(q_shape), [1,1,1,1]).result
-        dk = mlir.hlo.SliceOp(dk, [0,0,0,0], tuple(k_shape), [1,1,1,1]).result
-        dv = mlir.hlo.SliceOp(dv, [0,0,0,0], tuple(v_shape), [1,1,1,1]).result
+            result_types=out_types,
+            result_layouts=out_layouts,
+        )
+
+        if hq != hk:
+            assert hq > hk and hq % hk == 0
+            m = hq // hk
+            dk = einops.reduce(dk, 'n l (h m) d -> n l h d', reduction='sum', h=hk)
+            dv = einops.reduce(dv, 'n l (h m) d -> n l h d', reduction='sum', h=hk)
+
+        if dpad > 0:
+            dq = dq[:,:,:,:d]
+            dk = dk[:,:,:,:d]
+            dv = dv[:,:,:,:d]

        return dq, dk, dv
-    else:
-        out_types = [ir.RankedTensorType.get(q_shape, q_type),
-                     ir.RankedTensorType.get(k_shape, k_type),
-                     ir.RankedTensorType.get(v_shape, v_type)]
-
-        out = custom_call(
-            b"flash_mha_bwd",
-            result_types=out_types,
-            operands=[dout, q, k, v, out, lse],
-            backend_config=opaque,
-            operand_layouts=default_layouts(dout_shape, q_shape, k_shape, v_shape, out_shape, lse_shape),
-            result_layouts=default_layouts(*[o.shape for o in out_types]),
-        ).results
-        return out
+
+    return mlir.lower_fun(fwd, multiple_results=True)(ctx, dout, q, k, v, out, lse)

mlir.register_lowering(
    _flash_mha_bwd_hlo_p,
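Editor's note: mlir.lower_fun turns a function written with ordinary jnp operations into a lowering rule — the function is traced and lowered into the current module — which is what lets the padding, the MQA/GQA head reduction and the un-padding above be expressed with jnp.pad, einops.reduce and array slicing instead of hand-built PadOp/SliceOp MLIR. A minimal sketch of the same pattern, using a hypothetical primitive that is not part of this change:

# Sketch only: hypothetical _scale_p primitive, lowered by tracing a jnp-level
# function with mlir.lower_fun instead of emitting MLIR ops by hand.
_scale_p = core.Primitive("scale_example")
_scale_p.multiple_results = True
_scale_p.def_impl(partial(xla.apply_primitive, _scale_p))
_scale_p.def_abstract_eval(lambda x, *, factor: (ShapedArray(x.shape, x.dtype),))

def _scale_lowering(ctx, x, *, factor):
    def impl(x):
        return (x * factor,)  # plain jnp-compatible code; traced and lowered here
    return mlir.lower_fun(impl, multiple_results=True)(ctx, x)

mlir.register_lowering(_scale_p, _scale_lowering)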
@@ -266,3 +270,32 @@ def _flash_mha_bwd_abstract(dout, q, k, v, out, lse, softmax_scale=None, is_caus
        ShapedArray(v.shape, v_dtype, named_shape=v.named_shape),
    )
_flash_mha_bwd_hlo_p.def_abstract_eval(_flash_mha_bwd_abstract)
+
+# ==== Custom Call ====
+
+def _custom_call_abstract_eval(*args, call_target_name, result_types, backend_config, operand_layouts, result_layouts):
+    def convert(ty):
+        ty = ir.RankedTensorType(ty)
+        shape = tuple(ty.shape)
+        dtype = ir_type_to_dtype(ty.element_type)
+        return ShapedArray(shape, dtype)
+    out_types = [convert(o) for o in result_types]
+    return tuple(out_types)
+
+_custom_call_p.def_abstract_eval(_custom_call_abstract_eval)
+
+def _custom_call_hlo_lowering(ctx, *args, call_target_name, result_types, backend_config, operand_layouts, result_layouts):
+    out = mlir.custom_call(
+        call_target_name,
+        operands=args,
+        result_types=result_types,
+        backend_config=backend_config,
+        operand_layouts=operand_layouts,
+        result_layouts=result_layouts,
+    ).results
+    return out
+
+mlir.register_lowering(
+    _custom_call_p,
+    _custom_call_hlo_lowering
+)
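Editor's note: these two registrations close the loop for the primitive-level custom_call used in the backward lowering — at trace time _custom_call_abstract_eval maps each ir.RankedTensorType in result_types to a ShapedArray, and at lowering time _custom_call_hlo_lowering forwards everything unchanged to mlir.custom_call. Roughly, with made-up shapes (the ir types need an active MLIR context, as during lowering):

# Sketch only, with made-up shapes: the abstract eval just converts each
# requested result type into the corresponding ShapedArray.
f16 = ir.F16Type.get()
tys = [ir.RankedTensorType.get([2, 128, 8, 64], f16)]
avals = _custom_call_abstract_eval(call_target_name=b"flash_mha_bwd",
                                   result_types=tys, backend_config=b"",
                                   operand_layouts=[], result_layouts=[])
# -> a 1-tuple: (ShapedArray((2, 128, 8, 64), float16),)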