
Commit 59dfc8b

modified based on Thara's feedback
Signed-off-by: cliu-us <[email protected]>
1 parent 3a89c7b commit 59dfc8b

5 files changed: 30 additions, 31 deletions

.gitignore

Lines changed: 0 additions & 1 deletion
@@ -45,4 +45,3 @@ fms_mo.log
 data_train/
 data_test/
 act_scales/
-examples/

examples/QAT_INT8/run_qa_no_trainer_qat.py

Lines changed: 2 additions & 1 deletion
@@ -388,6 +388,7 @@ def parse_args():
     )
     parser.add_argument(
         "--do_lowering",
+        choices=["cutlass", "triton"],
         type=str,
         default="triton",
         help="convert QAT model to utilize real INT8 GPU kernel, 'cutlass' or 'triton'",
@@ -1162,7 +1163,7 @@ def speedtest(model, exam_inp, Ntest=100):
                 parent_mod,
                 module_name,
                 QLinearINT8Deploy.from_fms_mo(
-                    qmod, useINTkernel=args.do_lowering
+                    qmod, use_int_kernel=args.do_lowering
                 ),
             )
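The new `choices` constraint and the renamed keyword meet in the conversion step above. Below is a minimal sketch of how the lowering could be driven end to end; the swap loop and the `QLinear` class name / `QLinearINT8Deploy` import path are assumptions for illustration, not the script's exact code.

# Hypothetical invocation with the now-validated backend choice:
#   python examples/QAT_INT8/run_qa_no_trainer_qat.py --do_lowering triton ...
import torch
from fms_mo.modules.linear import QLinearINT8Deploy  # import path assumed

def lower_qlinears(model: torch.nn.Module, backend: str = "triton") -> torch.nn.Module:
    """Swap each quantized Linear child for its INT8 deploy version using `backend`."""
    targets = [
        (parent, name, child)
        for parent in model.modules()
        for name, child in parent.named_children()
        if child.__class__.__name__ == "QLinear"  # QAT module class name assumed
    ]
    for parent, name, child in targets:
        # Renamed keyword from this commit: use_int_kernel ("cutlass" or "triton")
        setattr(parent, name, QLinearINT8Deploy.from_fms_mo(child, use_int_kernel=backend))
    return model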

fms_mo/custom_ext_kernels/triton_kernels.py

Lines changed: 6 additions & 9 deletions
@@ -74,11 +74,9 @@ def get_cuda_autotune_config(chunk_size=None):
 # meta-parameters (e.g., `BLOCK_SIZE_M`) and compilation options (e.g., `num_warps`) to try
 # - An auto-tuning *key* whose change in values will trigger evaluation of all the
 #   provided configs
-# => Need to avoid using auto-tune for real model inference!
-# @triton.autotune(
-#     configs=get_cuda_autotune_config(),
-#     key=['M', 'N', 'K'],
-# )
+# => Need to avoid using auto-tune for real model inference! But for micro-benchmarking purposes,
+#    we could enable the decorator below.
+# @triton.autotune(configs=get_cuda_autotune_config(), key=['M', 'N', 'K'])
 @triton.jit
 def matmul_kernel(
     # Pointers to matrices
@@ -187,10 +185,9 @@ def matmul_kernel(
     tl.store(c_ptrs, c, mask=c_mask)


-# @triton.autotune(
-#     configs=get_cuda_autotune_config(),
-#     key=['M', 'N', 'K'],
-# )
+# Reminder: avoid auto-tune for real model inference! But for micro-benchmarking purposes,
+#           we could enable the decorator below.
+# @triton.autotune(configs=get_cuda_autotune_config(), key=['M', 'N', 'K'])
 @triton.jit
 def imatmul_kernel(
     a_ptr,
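As the consolidated comments say, the autotune decorator should stay off for real inference but may be enabled for micro-benchmarking. A rough timing sketch of that workflow follows; the import path of `tl_matmul` and its positional tensor arguments are assumptions (only the `chunk_trun_bits`/`chunk_size` keywords appear in this commit), so treat it as an outline rather than a verified harness.

import torch
import triton
from fms_mo.custom_ext_kernels.triton_kernels import tl_matmul  # module path assumed

# Keep @triton.autotune commented out for model inference; for a one-off
# micro-benchmark it can be uncommented so the best BLOCK_SIZE_*/num_warps
# config is searched for the (M, N, K) shape below.
M, N, K = 4096, 4096, 4096
a = torch.randint(-128, 127, (M, K), dtype=torch.int8, device="cuda")
b = torch.randint(-128, 127, (K, N), dtype=torch.int8, device="cuda")

# Assumed call pattern: tl_matmul(a, b, chunk_trun_bits=..., chunk_size=...)
ms_int8 = triton.testing.do_bench(
    lambda: tl_matmul(a, b, chunk_trun_bits=0, chunk_size=None)
)
ms_fp16 = triton.testing.do_bench(lambda: torch.matmul(a.half(), b.half()))
print(f"triton int8: {ms_int8:.3f} ms | torch fp16: {ms_fp16:.3f} ms")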

fms_mo/modules/bmm.py

Lines changed: 2 additions & 2 deletions
@@ -364,7 +364,7 @@ def from_fms_mo(cls, fms_mo_qbmm, **kwargs):
         qbmm_int.num_bits_m1 = fms_mo_qbmm.num_bits_m1
         qbmm_int.num_bits_m2 = fms_mo_qbmm.num_bits_m2
         qcfg = getattr(fms_mo_qbmm, "qcfg", None)
-        qbmm_int.useINTkernel = False  # always False until int kernel is implemented
+        qbmm_int.use_int_kernel = False  # always False until int kernel is implemented
         qbmm_int.use_PT_native_Qfunc = qcfg["use_PT_native_Qfunc"] if qcfg else False

         with torch.no_grad():
@@ -438,7 +438,7 @@ def extra_repr(self) -> str:
         """
         return (
             f"nbits_m1,m2={self.num_bits_m1},{self.num_bits_m2}, "
-            f"useINTkernel={self.useINTkernel}"
+            f"use_int_kernel={self.use_int_kernel}"
         )

     def forward(self, m1: torch.Tensor, m2: torch.Tensor) -> torch.Tensor:

fms_mo/modules/linear.py

Lines changed: 20 additions & 18 deletions
@@ -712,8 +712,8 @@ def from_fms_mo(cls, fms_mo_qlinear, **kwargs):
         cls: The class of the QLinearModule to be created.
         fms_mo_qlinear: The QLinear module to be converted.
         (experimental)
-        useINTkernel: choose from ['cutlass', 'triton', False], "cutlass" kernel is faster,
-                      "triton" support chunky truncation, "False" fallback to torch.matmul
+        use_int_kernel: choose from ['cutlass', 'triton', False], "cutlass" kernel is faster,
+                        "triton" supports chunky truncation, and False falls back to torch.matmul
         max_acc_bits: usually INT matmul accumulate in INT32, but some HW could have different
                       design, such as using INT24 accumulator, which will saturate at
                       (-2**(acc_bit-1) +1, 2**(acc_bit-1) )
@@ -745,8 +745,8 @@ def from_fms_mo(cls, fms_mo_qlinear, **kwargs):
         )
         # Make sure to register an Op for integer matmul, could be real INT matmul or emulation
         qcfg = getattr(fms_mo_qlinear, "qcfg", {})
-        qlin_int.useINTkernel = kwargs.get(
-            "useINTkernel", qcfg.get("useINTkernel", "cutlass")
+        qlin_int.use_int_kernel = kwargs.get(
+            "use_int_kernel", qcfg.get("use_int_kernel", "cutlass")
         )
         qlin_int.usePTnativeQfunc = kwargs.get("use_PT_native_Qfunc", False)
         qlin_int.max_acc_bits = kwargs.get("max_acc_bits", 32)
@@ -772,7 +772,7 @@ def from_fms_mo(cls, fms_mo_qlinear, **kwargs):
         )  # Qw.clipval should have been updated after this
         qlin_int.weight = nn.Parameter(
             w_int8.to(torch.int8), requires_grad=False
-        )  # NOTE: may needs INT W stored as FP in some cases
+        )  # NOTE: may need INT W stored as FP in some cases

         if qlin_int.usePTnativeQfunc:
             input_scale = torch.tensor(
@@ -873,7 +873,7 @@ def from_torch_iW(cls, nnlin_iW, prec, a_cv, a_cvn, w_cv, zero_shift, **kwargs):
         qlinear_iW.nbits_w = 8
         qlinear_iW.acc_dtype = torch.float16
         qlinear_iW.usePTnativeQfunc = kwargs.get("use_PT_native_Qfunc", True)
-        qlinear_iW.useINTkernel = True
+        qlinear_iW.use_int_kernel = True
         qlinear_iW.weight = nn.Parameter(
             nnlin_iW.weight.to(torch.int8), requires_grad=False
         )
@@ -1086,22 +1086,22 @@ def set_matmul_op(self):
         """
         Sets the matmul operator for the quantized linear module.

-        If `useINTkernel` is True and CUDA is available, it will use the INT kernel
+        If `use_int_kernel` is True and CUDA is available, it will use the INT kernel
         for integer matrix multiplication. Otherwise, it will use the FP kernel.

         If the operator has already been set, it will do nothing.
         """
-        if self.useINTkernel and not torch.cuda.is_available():
+        if self.use_int_kernel and not torch.cuda.is_available():
             logger.warning(
-                "Cannot set useINTkernel=True when CUDA is not available. "
-                "Fallback to useINTkernel=False"
+                "Cannot set use_int_kernel=True when CUDA is not available. "
+                "Falling back to use_int_kernel=False"
             )
-            self.useINTkernel = False
+            self.use_int_kernel = False

         if hasattr(torch.ops, "fms_mo") and hasattr(torch.ops.fms_mo, "imatmul"):
             # imatmul already registered, e.g. when swapping the 2nd QLinear
             self.imatmul = torch.ops.fms_mo.imatmul
-            self.iaddmm = self.iaddmm_int if self.useINTkernel else self.iaddmm_FP
+            self.iaddmm = self.iaddmm_int if self.use_int_kernel else self.iaddmm_FP
         else:
             # When swapping the first QLinear, need to register our custom Op and choose the kernel
             # Standard
@@ -1113,14 +1113,16 @@ def set_matmul_op(self):
                 imatmul_ops_reg,
             )

-            if self.useINTkernel == "triton":  # will use real imatmul written in triton
+            if self.use_int_kernel == "triton":
+                # will use real imatmul written in triton
                 imm_func = partial(
                     tl_matmul,
                     chunk_trun_bits=self.truncate_lsb,
                     chunk_size=self.chunk_size,
                 )

-            elif self.useINTkernel == "cutlass":
+            elif self.use_int_kernel == "cutlass":
+                # will use real imatmul written in cutlass
                 cutlass_ops_load_and_reg()
                 # Third Party
                 import cutlass_mm  # this module will only be available after calling reg()
@@ -1129,9 +1131,9 @@ def set_matmul_op(self):
             else:
                 imm_func = torch.matmul

-            imatmul_ops_reg(self.useINTkernel, imm_func)
+            imatmul_ops_reg(self.use_int_kernel, imm_func)
             self.imatmul = torch.ops.fms_mo.imatmul
-            self.iaddmm = self.iaddmm_int if self.useINTkernel else self.iaddmm_FP
+            self.iaddmm = self.iaddmm_int if self.use_int_kernel else self.iaddmm_FP

     def _get_name(self):
         """
@@ -1145,7 +1147,7 @@ def extra_repr(self) -> str:
         """
         return (
             f"in={self.in_features}, out={self.out_features}, bias={self.bias is not None}, "
-            f"useINTkernel={self.useINTkernel}"
+            f"use_int_kernel={self.use_int_kernel}"
         )

     def __getstate__(self):
@@ -1861,7 +1863,7 @@ class LinearFPxAcc(torch.nn.Linear):
     """Linear layer wrapper that can simulate the HW behavior of LSB truncation on FP accumulation.
     Some HW may have options to allow FP matmul engine to accumulate in precision lower than FP32,
     such as accumulate in TF32 or even BF16. According to Nvidia doc, ~7-10x speed up with minor
-    accuracy trade-off. This support both FWD and BWD.
+    accuracy trade-off. This supports both FWD and BWD.
     Ref:
     1. https://developer.nvidia.com/blog/accelerating-ai-training-with-tf32-tensor-cores/
     2. PyTorch's "torch.backends.cuda.matmul.allow_tf32"
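Reference 2 in the docstring above is PyTorch's built-in switch for a similar speed/accuracy trade-off. For context, a small standard-PyTorch example (not part of this commit):

import torch

# Allow cuBLAS matmuls to use TF32 tensor cores (reduced mantissa precision),
# the kind of lower-precision trade-off LinearFPxAcc simulates.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True  # same switch for cuDNN convolutions

x = torch.randn(1024, 1024, device="cuda")
w = torch.randn(1024, 1024, device="cuda")
y = x @ w  # runs on TF32 tensor cores when the hardware supports it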
