
Commit 005d54b

add DL16 option for LinearFPx (FP8 aiu sim)
Signed-off-by: cliu-us <[email protected]>
1 parent bcee5f3 · commit 005d54b

File tree: 3 files changed, +91 −16 lines


fms_mo/custom_ext_kernels/triton_kernels.py

Lines changed: 34 additions & 2 deletions
@@ -114,6 +114,7 @@ def matmul_kernel(
     stride_cn,
     chunk_trun_bits,
     max_acc_bits,  # pylint: disable=unused-argument
+    clamp_acc_to_dl16,
     truncate_then_accumulate,
     # Meta-parameters
     BLOCK_SIZE_M: tl.constexpr,
@@ -182,6 +183,8 @@ def matmul_kernel(
         ## ------ add chunky LSB rounding/masking --------
         if chunk_trun_bits > 0:
             accumulator_inner = round_and_trun(accumulator_inner, round_bit, trun_mask)
+        if clamp_acc_to_dl16:
+            accumulator = fp32_clamp_to_dl16(accumulator)
         ## ---------------------------------------------------------
         if truncate_then_accumulate:
             accumulator += accumulator_inner
@@ -226,6 +229,7 @@ def imatmul_kernel(
     stride_cn,
     chunk_trun_bits,
     max_acc_bits,
+    clamp_acc_to_dl16,  # pylint: disable=unused-argument
     truncate_then_accumulate,
     # Meta-parameters
     BLOCK_SIZE_M: tl.constexpr,
@@ -324,6 +328,7 @@ def matmul_kernel_DABC(
     stride_cn,
     chunk_trun_bits,
     max_acc_bits,  # pylint: disable=unused-argument
+    clamp_acc_to_dl16,
     truncate_then_accumulate,
     # Meta-parameters
     BLOCK_SIZE_M: tl.constexpr,
@@ -405,6 +410,8 @@ def matmul_kernel_DABC(
         ## ------ add chunky LSB rounding/masking --------
         if chunk_trun_bits > 0:
             accumulator_inner = round_and_trun(accumulator_inner, round_bit, trun_mask)
+        if clamp_acc_to_dl16:
+            accumulator = fp32_clamp_to_dl16(accumulator)
         ## ---------------------------------------------------------
         if truncate_then_accumulate:
             accumulator += accumulator_inner
@@ -438,6 +445,28 @@ def round_and_trun(x, round_bit, trun_mask):
     return libdevice.uint_as_float((libdevice.float_as_uint(x) + round_bit) & trun_mask)


+@triton.jit
+def fp32_clamp_to_dl16(x):
+    """clamp FP32 (1-8-23) TENSOR x to DL16 (1-6-9) range."""
+    # 1. rounding, add round bit to full uint representation
+    x = libdevice.float_as_uint(x)
+    round_bit = 1 << (23 - 9 - 1)
+    x = libdevice.uint_as_float(x + round_bit)
+
+    # 2. clamp to min/max:
+    #    max = 2^32 * 1.(1111 1111 0)_base2 => 2^32*1.(1111 1111 1) will become inf
+    #          (32 + 127) << 23 | (0xFF8 << (23 - 12)) in FP32 is 8581545984.0
+    #    min = 2^-31 * 1.(0000 0000 1)_base2 => set to 0 for those smaller than this
+    #          (-31 + 127) << 23 | (1 << (23 - 9)) in FP32 is 4.665707820095122e-10
+    dl16_max = 8581545984.0
+    dl16_min = 4.665707820095122e-10
+    x = tl.where(x >= dl16_max, float("inf"), x)
+    x = tl.where(x <= -dl16_max, float("-inf"), x)
+    x = tl.where(tl.abs(x) < dl16_min, 0, x)
+
+    return x
+
+
 def tl_matmul_chunk_truncate(
     a,
     b,
@@ -448,6 +477,7 @@ def tl_matmul_chunk_truncate(
     max_acc_bits=32,
     truncate_then_accumulate=True,
     cast_output_to_input_dtype=None,
+    clamp_acc_to_dl16=False,
 ):
     """Triton matmul for HW behavior simulation. Supports float and int8.
     i. variable chunk size (i.e., BLOCK_SIZE_K)
@@ -461,7 +491,8 @@ def tl_matmul_chunk_truncate(
         chunk_size (int, optional): BLOCK_SIZE_K, some HW has specific chunk size. must >= 16.
         max_acc_bits (int, optional): num of bits for the accumulator, e.g. if INT24 is used, will
                                       clamp each chunk of a*b to [-2**23-1, 2**23].
-                                      (assuming no inf when overflow)
+                                      (only used by INT)
+        clamp_acc_to_dl16(bool): Only used by FP8, whether to clamp local accumulator (FP32) to DL16
         truncate_then_accumulate (bool, optional): if True, c = truncate(a*b) + c, otherwise
                                       c = truncate(a*b+c)
         cast_output_to_input_dtype (bool, optional): accumulator has higher prec than input, usually
@@ -473,7 +504,7 @@ def tl_matmul_chunk_truncate(

     NOTE:
     use empirical way to determine BLOCK sizes, may not be optimal. But need to avoid autotune for
-    real model inference. otherwise auto-tune will be triggered in every forward call.
+    real model inference. otherwise auto-tune may be triggered in every forward call.
     """

     # Check constraints.
@@ -584,6 +615,7 @@ def grid(META):
         c.stride(1),
         chunk_trun_bits=chunk_trun_bits,
         max_acc_bits=max_acc_bits,
+        clamp_acc_to_dl16=clamp_acc_to_dl16,
         truncate_then_accumulate=truncate_then_accumulate,
         ACTIVATION=activation,
         **kernel_config,  # if using auto-tune, comment this line out.
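
The two magic numbers in the new fp32_clamp_to_dl16 kernel follow directly from the bit patterns spelled out in its comments (DL16 being a 1-6-9 sign/exponent/mantissa layout). The host-side sketch below is not part of the commit; it merely assembles those FP32 bit patterns with the standard library and checks that they reproduce the literals used in the kernel. The commented-out tl_matmul_chunk_truncate call at the end is likewise an illustrative assumption based only on the signature shown in this diff.

# Host-side sanity check for the DL16 constants used in fp32_clamp_to_dl16 (illustration only).
import struct


def bits_to_fp32(bits: int) -> float:
    """Reinterpret a 32-bit integer pattern as an IEEE-754 float32."""
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]


# Threshold at/above which the kernel maps values to inf:
# unbiased exponent 32 (stored as 32 + 127) with the top 9 mantissa bits set.
dl16_max = bits_to_fp32((32 + 127) << 23 | (0xFF8 << (23 - 12)))
# Threshold below which magnitudes are flushed to zero:
# unbiased exponent -31 (stored as -31 + 127) with only mantissa bit 9 set.
dl16_min = bits_to_fp32((-31 + 127) << 23 | (1 << (23 - 9)))

assert dl16_max == 8581545984.0
assert dl16_min == 4.665707820095122e-10

# Illustrative call of the updated matmul wrapper (assumes CUDA tensors a and b):
# c = tl_matmul_chunk_truncate(a, b, chunk_size=32, clamp_acc_to_dl16=True)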

fms_mo/custom_ext_kernels/utils.py

Lines changed: 36 additions & 11 deletions
@@ -870,14 +870,15 @@ def lower_qmodel_triton(
     model: torch.nn.Module,
     use_dyn_max_act=False,
     max_acc_bits=32,
+    clamp_acc_to_dl16=False,
     num_lsb_to_truncate=0,
     chunk_size=32,
 ):
     """
-    Examplar GPU lowering function using triton. Only swap Qlinears in transformers, nothing else.
+    Examplar GPU lowering function using triton. Only swap Linear/Qlinear in transformers.
     Triton kernel can be used to:
     1. test INT8 or FP8 HW performance (kernel is not optimized)
-    2. simulate MSB/LSB truncation effect
+    2. simulate MSB/LSB truncation effect or special dtype (DL16) accumulation

     Args:
         model: nn.Module. should be a fms_mo Qmodel, will do inplace layer swapping, no deepcopy
@@ -888,6 +889,8 @@ def lower_qmodel_triton(
                       efficiency at the expense of higher chance of accumulation "overflow".
                       For example, an INT24 accumulator can only hold values ranged from -2^23 to
                       2^23 -1, as opposed to typical range -2^31 to -2^31 -1.
+        clamp_acc_to_dl16: clamp local accumulator to DL16 (1-6-9) range. To simulate this special
+                      dtype effect on accumulation.
         num_lsb_to_truncate: number of bits to truncate from LSB side. For example, given fp32 is
                       s1e8m23, if we choose to truncate 13 mantissa bits from right most side,
                       i.e. LSB, the resulting number will be s1e8m10, which is TF32.
@@ -900,25 +903,47 @@ def lower_qmodel_triton(
     from torch.ao.quantization.utils import _parent_name

     # Local
-    from fms_mo.modules.linear import QLinear, QLinearINT8Deploy
+    from fms_mo.modules.linear import LinearFPxAcc, QLinear, QLinearINT8Deploy
+
+    # Currently QLinearINT8 has more options in dynamic quantization than LinearFP. Here we resolve
+    # the differences as a patch solution (will unify the codes in future release)
+    linFP_dyn_code = (
+        "per_token"
+        if use_dyn_max_act in [-1, -2]
+        else "per_tensor"
+        if use_dyn_max_act
+        else False
+    )

     for name, m in model.named_modules():
-        if not isinstance(m, QLinear):
+        if not isinstance(m, (QLinear, torch.nn.Linear)):
             continue
         parent_name, module_name = _parent_name(name)
         parent_mod = model.get_submodule(parent_name)
-        qmod = getattr(parent_mod, module_name)
-        setattr(
-            parent_mod,
-            module_name,
-            QLinearINT8Deploy.from_fms_mo(
-                qmod,
+
+        # Only support simulations of 1) QLinear -> INT8, 2) nnLinear->FP8 for now
+        if isinstance(m, QLinear):
+            new_lin = QLinearINT8Deploy.from_fms_mo(
+                m,
                 use_int_kernel="triton",
                 use_dynamic_max_act_Qfunc=use_dyn_max_act,
                 max_acc_bits=max_acc_bits,
                 truncate_lsb=num_lsb_to_truncate,
                 chunk_size=chunk_size,
-            ),
+            )
+        else:
+            new_lin = LinearFPxAcc.from_nn(
+                m,
+                trun_bits=max_acc_bits,
+                chunk_size=chunk_size,
+                dynamic_fp8=linFP_dyn_code,
+                clamp_acc_to_dl16=clamp_acc_to_dl16,
+            )
+
+        setattr(
+            parent_mod,
+            module_name,
+            new_lin,
         )

     logger.info(f"\nModel lowering with triton kernel is done.\n{model}")
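
A hypothetical end-to-end usage sketch of the updated lower_qmodel_triton follows; the checkpoint name, device handling, and flag values are illustrative assumptions and are not taken from the repository's documentation.

# Illustrative only: swap Linear/QLinear layers for Triton-backed simulation layers,
# enabling the new DL16 accumulator clamping on the FP8 path.
import torch
from transformers import AutoModelForCausalLM

from fms_mo.custom_ext_kernels.utils import lower_qmodel_triton

model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m", torch_dtype=torch.float16
).to("cuda")

lower_qmodel_triton(
    model,
    use_dyn_max_act=-1,      # -1/-2 are mapped to "per_token" dynamic FP8 activation quantization
    clamp_acc_to_dl16=True,  # new option added in this commit
    num_lsb_to_truncate=0,
    chunk_size=32,
)
# model can now be run as usual; matmuls go through the Triton simulation kernel

Per the new isinstance check, plain torch.nn.Linear modules are lowered to the FP8/DL16 simulation path (LinearFPxAcc), while fms_mo QLinear modules still take the INT8 QLinearINT8Deploy path.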

fms_mo/modules/linear.py

Lines changed: 21 additions & 3 deletions
@@ -1899,7 +1899,16 @@ class LinearFuncFPxFwdBwd(torch.autograd.Function):
     """

     @staticmethod
-    def forward(ctx, x, weight, bias=None, trun_bits=0, chunk_size=16, fp8_dyn=False):
+    def forward(
+        ctx,
+        x,
+        weight,
+        bias=None,
+        trun_bits=0,
+        chunk_size=16,
+        fp8_dyn=False,
+        clamp_acc_to_dl16=False,
+    ):
         assert x.dtype in [torch.float, torch.bfloat16, torch.float16]
         # input can be 2D or 3D, need to reshape before tl_matmul
         org_dtype = x.dtype
@@ -1916,6 +1925,7 @@ def forward(ctx, x, weight, bias=None, trun_bits=0, chunk_size=16, fp8_dyn=False
         ctx.trun_bits = trun_bits
         ctx.chunk_size = chunk_size
         ctx.fp8_dyn = fp8_dyn
+        ctx.clamp_acc_to_dl16 = clamp_acc_to_dl16

         if fp8_dyn:
             # use Q/dQ simulation for now, meaning still compute in fp16/bf16
@@ -1936,6 +1946,7 @@ def forward(ctx, x, weight, bias=None, trun_bits=0, chunk_size=16, fp8_dyn=False
             weight.t().to(org_dtype),
             chunk_trun_bits=trun_bits,
             chunk_size=chunk_size,
+            clamp_acc_to_dl16=clamp_acc_to_dl16,
         ).reshape(target_shape_output)

         if bias is not None:
@@ -1976,6 +1987,7 @@ def backward(ctx, grad_output):
                 x,
                 chunk_trun_bits=trun_bits,
                 chunk_size=chunk_size,
+                clamp_acc_to_dl16=ctx.clamp_acc_to_dl16,
             ).to(weight.dtype)
             # Compute grad_input in 2D then reshape to target shape, could be 3D or 2D
             grad_input = (
@@ -1984,6 +1996,7 @@ def backward(ctx, grad_output):
                     weight.to(dtype_input),
                     chunk_trun_bits=trun_bits,
                     chunk_size=chunk_size,
+                    clamp_acc_to_dl16=ctx.clamp_acc_to_dl16,
                 )
                 .reshape(target_shape_grad_input)
                 .to(dtype_input)
@@ -1994,7 +2007,7 @@ def backward(ctx, grad_output):
         else:
             grad_bias = grad_output_2D.sum(0).to(ctx.bias_dtype)

-        return grad_input, grad_weight, grad_bias, None, None, None
+        return grad_input, grad_weight, grad_bias, None, None, None, None


 class LinearFPxAcc(torch.nn.Linear):
@@ -2016,6 +2029,10 @@ def from_nn(cls, nnlin, trun_bits=0, **kwargs):
             cls (class): The class to be created.
             nnlin (torch.nn.Linear): The original torch.nn.Linear module.
             trun_bits (int): truncate [0 to 22] LSBs from FP32 accumulation.
+            dynamic_fp8: whether to use dynamic quantization for fp8 activations, available options
+                         are ["per_tensor", "per_token", False]
+            clamp_acc_to_dl16: clamp local accumulator into DL16 range, to simulate the effect of
+                         this special dtype
             **kwargs: Additional keyword arguments.

         Returns:
@@ -2037,7 +2054,7 @@ def from_nn(cls, nnlin, trun_bits=0, **kwargs):
         lin24acc.trun_bits = trun_bits
         lin24acc.chunk_size = kwargs.get("chunk_size", False)
         lin24acc.fp8_dyn = kwargs.get("dynamic_fp8", False)
-        # available options are ["per_tensor", "per_token"]
+        lin24acc.clamp_acc_to_dl16 = kwargs.get("clamp_acc_to_dl16", False)

         if nnlin.bias is not None:
             lin24acc.bias = nnlin.bias
@@ -2052,6 +2069,7 @@ def forward(self, inputs):
             self.trun_bits,
             self.chunk_size,
             self.fp8_dyn,
+            self.clamp_acc_to_dl16,
         )

     def extra_repr(self) -> str:
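
For completeness, a minimal sketch of exercising the new flag through LinearFPxAcc directly, assuming a CUDA device with a working FP8/Triton setup; shapes and flag values are illustrative and simply mirror the keyword arguments that from_nn now documents.

import torch

from fms_mo.modules.linear import LinearFPxAcc

lin = torch.nn.Linear(1024, 1024, bias=True, device="cuda", dtype=torch.float16)
lin_dl16 = LinearFPxAcc.from_nn(
    lin,
    trun_bits=0,              # no extra LSB truncation of the FP32 accumulator
    chunk_size=32,            # BLOCK_SIZE_K used by the chunked Triton matmul
    dynamic_fp8="per_token",  # dynamic FP8 activation quantization
    clamp_acc_to_dl16=True,   # clamp the local FP32 accumulator to DL16 range
)

x = torch.randn(8, 1024, device="cuda", dtype=torch.float16)
y = lin_dl16(x)  # forward (and backward) route through LinearFuncFPxFwdBwd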
