
Commit af4317a (1 parent: 66ef3e8)

Implement MQA and GQA support and add tests for it.

File tree: 9 files changed, +228 -170 lines
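For orientation (my summary, not text from the commit): with multi-query (MQA) and grouped-query attention (GQA), the query tensor keeps hq heads while K and V carry only hk heads, hq a multiple of hk, and each group of hq // hk query heads shares one K/V head. A toy shape sketch in JAX (names and sizes are illustrative):

# Toy GQA shapes: hq = 8 query heads share hk = 2 key/value heads (group size 4).
# MQA is the special case hk == 1.
import jax.numpy as jnp

n, lq, lk, hq, hk, d = 2, 128, 128, 8, 2, 64
q = jnp.zeros([n, lq, hq, d], jnp.float16)   # queries keep the full head count
k = jnp.zeros([n, lk, hk, d], jnp.float16)   # keys/values carry only hk heads
v = jnp.zeros([n, lk, hk, d], jnp.float16)
assert hq % hk == 0                          # query head i reads kv head i // (hq // hk)
# Forward: o has q's layout [n, lq, hq, d]; backward: dq is [n, lq, hq, d],
# while dk/dv come back at k/v's layout [n, lk, hk, d].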

csrc/flash_attn/mha_bwd.cpp

Lines changed: 19 additions & 13 deletions

@@ -2,6 +2,7 @@
 #include <cutlass/numeric_types.h>
 #include <cuda_runtime_api.h>
 #include <pybind11/pybind11.h>
+#include <cute/layout.hpp>

 #include "flash.h"
 #include "exception.h"
@@ -65,18 +66,24 @@ void set_params_dgrad(Flash_bwd_params &params,
     params.dq_ptr = dq_ptr;
     params.dk_ptr = dk_ptr;
     params.dv_ptr = dv_ptr;
-    params.dq_row_stride = params.q_row_stride;
-    params.dk_row_stride = params.k_row_stride;
-    params.dv_row_stride = params.v_row_stride;
-    params.dq_head_stride = params.q_head_stride;
-    params.dk_head_stride = params.k_head_stride;
-    params.dv_head_stride = params.v_head_stride;
+
+    // dk&dv is expanded to the same h as dq for MQA, we sum it later
+    auto dq = cute::compact_row_major(cute::make_shape(b, seqlen_q, h, d));
+    auto dk = cute::compact_row_major(cute::make_shape(b, seqlen_k, h, d));
+    auto dv = cute::compact_row_major(cute::make_shape(b, seqlen_k, h, d));
+
+    params.dq_row_stride = cute::get<1>(dq);
+    params.dk_row_stride = cute::get<1>(dk);
+    params.dv_row_stride = cute::get<1>(dv);
+    params.dq_head_stride = cute::get<2>(dq);
+    params.dk_head_stride = cute::get<2>(dk);
+    params.dv_head_stride = cute::get<2>(dv);

     if (cu_seqlens_q_d == nullptr) {
         params.do_batch_stride = params.o_batch_stride;
-        params.dq_batch_stride = params.q_batch_stride;
-        params.dk_batch_stride = params.k_batch_stride;
-        params.dv_batch_stride = params.v_batch_stride;
+        params.dq_batch_stride = cute::get<0>(dq);
+        params.dk_batch_stride = cute::get<0>(dk);
+        params.dv_batch_stride = cute::get<0>(dv);
     }

     params.dq_accum_ptr = dq_accum_d;
@@ -273,9 +280,8 @@ mha_bwd(cudaStream_t stream, void **buffers, const char* opaque, size_t opaque_l
     }


-    // Not sure what this is about. It needs extra scratch space for dk and dv when hk > h?
-    // Maybe because it's partitioning by n and h.
-    // disabled for now and figure out how to handle it later
+    // For MQA, dk and dv are expanded to the same n_heads as dq (handled in xla).
+    // After returning the result, it gets reduced to the original size by summing, so we don't need to do anything here.
     void* dk_expanded = dk;
     void* dv_expanded = dv;
     // at::Tensor dk_expanded, dv_expanded;
@@ -376,7 +382,7 @@ mha_bwd(cudaStream_t stream, void **buffers, const char* opaque, size_t opaque_l

     // For MQA/GQA we need to sum dK and dV across the groups
     if (num_heads_k != num_heads) {
-        CHECK(false, "don't handle MQA yet");
+        // CHECK(false, "don't handle MQA yet");
         // at::sum_out(dk, at::reshape(dk_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3});
         // at::sum_out(dv, at::reshape(dv_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3});
     }
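The stride change above configures the kernel for dk/dv buffers that are expanded to h (the number of query heads) rather than hk, so the strides can no longer be copied from k/v. As I read it, cute::compact_row_major(make_shape(b, s, h, d)) just yields the strides of a contiguous row-major tensor of that shape; a small NumPy sketch of the same computation (illustrative only, not the CUTLASS code):

def compact_row_major_strides(shape):
    # Strides, in elements, of a contiguous row-major tensor of `shape`
    # (rightmost dimension fastest), mirroring what cute::compact_row_major returns.
    strides = [1] * len(shape)
    for i in reversed(range(len(shape) - 1)):
        strides[i] = strides[i + 1] * shape[i + 1]
    return tuple(strides)

b, seqlen_k, h, d = 2, 128, 8, 64   # toy sizes; h counts *query* heads here
batch, row, head, elem = compact_row_major_strides((b, seqlen_k, h, d))
assert (row, head, elem) == (h * d, d, 1)   # what get<1> and get<2> pick out in the patch
assert batch == seqlen_k * h * d            # get<0>, used for the dk/dv batch stride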

src/flash_attn_jax/flash_hlo.py

Lines changed: 74 additions & 41 deletions

@@ -10,7 +10,6 @@
 from jax.interpreters import xla
 from jax.interpreters.mlir import ir
 from jax.lib import xla_client
-from jaxlib.hlo_helpers import custom_call
 from jax.experimental.custom_partitioning import custom_partitioning

 from jax.sharding import PartitionSpec as P
@@ -19,6 +18,7 @@
 from jax.sharding import PositionalSharding

 from einops import rearrange
+import einops
 import math

 import flash_attn_jax_lib.flash_api as flash_api
@@ -33,6 +33,10 @@
 _flash_mha_bwd_hlo_p.multiple_results = True
 _flash_mha_bwd_hlo_p.def_impl(partial(xla.apply_primitive, _flash_mha_bwd_hlo_p))

+_custom_call_p = core.Primitive("custom_call")
+_custom_call_p.multiple_results = True
+_custom_call_p.def_impl(partial(xla.apply_primitive, _custom_call_p))
+
 # ==== Primitive wrapper ====

 def _flash_mha_fwd_hlo(q, k, v, softmax_scale, is_causal, window_size):
@@ -43,6 +47,9 @@ def _flash_mha_bwd_hlo(dout, q, k, v, out, lse, softmax_scale, is_causal, window
     dq, dk, dv = _flash_mha_bwd_hlo_p.bind(dout, q, k, v, out, lse, softmax_scale=softmax_scale, is_causal=is_causal, window_size=window_size)
     return dq, dk, dv

+def custom_call(*args, call_target_name, result_types, backend_config, operand_layouts, result_layouts):
+    return _custom_call_p.bind(*args, call_target_name=call_target_name, result_types=result_types, backend_config=backend_config, operand_layouts=operand_layouts, result_layouts=result_layouts)
+
 # ==== HLO lowerings ====

 # Register functions defined in gpu_ops as custom call target for GPUs
@@ -112,7 +119,7 @@ def _flash_mha_fwd_hlo_lowering(ctx, q, k, v, softmax_scale=None, is_causal=Fals

         out_types = [ir.RankedTensorType.get(o_shape, element_type), lse_type]

-        (o, lse) = custom_call(
+        (o, lse) = mlir.custom_call(
             b"flash_mha_fwd",
             result_types=out_types,
             operands=[q_padded, k_padded, v_padded],
@@ -125,7 +132,7 @@ def _flash_mha_fwd_hlo_lowering(ctx, q, k, v, softmax_scale=None, is_causal=Fals
         return (o,lse)
     else:
         out_types = [ir.RankedTensorType.get([n, l, h, d], element_type), lse_type]
-        out = custom_call(
+        out = mlir.custom_call(
            b"flash_mha_fwd",
            result_types=out_types,
            operands=[q, k, v],
@@ -155,6 +162,7 @@ def _flash_mha_bwd_hlo_lowering(ctx, dout, q, k, v, out, lse, softmax_scale=None
     assert q_type == v_type
     assert q_type == out_type
     assert type(lse_type) in [ir.F32Type]
+    dtype = q_type

     dout_shape = ir.RankedTensorType(dout.type).shape
     q_shape = ir.RankedTensorType(q.type).shape
@@ -184,49 +192,45 @@
         flash_api.BF16 if type(q_type) == ir.BF16Type else flash_api.FP16,
         0)

-    if d % 8 != 0:
-        # We need padding. It's better to let xla's allocator handle it here than directly call cudaMalloc.
-        dpad = 8 - d%8
-
-        z = np.array(0.0, dtype=ir_type_to_dtype(q_type))
-        z = mlir.ir_constant(z)
-        q_padded = mlir.hlo.PadOp(q,z,[0,0,0,0],[0,0,0,dpad],[0,0,0,0]).result
-        k_padded = mlir.hlo.PadOp(k,z,[0,0,0,0],[0,0,0,dpad],[0,0,0,0]).result
-        v_padded = mlir.hlo.PadOp(v,z,[0,0,0,0],[0,0,0,dpad],[0,0,0,0]).result
-        out_padded = mlir.hlo.PadOp(out,z,[0,0,0,0],[0,0,0,dpad],[0,0,0,0]).result
-        dout_padded = mlir.hlo.PadOp(dout,z,[0,0,0,0],[0,0,0,dpad],[0,0,0,0]).result
-
-        # Outputs are the same shape as the q,k,v (including padding)
-        out_types = [q_padded.type, k_padded.type, v_padded.type]
+    def fwd(dout, q, k, v, out, lse):
+        dpad = (8 - d%8) % 8
+        if dpad > 0:
+            # We need padding. It's better to let xla's allocator handle it here than directly call cudaMalloc.
+            q = jnp.pad(q, ((0,0),(0,0),(0,0),(0,dpad)), 'constant')
+            k = jnp.pad(k, ((0,0),(0,0),(0,0),(0,dpad)), 'constant')
+            v = jnp.pad(v, ((0,0),(0,0),(0,0),(0,dpad)), 'constant')
+            out = jnp.pad(out, ((0,0),(0,0),(0,0),(0,dpad)), 'constant')
+            dout = jnp.pad(dout, ((0,0),(0,0),(0,0),(0,dpad)), 'constant')
+
+        # For MQA/GQA, hq != hk, but we pass a hq sized output tensor to the kernel and sum over it afterwards to reduce the size.
+        out_types = [ir.RankedTensorType.get([n, lq, hq, d+dpad], dtype),
+                     ir.RankedTensorType.get([n, lk, hq, d+dpad], dtype),
+                     ir.RankedTensorType.get([n, lk, hq, d+dpad], dtype)]
+        out_layouts = default_layouts([n, lq, hq, d+dpad], [n, lk, hq, d+dpad], [n, lk, hq, d+dpad])

         dq, dk, dv = custom_call(
-            b"flash_mha_bwd",
-            result_types=out_types,
-            operands=[dout_padded, q_padded, k_padded, v_padded, out_padded, lse],
+            dout, q, k, v, out, lse,
+            call_target_name=b"flash_mha_bwd",
+            operand_layouts=default_layouts(dout.shape, q.shape, k.shape, v.shape, out.shape, lse.shape),
             backend_config=opaque,
-            operand_layouts=value_layouts(dout_padded, q_padded, k_padded, v_padded, out_padded, lse),
-            result_layouts=value_layouts(q_padded, k_padded, v_padded), # dq, dk, dv
-        ).results
-
-        dq = mlir.hlo.SliceOp(dq, [0,0,0,0], tuple(q_shape), [1,1,1,1]).result
-        dk = mlir.hlo.SliceOp(dk, [0,0,0,0], tuple(k_shape), [1,1,1,1]).result
-        dv = mlir.hlo.SliceOp(dv, [0,0,0,0], tuple(v_shape), [1,1,1,1]).result
+            result_types=out_types,
+            result_layouts=out_layouts,
+        )
+
+        if hq != hk:
+            assert hq > hk and hq % hk == 0
+            m = hq // hk
+            dk = einops.reduce(dk, 'n l (h m) d -> n l h d', reduction='sum', h=hk)
+            dv = einops.reduce(dv, 'n l (h m) d -> n l h d', reduction='sum', h=hk)
+
+        if dpad > 0:
+            dq = dq[:,:,:,:d]
+            dk = dk[:,:,:,:d]
+            dv = dv[:,:,:,:d]

         return dq, dk, dv
-    else:
-        out_types = [ir.RankedTensorType.get(q_shape, q_type),
-                     ir.RankedTensorType.get(k_shape, k_type),
-                     ir.RankedTensorType.get(v_shape, v_type)]
-
-        out = custom_call(
-            b"flash_mha_bwd",
-            result_types=out_types,
-            operands=[dout, q, k, v, out, lse],
-            backend_config=opaque,
-            operand_layouts=default_layouts(dout_shape, q_shape, k_shape, v_shape, out_shape, lse_shape),
-            result_layouts=default_layouts(*[o.shape for o in out_types]),
-        ).results
-        return out
+
+    return mlir.lower_fun(fwd, multiple_results=True)(ctx, dout, q, k, v, out, lse)

 mlir.register_lowering(
     _flash_mha_bwd_hlo_p,
@@ -266,3 +270,32 @@ def _flash_mha_bwd_abstract(dout, q, k, v, out, lse, softmax_scale=None, is_caus
         ShapedArray(v.shape, v_dtype, named_shape=v.named_shape),
     )
 _flash_mha_bwd_hlo_p.def_abstract_eval(_flash_mha_bwd_abstract)
+
+# ==== Custom Call ====
+
+def _custom_call_abstract_eval(*args, call_target_name, result_types, backend_config, operand_layouts, result_layouts):
+    def convert(ty):
+        ty = ir.RankedTensorType(ty)
+        shape = tuple(ty.shape)
+        dtype = ir_type_to_dtype(ty.element_type)
+        return ShapedArray(shape, dtype)
+    out_types = [convert(o) for o in result_types]
+    return tuple(out_types)
+
+_custom_call_p.def_abstract_eval(_custom_call_abstract_eval)
+
+def _custom_call_hlo_lowering(ctx, *args, call_target_name, result_types, backend_config, operand_layouts, result_layouts):
+    out = mlir.custom_call(
+        call_target_name,
+        operands=args,
+        result_types=result_types,
+        backend_config=backend_config,
+        operand_layouts=operand_layouts,
+        result_layouts=result_layouts,
+    ).results
+    return out
+
+mlir.register_lowering(
+    _custom_call_p,
+    _custom_call_hlo_lowering
+)
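As I read this file, two things are going on: the backward lowering is now written as an ordinary jnp-level function (fwd) and turned into an MLIR lowering with mlir.lower_fun, with the kernel reached through a small custom_call primitive, so the padding, the MQA/GQA head-group sum, and the final slice stay in plain JAX; and the group sum is the same "reshape into (hk, group) and sum over the group axis" reduction that the commented-out at::sum_out call in mha_bwd.cpp describes. A standalone check of that reduction with toy sizes (illustrative, not part of the commit):

import jax.numpy as jnp
import einops

n, lk, hq, hk, d = 2, 16, 8, 2, 32
dk_expanded = jnp.arange(n * lk * hq * d, dtype=jnp.float32).reshape(n, lk, hq, d)

# The pattern used in fwd(): fold each group of hq // hk heads into its shared kv head.
dk = einops.reduce(dk_expanded, 'n l (h m) d -> n l h d', reduction='sum', h=hk)

# Same reduction spelled out with reshape + sum, as in the at::sum_out comment.
dk_manual = dk_expanded.reshape(n, lk, hk, hq // hk, d).sum(axis=3)
assert dk.shape == (n, lk, hk, d)
assert jnp.allclose(dk, dk_manual)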

src/flash_attn_jax/ring_attention.py

Lines changed: 3 additions & 2 deletions

@@ -79,14 +79,15 @@ def f(c, a):

 def ring_bwd(do,q,k,v,o,lse, axis_name, axis_size, mha_bwd, softmax_scale=None, is_causal=False):
     [n,l,h,d] = q.shape
+    [n,lk,hk,d] = k.shape
     if softmax_scale is None:
         softmax_scale = 1/math.sqrt(d)

     ix = jax.lax.axis_index(axis_name)

     dq = jnp.zeros([n,l,h,d], jnp.float32)
-    dk = jnp.zeros([n,l,h,d], jnp.float32)
-    dv = jnp.zeros([n,l,h,d], jnp.float32)
+    dk = jnp.zeros([n,lk,hk,d], jnp.float32)
+    dv = jnp.zeros([n,lk,hk,d], jnp.float32)

     # scan :: (c -> a -> (c, b)) -> c -> [a] -> (c, [b])
     def f(acc, _):
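Since the backward pass now returns dk/dv with K/V's head count, the scan accumulators in ring_bwd have to follow k's shape. A minimal shape sketch with toy sizes (names are illustrative; the real mha_bwd is not invoked here):

import jax.numpy as jnp

n, l, lk, h, hk, d = 2, 32, 32, 8, 2, 64
k = jnp.zeros([n, lk, hk, d], jnp.float16)

dk_acc = jnp.zeros([n, lk, hk, d], jnp.float32)   # follows k.shape, as in the patch
dk_step = jnp.zeros(k.shape, jnp.float32)         # one ring step's partial dk, already at hk heads
dk_acc = dk_acc + dk_step                         # a [n, l, h, d] accumulator would not match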

tests/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+0

tests/ref_mha.py

Lines changed: 82 additions & 0 deletions

@@ -0,0 +1,82 @@
+import glob
+import sys, os
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+from functools import partial
+import einops
+
+def make_mask(R, C, is_causal, window_size):
+    mask = jnp.ones([R,C], dtype=jnp.int32)
+    if is_causal:
+        mask = jnp.tril(mask)
+    if window_size[0] != -1:
+        mask = jnp.triu(mask, -window_size[0])
+    if window_size[1] != -1:
+        mask = jnp.tril(mask, window_size[1])
+    return mask
+
+def ref_mha(q,k,v, is_causal=False, window_size=(-1,-1), softmax_scale=None):
+    return ref_fwd(q,k,v, is_causal=is_causal, window_size=window_size, softmax_scale=softmax_scale)[0]
+
+def ref_fwd(q,k,v, is_causal=False, window_size=(-1,-1), softmax_scale=None):
+    [n, l, h, d] = q.shape
+    [n, lk, hk, d] = k.shape
+    if softmax_scale is None:
+        softmax_scale = 1/np.sqrt(d)
+    mask = make_mask(l,lk,is_causal,window_size)
+    if h != hk:
+        assert h > hk and h % hk == 0
+        q = einops.rearrange(q, 'n L (h x) d -> n L h x d', h=hk)
+        S = jnp.einsum('nlhxd,nLhd->nhxlL',q,k) * softmax_scale
+        S = jnp.where(mask, S, float('-inf'))
+        lse = jax.nn.logsumexp(S, axis=-1) # n h x l
+        P = jnp.exp(S - lse[...,None]) # n h x l L
+        o = jnp.einsum('nhxlL,nLhd->nlhxd',P,v)
+        o = einops.rearrange(o, 'n l h x d -> n l (h x) d')
+        lse = einops.rearrange(lse, 'n h x l -> n (h x) l')
+        return o.astype(q.dtype), lse.astype(jnp.float32)
+    else:
+        att = jnp.einsum('nlhd,nLhd->nhlL',q,k)*softmax_scale
+        [_, _, l, L] = att.shape
+        mask = make_mask(l,L,is_causal,window_size)
+        att = jnp.where(mask, att, float('-inf'))
+        lse = jax.nn.logsumexp(att, axis=-1) # n h l
+        att = jnp.exp(att - lse[...,None])
+        o = jnp.einsum('nhlL,nLhd->nlhd',att,v)
+        return o.astype(q.dtype), lse.astype(jnp.float32)
+
+def ref_bwd(do,q,k,v,o,lse, is_causal=False, window_size=(-1,-1), softmax_scale=None):
+    [n, l, h, d] = q.shape
+    [n, lk, hk, d] = k.shape
+    if softmax_scale is None:
+        softmax_scale = 1/np.sqrt(d)
+    mask = make_mask(l,lk,is_causal,window_size)
+    if h != hk:
+        assert h > hk and h % hk == 0
+        q = einops.rearrange(q, 'n l (h x) d -> n l h x d', h=hk)
+        lse = einops.rearrange(lse, 'n (h x) l -> n h x l', h=hk)
+        S = jnp.einsum('nlhxd,nLhd->nhxlL',q,k) * softmax_scale
+        D = einops.reduce(do * o, 'n l (h x) d -> n h x l', reduction='sum', h=hk)
+        do = einops.rearrange(do, 'n l (h x) d -> n l h x d', h=hk)
+        S = jnp.where(mask, S, float('-inf'))
+        P = jnp.exp(S - lse[...,None]) # n h x l L
+        dP = jnp.einsum('nlhxd,nLhd->nhxlL',do,v)
+        dv = jnp.einsum('nlhxd,nhxlL->nLhd',do,P)
+        dS = P * (dP - D[...,None])
+        dq = softmax_scale*jnp.einsum('nLhd,nhxlL->nlhxd',k,dS)
+        dk = softmax_scale*jnp.einsum('nlhxd,nhxlL->nLhd',q,dS)
+        dq = einops.rearrange(dq, 'n l h x d -> n l (h x) d')
+        return dq.astype(q.dtype),dk.astype(q.dtype),dv.astype(q.dtype)
+    else:
+        S = jnp.einsum('nlhd,nLhd->nhlL',q,k)*softmax_scale
+        D = einops.reduce(do * o, 'n l h d -> n h l', reduction='sum')
+        S = jnp.where(mask, S, float('-inf'))
+        P = jnp.exp(S - lse[...,None]) # n h l L
+        dP = jnp.einsum('nlhd,nLhd->nhlL',do,v)
+        dv = jnp.einsum('nlhd,nhlL->nLhd',do,P)
+        dS = P * (dP - D[...,None])
+        dq = softmax_scale*jnp.einsum('nLhd,nhlL->nlhd',k,dS)
+        dk = softmax_scale*jnp.einsum('nlhd,nhlL->nLhd',q,dS)
+        return dq.astype(q.dtype),dk.astype(q.dtype),dv.astype(q.dtype)
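A usage sketch for the reference above (illustrative, not part of the commit; it assumes ref_mha.py is importable as tests.ref_mha with the repo root on PYTHONPATH): the GQA branch of ref_fwd should agree with the plain-MHA branch run on K/V explicitly repeated across each group of query heads.

import jax
import jax.numpy as jnp
import einops
from tests.ref_mha import ref_fwd

n, l, lk, h, hk, d = 1, 16, 16, 8, 2, 32
q = jax.random.normal(jax.random.PRNGKey(0), [n, l, h, d], jnp.float32)
k = jax.random.normal(jax.random.PRNGKey(1), [n, lk, hk, d], jnp.float32)
v = jax.random.normal(jax.random.PRNGKey(2), [n, lk, hk, d], jnp.float32)

o_gqa, lse_gqa = ref_fwd(q, k, v, is_causal=True)

# Repeat each kv head h // hk times; this ordering matches the '(h x)' grouping above.
k_full = einops.repeat(k, 'n L h d -> n L (h x) d', x=h // hk)
v_full = einops.repeat(v, 'n L h d -> n L (h x) d', x=h // hk)
o_mha, lse_mha = ref_fwd(q, k_full, v_full, is_causal=True)

assert jnp.allclose(o_gqa, o_mha, atol=1e-4)
assert jnp.allclose(lse_gqa, lse_mha, atol=1e-4)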
