Commit 18e4a80

Fix issue with edge padding <#7>.
1 parent 565ee42 commit 18e4a80

1 file changed: +28, -38 lines changed

src/flash_attn_jax/flash_hlo.py

Lines changed: 28 additions & 38 deletions
@@ -77,16 +77,14 @@ def _flash_mha_fwd_hlo_lowering(ctx, q, k, v, softmax_scale=None, is_causal=Fals
     v_type = ir.RankedTensorType(v.type)
     v_shape = v_type.shape

-    assert q_type.element_type == k_type.element_type
-    assert q_type.element_type == v_type.element_type
+    assert q_type.element_type == k_type.element_type, "Q and K must have the same dtype"
+    assert q_type.element_type == v_type.element_type, "Q and V must have the same dtype"
     element_type = q_type.element_type
-    assert type(element_type) in [ir.F16Type, ir.BF16Type]
+    assert type(element_type) in [ir.F16Type, ir.BF16Type], "Only support fp16 and bf16 data type"
     [n, l, h, d] = q_shape
     [nk, lk, hk, dk] = k_shape
-
-
-    assert k_shape == v_shape
-    assert [n, d] == [nk, dk]
+    assert k_shape == v_shape, "K and V must have the same shape"
+    assert [n, d] == [nk, dk], "Q and K must have the same batch size and head size"

     opaque = flash_api.make_flash_mha_fwd_args(
         0.0, # p_dropout
@@ -100,47 +98,39 @@ def _flash_mha_fwd_hlo_lowering(ctx, q, k, v, softmax_scale=None, is_causal=Fals
         flash_api.BF16 if type(element_type) == ir.BF16Type else flash_api.FP16,
         0)

-    lse_type = ir.RankedTensorType.get([n, h, l], ir.F32Type.get(ctx.module_context.context))
-
-    if d % 8 != 0:
-        # We need padding. It's better to let xla's allocator handle it here than directly call cudaMalloc.
-        dpad = 8 - d%8
-
-        z = np.array(0.0, dtype=ir_type_to_dtype(element_type))
-        z = mlir.ir_constant(z)
-        q_padded = mlir.hlo.PadOp(q,z,[0,0,0,0],[0,0,0,dpad],[0,0,0,0]).result
-        k_padded = mlir.hlo.PadOp(k,z,[0,0,0,0],[0,0,0,dpad],[0,0,0,0]).result
-        v_padded = mlir.hlo.PadOp(v,z,[0,0,0,0],[0,0,0,dpad],[0,0,0,0]).result
-
+    def fwd(q, k, v):
+        dpad = (8 - d%8) % 8
+        if dpad > 0:
+            # We need padding. It's better to let xla's allocator handle it here than directly call cudaMalloc.
+            q = jnp.pad(q, ((0,0),(0,0),(0,0),(0,dpad)), 'constant')
+            k = jnp.pad(k, ((0,0),(0,0),(0,0),(0,dpad)), 'constant')
+            v = jnp.pad(v, ((0,0),(0,0),(0,0),(0,dpad)), 'constant')
+
         q_shape = [n, l, h, d+dpad]
         k_shape = [n, lk, hk, d+dpad]
         v_shape = [n, lk, hk, d+dpad]
        o_shape = [n, l, h, d+dpad]
+        lse_shape = [n, h, l]

+
+        lse_type = ir.RankedTensorType.get([n, h, l], mlir.dtype_to_ir_type(jnp.float32.dtype))
         out_types = [ir.RankedTensorType.get(o_shape, element_type), lse_type]
+        operand_layouts = default_layouts(q_shape, k_shape, v_shape)
+        result_layouts = default_layouts(o_shape, lse_shape)

-        (o, lse) = mlir.custom_call(
-            b"flash_mha_fwd",
+        o, lse = custom_call(
+            q, k, v,
+            call_target_name = b"flash_mha_fwd",
             result_types=out_types,
-            operands=[q_padded, k_padded, v_padded],
             backend_config=opaque,
-            operand_layouts=default_layouts(q_shape, k_shape, v_shape),
-            result_layouts=default_layouts(*[o.shape for o in out_types]),
-        ).results
+            operand_layouts=operand_layouts,
+            result_layouts=result_layouts,
+        )

-        o = mlir.hlo.SliceOp(o, [0,0,0,0], (n, l, h, d), [1,1,1,1]).result
-        return (o,lse)
-    else:
-        out_types = [ir.RankedTensorType.get([n, l, h, d], element_type), lse_type]
-        out = mlir.custom_call(
-            b"flash_mha_fwd",
-            result_types=out_types,
-            operands=[q, k, v],
-            backend_config=opaque,
-            operand_layouts=default_layouts(q_shape, k_shape, v_shape),
-            result_layouts=default_layouts(*[o.shape for o in out_types]),
-        ).results
-        return out
+        if dpad > 0:
+            o = o[:,:,:,:d]
+        return o, lse
+    return mlir.lower_fun(fwd, multiple_results=True)(ctx, q, k, v)

 mlir.register_lowering(
     _flash_mha_fwd_hlo_p,
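
Note on the padding scheme: the new lowering rounds the head dimension d up to a multiple of 8 with dpad = (8 - d%8) % 8, zero-pads q, k and v along the last axis, runs the kernel on the padded tensors, and slices the output back to d. Zero padding is exact here because the softmax scale is supplied as a separate parameter (softmax_scale in the signature above): the extra channels contribute zero to every Q·K dot product and produce all-zero output channels, which the final slice discards. The sketch below illustrates the same pad-then-slice pattern at the jax.numpy level; pad_head_dim, reference_attention and padded_mha are hypothetical names used only for this illustration, not functions from this repository.

import jax
import jax.numpy as jnp

def pad_head_dim(x, multiple=8):
    # Pad the last axis of an [n, l, h, d] tensor up to the next multiple of `multiple`.
    d = x.shape[-1]
    dpad = (multiple - d % multiple) % multiple   # 0 when d is already aligned
    if dpad > 0:
        x = jnp.pad(x, ((0, 0), (0, 0), (0, 0), (0, dpad)), 'constant')
    return x

def reference_attention(q, k, v, softmax_scale):
    # Plain softmax attention, standing in for the flash_mha_fwd kernel.
    s = jnp.einsum('nlhd,nkhd->nhlk', q, k) * softmax_scale
    o = jnp.einsum('nhlk,nkhd->nlhd', jax.nn.softmax(s, axis=-1), v)
    lse = jax.scipy.special.logsumexp(s, axis=-1)        # shape [n, h, l]
    return o, lse

def padded_mha(q, k, v):
    d = q.shape[-1]
    scale = 1.0 / jnp.sqrt(d)                            # scale comes from the unpadded d
    o, lse = reference_attention(pad_head_dim(q), pad_head_dim(k), pad_head_dim(v), scale)
    return o[:, :, :, :d], lse                           # drop the padded channels

q = k = v = jnp.ones((2, 16, 4, 36), jnp.float16)        # d=36 is not a multiple of 8
o, lse = padded_mha(q, k, v)
print(o.shape, lse.shape)                                # (2, 16, 4, 36) (2, 4, 16)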

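The structural change is that the lowering no longer emits PadOp/SliceOp and the custom call by hand for the padded and unpadded cases separately: it wraps the whole forward path in a nested fwd function written with jnp operations and hands it to mlir.lower_fun, which traces that function and generates the HLO (so XLA's allocator manages the padded buffers, as the original comment asks for). Below is a minimal, self-contained sketch of that registration pattern with a toy double_demo primitive invented for illustration; it assumes a JAX version where jax.core.Primitive is available (newer releases expose it as jax.extend.core.Primitive). The real code registers _flash_mha_fwd_hlo_p and calls the flash kernel inside fwd.

import jax
import jax.numpy as jnp
from jax import core
from jax.interpreters import mlir

double_p = core.Primitive("double_demo")    # toy stand-in for _flash_mha_fwd_hlo_p

def double(x):
    return double_p.bind(x)

@double_p.def_abstract_eval
def _double_abstract(x):
    return core.ShapedArray(x.shape, x.dtype)

def _double_lowering_fn(x):
    # Ordinary jax.numpy code: mlir.lower_fun traces this function and emits the
    # corresponding HLO, so no MLIR ops have to be built by hand. Padding and
    # slicing can be written here with jnp.pad / indexing, as the commit does
    # inside its fwd() helper.
    return 2.0 * x

mlir.register_lowering(
    double_p,
    mlir.lower_fun(_double_lowering_fn, multiple_results=False))

print(jax.jit(double)(jnp.arange(4.0)))     # [0. 2. 4. 6.]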