"""
Copyright (c) 2025, Amazon.com. All Rights Reserved

kernels - Builtin high performance NKI kernels.

"""

from neuronxcc import nki
import neuronxcc.nki.isa as nisa
import neuronxcc.nki.language as nl

@nki.jit(platform_target='trn2')
def quantized_double_row_matmul(
    lhs,
    rhs_quantized, rhs_scale,
    # Meta-parameters
    TILES_IN_BLOCK_M,
    TILES_IN_BLOCK_N,
    TILES_IN_BLOCK_K
):
  """NKI kernel that computes a matrix multiplication by blocking along all dimensions
  and performing fp8_e4m3 quantization on the lhs matrix.

  Args:
    lhs: an unquantized input tensor of shape [M,K], where K is a multiple of 128 *
      TILES_IN_BLOCK_K and M is a multiple of 128 * TILES_IN_BLOCK_M. It is the
      left-hand-side argument of the matrix multiplication.
    rhs_quantized: a pre-quantized input tensor of dtype float8_e4m3 and of shape
      [K // 2, 2 * N] (reshaped from the original [K,N] rhs), where K is a multiple of
      128 * TILES_IN_BLOCK_K and N is a multiple of 512 * TILES_IN_BLOCK_N. It is the
      right-hand-side argument of the matrix multiplication. See test_double_row_matmul.py
      for the reshape expected to be performed on the original rhs matrix.
    rhs_scale: the column-wise quantization scale of rhs, of shape [128, N],
      pre-broadcast from [1, N].
    TILES_IN_BLOCK_*: meta-parameters controlling the blocking dimensions.
  Returns:
    result: the resulting output tensor of shape [M,N]
  """
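
  # Kernel outline:
  #   1. For each output row block (m) and contraction block (k), load an lhs block,
  #      quantize it row-wise to fp8_e4m3 and transpose it tile by tile.
  #   2. Load the matching blocks of the pre-quantized, packed rhs and run double-row
  #      matmuls over all tile pairs, accumulating partial products in PSUM.
  #   3. Partially dequantize with the lhs scales while accumulating block results in SBUF.
  #   4. After all k blocks, dequantize with the rhs scale and store the block to HBM.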

  assert rhs_quantized.dtype == nl.float8_e4m3, "rhs must be pre-quantized to dtype float8_e4m3"

  M, K = lhs.shape
  K_RESHAPED, N_RESHAPED = rhs_quantized.shape
  K_ = 2 * K_RESHAPED

  assert K == K_, "lhs and rhs must have the same contraction dimension"

  assert N_RESHAPED % 2 == 0, f"N_RESHAPED={N_RESHAPED} must be divisible by 2"
  N = N_RESHAPED // 2

  TILE_M = nl.tile_size.gemm_stationary_fmax  # 128
  TILE_K = nl.tile_size.pmax  # 128
  TILE_N = nl.tile_size.gemm_moving_fmax  # 512

  BLOCK_M = TILE_M * TILES_IN_BLOCK_M
  BLOCK_N = TILE_N * TILES_IN_BLOCK_N
  BLOCK_K = TILE_K * TILES_IN_BLOCK_K
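
  # Illustrative sizing (example values, not requirements): with TILES_IN_BLOCK_M=2,
  # TILES_IN_BLOCK_N=2 and TILES_IN_BLOCK_K=8, blocks are BLOCK_M=256, BLOCK_N=1024 and
  # BLOCK_K=1024, so M=2048, K=2048, N=8192 would give NUM_BLOCK_M=8, NUM_BLOCK_K=2 and
  # NUM_BLOCK_N=8 below.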

  assert M % BLOCK_M == 0, f"M={M} must be a multiple of BLOCK_M={BLOCK_M}"
  assert N % BLOCK_N == 0, f"N={N} must be a multiple of BLOCK_N={BLOCK_N}"
  assert K % BLOCK_K == 0, f"K={K} must be a multiple of BLOCK_K={BLOCK_K}"

  # Each dimension is an exact multiple of its block size.
  NUM_BLOCK_M = M // BLOCK_M
  NUM_BLOCK_N = N // BLOCK_N
  NUM_BLOCK_K = K // BLOCK_K

  # dtype fp8_e4m3 can represent values in [-240, 240].
  FP8_RANGE = 240

  assert TILES_IN_BLOCK_K % 2 == 0, f"TILES_IN_BLOCK_K={TILES_IN_BLOCK_K} must be even to load 2 tiles at a time for double row matmul"

  result = nl.ndarray((M, N), dtype=lhs.dtype, buffer=nl.shared_hbm)

  # Blocking M dimension (lhs partition dimension).
  for m in nl.affine_range(NUM_BLOCK_M):
    result_tiles = nl.zeros((TILE_M, NUM_BLOCK_N * TILES_IN_BLOCK_M * TILES_IN_BLOCK_N * TILE_N),
                            dtype=lhs.dtype,
                            buffer=nl.sbuf)
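    # result_tiles stages this m-block's full [BLOCK_M, N] output in SBUF: for each M-tile bm,
    # a [TILE_M, N] slab lives at free-dimension offset bm * NUM_BLOCK_N * BLOCK_N (see the
    # accumulation and store loops below).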

    # Blocking K dimension (the contraction dimension).
    # Use `sequential_range` because we do not want the compiler to change this loop by,
    # for example, vectorizing it.
    for k in nl.sequential_range(NUM_BLOCK_K):
      lhsT_quantized_tiles = nl.ndarray((TILES_IN_BLOCK_M, nl.par_dim(TILE_M), BLOCK_K),
                                        dtype=nl.float8_e4m3,
                                        buffer=nl.sbuf)
      lhsT_scale_tiles = nl.ndarray((TILES_IN_BLOCK_M, nl.par_dim(TILE_M), 1),
                                    dtype=lhs.dtype,
                                    buffer=nl.sbuf)

      i_lhs = nl.mgrid[0:TILE_M, 0:BLOCK_K]
      for bm_l in nl.affine_range(TILES_IN_BLOCK_M):
        # Load and quantize tiles from lhs,
        # setting the load tile to [TILE_M, BLOCK_K] to optimize DMA performance.
        lhs_i_m = m * BLOCK_M + bm_l * TILE_M + i_lhs.p
        lhs_i_k = k * BLOCK_K + i_lhs.x

        tile_block = nl.load(lhs[lhs_i_m, lhs_i_k])
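
        # Row-wise quantization: each of the TILE_M rows gets scale = max(|row|) / FP8_RANGE so
        # that its largest magnitude maps to the edge of the fp8_e4m3 range; the scales are kept
        # in lhsT_scale_tiles and multiplied back into the matmul results further below.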
        # FIXME: use nisa.tensor_scalar_reduce to fuse nl.abs and nisa.tensor_reduce into
        # 1 operation.
        abs_tile_block = nl.abs(tile_block)
        lhsT_scale_tiles[bm_l] = nisa.tensor_reduce(nl.max,
                                                    abs_tile_block,
                                                    axis=[1])
        lhsT_scale_tiles[bm_l] = nl.divide(lhsT_scale_tiles[bm_l], FP8_RANGE)
        lhsT_quantized_tiles[bm_l] = nl.divide(tile_block, lhsT_scale_tiles[bm_l])

        # For each [TILE_M, TILE_K] tile, since TILE_K == TILE_M and the K dimension needs to be
        # along the partition dimension, transpose the tile in-place.
        for bk_l in nl.affine_range(TILES_IN_BLOCK_K):
          # FIXME: use dma_transpose instead of nc_transpose.
          lhsT_quantized_tiles[bm_l, :, TILE_M * bk_l:(bk_l + 1) * TILE_M] = nisa.nc_transpose(
              lhsT_quantized_tiles[bm_l, :, TILE_M * bk_l:(bk_l + 1) * TILE_M])
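
        # After the in-place transposes, lhsT_quantized_tiles[bm_l] holds TILES_IN_BLOCK_K
        # transposed [TILE_K, TILE_M] chunks side by side along the free dimension: the partition
        # axis now indexes K within a tile and the free axis indexes M, as nc_matmul's stationary
        # operand expects.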

      # Each lhs block's matmul results need to be dequantized independently of other lhs blocks'
      # matmul results. scoped_result_tiles stores the not-yet-dequantized matmul results scoped
      # to the current `for m` and `for k` iterations.
      scoped_result_tiles = nl.zeros((TILE_M, NUM_BLOCK_N * TILES_IN_BLOCK_M * TILES_IN_BLOCK_N * TILE_N),
                                     dtype=lhs.dtype,
                                     buffer=nl.sbuf)

      for n in nl.affine_range(NUM_BLOCK_N):
        # Load tiles from rhs,
        # setting the load tile to [TILE_K, 2 * BLOCK_N] to optimize DMA performance
        # (i.e. loading 2 rows of a rhs block at a time).
        i_rhs = nl.mgrid[0:TILE_K, 0:2 * BLOCK_N]

        rhs_quantized_tiles = nl.ndarray((TILES_IN_BLOCK_K // 2, nl.par_dim(TILE_K), 2 * BLOCK_N),
                                         dtype=rhs_quantized.dtype)
        for bk_r in nl.affine_range(TILES_IN_BLOCK_K // 2):
          rhs_quantized_i_k = (k * TILES_IN_BLOCK_K // 2 + bk_r) * TILE_K + i_rhs.p
          rhs_quantized_i_n = 2 * n * BLOCK_N + i_rhs.x
          rhs_quantized_tiles[bk_r] = nl.load(rhs_quantized[rhs_quantized_i_k, rhs_quantized_i_n])
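
        # Each rhs_quantized_tiles[bk_r] slab is [TILE_K, 2 * BLOCK_N]; for every output column
        # tile bn, the matmul loop below reads the [TILE_K, 2, TILE_N] slice at columns
        # [2 * bn * TILE_N, 2 * (bn + 1) * TILE_N) and pairs it with lhsT K-tiles 2*bk and 2*bk+1
        # in a single double-row nc_matmul.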

        # Do matmul with all tiles in the loaded lhs and rhs blocks.
        i_res_mm = nl.mgrid[0:TILE_M, 0:TILE_N]
        for bm in nl.affine_range(TILES_IN_BLOCK_M):
          for bn in nl.affine_range(TILES_IN_BLOCK_N):
            res_tile = nl.zeros((TILE_M, TILE_N), dtype=nl.float32, buffer=nl.psum)
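
            # One double-row nc_matmul consumes a [TILE_K, 2, TILE_M] stationary tile and a
            # [TILE_K, 2, TILE_N] moving tile, i.e. 2 * TILE_K contraction rows per call, which is
            # why the loop below runs TILES_IN_BLOCK_K // 2 times; partial sums accumulate in the
            # PSUM tile res_tile.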
            for bk in nl.affine_range(TILES_IN_BLOCK_K // 2):
              i_k, i_tile_m, i_m = nl.mgrid[0:TILE_K, 0:2, 0:TILE_M]
              lhsT_double_tile = lhsT_quantized_tiles[
                  bm,
                  i_k,
                  bk * (2 * TILE_M) + i_tile_m * TILE_M + i_m
              ]
              assert lhsT_double_tile.shape == (TILE_K, 2, TILE_M)

              i_k, i_tile_n, i_n = nl.mgrid[0:TILE_K, 0:2, 0:TILE_N]
              rhs_double_tile = rhs_quantized_tiles[
                  bk,
                  i_k,
                  2 * bn * TILE_N + i_tile_n * TILE_N + i_n
              ]
              assert rhs_double_tile.shape == (TILE_K, 2, TILE_N)

              res_tile[...] += nisa.nc_matmul(lhsT_double_tile,
                                              rhs_double_tile,
                                              perf_mode='double_row_gen3')

            i_scoped_result_tiles_k = i_res_mm.p
            i_scoped_result_tiles_n = bm * (NUM_BLOCK_N * BLOCK_N) + n * BLOCK_N + bn * TILE_N + i_res_mm.x
            scoped_result_tiles[i_scoped_result_tiles_k, i_scoped_result_tiles_n] += res_tile[...]

      # FIXME: dequantize using both the lhs and rhs scales with nisa.scalar_tensor_tensor when
      # accumulating from PSUM to SBUF.
      # Partially dequantize the matmul results using the lhs block scales.
      i_scoped_result_tiles = nl.mgrid[0:TILE_K, 0:NUM_BLOCK_N * BLOCK_N]
      for bm in nl.affine_range(TILES_IN_BLOCK_M):
        result_tiles_i_k = i_scoped_result_tiles.p
        result_tiles_i_n = bm * NUM_BLOCK_N * BLOCK_N + i_scoped_result_tiles.x
        dequantized_tile_block = nisa.tensor_tensor(
            scoped_result_tiles[result_tiles_i_k, result_tiles_i_n],
            lhsT_scale_tiles[bm],
            nl.multiply
        )

        result_tiles[result_tiles_i_k, result_tiles_i_n] += dequantized_tile_block

    # Dequantize matmul results using the rhs scale and copy them from SBUF to HBM.
    rhs_scale_sbuf = nl.ndarray(rhs_scale.shape, buffer=nl.sbuf, dtype=rhs_scale.dtype)
    rhs_scale_sbuf = nl.load(rhs_scale)

    i_result = nl.mgrid[0:TILE_M, 0:N]
    for bm in nl.affine_range(TILES_IN_BLOCK_M):
      result_tiles_i_k = i_result.p
      result_tiles_i_n = bm * (NUM_BLOCK_N * BLOCK_N) + i_result.x

      result_i_m = m * BLOCK_M + bm * TILE_M + i_result.p
      result_i_n = i_result.x

      # FIXME: remove once dequantization uses nisa.scalar_tensor_tensor.
      dequantized = nisa.tensor_tensor(
          result_tiles[result_tiles_i_k, result_tiles_i_n],
          rhs_scale_sbuf,
          nl.multiply
      )

      nl.store(result[result_i_m, result_i_n], value=dequantized)

  return result
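

# Illustrative reference (a sketch, not used by the kernel): the quantized matmul math above
# written in plain numpy. It mirrors the scheme of row-wise fp8-range scaling of lhs,
# column-wise scaling of rhs, matmul in the scaled domain, and dequantization by both scales,
# but it keeps values in float (numpy has no fp8 dtype), uses one scale per full lhs row rather
# than one per K-block, and ignores the [K // 2, 2 * N] packing of rhs. The helper name and
# signature are assumptions for documentation purposes only.
def _quantized_matmul_reference(lhs, rhs, fp8_range=240.0):
  import numpy as np

  # Scale each lhs row and each rhs column so its largest magnitude maps to the fp8 range edge.
  lhs_scale = np.abs(lhs).max(axis=1, keepdims=True) / fp8_range  # shape [M, 1]
  rhs_scale = np.abs(rhs).max(axis=0, keepdims=True) / fp8_range  # shape [1, N]

  # Matmul on the scaled operands, then undo both scales on the result.
  return ((lhs / lhs_scale) @ (rhs / rhs_scale)) * lhs_scale * rhs_scale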