"""
Flash-attention forward kernel with software pipelining, adapted from the
official attention NKI sample.

Author: Hongyi Jin ([email protected])

WARNING: These kernels:
  - are tested only against internal nightly builds,
  - may not be compatible with public NeuronSDK releases,
  - have not been extensively tested across all input configurations,
  - carry no compatibility guarantees,
  - and may be modified without prior notice.
"""
import numpy as np

import neuronxcc.nki.isa as nisa
import neuronxcc.nki.language as nl
from neuronxcc import nki

from neuronxcc.nki.language import par_dim, bfloat16

# Shorthand for the direct-allocation buffer constructors used throughout.
sb_mod = nki.compiler.sbuf.mod_alloc
psum_mod = nki.compiler.psum.mod_alloc

# PSUM bank mapping for the first matmul (Q^T K): cycle through banks 0-3 by
# 512-column tile index.
def mm1_dot_psum_alloc(idx, pdim_size, fdim_size):
    _, _, tile_i = idx
    return (tile_i % 4), 0, 0

# PSUM bank mapping for the second matmul (P V): cycle through banks 4-7.
def mm2_dot_psum_alloc(idx, pdim_size, fdim_size):
    _, tile_i = idx
    return 4 + (tile_i % 4), 0, 0

# PSUM bank mapping for the transpose matmuls: one bank per transpose group.
def exp_tp_psum_alloc(idx, pdim_size, fdim_size):
    _, _, tp_grp_i = idx
    return tp_grp_i, 0, 0

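# Illustration only (not part of the kernel): the callbacks above map a logical
# tile index to a physical PSUM location, whose first element is the bank
# index. mm1 and mm2 round-robin over disjoint bank ranges so their
# accumulations never collide.
assert [mm1_dot_psum_alloc((0, 0, t), None, None)[0] for t in range(8)] == [0, 1, 2, 3, 0, 1, 2, 3]
assert [mm2_dot_psum_alloc((0, t), None, None)[0] for t in range(8)] == [4, 5, 6, 7, 4, 5, 6, 7]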

# This kernel is hard-wired for a 16K sequence length (seqlen_q == seqlen_k == 16384).
@nki.compiler.skip_middle_end_transformations
@nki.baremetal(artifacts_dir="debug", additional_compile_opt="--internal-skip-backend-allocation-opt-nki --disable-internal-io-dge")
def flash_fwd(q, k, v,
              softmax_scale=None,
              use_causal_mask=False,
              mixed_precision=True,
              ):
| 48 | + """ |
| 49 | + Flash Attention Forward kernel |
| 50 | +
|
| 51 | + IO tensor layouts: |
| 52 | + - q: shape (bs, n_heads, d, seq_q) |
| 53 | + - k: shape (bs, nk_heads, d, seq_k) |
| 54 | + - v: shape (bs, nv_heads, d, seq_v) if config.should_transpose_v else (bs, nv_heads, seq_v, d) |
| 55 | + - o: shape (bs, n_heads, seq_q, d) |
| 56 | + - This kernel requires seq_k == seq_v |
| 57 | +
|
| 58 | + IO tensor dtypes: |
| 59 | + - This kernel assumes all IO tensors have the same dtype |
| 60 | + - If mixed_precision is True, then all Tensor Engine operation will be performed in |
| 61 | + bfloat16 and accumulation will be performed in float32. Otherwise the intermediates |
| 62 | + will be in the same type as the inputs. |
| 63 | +
|
| 64 | + Compile-time Constants: |
| 65 | + - softmax_scale: scaling for softmax, is None, default is `1.0/(d**0.5)` |
| 66 | + - mixed_precision: flag to set non-matmul ops in fp32 precision, default is set to `true`, if false, we use same precision as input types |
| 67 | + """ |
    b, d, seqlen_q = q.shape
    _, _, seqlen_k = k.shape

    assert not use_causal_mask, "causal mask code path disabled"

    assert tuple(v.shape) == (b, seqlen_k, d), f"Expect shape of V to be {(b, seqlen_k, d)} (batch, seqlen_k, d_head) but got {v.shape}"
    assert tuple(k.shape) == (b, d, seqlen_k), f"Expect shape of K to be {(b, d, seqlen_k)} (batch, d_head, seqlen_k) but got {k.shape}"
    assert d <= 128, f"we do not support head_dim > 128, got head dim {d}"
    kernel_dtype = nl.bfloat16 if mixed_precision else q.dtype
    acc_type = np.dtype(np.float32) if mixed_precision else kernel_dtype

    o = nl.ndarray((b, seqlen_q, d), dtype=q.dtype, buffer=nl.shared_hbm)

    batch_id = nl.program_id(0)
    softmax_scale = softmax_scale or (1.0 / (d ** 0.5))

    sb_p = 128  # each group covers 128 rows of Q (one SBUF partition dim)
    assert seqlen_q % sb_p == 0
    num_grps = seqlen_q // sb_p
    section_len = 8192  # K/V are processed in sections of 8192 tokens
    assert seqlen_k % section_len == 0
    num_sections = seqlen_k // section_len

    num_2048_tiles_per_section = section_len // 2048
    num_512_tiles_per_section = section_len // 512
    num_128_tiles_per_section = section_len // 128

    sca = 0  # running SBUF byte offset (per partition) for direct allocation

    # 128x128 identity, used to transpose tiles on the Tensor Engine.
    identity = nl.shared_constant(np.identity(128, dtype=np.int8), dtype=nl.bfloat16)
    identity_load = nl.ndarray((par_dim(128), 128), dtype=nl.bfloat16, buffer=sb_mod(base_addr=sca))
    id_p, id_f = nl.mgrid[0:128, 0:128]
    identity_load[id_p, id_f] = nl.load(identity)
    sca += 128 * 2

    zero_bias_tensor = nl.ndarray((128, 1), dtype=nl.float32, buffer=sb_mod(base_addr=sca))
    zero_bias_tensor[...] = 0.0
    sca += 4

    # Online-softmax running statistics, one column per 128-row Q group.
    # running_max holds the *negated* running max; see update_max below.
    running_max = nl.ndarray((sb_p, num_grps), dtype=nl.float32, buffer=sb_mod(base_addr=sca))
    sca += num_grps * 4
    running_sum = nl.ndarray((sb_p, num_grps), dtype=nl.float32, buffer=sb_mod(base_addr=sca))
    sca += num_grps * 4
    div_25_sbuf = nl.ndarray((128, num_grps), dtype=nl.float32, buffer=sb_mod(base_addr=sca))  # reciprocal of the final running sum
    sca += num_grps * 4

    for section_i in nl.sequential_range(num_sections):
        num_2048_tiles_cur_section = num_2048_tiles_per_section
        num_512_tiles_cur_section = num_512_tiles_per_section
        num_128_tiles_cur_section = num_128_tiles_per_section

        # Load this section of K as (d, 512) bf16 tiles.
        p, n = d, 128 * 4
        k_loaded = nl.ndarray((num_512_tiles_cur_section, nl.par_dim(p), n), dtype=nl.bfloat16, buffer=sb_mod(base_addr=sca, num_free_tiles=(num_512_tiles_cur_section, )))
        sca += num_512_tiles_cur_section * n * 2
        for i in nl.affine_range(num_512_tiles_cur_section):
            ip_k, if_k = nl.mgrid[0:p, 0:n]
            k_loaded[i, ip_k, if_k] = nl.load(k[batch_id, ip_k, section_len*section_i + 512*i + if_k], dtype=nl.bfloat16)

        # Load this section of V as (128, d) bf16 tiles.
        p, n = sb_p, d
        v_loaded = nl.ndarray((num_128_tiles_cur_section, nl.par_dim(p), n), dtype=nl.bfloat16, buffer=sb_mod(base_addr=sca, num_free_tiles=(num_128_tiles_cur_section, )))
        sca += num_128_tiles_cur_section * n * 2
        for i in nl.affine_range(num_128_tiles_cur_section):
            ip_v, if_v = nl.mgrid[0:p, 0:n]
            v_loaded[i, ip_v, if_v] = nl.load(v[batch_id, section_len*section_i + i*128 + ip_v, if_v], dtype=nl.bfloat16)

        q_loaded = nl.ndarray((num_grps, nl.par_dim(d), sb_p), dtype=nl.bfloat16, buffer=sb_mod(base_addr=sca, num_free_tiles=(2, )))
        sca += 2 * sb_p * 2
        reduce14_num_parts = 128
        scaling_factor = nl.ndarray((num_grps, nl.par_dim(reduce14_num_parts), 1), dtype=nl.float32, buffer=sb_mod(base_addr=sca, num_free_tiles=(2, )))
        sca += 2 * 1 * 4
        num_blks = num_512_tiles_cur_section
        temp_reduce14_sbuf = nl.ndarray((num_grps, nl.par_dim(reduce14_num_parts), num_blks), dtype=nl.float32, buffer=sb_mod(base_addr=sca, num_free_tiles=(1, )))
        sca += 2 * num_blks * 4

        p, n = d, sb_p
        psum_p, psum_n = 128, sb_p * 4  # (128, 512)
        sbuf_p, sbuf_n = psum_p, psum_n * 4  # (128, 2048)

        mhlo_mul_2 = nl.ndarray((num_grps, num_2048_tiles_cur_section, nl.par_dim(sbuf_p), sbuf_n), dtype=nl.float32, buffer=sb_mod(base_addr=sca, num_free_tiles=(2, num_2048_tiles_cur_section)))
        sca += num_2048_tiles_cur_section * sbuf_n * 4 * 2
        mm1_psum_dot = nl.ndarray((num_grps, num_2048_tiles_cur_section, 4, nl.par_dim(psum_p), psum_n), dtype=nl.float32, buffer=nki.compiler.psum.alloc(mm1_dot_psum_alloc))

        final_reduce_max = nl.ndarray((num_grps, nl.par_dim(128), 1), dtype=nl.float32, buffer=sb_mod(base_addr=sca, num_free_tiles=(1, )))
        sca += 2 * 4

        prev_running_max = nl.ndarray((num_grps, nl.par_dim(reduce14_num_parts), 1), dtype=nl.float32, buffer=sb_mod(base_addr=sca, num_free_tiles=(1, )))
        sca += 2 * 4

        exp_inst_elems = 2048
        exp_insts = 2048 // exp_inst_elems  # number of exp instructions per si iteration
        ip_final_reduce_sum, _ = nl.mgrid[0:128, 0:1]
        final_reduce_sum_b = nl.ndarray((num_grps, nl.par_dim(128), section_len//exp_inst_elems), dtype=nl.float32, buffer=sb_mod(base_addr=sca, num_free_tiles=(2, )))
        sca += 2 * (section_len // exp_inst_elems) * 4

        final_reduce_sum_b_collect = nl.ndarray((num_grps, nl.par_dim(128), 1), dtype=nl.float32, buffer=sb_mod(base_addr=sca, num_free_tiles=(1, )))
        sca += 2 * 1 * 4

        prev_running_sum = nl.ndarray(shape=(num_grps, nl.par_dim(128), 1), dtype=nl.float32, buffer=sb_mod(base_addr=sca, num_free_tiles=(1, )))
        sca += 2 * 1 * 4

        prev_output = nl.ndarray((num_grps, nl.par_dim(128), 128), dtype=o.dtype, buffer=sb_mod(base_addr=sca, num_free_tiles=(1, )))
        sca += 2 * 128 * 4
        mm2_sbuf_res = nl.ndarray((num_grps, nl.par_dim(128), 128), dtype=q.dtype, buffer=sb_mod(base_addr=sca, num_free_tiles=(1, )))
        sca += 2 * 128 * 4
        mm2_div_sbuf = nl.ndarray((num_grps, nl.par_dim(128), 128), dtype=q.dtype, buffer=sb_mod(base_addr=sca, num_free_tiles=(1, )))
        sca += 2 * 128 * 4

        # Transpose geometry: 16 128-wide transposes per 2048-column tile,
        # grouped four at a time (one PSUM bank per group).
        num_tps = exp_inst_elems // 128
        num_tp_grps = num_tps // 4
        num_tps_in_grp = 4
        n_per_part = num_tps_in_grp * 128

        exp6_sbuf = nl.ndarray((num_grps, num_2048_tiles_cur_section, nl.par_dim(p), 2048), dtype=nl.bfloat16, buffer=sb_mod(base_addr=sca, num_free_tiles=(1, num_2048_tiles_cur_section)))
        sca += 2 * num_2048_tiles_cur_section * 2048

        tp_sbuf = nl.ndarray((num_grps, num_2048_tiles_cur_section, num_tp_grps, nl.par_dim(128), n_per_part), dtype=nl.bfloat16,
                             buffer=sb_mod(base_addr=sca, num_free_tiles=(1, num_2048_tiles_cur_section, num_tp_grps)))
        sca += num_2048_tiles_cur_section * num_tp_grps * 2 * n_per_part

        mm2_p, mm2_n = sb_p, d
        mm2_sbuf = nl.ndarray((num_grps, nl.par_dim(mm2_p), mm2_n), dtype=nl.float32, buffer=sb_mod(base_addr=sca, num_free_tiles=(1, )))
        sca += 2 * mm2_n * 4
        tp_psum = nl.ndarray((num_grps, num_2048_tiles_cur_section, num_tp_grps, nl.par_dim(128), n_per_part), dtype=nl.float32, buffer=nki.compiler.psum.alloc(exp_tp_psum_alloc))

        mm2_psum = nl.ndarray((num_grps, num_2048_tiles_cur_section, nl.par_dim(sb_p), mm2_n), dtype=nl.float32, buffer=nki.compiler.psum.alloc(mm2_dot_psum_alloc))

        # Pipeline stages. Each stage below closes over the section-local
        # buffers and operates on one 128-row Q group at a time.
        iq_p, iq_f = nl.mgrid[0:p, 0:n]

        def load_q(grp_i):
            q_loaded[grp_i, iq_p, iq_f] = nl.load(q[batch_id, iq_p, grp_i*n + iq_f])

        # S = (Q^T K) * softmax_scale, with a fused per-512-tile row max.
        def qk_and_max(grp_i):
            for si in nl.affine_range(num_2048_tiles_cur_section):
                for pi in nl.affine_range(4):
                    loc_512_tile_i = si*4 + pi
                    ip_res, if_res = nl.mgrid[0:128, 0:512]
                    ip_reduce_res, _ = nl.mgrid[0:128, 0:1]
                    mm1_psum_dot[grp_i, si, pi, ip_res, if_res] = nisa.nc_matmul(q_loaded[grp_i, :, :], k_loaded[loc_512_tile_i, :, :])
                    mhlo_mul_2[grp_i, si, ip_res, pi*512 + if_res] = nisa.tensor_scalar_reduce(
                        data=mm1_psum_dot[grp_i, si, pi, ip_res, if_res], op0=np.multiply, operand0=softmax_scale,
                        reduce_op=nl.max, reduce_res=temp_reduce14_sbuf[grp_i, ip_reduce_res, si*4 + pi], name="mm1-tsp"
                    )

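        # Online-softmax bookkeeping. tensor_reduce(..., negate=True) yields the
        # *negated* block max, so running_max can be fed directly as the bias of
        # the exp activation below (exp(x + running_max) == exp(x - max)), and
        # "take the larger max" becomes nl.minimum over negated values. For
        # sections after the first, scaling_factor = exp(max_prev - max_new)
        # rescales the partial sums and outputs accumulated so far.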
        def update_max(grp_i):
            ip_reduce, _ = nl.mgrid[0:128, 0:1]
            final_reduce_max[grp_i, ip_reduce, 0] = nisa.tensor_reduce(np.max, temp_reduce14_sbuf[grp_i], 1, negate=True)
            if section_i == 0:
                running_max[ip_reduce, grp_i] = nisa.tensor_copy(final_reduce_max[grp_i])
            if section_i > 0:
                prev_running_max[grp_i, ip_reduce, 0] = nisa.activation(np.copy, running_max[ip_reduce, grp_i], scale=-1.0, bias=zero_bias_tensor)
                running_max[ip_reduce, grp_i] = nisa.tensor_tensor(running_max[ip_reduce, grp_i], final_reduce_max[grp_i], op=nl.minimum)
                scaling_factor[grp_i, ip_reduce, 0] = nisa.activation(np.exp, prev_running_max[grp_i], bias=running_max[ip_reduce, grp_i], scale=1.0)

        assert section_len // exp_inst_elems == num_2048_tiles_cur_section

        # P = exp(S - max), with a fused row-sum reduction: one activation-engine
        # pass per 2048-column tile.
        def exp(grp_i):
            ip_reduce, _ = nl.mgrid[0:128, 0:1]
            for si in nl.affine_range(num_2048_tiles_cur_section):
                p, n, access_n = 128, 2048, exp_inst_elems
                ip_p, ip_n = nl.mgrid[0:p, 0:access_n]
                for pi in nl.affine_range(exp_insts):  # single iteration in the current config (exp_insts == 1)
                    exp6_sbuf[grp_i, si, ip_p, pi*access_n + ip_n] = nisa.activation_reduce(np.exp, mhlo_mul_2[grp_i, si, ip_p, access_n*pi + ip_n],
                                                                                            reduce_op=np.add, reduce_res=final_reduce_sum_b[grp_i, ip_final_reduce_sum, si*exp_insts + pi],
                                                                                            bias=running_max[ip_reduce, grp_i], name='exp6',
                                                                                            )

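        # nc_matmul transposes its stationary operand internally, so computing
        # P @ V on the Tensor Engine requires the probability tiles to be
        # materialized as P^T first. The tp stage below does this with 128x128
        # identity-matrix matmuls, four transposes per PSUM bank group.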
        def tp(grp_i):
            for si in nl.affine_range(num_2048_tiles_cur_section):
                for tp_grp in nl.affine_range(num_tp_grps):
                    ip_tp, if_tp = nl.mgrid[0:128, 0:128]
                    ip_cp, if_cp = nl.mgrid[0:128, 0:n_per_part]
                    for ti in nl.affine_range(num_tps_in_grp):
                        tp_psum[grp_i, si, tp_grp, ip_tp, ti*128 + if_tp] = nisa.nc_matmul(exp6_sbuf[grp_i, si, ip_tp, tp_grp*n_per_part + ti*128 + if_tp], identity_load)
                    tp_sbuf[grp_i, si, tp_grp, ip_cp, if_cp] = nisa.tensor_copy(tp_psum[grp_i, si, tp_grp], dtype=nl.bfloat16, name='tp-act-cp')

        # O_partial = P @ V, accumulated over the 2048-column tiles.
        def pv(grp_i):
            mm2_sbuf[grp_i] = 0.0
            for mm2i in nl.affine_range(num_2048_tiles_cur_section):
                num_tp_grps_in_2048_tile = 4
                for tp_grp_i in nl.affine_range(num_tp_grps_in_2048_tile):
                    mm2_num_grps = 4
                    ip_mm2, if_mm2 = nl.mgrid[0:128, 0:128]
                    for mm2_si in nl.affine_range(mm2_num_grps):
                        mm2_psum[grp_i, mm2i, ip_mm2, if_mm2] += nisa.nc_matmul(tp_sbuf[grp_i, mm2i, tp_grp_i, ip_mm2, mm2_si*128 + if_mm2], v_loaded[mm2i*16 + tp_grp_i*4 + mm2_si, ip_mm2, if_mm2])
                mm2_sbuf[grp_i] = nl.loop_reduce(mm2_psum[grp_i, mm2i], np.add, loop_indices=[mm2i], name='mm2-itt')

        # Fused stage for the steady-state loop: issue QK^T for group grp_i+2
        # while accumulating P @ V for group grp_i, keeping the Tensor Engine busy.
        def fused_qkmax_and_pv(grp_i):
            mm2_sbuf[grp_i] = 0.0
            for si in nl.affine_range(num_2048_tiles_cur_section):
                for pi in nl.affine_range(4):
                    loc_512_tile_i = si*4 + pi
                    ip_res, if_res = nl.mgrid[0:128, 0:512]
                    ip_reduce_res, _ = nl.mgrid[0:128, 0:1]
                    mm1_psum_dot[grp_i+2, si, pi, ip_res, if_res] = nisa.nc_matmul(q_loaded[grp_i+2, :, :], k_loaded[loc_512_tile_i, :, :])
                    mhlo_mul_2[grp_i+2, si, ip_res, pi*512 + if_res] = nisa.tensor_scalar_reduce(
                        data=mm1_psum_dot[grp_i+2, si, pi, ip_res, if_res], op0=np.multiply, operand0=softmax_scale,
                        reduce_op=nl.max, reduce_res=temp_reduce14_sbuf[grp_i+2, ip_reduce_res, si*4 + pi], name="mm1-tsp"
                    )
                mm2i = si
                num_tp_grps_in_2048_tile = 4
                for tp_grp_i in nl.affine_range(num_tp_grps_in_2048_tile):
                    mm2_num_grps = 4
                    ip_mm2, if_mm2 = nl.mgrid[0:128, 0:128]
                    for mm2_si in nl.affine_range(mm2_num_grps):
                        mm2_psum[grp_i, mm2i, ip_mm2, if_mm2] += nisa.nc_matmul(tp_sbuf[grp_i, mm2i, tp_grp_i, ip_mm2, mm2_si*128 + if_mm2], v_loaded[mm2i*16 + tp_grp_i*4 + mm2_si, ip_mm2, if_mm2])
                mm2_sbuf[grp_i] = nl.loop_reduce(mm2_psum[grp_i, mm2i], np.add, loop_indices=[mm2i], name='mm2-itt')

        # Update the running sum, and write the (rescaled, finally normalized)
        # output for this group back to HBM.
        def write_back(grp_i):
            ip_o, if_o = nl.mgrid[0:128, 0:128]

            ip_reduce, _ = nl.mgrid[0:128, 0:1]
            final_reduce_sum_b_collect[grp_i] = nisa.tensor_reduce(np.sum, final_reduce_sum_b[grp_i], axis=(1,))
            if section_i == 0:
                running_sum[ip_reduce, grp_i] = nisa.tensor_copy(final_reduce_sum_b_collect[grp_i])
            if section_i > 0:
                prev_running_sum[grp_i] = nisa.tensor_copy(running_sum[ip_reduce, grp_i])
                running_sum[ip_reduce, grp_i] = nisa.tensor_scalar(prev_running_sum[grp_i, ip_reduce, 0], np.multiply, scaling_factor[grp_i], op1=nl.add, operand1=final_reduce_sum_b_collect[grp_i])
            if section_i == num_sections - 1:
                div_25_sbuf[ip_reduce, grp_i] = nisa.reciprocal(running_sum[ip_reduce, grp_i])

            if section_i == 0:
                nl.store(o[batch_id, grp_i*sb_p + ip_o, if_o], value=mm2_sbuf[grp_i])

            if section_i == num_sections - 1:
                prev_output[grp_i] = nl.load(o[batch_id, grp_i*sb_p + ip_o, if_o])
                mm2_sbuf_res[grp_i] = nisa.scalar_tensor_tensor(data=prev_output[grp_i], op0=np.multiply, operand0=scaling_factor[grp_i], op1=np.add, operand1=mm2_sbuf[grp_i])

                mm2_div_sbuf[grp_i] = nisa.activation(np.copy, mm2_sbuf_res[grp_i], scale=div_25_sbuf[ip_reduce, grp_i], bias=zero_bias_tensor)
                nl.store(o[batch_id, grp_i*sb_p + ip_o, if_o], value=mm2_div_sbuf[grp_i])

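        # Software pipeline over the Q groups. The prologue warms up groups 0
        # and 1; each steady-state iteration then keeps three groups in flight
        # (load Q for i+2, exp/transpose for i+1, fused QK + PV and write-back
        # for i); the epilogue drains the last two groups.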
        load_q(0)
        qk_and_max(0)
        update_max(0)
        exp(0)
        tp(0)
        load_q(1)
        qk_and_max(1)
        update_max(1)
        for grp_i in nl.affine_range(num_grps - 2, precise_schedule=True):  # steady state: one 128-row block of seq_q per iteration
            load_q(grp_i + 2)
            exp(grp_i + 1)
            fused_qkmax_and_pv(grp_i)
            tp(grp_i + 1)
            write_back(grp_i)
            update_max(grp_i + 2)
        pv(num_grps - 2)
        write_back(num_grps - 2)
        exp(num_grps - 1)
        tp(num_grps - 1)
        pv(num_grps - 1)
        write_back(num_grps - 1)

    return o

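# Reference sketch (not used by the kernel): the per-section online-softmax
# recurrence that flash_fwd implements in hardware with negated maxima and
# fused activation reductions. `score_sections` is a list of per-section score
# blocks for one Q group; the name and shapes here are illustrative only.
def online_softmax_reference(score_sections):
    m = np.full((score_sections[0].shape[0], 1), -np.inf, dtype=np.float32)  # running max
    s = np.zeros_like(m)                                                     # running sum of exp
    for blk in score_sections:
        m_new = np.maximum(m, blk.max(axis=-1, keepdims=True))
        # Rescale the old partial sum, then add this section's contribution.
        s = s * np.exp(m - m_new) + np.exp(blk - m_new).sum(axis=-1, keepdims=True)
        m = m_new
    return m, s  # the kernel divides its accumulated P@V output by s
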
def softmax(x: np.ndarray, dim: int, zero_max_mode=False,
            mixed_precision=False):
    max_value = np.amax(x, axis=dim, keepdims=True)
    max_value = np.maximum(0, max_value) if zero_max_mode else max_value

    exp = np.exp(x - max_value)

    # The sum is accumulated in fp32 for numerical robustness.
    reduce = np.add.reduce(exp.astype(np.float32), axis=dim, keepdims=True).astype(x.dtype)

    return exp / reduce

def cpu_attention_forward(q, k, v, softmax_scale, use_causal_mask=True, mixed_precision=True):
    def mixed_precision_matmul(a, b):
        input_dtype = a.dtype
        a, b = a.astype(np.float32), b.astype(np.float32)
        c = np.matmul(a, b)
        return c.astype(input_dtype)

    # Compute golden output
    q_scaled = q * softmax_scale

    # q is stored (bs, d, seq_q); transpose to (bs, seq_q, d) before QK^T.
    raw_score = mixed_precision_matmul(q_scaled.transpose(0, 2, 1), k)

    if use_causal_mask:
        # raw_score has the K sequence in the innermost dim; mask all elements
        # where the Q index is smaller than the K index, i.e. the upper
        # triangle of the final two axes.
        for i in range(raw_score.shape[0]):
            # -inf triggers an invalid-input error in the softmax implementation,
            # so use a large negative value instead. k=1 excludes the diagonal,
            # because each token can still attend to itself.
            raw_score[i][np.triu_indices_from(raw_score[i], k=1)] = -9984.0

    norm_score = softmax(raw_score, dim=-1, mixed_precision=mixed_precision)

    # norm_score @ v already has the kernel's output layout (bs, seq_q, d).
    out_golden = mixed_precision_matmul(norm_score, v)

    return out_golden, norm_score


# Smoke test: 16K x 16K bf16 attention against the NumPy reference.
bs = 1
seqlen_q = 16 * 1024
seqlen_k = 16 * 1024
d = 128
softmax_scale = 0.125
dtype = bfloat16

q = np.random.rand(bs, d, seqlen_q).astype(dtype)
k = np.random.rand(bs, d, seqlen_k).astype(dtype)
v = np.random.rand(bs, seqlen_k, d).astype(dtype)
o = flash_fwd[bs](q, k, v, mixed_precision=True, softmax_scale=softmax_scale, use_causal_mask=False)
o_std, scores = cpu_attention_forward(q, k, v, softmax_scale, use_causal_mask=False, mixed_precision=True)

np.testing.assert_allclose(o, o_std, atol=1e-2)