Commit a5e9d8c

Double row matmul (#49)
* add: double row matmul kernel
1 parent 97e49c9 commit a5e9d8c

File tree

2 files changed: +334 -0
nki_samples/reference/double_row_matmul.py: 202 additions & 0 deletions
@@ -0,0 +1,202 @@
"""
Copyright (c) 2025, Amazon.com. All Rights Reserved

kernels - Builtin high performance NKI kernels.

"""
from neuronxcc import nki
import neuronxcc.nki.isa as nisa
import neuronxcc.nki.language as nl

@nki.jit(platform_target='trn2')
def quantized_double_row_matmul(
    lhs,
    rhs_quantized, rhs_scale,
    # Meta-parameters
    TILES_IN_BLOCK_M,
    TILES_IN_BLOCK_N,
    TILES_IN_BLOCK_K
):
  """NKI kernel to compute a matrix multiplication by blocking along all dimensions
  and performing fp8_e4m3 quantization on the lhs matrix.

  Args:
    lhs: an unquantized input tensor of shape [M,K], where K is a multiple of
      128 * TILES_IN_BLOCK_K and M is a multiple of 128 * TILES_IN_BLOCK_M. It is the
      left-hand-side argument of the matrix multiplication.
    rhs_quantized: a pre-quantized input tensor of dtype float8_e4m3 and of shape
      [K // 2, 2 * N] (reshaped from the original [K,N] rhs), where K is a multiple of
      128 * TILES_IN_BLOCK_K and N is a multiple of 512 * TILES_IN_BLOCK_N. It is the
      right-hand-side argument of the matrix multiplication. See test_double_row_matmul.py
      for the expected reshape to be performed on the original rhs matrix.
    rhs_scale: the column-wise quantization scale of rhs, of shape [128, N],
      pre-broadcast from [1, N].
    TILES_IN_BLOCK_*: meta-parameters to control the blocking dimensions.

  Returns:
    result: the resulting output tensor of shape [M,N]
  """

  assert rhs_quantized.dtype == nl.float8_e4m3, "rhs must be pre-quantized to dtype float8_e4m3"

  M, K = lhs.shape
  K_RESHAPED, N_RESHAPED = rhs_quantized.shape
  K_ = 2 * K_RESHAPED

  assert K == K_, "lhs and rhs must have the same contraction dimension"

  assert N_RESHAPED % 2 == 0, f"N_RESHAPED={N_RESHAPED} must be divisible by 2"
  N = N_RESHAPED // 2

  TILE_M = nl.tile_size.gemm_stationary_fmax  # 128
  TILE_K = nl.tile_size.pmax  # 128
  TILE_N = nl.tile_size.gemm_moving_fmax  # 512

  BLOCK_M = TILE_M * TILES_IN_BLOCK_M
  BLOCK_N = TILE_N * TILES_IN_BLOCK_N
  BLOCK_K = TILE_K * TILES_IN_BLOCK_K

  # Each dimension size has to be a multiple of its block size.
  assert M % BLOCK_M == 0
  assert N % BLOCK_N == 0
  assert K % BLOCK_K == 0

  NUM_BLOCK_M = M // BLOCK_M
  NUM_BLOCK_N = N // BLOCK_N
  NUM_BLOCK_K = K // BLOCK_K

  # dtype fp8_e4m3 can represent [-240, 240].
  FP8_RANGE = 240

  assert TILES_IN_BLOCK_K % 2 == 0, f"TILES_IN_BLOCK_K={TILES_IN_BLOCK_K} must be even to load 2 tiles at a time for double row matmul"

  result = nl.ndarray((M, N), dtype=lhs.dtype, buffer=nl.shared_hbm)

  # Blocking M dimension (lhs partition dimension).
  for m in nl.affine_range(NUM_BLOCK_M):
    result_tiles = nl.zeros((TILE_M, NUM_BLOCK_N * TILES_IN_BLOCK_M * TILES_IN_BLOCK_N * TILE_N),
                            dtype=lhs.dtype,
                            buffer=nl.sbuf)

    # Blocking K dimension (the contraction dimension).
    # Use `sequential_range` because we do not want the compiler to change this loop by,
    # for example, vectorizing it.
    for k in nl.sequential_range(NUM_BLOCK_K):
      lhsT_quantized_tiles = nl.ndarray((TILES_IN_BLOCK_M, nl.par_dim(TILE_M), BLOCK_K),
                                        dtype=nl.float8_e4m3,
                                        buffer=nl.sbuf)
      lhsT_scale_tiles = nl.ndarray((TILES_IN_BLOCK_M, nl.par_dim(TILE_M), 1),
                                    dtype=lhs.dtype,
                                    buffer=nl.sbuf)

      i_lhs = nl.mgrid[0:TILE_M, 0:BLOCK_K]
      for bm_l in nl.affine_range(TILES_IN_BLOCK_M):
        # Load and quantize tiles from lhs,
        # setting the load tile to [TILE_M, BLOCK_K] to optimize DMA performance.
        lhs_i_m = m * BLOCK_M + bm_l * TILE_M + i_lhs.p
        lhs_i_k = k * BLOCK_K + i_lhs.x

        tile_block = nl.load(lhs[lhs_i_m, lhs_i_k])

        # FIXME: use nisa.tensor_scalar_reduce to fuse nl.abs and nisa.tensor_reduce into
        # 1 operation.
        abs_tile_block = nl.abs(tile_block)
        lhsT_scale_tiles[bm_l] = nisa.tensor_reduce(nl.max,
                                                    abs_tile_block,
                                                    axis=[1])
        lhsT_scale_tiles[bm_l] = nl.divide(lhsT_scale_tiles[bm_l], FP8_RANGE)
        lhsT_quantized_tiles[bm_l] = nl.divide(tile_block, lhsT_scale_tiles[bm_l])

        # For each [TILE_M, TILE_K] tile, since TILE_K == TILE_M and the K dimension needs
        # to be along the partition dimension, transpose said tile in-place.
        for bk_l in nl.affine_range(TILES_IN_BLOCK_K):
          # FIXME: use dma_transpose instead of nc_transpose.
          lhsT_quantized_tiles[bm_l, :, TILE_M * bk_l:(bk_l + 1) * TILE_M] = nisa.nc_transpose(
              lhsT_quantized_tiles[bm_l, :, TILE_M * bk_l:(bk_l + 1) * TILE_M])

      # Each lhs block's matmul results need to be dequantized independently of the other
      # lhs blocks' results. scoped_result_tiles stores the not-yet-dequantized matmul
      # results scoped to the current `for m` and `for k` iterations.
      scoped_result_tiles = nl.zeros((TILE_M, NUM_BLOCK_N * TILES_IN_BLOCK_M * TILES_IN_BLOCK_N * TILE_N),
                                     dtype=lhs.dtype,
                                     buffer=nl.sbuf)

      for n in nl.affine_range(NUM_BLOCK_N):
        # Loading tiles from rhs,
        # setting the load tile to [TILE_K, 2 * BLOCK_N] to optimize DMA performance
        # (i.e. loading 2 rows of a rhs block at a time).
        i_rhs = nl.mgrid[0:TILE_K, 0:2 * BLOCK_N]

        rhs_quantized_tiles = nl.ndarray((TILES_IN_BLOCK_K // 2, nl.par_dim(TILE_K), 2 * BLOCK_N), dtype=rhs_quantized.dtype)
        for bk_r in nl.affine_range(TILES_IN_BLOCK_K // 2):
          rhs_quantized_i_k = (k * TILES_IN_BLOCK_K // 2 + bk_r) * TILE_K + i_rhs.p
          rhs_quantized_i_n = 2 * n * BLOCK_N + i_rhs.x
          rhs_quantized_tiles[bk_r] = nl.load(rhs_quantized[rhs_quantized_i_k, rhs_quantized_i_n])

        # Do matmuls with all tiles in the loaded lhs and rhs blocks.
        i_res_mm = nl.mgrid[0:TILE_M, 0:TILE_N]
        for bm in nl.affine_range(TILES_IN_BLOCK_M):
          for bn in nl.affine_range(TILES_IN_BLOCK_N):
            res_tile = nl.zeros((TILE_M, TILE_N), dtype=nl.float32, buffer=nl.psum)
            for bk in nl.affine_range(TILES_IN_BLOCK_K // 2):
              i_k, i_tile_m, i_m = nl.mgrid[0:TILE_K, 0:2, 0:TILE_M]
              lhsT_double_tile = lhsT_quantized_tiles[
                  bm,
                  i_k,
                  bk * (2 * TILE_M) + i_tile_m * TILE_M + i_m
              ]
              assert lhsT_double_tile.shape == (TILE_K, 2, TILE_M)

              i_k, i_tile_n, i_n = nl.mgrid[0:TILE_K, 0:2, 0:TILE_N]
              rhs_double_tile = rhs_quantized_tiles[
                  bk,
                  i_k,
                  2 * bn * TILE_N + i_tile_n * TILE_N + i_n
              ]
              assert rhs_double_tile.shape == (TILE_K, 2, TILE_N)

              res_tile[...] += nisa.nc_matmul(lhsT_double_tile,
                                              rhs_double_tile,
                                              perf_mode='double_row_gen3')

            i_scoped_result_tiles_k = i_res_mm.p
            i_scoped_result_tiles_n = bm * (NUM_BLOCK_N * BLOCK_N) + n * BLOCK_N + bn * TILE_N + i_res_mm.x
            scoped_result_tiles[i_scoped_result_tiles_k, i_scoped_result_tiles_n] += res_tile[...]

      # FIXME: dequantize with both the lhs and rhs scales using nisa.scalar_tensor_tensor when
      # accumulating from PSUM to SBUF.
      # Partially dequantize matmul results using the lhs block scale.
      i_scoped_result_tiles = nl.mgrid[0:TILE_K, 0:NUM_BLOCK_N * BLOCK_N]
      for bm in nl.affine_range(TILES_IN_BLOCK_M):
        result_tiles_i_k = i_scoped_result_tiles.p
        result_tiles_i_n = bm * NUM_BLOCK_N * BLOCK_N + i_scoped_result_tiles.x
        dequantized_tile_block = nisa.tensor_tensor(
            scoped_result_tiles[result_tiles_i_k, result_tiles_i_n],
            lhsT_scale_tiles[bm],
            nl.multiply
        )

        result_tiles[result_tiles_i_k, result_tiles_i_n] += dequantized_tile_block

    # Dequantize matmul results using the rhs scale and copy results from SBUF to HBM.
    rhs_scale_sbuf = nl.ndarray(rhs_scale.shape, buffer=nl.sbuf, dtype=rhs_scale.dtype)
    rhs_scale_sbuf = nl.load(rhs_scale)

    i_result = nl.mgrid[0:TILE_M, 0:N]
    for bm in nl.affine_range(TILES_IN_BLOCK_M):
      result_tiles_i_k = i_result.p
      result_tiles_i_n = bm * (NUM_BLOCK_N * BLOCK_N) + i_result.x

      result_i_m = m * BLOCK_M + bm * TILE_M + i_result.p
      result_i_n = i_result.x

      # FIXME: remove once dequantization is done with nisa.scalar_tensor_tensor above.
      dequantized = nisa.tensor_tensor(
          result_tiles[result_tiles_i_k, result_tiles_i_n],
          rhs_scale_sbuf,
          nl.multiply
      )

      nl.store(result[result_i_m, result_i_n], value=dequantized)

  return result
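
For intuition, the quantize / matmul / dequantize flow the kernel implements can be sketched in plain NumPy. This is an illustrative model only, not part of the commit: it collapses the per-(row, K-block) lhs scale the kernel computes into a single per-row scale, and it does not simulate fp8_e4m3 rounding, so the round-trip below is exact up to float error.

import numpy as np

FP8_RANGE = 240  # max magnitude of the fp8_e4m3 variant targeted by the kernel above

def reference_quantized_matmul(lhs, rhs):
  # Hypothetical NumPy model of the kernel's math; fp8 rounding is not simulated.
  # Row-wise lhs scale (the kernel derives one scale per row per K block; collapsed here).
  lhs_scale = np.max(np.abs(lhs), axis=1, keepdims=True) / FP8_RANGE   # [M, 1]
  lhs_q = lhs / lhs_scale

  # Column-wise rhs scale, matching column_wise_quantize in the test file below.
  rhs_scale = np.max(np.abs(rhs), axis=0, keepdims=True) / FP8_RANGE   # [1, N]
  rhs_q = rhs / rhs_scale

  # Matmul in the quantized domain, then dequantize with both scales.
  return (lhs_q @ rhs_q) * lhs_scale * rhs_scale

# Without fp8 rounding the dequantized product matches the unquantized matmul.
lhs = np.random.rand(64, 256)
rhs = np.random.rand(256, 32)
assert np.allclose(reference_quantized_matmul(lhs, rhs), lhs @ rhs)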
test_double_row_matmul.py: 132 additions & 0 deletions
@@ -0,0 +1,132 @@
"""
Copyright (c) 2025, Amazon.com. All Rights Reserved
"""
import pytest
from nki_samples.reference.double_row_matmul import quantized_double_row_matmul
from neuronxcc.nki import benchmark, baremetal, simulate_kernel
import neuronxcc.nki.language as nl
import numpy as np

xfail = pytest.mark.arch_specific_xfail


bench_func = benchmark(warmup=5, iters=10)(quantized_double_row_matmul)

def reshape(matrix):
  """
  Interleaves the [128,512] tiles of every 2 consecutive tile rows.

  A [K,N] matrix is reshaped into [K//2, 2*N] where K must be divisible by 2 * 128 and
  N must be divisible by 512.

  E.g. if Tij is the (i,j)-th tile and assuming a matrix with 4x4 [128,512] tiles,
  the reshaped matrix looks as follows

  # T11 T12 T13 T14
  # T21 T22 T23 T24  reshape   T11 T21 T12 T22 T13 T23 T14 T24
  # T31 T32 T33 T34 ---------> T31 T41 T32 T42 T33 T43 T34 T44
  # T41 T42 T43 T44
  """
  K, N = matrix.shape

  TILE_K = 128
  TILE_N = 512

  assert K % (2 * TILE_K) == 0
  assert N % TILE_N == 0

  result = np.zeros((K // 2, 2 * N))

  for k in range(0, K // TILE_K, 2):
    for n in range(N // TILE_N):
      # Get 2 tiles in the same tile column and consecutive tile rows.
      tile1 = matrix[k * TILE_K:(k + 1) * TILE_K, n * TILE_N:(n + 1) * TILE_N]
      tile2 = matrix[(k + 1) * TILE_K:(k + 2) * TILE_K, n * TILE_N:(n + 1) * TILE_N]

      # Place the 2 tiles in the same tile row side by side.
      result[(k // 2) * TILE_K:(k // 2 + 1) * TILE_K, n * TILE_N * 2:n * TILE_N * 2 + TILE_N] = tile1
      result[(k // 2) * TILE_K:(k // 2 + 1) * TILE_K, n * TILE_N * 2 + TILE_N:(n + 1) * TILE_N * 2] = tile2

  return result

def column_wise_quantize(matrix):
  """
  Quantizes a matrix.

  Returns a column-wise scale broadcasted to (128, matrix.shape[1]) and the quantized matrix.
  """
  FP8_RANGE = 240
  column_wise_max = np.max(np.abs(matrix), axis=0, keepdims=True)
  column_wise_scale = column_wise_max / FP8_RANGE

  matrix_quantized = matrix / column_wise_scale
  column_wise_scale = np.broadcast_to(column_wise_scale, (128, matrix.shape[1]))

  return column_wise_scale, matrix_quantized

class TestDoubleRowMatmul:

  @xfail(fail=['trn1'])
  @pytest.mark.parametrize("M, K, N, dtype, TILES_IN_BLOCK_M, TILES_IN_BLOCK_N, TILES_IN_BLOCK_K, max_p99_latency", [
    [512, 16 * 1024, 1024, nl.bfloat16, 2, 2, 16, 320],
  ])
  def test_double_row_matmul_perf(self, M, K, N, dtype, TILES_IN_BLOCK_M, TILES_IN_BLOCK_N, TILES_IN_BLOCK_K, max_p99_latency):
    # Initializing random inputs
    lhs = np.random.rand(M, K)
    rhs = np.random.rand(K, N)

    # Quantizing rhs
    rhs_scale, rhs_quantized = column_wise_quantize(rhs)
    rhs_quantized_reshaped = reshape(rhs_quantized)

    # Casting to the correct data types (rhs is pre-quantized, thus cast to FP8)
    lhs = nl.static_cast(lhs, dtype)
    rhs_scale = nl.static_cast(rhs_scale, dtype)
    rhs_quantized_reshaped = nl.static_cast(rhs_quantized_reshaped, nl.float8_e4m3)

    # Latency checks
    bench_func(lhs, rhs_quantized_reshaped, rhs_scale, TILES_IN_BLOCK_M, TILES_IN_BLOCK_N, TILES_IN_BLOCK_K)
    latency_res = bench_func.benchmark_result.nc_latency
    p99_latency = latency_res.get_latency_percentile(99)

    assert p99_latency <= max_p99_latency

  @xfail(fail=['trn1'])
  @pytest.mark.simulation
  @pytest.mark.parametrize("M, K, N, dtype, TILES_IN_BLOCK_M, TILES_IN_BLOCK_N, TILES_IN_BLOCK_K", [
    [512, 16 * 1024, 1024, nl.bfloat16, 2, 2, 16],
    [512, 16 * 1024, 1024, nl.bfloat16, 4, 1, 32],
    [512, 16 * 1024, 1024, nl.bfloat16, 4, 2, 128],
  ])
  def test_double_row_matmul_numerical(self, simulation_only, M, K, N, dtype, TILES_IN_BLOCK_M, TILES_IN_BLOCK_N, TILES_IN_BLOCK_K):
    # Initializing random inputs
    lhs = np.random.rand(M, K)
    rhs = np.random.rand(K, N)

    # Correct CPU results
    result_golden = np.matmul(lhs, rhs)

    # Quantizing rhs
    rhs_scale, rhs_quantized = column_wise_quantize(rhs)
    rhs_quantized_reshaped = reshape(rhs_quantized)

    # Casting to the correct data types (rhs is pre-quantized, thus cast to FP8)
    lhs = nl.static_cast(lhs, dtype)
    rhs_scale = nl.static_cast(rhs_scale, dtype)
    rhs_quantized_reshaped = nl.static_cast(rhs_quantized_reshaped, nl.float8_e4m3)

    # Numerical accuracy checks
    numeric_func = baremetal(quantized_double_row_matmul)

    if simulation_only:
      result_nki = simulate_kernel(numeric_func, lhs, rhs_quantized_reshaped, rhs_scale, TILES_IN_BLOCK_M, TILES_IN_BLOCK_N, TILES_IN_BLOCK_K)
    else:
      result_nki = numeric_func(lhs, rhs_quantized_reshaped, rhs_scale, TILES_IN_BLOCK_M, TILES_IN_BLOCK_N, TILES_IN_BLOCK_K)

    # Casting result_nki from dtype BF16 back to FP32 to compare the NumPy and NKI results
    result_nki = result_nki.astype(np.float32)

    assert np.allclose(result_golden, result_nki, rtol=2e-2)
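
As a quick illustration of the tile interleaving that `reshape` performs (and that the kernel's docstring refers to), the snippet below builds a small matrix of 4x2 constant-valued [128, 512] tiles and checks where each tile lands. It is a hypothetical usage sketch, not part of the test suite, and assumes the `reshape` helper defined above is in scope.

import numpy as np

TILE_K, TILE_N = 128, 512

# 4x2 grid of constant tiles: tile (i, j) is filled with the value 10 * i + j.
tile_ids = 10 * np.arange(4)[:, None] + np.arange(2)         # [[0, 1], [10, 11], [20, 21], [30, 31]]
matrix = np.kron(tile_ids, np.ones((TILE_K, TILE_N)))         # shape [512, 1024]

reshaped = reshape(matrix)                                    # shape [256, 2048]

# Tile rows 0 and 1 are interleaved into the first reshaped tile row: T00 T10 T01 T11.
assert (reshaped[:TILE_K, 0 * TILE_N:1 * TILE_N] == 0).all()   # T00
assert (reshaped[:TILE_K, 1 * TILE_N:2 * TILE_N] == 10).all()  # T10
assert (reshaped[:TILE_K, 2 * TILE_N:3 * TILE_N] == 1).all()   # T01
assert (reshaped[:TILE_K, 3 * TILE_N:4 * TILE_N] == 11).all()  # T11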
