import jax.numpy as jnp
from jax import core, dtypes
from jax.core import ShapedArray
- from jax.interpreters import batching
from jax.interpreters import mlir
from jax.interpreters import xla
from jax.interpreters.mlir import ir
- from jax.lib import xla_client
- from jax.experimental.custom_partitioning import custom_partitioning

from jax.extend.core import Primitive

import flash_attn_jax_lib.flash_api as flash_api

+ # jax.ffi.ffi_call()
+
# ==== Register primitives ====

_flash_mha_fwd_hlo_p = Primitive("flash_mha_fwd_hlo")
...
_flash_mha_bwd_hlo_p.multiple_results = True
_flash_mha_bwd_hlo_p.def_impl(partial(xla.apply_primitive, _flash_mha_bwd_hlo_p))

- _custom_call_p = Primitive("custom_call")
- _custom_call_p.multiple_results = True
- _custom_call_p.def_impl(partial(xla.apply_primitive, _custom_call_p))
-
# ==== Primitive wrapper ====

def _flash_mha_fwd_hlo(q, k, v, softmax_scale, is_causal, window_size):
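
The `# jax.ffi.ffi_call()` note added above marks where this change is headed: the hand-rolled `custom_call` primitive is dropped (here and in the hunks below) in favor of JAX's public FFI module, which supplies abstract evaluation and lowering itself. A minimal sketch of that call pattern, with a hypothetical target name and shapes that are not part of this diff:

import jax
import jax.numpy as jnp

def call_my_kernel(x):
    # Describe the output at the JAX level; no MLIR types or layouts needed.
    out_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    # ffi_call builds a callable for the registered target "my_kernel"
    # (hypothetical) and applies it to the operands.
    return jax.ffi.ffi_call("my_kernel", out_spec)(x)
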
@@ -44,18 +39,10 @@ def _flash_mha_bwd_hlo(dout, q, k, v, out, lse, softmax_scale, is_causal, window
    dq, dk, dv = _flash_mha_bwd_hlo_p.bind(dout, q, k, v, out, lse, softmax_scale=softmax_scale, is_causal=is_causal, window_size=window_size)
    return dq, dk, dv

- def custom_call(*args, call_target_name, result_types, backend_config, operand_layouts, result_layouts):
-     return _custom_call_p.bind(*args, call_target_name=call_target_name,
-                                result_types=tuple(result_types),
-                                backend_config=backend_config,
-                                operand_layouts=tuple(operand_layouts),
-                                result_layouts=tuple(result_layouts))
-
# ==== HLO lowerings ====

# Register functions defined in gpu_ops as custom call target for GPUs
for _name, _value in flash_api.get_registrations().items():
-     # xla_client.register_custom_call_target(_name, _value, platform="gpu")
    jax.ffi.register_ffi_target(_name, _value, platform="gpu", api_version=0)

def default_layouts(*shapes):
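
Registration now goes through `jax.ffi.register_ffi_target`, but with `api_version=0` the flash kernels remain legacy custom-call targets that take a serialized "opaque" descriptor. That is why the call sites in the later hunks pass `legacy_backend_config=opaque` and `custom_call_api_version=1`. A minimal sketch of that pairing, using a hypothetical target and descriptor that are not part of this diff:

import jax

# Hypothetical legacy target, registered the same way as the loop above:
# jax.ffi.register_ffi_target("my_legacy_kernel", capsule, platform="gpu", api_version=0)

def call_my_legacy_kernel(x, opaque):
    # Legacy (api_version=0) targets receive their parameters as an opaque
    # descriptor, so the call opts into the original custom-call ABI
    # (custom_call_api_version=1) and forwards the descriptor through
    # legacy_backend_config instead of typed FFI attributes.
    out_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.ffi.ffi_call(
        "my_legacy_kernel",
        result_shape_dtypes=out_spec,
        legacy_backend_config=opaque,
        custom_call_api_version=1,
    )(x)
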
@@ -109,26 +96,24 @@ def fwd(q, k, v):
    k = jnp.pad(k, ((0,0),(0,0),(0,0),(0,dpad)), 'constant')
    v = jnp.pad(v, ((0,0),(0,0),(0,0),(0,dpad)), 'constant')

-     q_shape = [n, l, h, d+dpad]
-     k_shape = [n, lk, hk, d+dpad]
-     v_shape = [n, lk, hk, d+dpad]
+     # q_shape = [n, l, h, d+dpad]
+     # k_shape = [n, lk, hk, d+dpad]
+     # v_shape = [n, lk, hk, d+dpad]
    o_shape = [n, l, h, d+dpad]
    lse_shape = [n, h, l]
-

-     lse_type = ir.RankedTensorType.get([n, h, l], mlir.dtype_to_ir_type(jnp.float32.dtype))
-     out_types = [ir.RankedTensorType.get(o_shape, element_type), lse_type]
-     operand_layouts = default_layouts(q_shape, k_shape, v_shape)
-     result_layouts = default_layouts(o_shape, lse_shape)
-
-     o, lse = custom_call(
-         q, k, v,
-         call_target_name=b"flash_mha_fwd",
-         result_types=out_types,
-         backend_config=opaque,
-         operand_layouts=operand_layouts,
-         result_layouts=result_layouts,
-     )
+     jax_dtype = jnp.bfloat16 if type(element_type) == ir.BF16Type else jnp.float16
+     out_types = [jax.ShapeDtypeStruct(o_shape, jax_dtype), jax.ShapeDtypeStruct(lse_shape, jnp.float32)]
+
+     o, lse = jax.ffi.ffi_call(
+         "flash_mha_fwd",
+         result_shape_dtypes=out_types,
+         has_side_effect=False,
+         legacy_backend_config=opaque,
+         input_layouts=[None, None, None],  # default row major
+         output_layouts=[None, None],
+         custom_call_api_version=1
+     )(q, k, v)

    if dpad > 0:
        o = o[:,:,:,:d]
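
With `ffi_call`, result shapes are declared as `jax.ShapeDtypeStruct`s instead of `ir.RankedTensorType`s, so the lowering has to map the MLIR element type back to a JAX dtype; the inline conditional above distinguishes only bf16 and fp16, defaulting to fp16 otherwise. A hypothetical helper (not in the diff) that spells out the same mapping with an explicit error for unsupported types:

import jax.numpy as jnp
from jax.interpreters.mlir import ir

# Map an MLIR element type to the JAX dtype used in ShapeDtypeStruct.
# Mirrors the `type(element_type) == ir.BF16Type` check in the lowering above.
_IR_TO_DTYPE = {
    ir.BF16Type: jnp.bfloat16,
    ir.F16Type: jnp.float16,
}

def element_type_to_dtype(element_type):
    dtype = _IR_TO_DTYPE.get(type(element_type))
    if dtype is None:
        raise TypeError(f"unsupported element type: {element_type}")
    return dtype
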
@@ -197,19 +182,20 @@ def fwd(dout, q, k, v, out, lse):
    dout = jnp.pad(dout, ((0,0),(0,0),(0,0),(0,dpad)), 'constant')

    # For MQA/GQA, hq != hk, but we pass a hq sized output tensor to the kernel and sum over it afterwards to reduce the size.
-     out_types = [ir.RankedTensorType.get([n, lq, hq, d+dpad], dtype),
-                  ir.RankedTensorType.get([n, lk, hq, d+dpad], dtype),
-                  ir.RankedTensorType.get([n, lk, hq, d+dpad], dtype)]
-     out_layouts = default_layouts([n, lq, hq, d+dpad], [n, lk, hq, d+dpad], [n, lk, hq, d+dpad])
-
-     dq, dk, dv = custom_call(
-         dout, q, k, v, out, lse,
-         call_target_name=b"flash_mha_bwd",
-         operand_layouts=default_layouts(dout.shape, q.shape, k.shape, v.shape, out.shape, lse.shape),
-         backend_config=opaque,
-         result_types=out_types,
-         result_layouts=out_layouts,
-     )
+     jax_dtype = jnp.bfloat16 if type(dtype) == ir.BF16Type else jnp.float16
+     out_types = [jax.ShapeDtypeStruct((n, lq, hq, d+dpad), jax_dtype),
+                  jax.ShapeDtypeStruct((n, lk, hq, d+dpad), jax_dtype),
+                  jax.ShapeDtypeStruct((n, lk, hq, d+dpad), jax_dtype)]
+
+     dq, dk, dv = jax.ffi.ffi_call(
+         "flash_mha_bwd",
+         result_shape_dtypes=out_types,
+         has_side_effect=False,
+         legacy_backend_config=opaque,
+         input_layouts=[None]*6,  # default row major
+         output_layouts=[None]*3,
+         custom_call_api_version=1
+     )(dout, q, k, v, out, lse)

    if hq != hk:
        assert hq > hk and hq % hk == 0
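
The comment in this hunk notes that for MQA/GQA (hq != hk) the backward kernel writes dk/dv with hq heads, which are then summed down to hk heads. A standalone sketch of that reduction, assuming the usual contiguous grouping of query heads per KV head (the exact grouping used by the kernel is not shown in this diff):

import jax.numpy as jnp

def reduce_gqa_heads(dkv, hk):
    # dkv has kernel-output shape [n, lk, hq, d]; each KV head serves
    # hq // hk consecutive query heads, so sum the gradients within each group.
    n, lk, hq, d = dkv.shape
    assert hq % hk == 0
    return dkv.reshape(n, lk, hk, hq // hk, d).sum(axis=3)
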
@@ -264,32 +250,3 @@ def _flash_mha_bwd_abstract(dout, q, k, v, out, lse, softmax_scale=None, is_caus
        ShapedArray(v.shape, v_dtype),
    )
_flash_mha_bwd_hlo_p.def_abstract_eval(_flash_mha_bwd_abstract)
-
- # ==== Custom Call ====
-
- def _custom_call_abstract_eval(*args, call_target_name, result_types, backend_config, operand_layouts, result_layouts):
-     def convert(ty):
-         ty = ir.RankedTensorType(ty)
-         shape = tuple(ty.shape)
-         dtype = ir_type_to_dtype(ty.element_type)
-         return ShapedArray(shape, dtype)
-     out_types = [convert(o) for o in result_types]
-     return tuple(out_types)
-
- _custom_call_p.def_abstract_eval(_custom_call_abstract_eval)
-
- def _custom_call_hlo_lowering(ctx, *args, call_target_name, result_types, backend_config, operand_layouts, result_layouts):
-     out = mlir.custom_call(
-         call_target_name,
-         operands=args,
-         result_types=list(result_types),
-         backend_config=backend_config,
-         operand_layouts=list(operand_layouts),
-         result_layouts=list(result_layouts),
-     ).results
-     return out
-
- mlir.register_lowering(
-     _custom_call_p,
-     _custom_call_hlo_lowering
- )