Commit a037f7a

Refactor flash HLO primitives to use jax.ffi.ffi_call and remove custom call implementation
1 parent a224720 commit a037f7a

File tree: 1 file changed (+31 -74 lines)

src/flash_attn_jax/flash_hlo.py

Lines changed: 31 additions & 74 deletions
@@ -5,12 +5,9 @@
 import jax.numpy as jnp
 from jax import core, dtypes
 from jax.core import ShapedArray
-from jax.interpreters import batching
 from jax.interpreters import mlir
 from jax.interpreters import xla
 from jax.interpreters.mlir import ir
-from jax.lib import xla_client
-from jax.experimental.custom_partitioning import custom_partitioning

 from jax.extend.core import Primitive

@@ -20,6 +17,8 @@

 import flash_attn_jax_lib.flash_api as flash_api

+# jax.ffi.ffi_call()
+
 # ==== Register primitives ====

 _flash_mha_fwd_hlo_p = Primitive("flash_mha_fwd_hlo")
@@ -30,10 +29,6 @@
 _flash_mha_bwd_hlo_p.multiple_results = True
 _flash_mha_bwd_hlo_p.def_impl(partial(xla.apply_primitive, _flash_mha_bwd_hlo_p))

-_custom_call_p = Primitive("custom_call")
-_custom_call_p.multiple_results = True
-_custom_call_p.def_impl(partial(xla.apply_primitive, _custom_call_p))
-
 # ==== Primitive wrapper ====

 def _flash_mha_fwd_hlo(q, k, v, softmax_scale, is_causal, window_size):
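
Note: the flash forward/backward primitives keep the registration pattern shown above; only the ad-hoc `custom_call` primitive goes away. For orientation, a minimal, self-contained sketch of that pattern with a hypothetical primitive (`demo_scale` and its rules are illustrative, not part of this repo):

```python
# Hypothetical example of the Primitive registration pattern used above.
import jax
import jax.numpy as jnp
from jax.core import ShapedArray
from jax.extend.core import Primitive
from jax.interpreters import mlir

_demo_p = Primitive("demo_scale")
_demo_p.multiple_results = False
_demo_p.def_impl(lambda x, *, scale: x * scale)                               # eager implementation
_demo_p.def_abstract_eval(lambda x, *, scale: ShapedArray(x.shape, x.dtype))  # shape/dtype rule
# Lower through an ordinary jnp function so the primitive also works under jit.
mlir.register_lowering(_demo_p, mlir.lower_fun(lambda x, *, scale: x * scale,
                                               multiple_results=False))

def demo_scale(x, scale=2.0):
    return _demo_p.bind(x, scale=scale)

print(demo_scale(jnp.arange(4.0)))           # [0. 2. 4. 6.]
print(jax.jit(demo_scale)(jnp.arange(4.0)))  # same result under jit
```
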
@@ -44,18 +39,10 @@ def _flash_mha_bwd_hlo(dout, q, k, v, out, lse, softmax_scale, is_causal, window
     dq, dk, dv = _flash_mha_bwd_hlo_p.bind(dout, q, k, v, out, lse, softmax_scale=softmax_scale, is_causal=is_causal, window_size=window_size)
     return dq, dk, dv

-def custom_call(*args, call_target_name, result_types, backend_config, operand_layouts, result_layouts):
-    return _custom_call_p.bind(*args, call_target_name=call_target_name,
-                               result_types=tuple(result_types),
-                               backend_config=backend_config,
-                               operand_layouts=tuple(operand_layouts),
-                               result_layouts=tuple(result_layouts))
-
 # ==== HLO lowerings ====

 # Register functions defined in gpu_ops as custom call target for GPUs
 for _name, _value in flash_api.get_registrations().items():
-    # xla_client.register_custom_call_target(_name, _value, platform="gpu")
     jax.ffi.register_ffi_target(_name, _value, platform="gpu", api_version=0)

 def default_layouts(*shapes):
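
Note: `jax.ffi.register_ffi_target(..., api_version=0)` appears to register each kernel under the legacy (untyped, opaque-string) custom-call ABI rather than the newer typed FFI, which lines up with the `custom_call_api_version=1` call sites below. A hedged sketch of the same loop against a hypothetical extension module:

```python
# Hypothetical sketch: `my_ext` stands in for a compiled extension (like
# flash_attn_jax_lib.flash_api) exposing PyCapsule custom-call entry points.
import jax
import my_ext  # hypothetical compiled module

for name, capsule in my_ext.get_registrations().items():
    # api_version=0 -> legacy custom-call registration (opaque-string config),
    # corresponding to the untyped custom_call_api_version=1 ffi_call sites below.
    jax.ffi.register_ffi_target(name, capsule, platform="gpu", api_version=0)
```
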
@@ -109,26 +96,24 @@ def fwd(q, k, v):
         k = jnp.pad(k, ((0,0),(0,0),(0,0),(0,dpad)), 'constant')
         v = jnp.pad(v, ((0,0),(0,0),(0,0),(0,dpad)), 'constant')

-    q_shape = [n, l, h, d+dpad]
-    k_shape = [n, lk, hk, d+dpad]
-    v_shape = [n, lk, hk, d+dpad]
+    # q_shape = [n, l, h, d+dpad]
+    # k_shape = [n, lk, hk, d+dpad]
+    # v_shape = [n, lk, hk, d+dpad]
     o_shape = [n, l, h, d+dpad]
     lse_shape = [n, h, l]
-

-    lse_type = ir.RankedTensorType.get([n, h, l], mlir.dtype_to_ir_type(jnp.float32.dtype))
-    out_types = [ir.RankedTensorType.get(o_shape, element_type), lse_type]
-    operand_layouts = default_layouts(q_shape, k_shape, v_shape)
-    result_layouts = default_layouts(o_shape, lse_shape)
-
-    o, lse = custom_call(
-        q, k, v,
-        call_target_name = b"flash_mha_fwd",
-        result_types=out_types,
-        backend_config=opaque,
-        operand_layouts=operand_layouts,
-        result_layouts=result_layouts,
-    )
+    jax_dtype = jnp.bfloat16 if type(element_type) == ir.BF16Type else jnp.float16
+    out_types = [jax.ShapeDtypeStruct(o_shape, jax_dtype), jax.ShapeDtypeStruct(lse_shape, jnp.float32)]
+
+    o, lse = jax.ffi.ffi_call(
+        "flash_mha_fwd",
+        result_shape_dtypes=out_types,
+        has_side_effect=False,
+        legacy_backend_config=opaque,
+        input_layouts=[None, None, None], # default row major
+        output_layouts=[None, None],
+        custom_call_api_version=1
+    )(q, k, v)

     if dpad > 0:
         o = o[:,:,:,:d]
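
Note: `jax.ffi.ffi_call` replaces the hand-built `custom_call` wrapper here: the output avals are declared up front as `jax.ShapeDtypeStruct`s, and `legacy_backend_config` plus `custom_call_api_version=1` keep the old opaque-string calling convention. A minimal hedged sketch of the call pattern (the target name `"my_kernel"` is hypothetical and would have to be registered first, as above):

```python
import jax
import jax.numpy as jnp

def call_my_kernel(x, opaque):
    # "my_kernel" is a hypothetical, already-registered legacy custom-call target;
    # `opaque` is whatever packed descriptor its entry point expects.
    out_spec = [jax.ShapeDtypeStruct(x.shape, x.dtype)]
    (y,) = jax.ffi.ffi_call(
        "my_kernel",
        result_shape_dtypes=out_spec,   # declares output shapes/dtypes up front
        has_side_effect=False,
        legacy_backend_config=opaque,   # opaque config, as in the old custom_call
        input_layouts=[None],           # None = default row-major layout
        output_layouts=[None],
        custom_call_api_version=1,      # legacy (untyped) custom-call API
    )(x)
    return y
```
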
@@ -197,19 +182,20 @@ def fwd(dout, q, k, v, out, lse):
         dout = jnp.pad(dout, ((0,0),(0,0),(0,0),(0,dpad)), 'constant')

     # For MQA/GQA, hq != hk, but we pass a hq sized output tensor to the kernel and sum over it afterwards to reduce the size.
-    out_types = [ir.RankedTensorType.get([n, lq, hq, d+dpad], dtype),
-                 ir.RankedTensorType.get([n, lk, hq, d+dpad], dtype),
-                 ir.RankedTensorType.get([n, lk, hq, d+dpad], dtype)]
-    out_layouts = default_layouts([n, lq, hq, d+dpad], [n, lk, hq, d+dpad], [n, lk, hq, d+dpad])
-
-    dq, dk, dv = custom_call(
-        dout, q, k, v, out, lse,
-        call_target_name=b"flash_mha_bwd",
-        operand_layouts=default_layouts(dout.shape, q.shape, k.shape, v.shape, out.shape, lse.shape),
-        backend_config=opaque,
-        result_types=out_types,
-        result_layouts=out_layouts,
-    )
+    jax_dtype = jnp.bfloat16 if type(dtype) == ir.BF16Type else jnp.float16
+    out_types = [jax.ShapeDtypeStruct((n, lq, hq, d+dpad), jax_dtype),
+                 jax.ShapeDtypeStruct((n, lk, hq, d+dpad), jax_dtype),
+                 jax.ShapeDtypeStruct((n, lk, hq, d+dpad), jax_dtype)]
+
+    dq, dk, dv = jax.ffi.ffi_call(
+        "flash_mha_bwd",
+        result_shape_dtypes=out_types,
+        has_side_effect=False,
+        legacy_backend_config=opaque,
+        input_layouts=[None]*6, # default row major
+        output_layouts=[None]*3,
+        custom_call_api_version=1
+    )(dout, q, k, v, out, lse)

     if hq != hk:
         assert hq > hk and hq % hk == 0
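
Note: as the comment in this hunk says, for MQA/GQA (`hq != hk`) the kernel writes `dk`/`dv` with `hq` heads and the wrapper sums them down to `hk` afterwards. A runnable sketch of that reduction (the exact reshape/grouping is illustrative, not copied from the code that follows this hunk):

```python
import jax.numpy as jnp

n, lk, hq, hk, d = 2, 16, 8, 2, 64               # illustrative sizes with hq % hk == 0
dk_full = jnp.ones((n, lk, hq, d), jnp.float16)  # kernel output with hq heads

# Group the hq query heads over their hk shared K/V heads and sum each group.
dk = dk_full.reshape(n, lk, hk, hq // hk, d).sum(axis=3)
print(dk.shape)  # (2, 16, 2, 64)
```
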
@@ -264,32 +250,3 @@ def _flash_mha_bwd_abstract(dout, q, k, v, out, lse, softmax_scale=None, is_caus
         ShapedArray(v.shape, v_dtype),
     )
 _flash_mha_bwd_hlo_p.def_abstract_eval(_flash_mha_bwd_abstract)
-
-# ==== Custom Call ====
-
-def _custom_call_abstract_eval(*args, call_target_name, result_types, backend_config, operand_layouts, result_layouts):
-    def convert(ty):
-        ty = ir.RankedTensorType(ty)
-        shape = tuple(ty.shape)
-        dtype = ir_type_to_dtype(ty.element_type)
-        return ShapedArray(shape, dtype)
-    out_types = [convert(o) for o in result_types]
-    return tuple(out_types)
-
-_custom_call_p.def_abstract_eval(_custom_call_abstract_eval)
-
-def _custom_call_hlo_lowering(ctx, *args, call_target_name, result_types, backend_config, operand_layouts, result_layouts):
-    out = mlir.custom_call(
-        call_target_name,
-        operands=args,
-        result_types=list(result_types),
-        backend_config=backend_config,
-        operand_layouts=list(operand_layouts),
-        result_layouts=list(result_layouts),
-    ).results
-    return out
-
-mlir.register_lowering(
-    _custom_call_p,
-    _custom_call_hlo_lowering
-)
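
Note: the deleted block only taught JAX the output avals and MLIR lowering of the removed `custom_call` primitive; `jax.ffi.ffi_call` derives both from the declared `result_shape_dtypes`, so nothing needs to convert `ir.RankedTensorType`s back into `ShapedArray`s. A tiny illustration (shapes hypothetical):

```python
import jax
import jax.numpy as jnp
from jax.core import ShapedArray

# What _custom_call_abstract_eval used to rebuild from ir.RankedTensorType is now
# stated directly: a ShapeDtypeStruct already carries the (shape, dtype) pair.
spec = jax.ShapeDtypeStruct((2, 1024, 8, 64), jnp.float16)  # hypothetical o_shape
aval = ShapedArray(spec.shape, spec.dtype)
print(spec.shape, spec.dtype)  # (2, 1024, 8, 64) float16
print(aval.shape, aval.dtype)  # same information, as a JAX aval
```
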
