Skip to content

Commit eea8426

Browse files
modify triton matmul to allow the formula C=A*B+C
Signed-off-by: cliu-us <[email protected]>
1 parent 262db69 commit eea8426

File tree

1 file changed

+172
-12
lines changed

1 file changed

+172
-12
lines changed

fms_mo/custom_ext_kernels/triton_kernels.py

Lines changed: 172 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ def matmul_kernel(
101101
stride_cm,
102102
stride_cn,
103103
chunk_trun_bits,
104+
truncate_then_accumulate,
104105
# Meta-parameters
105106
BLOCK_SIZE_M: tl.constexpr,
106107
BLOCK_SIZE_N: tl.constexpr,
@@ -159,15 +160,20 @@ def matmul_kernel(
159160
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
160161
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
161162
# We accumulate along the K dimension.
162-
accumulator = tl.dot(a, b, accumulator, input_precision="ieee")
163+
if truncate_then_accumulate:
164+
accumulator_inner = tl.dot(a, b, input_precision="ieee")
165+
else:
166+
accumulator_inner = tl.dot(a, b, accumulator, input_precision="ieee")
163167
# tl.dot() default is using TF32 approximation, not good enough for LSB truncation exp
164168

165169
## ------ add chunky LSB rounding/masking --------
166170
if chunk_trun_bits > 0:
167-
accumulator = libdevice.uint_as_float(
168-
(libdevice.float_as_uint(accumulator) + round_bit) & trun_mask
169-
)
171+
accumulator_inner = round_and_trun(accumulator_inner, round_bit, trun_mask)
170172
## ---------------------------------------------------------
173+
if truncate_then_accumulate:
174+
accumulator += accumulator_inner
175+
else:
176+
accumulator = accumulator_inner
171177

172178
# Advance the ptrs to the next K block.
173179
a_ptrs += BLOCK_SIZE_K * stride_ak
@@ -206,6 +212,7 @@ def imatmul_kernel(
206212
stride_cm,
207213
stride_cn,
208214
chunk_trun_bits,
215+
truncate_then_accumulate,
209216
# Meta-parameters
210217
BLOCK_SIZE_M: tl.constexpr,
211218
BLOCK_SIZE_N: tl.constexpr,
@@ -244,13 +251,21 @@ def imatmul_kernel(
244251
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
245252
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
246253
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
247-
accumulator = tl.dot(a, b, accumulator, input_precision="ieee")
254+
if truncate_then_accumulate:
255+
accumulator_inner = tl.dot(a, b, input_precision="ieee")
256+
else:
257+
accumulator_inner = tl.dot(a, b, accumulator, input_precision="ieee")
248258

249259
## ------ add chunky LSB rounding/masking --------
250260
if chunk_trun_bits != 0:
251-
accumulator = (accumulator + round_bit) >> chunk_trun_bits
252-
accumulator = accumulator << chunk_trun_bits
261+
accumulator_inner = (accumulator_inner + round_bit) >> chunk_trun_bits
262+
accumulator_inner = accumulator_inner << chunk_trun_bits
253263
## ---------------------------------------------------------
264+
if truncate_then_accumulate:
265+
accumulator += accumulator_inner
266+
else:
267+
accumulator = accumulator_inner
268+
254269

255270
a_ptrs += BLOCK_SIZE_K * stride_ak
256271
b_ptrs += BLOCK_SIZE_K * stride_bk
@@ -266,29 +281,163 @@ def imatmul_kernel(
266281
tl.store(c_ptrs, c, mask=c_mask)
267282

268283

284+
285+
@triton.jit
286+
def matmul_kernel_DABC(
287+
# Pointers to matrices
288+
a_ptr,
289+
b_ptr,
290+
c_ptr,
291+
# Matrix dimensions
292+
M,
293+
N,
294+
K,
295+
# The stride variables represent how much to increase the ptr by when moving by 1
296+
# element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
297+
# by to get the element one row down (A has M rows).
298+
stride_am,
299+
stride_ak,
300+
stride_bk,
301+
stride_bn,
302+
stride_cm,
303+
stride_cn,
304+
chunk_trun_bits,
305+
truncate_then_accumulate,
306+
# Meta-parameters
307+
BLOCK_SIZE_M: tl.constexpr,
308+
BLOCK_SIZE_N: tl.constexpr,
309+
BLOCK_SIZE_K: tl.constexpr,
310+
GROUP_SIZE_M: tl.constexpr,
311+
ACTIVATION: tl.constexpr,
312+
):
313+
"""Kernel for computing the matmul D = A x B + C that include LSB truncation.
314+
A has shape (M, K), B has shape (K, N) and C/D has shape (M, N).
315+
NOTE:
316+
C should be consistent with accumulator dtype, e.g. fp8xfp8 -> fp32.
317+
*D ptr is supposed to be the same as C ptr, no need to provide D as arg
318+
**we can be used C to verify unintended truncation by CUDA as well.
319+
Args:
320+
chunk_trun_bits (int): number of LSB to truncate/round. [0 to 23]
321+
"""
322+
# -----------------------------------------------------------
323+
# Map program ids `pid` to the block of C it should compute.
324+
# This is done in a grouped ordering to promote L2 data reuse.
325+
# See above `L2 Cache Optimizations` section for details.
326+
pid = tl.program_id(axis=0)
327+
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
328+
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
329+
num_pid_in_group = GROUP_SIZE_M * num_pid_n
330+
group_id = pid // num_pid_in_group
331+
first_pid_m = group_id * GROUP_SIZE_M
332+
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
333+
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
334+
pid_n = (pid % num_pid_in_group) // group_size_m
335+
336+
# ----------------------------------------------------------
337+
# Create pointers for the first blocks of A and B.
338+
# We will advance this pointer as we move in the K direction
339+
# and accumulate
340+
# `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
341+
# `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
342+
# See above `Pointer Arithmetic` section for details
343+
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
344+
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
345+
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
346+
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
347+
offs_k = tl.arange(0, BLOCK_SIZE_K)
348+
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
349+
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
350+
c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
351+
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
352+
353+
# -----------------------------------------------------------
354+
# Iterate to compute a block of the C matrix.
355+
# We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
356+
# of fp32 values for higher accuracy.
357+
# accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
358+
accumulator = tl.load(c_ptrs, mask=c_mask, other=0.0) # should have been cast to fp32 already
359+
## ------ prepare LSB rounding/truncation masks -------
360+
# NOTE mask will be applied on accumulator, which is alway FP32, so we may truncate up to 23b
361+
# e.g., 20b -> trun_mask = 0xFFF00000, round_bit = 0x00080000
362+
# 8b -> trun_mask = 0xFFFFFF00, round_bit = 0x00000080
363+
trun_mask = tl.cast((0xFFFFFFFF >> chunk_trun_bits) << chunk_trun_bits, tl.uint32)
364+
round_bit = 1 << (chunk_trun_bits - 1) if chunk_trun_bits > 0 else 0
365+
## ---------------------------------------------------------
366+
367+
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
368+
# Load the next block of A, B, and C, generate a mask by checking the K dimension.
369+
# If it is out of bounds, set it to 0.
370+
# D = truncation(A*B) + C
371+
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
372+
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
373+
# We accumulate along the K dimension. but apply truncation on local A*B first
374+
if truncate_then_accumulate:
375+
accumulator_inner = tl.dot(a, b, input_precision="ieee")
376+
else:
377+
accumulator_inner = tl.dot(a, b, accumulator, input_precision="ieee")
378+
# tl.dot() default is using TF32 approximation, not good enough for LSB truncation exp
379+
# NOTE: tl.dot(a, b, c) should use one single CUDA mma instruction to handle "c = a*b+c". If
380+
# this mma instruction uses "reduced-precision" under the hood, not only a*b will
381+
# be accumulated in that precision, c most likely will be cast to that "lower"
382+
# precision first, hence, will lose some precision!
383+
384+
## ------ add chunky LSB rounding/masking --------
385+
if chunk_trun_bits > 0:
386+
accumulator_inner = round_and_trun(accumulator_inner, round_bit, trun_mask)
387+
## ---------------------------------------------------------
388+
if truncate_then_accumulate:
389+
accumulator += accumulator_inner
390+
else:
391+
accumulator = accumulator_inner
392+
393+
# Advance the ptrs to the next K block.
394+
a_ptrs += BLOCK_SIZE_K * stride_ak
395+
b_ptrs += BLOCK_SIZE_K * stride_bk
396+
# You can fuse arbitrary activation functions here
397+
# while the accumulator is still in FP32!
398+
if ACTIVATION == "leaky_relu":
399+
accumulator = leaky_relu(accumulator)
400+
401+
d = accumulator # do not cast to (tl.float16) just yet
402+
403+
# -----------------------------------------------------------
404+
# Write back the block of the output to matrix "C" with masks.
405+
tl.store(c_ptrs, d, mask=c_mask)
406+
407+
269408
@triton.jit
270409
def leaky_relu(x):
271410
"""Activation function that could be fused into matmul kernel"""
272411
return tl.where(x >= 0, x, 0.01 * x)
273412

274413

414+
@triton.jit
415+
def round_and_trun(x, round_bit, trun_mask):
416+
"""Round and truncate (usually for accumulator)."""
417+
return libdevice.uint_as_float((libdevice.float_as_uint(x) + round_bit) & trun_mask)
418+
419+
275420
def tl_matmul_chunk_truncate(
276421
a,
277422
b,
278423
activation="",
279424
chunk_trun_bits=0,
280425
chunk_size=16,
426+
truncate_then_accumulate=True,
281427
cast_output_to_input_dtype=None,
282428
):
283429
"""Triton matmul for HW behavior simulation. Supports float and int8.
284-
a. variable chunk size (i.e., BLOCK_SIZE_K)
285-
b. LSB truncation, must <23 if using float.
430+
i. variable chunk size (i.e., BLOCK_SIZE_K)
431+
ii. LSB truncation, must <23 if using float.
432+
iii. assume D = A*B + C, where C is optional. If C exists, it will be updated inplace.
286433
287434
Args:
288435
a, b: input tensors. FloatX, X in [32, 16, 8] or INT8.
289436
activation (str, optional): activation func to be fused, see relu example.
290437
chunk_trun_bits (int, optional): number of LSBs to be truncated/rounded.
291438
chunk_size (int, optional): BLOCK_SIZE_K, some HW has specific chunk size. must >= 16.
439+
truncate_then_accumulate (bool, optional): if True, c = truncate(a*b) + c, otherwise
440+
c = truncate(a*b+c)
292441
cast_output_to_input_dtype (bool, optional): accumulator has higher prec than input, usually
293442
FP32 or INT32. by default we cast the final
294443
output to the same dtype as input for non-8bits.
@@ -300,6 +449,7 @@ def tl_matmul_chunk_truncate(
300449
use empirical way to determine BLOCK sizes, may not be optimal. But need to avoid autotune for
301450
real model inference. otherwise auto-tune will be triggered in every forward call.
302451
"""
452+
303453
# Check constraints.
304454
assert a.shape[1] == b.shape[0], "Incompatible dimensions"
305455
assert a.is_contiguous(), "Matrix A must be contiguous"
@@ -343,9 +493,18 @@ def isPowerofTwo(x):
343493
mm_kernel = imatmul_kernel
344494
else:
345495
acc_dtype = torch.float32
346-
mm_kernel = matmul_kernel
496+
mm_kernel = matmul_kernel if c == None else matmul_kernel_DABC
347497
assert chunk_trun_bits < 23, "FP32 accumulator only has 23 mantissa bits"
348-
c = torch.zeros((M, N), device=a.device, dtype=acc_dtype)
498+
499+
if c == None:
500+
c_org_dtype = a.dtype
501+
c = torch.zeros((M, N), device=a.device, dtype=acc_dtype)
502+
else:
503+
# if C is in fp16, accumulate in fp32 no matter what, decide whether to cast back later
504+
c_org_dtype = c.dtype
505+
c = c.to(acc_dtype)
506+
assert c.shape[0]==M and c.shape[1]==N, "C shape is inconsistent with A B."
507+
assert acc_dtype==torch.float32, "INT truncation experiment is not yet supported."
349508

350509
# 1D launch kernel where each block gets its own program.
351510
def grid(META):
@@ -386,7 +545,8 @@ def grid(META):
386545
c.stride(0),
387546
c.stride(1),
388547
chunk_trun_bits=chunk_trun_bits,
548+
truncate_then_accumulate=truncate_then_accumulate,
389549
ACTIVATION=activation,
390550
**kernel_config, # if using auto-tune, comment this line out.
391551
)
392-
return c.to(a.dtype) if cast_output_to_input_dtype else c
552+
return c.to(c_org_dtype) if cast_output_to_input_dtype else c

0 commit comments

Comments
 (0)