Commit 794d337

Add pipelined attention kernel (#63)

1 parent a5e9d8c commit 794d337
File tree

1 file changed: +378 -0 lines changed
contributed/pipelined_attention.py

Lines changed: 378 additions & 0 deletions
"""
Kernel with software pipelining, adapted from the official attention NKI sample.

Author: Hongyi Jin ([email protected])

WARNING: These kernels:
- Are tested only against internal nightly builds
- May not be compatible with public NeuronSDK releases
- Have not been extensively tested across all input configurations
- Carry no compatibility guarantees
- May change behavior without prior notice
"""
import numpy as np

import neuronxcc.nki.isa as nisa
import neuronxcc.nki.language as nl
from neuronxcc import nki
from neuronxcc.nki.language import par_dim, bfloat16

# Shorthand for the SBUF / PSUM modulo allocators used below.
sb_mod = nki.compiler.sbuf.mod_alloc
psum_mod = nki.compiler.psum.mod_alloc

def mm1_dot_psum_alloc(idx, pdim_size, fdim_size):
  grp_i, _, tile_i = idx
  grp_i = grp_i % 2
  return (tile_i % 4), 0, 0

def mm2_dot_psum_alloc(idx, pdim_size, fdim_size):
  grp_i, tile_i = idx
  return 4 + (tile_i % 4), 0, 0

def exp_tp_psum_alloc(idx, pdim_size, fdim_size):
  grp_i, tile_i, tp_grp_i = idx
  grp_i = grp_i % 2
  return tp_grp_i, 0, 0

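# A note on the three callbacks above (my reading of the allocation API, not an
# official contract): each callback receives the logical tile index tuple plus the
# tile's partition/free sizes and returns a fixed placement, which here appears to
# be (PSUM bank, partition offset, byte offset). Under that reading:
#   mm1_dot_psum_alloc((grp_i, si, pi), ...)      -> bank pi % 4       (banks 0-3)
#   mm2_dot_psum_alloc((grp_i, si), ...)          -> bank 4 + si % 4   (banks 4-7)
#   exp_tp_psum_alloc((grp_i, si, tp_grp_i), ...) -> bank tp_grp_i     (banks 0-3)
# so the QK^T accumulators and the PV accumulators live in disjoint PSUM banks,
# which is what lets the pipelined schedule below keep both matmul stages in flight.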
# This kernel can only run on 16k seqlen (the test harness below uses seq_q == seq_k == 16384).
@nki.compiler.skip_middle_end_transformations
@nki.baremetal(artifacts_dir="debug", additional_compile_opt="--internal-skip-backend-allocation-opt-nki --disable-internal-io-dge")
def flash_fwd(q, k, v,
              softmax_scale=None,
              use_causal_mask=True,
              mixed_precision=True,
              ):
  """
  Flash Attention forward kernel with software pipelining.

  IO tensor layouts:
    - q: shape (bs, d, seq_q)
    - k: shape (bs, d, seq_k)
    - v: shape (bs, seq_v, d)
    - o: shape (bs, seq_q, d)
    - This kernel requires seq_k == seq_v

  IO tensor dtypes:
    - This kernel assumes all IO tensors have the same dtype
    - If mixed_precision is True, all Tensor Engine operations are performed in
      bfloat16 and accumulation is performed in float32. Otherwise the intermediates
      use the same dtype as the inputs.

  Compile-time constants:
    - softmax_scale: scaling applied to the attention scores; if None, defaults to
      `1.0 / (d ** 0.5)`
    - mixed_precision: if True (the default), non-matmul ops run in fp32; if False,
      they use the same precision as the input dtypes
  """
  b, d, seqlen_q = q.shape
  _, _, seqlen_k = k.shape

  assert use_causal_mask == False, "causal mask code path disabled"

  assert tuple(v.shape) == (b, seqlen_k, d), f"Expect shape of V to be {(b, seqlen_k, d)} (batch, seqlen_k, d_head) but got {v.shape}"
  assert tuple(k.shape) == (b, d, seqlen_k), f"Expect shape of K to be {(b, d, seqlen_k)} (batch, d_head, seqlen_k) but got {k.shape}"
  assert d <= 128, f"we do not support head_dim > 128, got head dim {d}"
  kernel_dtype = nl.bfloat16 if mixed_precision else q.dtype
  acc_type = np.dtype(np.float32) if mixed_precision else kernel_dtype

  o = nl.ndarray((b, seqlen_q, d), dtype=q.dtype, buffer=nl.shared_hbm)

  batch_id = nl.program_id(0)
  softmax_scale = softmax_scale or (1.0 / (d ** 0.5))

  sb_p = 128                       # partition size of one query group
  assert seqlen_k % sb_p == 0
  num_grps = seqlen_k // sb_p      # number of 128-query groups (assumes seq_q == seq_k)
  section_len = 8192               # keys/values processed per outer section
  num_sections = seqlen_q // section_len

  num_2048_tiles_per_section = section_len // 2048
  num_512_tiles_per_section = section_len // 512
  num_128_tiles_per_section = section_len // 128

  # Running SBUF byte offset (per partition) for the hand-placed allocations below.
  sca = 0

  identity = nl.shared_constant(np.identity(128, dtype=np.int8), dtype=nl.bfloat16)
  identity_load = nl.ndarray((par_dim(128), 128), dtype=nl.bfloat16, buffer=sb_mod(base_addr=sca))
  id_p, id_f = nl.mgrid[0:128, 0:128]
  identity_load[id_p, id_f] = nl.load(identity)
  sca += 128 * 2

  zero_bias_tensor = nl.ndarray((128, 1), dtype=nl.float32, buffer=sb_mod(base_addr=sca))
  zero_bias_tensor[...] = 0.0
  sca += 4

  running_max = nl.ndarray((sb_p, num_grps), dtype=nl.float32, buffer=sb_mod(base_addr=sca))
  sca += num_grps * 4
  running_sum = nl.ndarray((sb_p, num_grps), dtype=nl.float32, buffer=sb_mod(base_addr=sca))
  sca += num_grps * 4
  div_25_sbuf = nl.ndarray((128, num_grps), dtype=nl.float32, buffer=sb_mod(base_addr=sca))
  sca += num_grps * 4

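  # How the loop below is organized (summary derived from the code): the key/value
  # sequence is walked in `num_sections` sections of 8192 tokens. For every
  # 128-query group, running_max and running_sum carry the online-softmax statistics
  # across sections (stored negated, see update_max), and div_25_sbuf receives
  # 1 / running_sum on the last section so the stored output can be normalized in
  # write_back.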
  for section_i in nl.sequential_range(num_sections):
    num_2048_tiles_cur_section = num_2048_tiles_per_section
    num_512_tiles_cur_section = num_512_tiles_per_section
    num_128_tiles_cur_section = num_128_tiles_per_section

    # Load this section's K tiles: num_512_tiles of (d, 512) in bfloat16.
    p, n = d, 128 * 4
    k_loaded = nl.ndarray((num_512_tiles_cur_section, nl.par_dim(p), n), dtype=nl.bfloat16, buffer=sb_mod(base_addr=sca, num_free_tiles=(num_512_tiles_cur_section, )))
    sca += num_512_tiles_cur_section * n * 2
    for i in nl.affine_range(num_512_tiles_cur_section):
      ip_k, if_k = nl.mgrid[0:p, 0:n]
      k_loaded[i, ip_k, if_k] = nl.load(k[batch_id, ip_k, section_len*section_i+512*i+if_k], dtype=nl.bfloat16)

    # Load this section's V tiles: num_128_tiles of (128, d) in bfloat16.
    p, n = sb_p, d
    v_loaded = nl.ndarray((num_128_tiles_cur_section, nl.par_dim(p), n), dtype=nl.bfloat16, buffer=sb_mod(base_addr=sca, num_free_tiles=(num_128_tiles_cur_section, )))
    sca += num_128_tiles_cur_section * n * 2
    for i in nl.affine_range(num_128_tiles_cur_section):
      ip_v, if_v = nl.mgrid[0:p, 0:n]
      v_loaded[i, ip_v, if_v] = nl.load(v[batch_id, section_len*section_i + i * 128 + ip_v, if_v], dtype=nl.bfloat16)

    # Per-group working buffers (several are double-buffered via num_free_tiles=(2,)).
    q_loaded = nl.ndarray((num_grps, nl.par_dim(d), sb_p), dtype=nl.bfloat16, buffer=sb_mod(base_addr=sca, num_free_tiles=(2, )))
    sca += 2 * sb_p * 2
    reduce14_num_parts = 128
    scaling_factor = nl.ndarray((num_grps, nl.par_dim(reduce14_num_parts), 1), dtype=nl.float32, buffer=sb_mod(base_addr=sca, num_free_tiles=(2, )))
    sca += 2 * 1 * 4
    num_blks = num_512_tiles_cur_section
    temp_reduce14_sbuf = nl.ndarray((num_grps, nl.par_dim(reduce14_num_parts), num_blks), dtype=nl.float32, buffer=sb_mod(base_addr=sca, num_free_tiles=(1, )))
    sca += 2 * num_blks * 4

    p, n = d, sb_p
    psum_p, psum_n = 128, sb_p * 4       # (128, 512)
    sbuf_p, sbuf_n = psum_p, psum_n * 4  # (128, 2048)

    mhlo_mul_2 = nl.ndarray((num_grps, num_2048_tiles_cur_section, nl.par_dim(sbuf_p), sbuf_n), dtype=nl.float32, buffer=sb_mod(base_addr=sca, num_free_tiles=(2, num_2048_tiles_cur_section)))
    sca += num_2048_tiles_cur_section * sbuf_n * 4 * 2
    mm1_psum_dot = nl.ndarray((num_grps, num_2048_tiles_cur_section, 4, nl.par_dim(psum_p), psum_n), dtype=nl.float32, buffer=nki.compiler.psum.alloc(mm1_dot_psum_alloc))

    final_reduce_max = nl.ndarray((num_grps, nl.par_dim(128), 1), dtype=nl.float32, buffer=sb_mod(base_addr=sca, num_free_tiles=(1, )))
    sca += 2 * 4

    prev_running_max = nl.ndarray((num_grps, nl.par_dim(reduce14_num_parts), 1), dtype=nl.float32, buffer=sb_mod(base_addr=sca, num_free_tiles=(1, )))
    sca += 2 * 4

    exp_inst_elems = 2048
    exp_insts = 2048 // exp_inst_elems  # number of exp instructions per si iteration
    ip_final_reduce_sum, _ = nl.mgrid[0:128, 0:1]
    final_reduce_sum_b = nl.ndarray((num_grps, nl.par_dim(128), section_len//exp_inst_elems), dtype=nl.float32, buffer=sb_mod(base_addr=sca, num_free_tiles=(2, )))
    sca += 2 * (section_len // exp_inst_elems) * 4

    final_reduce_sum_b_collect = nl.ndarray((num_grps, nl.par_dim(128), 1), dtype=nl.float32, buffer=sb_mod(base_addr=sca, num_free_tiles=(1, )))
    sca += 2 * 1 * 4

    prev_running_sum = nl.ndarray(shape=(num_grps, nl.par_dim(128), 1), dtype=nl.float32, buffer=sb_mod(base_addr=sca, num_free_tiles=(1, )))
    sca += 2 * 1 * 4

    prev_output = nl.ndarray((num_grps, nl.par_dim(128), 128), dtype=o.dtype, buffer=sb_mod(base_addr=sca, num_free_tiles=(1, )))
    sca += 2 * 128 * 4
    mm2_sbuf_res = nl.ndarray((num_grps, nl.par_dim(128), 128), dtype=q.dtype, buffer=sb_mod(base_addr=sca, num_free_tiles=(1, )))
    sca += 2 * 128 * 4
    mm2_div_sbuf = nl.ndarray((num_grps, nl.par_dim(128), 128), dtype=q.dtype, buffer=sb_mod(base_addr=sca, num_free_tiles=(1, )))
    sca += 2 * 128 * 4

    num_tps = exp_inst_elems // 128
    num_tp_grps = num_tps // 4
    num_tps_in_grp = 4
    n_per_part = num_tps_in_grp * 128

    # p, n, access_n = 128, 2048, exp_inst_elems
    exp6_sbuf = nl.ndarray((num_grps, num_2048_tiles_cur_section, nl.par_dim(p), 2048), dtype=nl.bfloat16, buffer=sb_mod(base_addr=sca, num_free_tiles=(1, num_2048_tiles_cur_section)))
    sca += 2 * num_2048_tiles_cur_section * 2048

    tp_sbuf = nl.ndarray((num_grps, num_2048_tiles_cur_section, num_tp_grps, nl.par_dim(128), n_per_part), dtype=nl.bfloat16,
                         buffer=sb_mod(base_addr=sca, num_free_tiles=(1, num_2048_tiles_cur_section, num_tp_grps)))
    sca += num_2048_tiles_cur_section * num_tp_grps * 2 * n_per_part

    mm2_p, mm2_n = sb_p, d
    mm2_sbuf = nl.ndarray((num_grps, nl.par_dim(mm2_p), mm2_n), dtype=nl.float32, buffer=sb_mod(base_addr=sca, num_free_tiles=(1, )))
    sca += 2 * mm2_n * 4
    tp_psum = nl.ndarray((num_grps, num_2048_tiles_cur_section, num_tp_grps, nl.par_dim(128), n_per_part), dtype=nl.float32, buffer=nki.compiler.psum.alloc(exp_tp_psum_alloc))

    mm2_psum = nl.ndarray((num_grps, num_2048_tiles_cur_section, nl.par_dim(sb_p), mm2_n), dtype=nl.float32, buffer=nki.compiler.psum.alloc(mm2_dot_psum_alloc))
    iq_p, iq_f = nl.mgrid[0:p, 0:n]
    def load_q(grp_i):
      q_loaded[grp_i, iq_p, iq_f] = nl.load(q[batch_id, iq_p, grp_i*n+iq_f])

    def qk_and_max(grp_i):
      for si in nl.affine_range(num_2048_tiles_cur_section):
        for pi in nl.affine_range(4):
          loc_512_tile_i = si*4+pi
          ip_res, if_res = nl.mgrid[0:128, 0:512]
          ip_reduce_res, _ = nl.mgrid[0:128, 0:1]
          mm1_psum_dot[grp_i, si, pi, ip_res, if_res] = nisa.nc_matmul(q_loaded[grp_i, :, :], k_loaded[loc_512_tile_i, :, :])
          mhlo_mul_2[grp_i, si, ip_res, pi*512+if_res] = nisa.tensor_scalar_reduce(
            data=mm1_psum_dot[grp_i, si, pi, ip_res, if_res], op0=np.multiply, operand0=softmax_scale,
            reduce_op=nl.max, reduce_res=temp_reduce14_sbuf[grp_i, ip_reduce_res, si*4+pi], name="mm1-tsp"
          )
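    # Arithmetic behind qk_and_max, spelled out (derived from the shapes above):
    # each group holds 128 queries; a section holds 8192 keys, covered as 4 (si)
    # x 4 (pi) tiles of 512 keys. Every nc_matmul contracts q (d x 128) with a K
    # tile (d x 512) over the head dimension to produce a (128, 512) score tile on
    # the Tensor Engine, and tensor_scalar_reduce scales it by softmax_scale while
    # also writing the per-row max of the tile into temp_reduce14_sbuf, so no
    # separate reduction pass over the scores is needed.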
    def update_max(grp_i):
      ip_reduce, _ = nl.mgrid[0:128, 0:1]
      final_reduce_max[grp_i, ip_reduce, 0] = nisa.tensor_reduce(np.max, temp_reduce14_sbuf[grp_i], 1, negate=True)
      if section_i == 0:
        running_max[ip_reduce, grp_i] = nisa.tensor_copy(final_reduce_max[grp_i])
      if section_i > 0:
        prev_running_max[grp_i, ip_reduce, 0] = nisa.activation(np.copy, running_max[ip_reduce, grp_i], scale=-1.0, bias=zero_bias_tensor)
        running_max[ip_reduce, grp_i] = nisa.tensor_tensor(running_max[ip_reduce, grp_i], final_reduce_max[grp_i], op=nl.minimum)
        scaling_factor[grp_i, ip_reduce, 0] = nisa.activation(np.exp, prev_running_max[grp_i], bias=running_max[ip_reduce, grp_i], scale=1.0)

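    # What update_max computes (a short derivation; nisa.activation is read here as
    # op(scale * data + bias)): final_reduce_max holds the NEGATED per-row max of
    # the current section (negate=True), so running_max always stores -max and is
    # combined with nl.minimum rather than a maximum. For later sections,
    # prev_running_max is flipped back to +max_old, and
    # scaling_factor = exp(max_old - max_new), the usual FlashAttention correction
    # applied to the previously accumulated sum and output in write_back.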
    assert section_len//exp_inst_elems == num_2048_tiles_cur_section
    def exp(grp_i):
      ip_reduce, _ = nl.mgrid[0:128, 0:1]
      for si in nl.affine_range(num_2048_tiles_cur_section):
        p, n, access_n = 128, 2048, exp_inst_elems
        ip_p, ip_n = nl.mgrid[0:p, 0:access_n]
        for pi in nl.affine_range(exp_insts):  # this loop collapses in the current config: exp_insts == 1
          exp6_sbuf[grp_i, si, ip_p, pi*access_n+ip_n] = nisa.activation_reduce(np.exp, mhlo_mul_2[grp_i, si, ip_p, access_n*pi+ip_n],
            reduce_op=np.add, reduce_res=final_reduce_sum_b[grp_i, ip_final_reduce_sum, si*exp_insts+pi],
            bias=running_max[ip_reduce, grp_i], name='exp6',
          )
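    # exp() fuses two things into one pass over the scaled scores: it evaluates
    # exp(score - row_max) (the bias argument carries -max, see update_max) and
    # accumulates the per-row partial softmax denominators into final_reduce_sum_b,
    # one column per 2048-wide tile. write_back later folds these partials into
    # running_sum.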
    def tp(grp_i):
      for si in nl.affine_range(num_2048_tiles_cur_section):
        for tp_grp in nl.affine_range(num_tp_grps):
          ip_tp, if_tp = nl.mgrid[0:128, 0:128]
          ip_cp, if_cp = nl.mgrid[0:128, 0:n_per_part]
          for ti in nl.affine_range(num_tps_in_grp):
            tp_psum[grp_i, si, tp_grp, ip_tp, ti*128+if_tp] = nisa.nc_matmul(exp6_sbuf[grp_i, si, ip_tp, tp_grp*n_per_part+ti*128+if_tp], identity_load)
          tp_sbuf[grp_i, si, tp_grp, ip_cp, if_cp] = nisa.tensor_copy(tp_psum[grp_i, si, tp_grp], dtype=nl.bfloat16, name='tp-act-cp')
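    # The Tensor Engine contracts over the partition dimension, so the probability
    # tiles have to be transposed from (queries x keys) to (keys x queries) before
    # they can be multiplied with V. tp() performs that transpose as a matmul with
    # the 128x128 identity loaded at the top of the kernel, then copies the PSUM
    # result back to SBUF in bfloat16 for the second matmul.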
    def pv(grp_i):
      mm2_sbuf[grp_i] = 0.0
      for mm2i in nl.affine_range(num_2048_tiles_cur_section):
        num_tp_grps_in_2048_tile = 4
        # mm2_psum = nl.zeros((nl.par_dim(sb_p), mm2_n), dtype=nl.float32, buffer=nl.psum)
        for tp_grp_i in nl.affine_range(num_tp_grps_in_2048_tile):
          mm2_num_grps = 4
          ip_mm2, if_mm2 = nl.mgrid[0:128, 0:128]
          for mm2_si in nl.affine_range(mm2_num_grps):
            mm2_psum[grp_i, mm2i, ip_mm2, if_mm2] += nisa.nc_matmul(tp_sbuf[grp_i, mm2i, tp_grp_i, ip_mm2, mm2_si*128+if_mm2], v_loaded[mm2i*16+tp_grp_i*4+mm2_si, ip_mm2, if_mm2])
        mm2_sbuf[grp_i] = nl.loop_reduce(mm2_psum[grp_i, mm2i], np.add, loop_indices=[mm2i], name='mm2-itt')
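    # pv() accumulates the attention output for one query group: each 2048-key tile
    # contributes 16 nc_matmul calls that contract a transposed 128x128 probability
    # block with the matching 128-row slice of V; the partial products are
    # accumulated in PSUM and then summed across the four 2048-key tiles with
    # nl.loop_reduce into mm2_sbuf (still unnormalized at this point).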
    def fused_qkmax_and_pv(grp_i):
      # Steady-state pipeline stage: computes QK^T (with the fused max reduction)
      # for group grp_i+2 and the PV matmuls for group grp_i in the same loop nest;
      # the surrounding schedule interleaves this with exp/tp for group grp_i+1.
      mm2_sbuf[grp_i] = 0.0
      for si in nl.affine_range(num_2048_tiles_cur_section):
        for pi in nl.affine_range(4):
          loc_512_tile_i = si*4+pi
          ip_res, if_res = nl.mgrid[0:128, 0:512]
          ip_reduce_res, _ = nl.mgrid[0:128, 0:1]
          mm1_psum_dot[grp_i+2, si, pi, ip_res, if_res] = nisa.nc_matmul(q_loaded[grp_i+2, :, :], k_loaded[loc_512_tile_i, :, :])
          mhlo_mul_2[grp_i+2, si, ip_res, pi*512+if_res] = nisa.tensor_scalar_reduce(
            data=mm1_psum_dot[grp_i+2, si, pi, ip_res, if_res], op0=np.multiply, operand0=softmax_scale,
            reduce_op=nl.max, reduce_res=temp_reduce14_sbuf[grp_i+2, ip_reduce_res, si*4+pi], name="mm1-tsp"
          )
        mm2i = si
        num_tp_grps_in_2048_tile = 4
        # mm2_psum = nl.zeros((nl.par_dim(sb_p), mm2_n), dtype=nl.float32, buffer=nl.psum)
        for tp_grp_i in nl.affine_range(num_tp_grps_in_2048_tile):
          mm2_num_grps = 4
          ip_mm2, if_mm2 = nl.mgrid[0:128, 0:128]
          for mm2_si in nl.affine_range(mm2_num_grps):
            mm2_psum[grp_i, mm2i, ip_mm2, if_mm2] += nisa.nc_matmul(tp_sbuf[grp_i, mm2i, tp_grp_i, ip_mm2, mm2_si*128+if_mm2], v_loaded[mm2i*16+tp_grp_i*4+mm2_si, ip_mm2, if_mm2])
        mm2_sbuf[grp_i] = nl.loop_reduce(mm2_psum[grp_i, mm2i], np.add, loop_indices=[mm2i], name='mm2-itt')

    def write_back(grp_i):
      ip_o, if_o = nl.mgrid[0:128, 0:128]

      ip_reduce, _ = nl.mgrid[0:128, 0:1]
      # Fold this section's partial row sums into the running softmax denominator.
      final_reduce_sum_b_collect[grp_i] = nisa.tensor_reduce(np.sum, final_reduce_sum_b[grp_i], axis=(1,))
      if section_i == 0:
        running_sum[ip_reduce, grp_i] = nisa.tensor_copy(final_reduce_sum_b_collect[grp_i])
      if section_i > 0:
        prev_running_sum[grp_i] = nisa.tensor_copy(running_sum[ip_reduce, grp_i])
        running_sum[ip_reduce, grp_i] = nisa.tensor_scalar(prev_running_sum[grp_i, ip_reduce, 0], np.multiply, scaling_factor[grp_i], op1=nl.add, operand1=final_reduce_sum_b_collect[grp_i])
      if section_i == num_sections - 1:
        div_25_sbuf[ip_reduce, grp_i] = nisa.reciprocal(running_sum[ip_reduce, grp_i])

      if section_i == 0:
        # First section: store the unnormalized partial output.
        nl.store(o[batch_id, grp_i*sb_p+ip_o, if_o], value=mm2_sbuf[grp_i])

      if section_i == num_sections - 1:
        # Last section: rescale the previously stored partial output, add this
        # section's contribution, and divide by the accumulated softmax sum.
        prev_output[grp_i] = nl.load(o[batch_id, grp_i*sb_p+ip_o, if_o])
        mm2_sbuf_res[grp_i] = nisa.scalar_tensor_tensor(data=prev_output[grp_i], op0=np.multiply, operand0=scaling_factor[grp_i], op1=np.add, operand1=mm2_sbuf[grp_i])

        mm2_div_sbuf[grp_i] = nisa.activation(np.copy, mm2_sbuf_res[grp_i], scale=div_25_sbuf[ip_reduce, grp_i], bias=zero_bias_tensor)
        nl.store(o[batch_id, grp_i*sb_p+ip_o, if_o], value=mm2_div_sbuf[grp_i])
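    # Software-pipelined schedule over the 128-query groups. The prologue below
    # takes group 0 through exp/tp and group 1 through its QK^T and max update.
    # In the steady-state loop, iteration grp_i loads Q for group grp_i+2, runs
    # exp/tp for group grp_i+1, runs the fused QK^T (grp_i+2) + PV (grp_i) matmul
    # stage, writes back group grp_i, and updates the running max for group
    # grp_i+2. The epilogue then drains the last two groups. precise_schedule=True
    # appears to be a scheduling hint to preserve this interleaving.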
    # Prologue: fill the pipeline with the first two query groups.
    load_q(0)
    qk_and_max(0)
    update_max(0)
    exp(0)
    tp(0)
    load_q(1)
    qk_and_max(1)
    update_max(1)
    # Steady state: one query group retires per iteration.
    for grp_i in nl.affine_range(num_grps-2, precise_schedule=True):  # for each block of seq_q
      load_q(grp_i+2)
      exp(grp_i+1)
      fused_qkmax_and_pv(grp_i)
      tp(grp_i+1)
      write_back(grp_i)
      update_max(grp_i+2)
    # Epilogue: drain the last two query groups.
    pv(num_grps-2)
    write_back(num_grps-2)
    exp(num_grps-1)
    tp(num_grps-1)
    pv(num_grps-1)
    write_back(num_grps-1)

  return o

def softmax(x: np.ndarray, dim: int, zero_max_mode=False,
            mixed_precision=False):
  max_value = np.amax(x, axis=dim, keepdims=True)
  max_value = np.maximum(0, max_value) if zero_max_mode else max_value

  exp = np.exp(x - max_value)

  reduce = np.add.reduce(exp.astype(np.float32), axis=dim, keepdims=True).astype(x.dtype)

  return exp / reduce

def cpu_attention_forward(q, k, v, softmax_scale, use_causal_mask=True, mixed_precision=True):
  def mixed_precision_matmul(a, b):
    input_dtype = a.dtype
    a, b = a.astype(np.float32), b.astype(np.float32)
    c = np.matmul(a, b)
    return c.astype(input_dtype)

  # Compute golden output
  q_scaled = q * softmax_scale

  raw_score = mixed_precision_matmul(q_scaled.transpose(0, 2, 1), k)

  if use_causal_mask:
    # raw_score has the K sequence in the innermost dim.
    # We want to mask all elements where the Q idx is smaller than the K idx with -inf;
    # this maps to the upper triangle of the final two axes.
    for i in range(raw_score.shape[0]):
      for j in range(raw_score.shape[1]):
        # -inf triggers an invalid-input error in the softmax implementation, so use a small negative instead.
        # k=1 excludes the diagonal, because each token can still attend to itself.
        raw_score[i, j][np.triu_indices_from(raw_score[i, j], k=1)] = -9984.0

  norm_score = softmax(raw_score, dim=-1, mixed_precision=mixed_precision)

  # The result already has the kernel's (seq_q, d) output layout.
  out_golden = mixed_precision_matmul(norm_score, v)

  return out_golden, norm_score

# Test harness: 16k x 16k attention, checked against the NumPy reference above.
bs = 1
seqlen_q = 16 * 1024
seqlen_k = 16 * 1024
d = 128
softmax_scale = 0.125
dtype = bfloat16

q = np.random.rand(bs, d, seqlen_q).astype(dtype)
k = np.random.rand(bs, d, seqlen_k).astype(dtype)
v = np.random.rand(bs, seqlen_k, d).astype(dtype)
o = flash_fwd[0](q, k, v, mixed_precision=True, softmax_scale=softmax_scale, use_causal_mask=False)
o_std, scores = cpu_attention_forward(q, k, v, softmax_scale, use_causal_mask=False, mixed_precision=True)

np.testing.assert_allclose(o, o_std, atol=1e-2)
