
Commit 608068d

Merge pull request #61 from chichun-charlie-liu/triton_kernel
fix: Triton kernel bug fix
2 parents a31a4e5 + b350fde commit 608068d

File tree

examples/QAT_INT8/run_qa_no_trainer_qat.py
fms_mo/custom_ext_kernels/triton_kernels.py
fms_mo/modules/linear.py
pyproject.toml

4 files changed: +22 −17 lines changed

examples/QAT_INT8/run_qa_no_trainer_qat.py

Lines changed: 3 additions & 3 deletions

@@ -390,7 +390,7 @@ def parse_args():
         "--do_lowering",
         choices=["cutlass", "triton"],
         type=str,
-        default="triton",
+        default=None,
         help="convert QAT model to utilize real INT8 GPU kernel, 'cutlass' or 'triton'",
     )

@@ -1136,7 +1136,7 @@ def speedtest(model, exam_inp, Ntest=100):
     logger.info(
         f"\n {label} {'with' if comp_mode else 'without'} torch.compile"
     )
-    model_copy = deepcopy(model)
+    model_copy = deepcopy(model).half()

     if label == "int8":
         qcfg = qconfig_init(recipe="qat_int8", args=args)

@@ -1178,7 +1178,7 @@ def speedtest(model, exam_inp, Ntest=100):

     # Median runtime using fixed input (in msec)
     med_runtime = speedtest(model_copy, exam_inp)
-    metrics = squad_eval(model_copy) if label == "int8" else {"f1": None}
+    metrics = squad_eval(model_copy)  # if label == "int8" else {"f1": None}

     summary["precision"].append(label)
     summary["compile mode"].append(comp_mode)

fms_mo/custom_ext_kernels/triton_kernels.py

Lines changed: 3 additions & 2 deletions

@@ -235,6 +235,7 @@ def imatmul_kernel(
     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.int32)
     ## ------ prepare LSB rounding/truncation masks -------
     round_bit = 1 << (chunk_trun_bits - 1) if chunk_trun_bits > 0 else 0
+    # msb_mask = 0x00FFFFFF  # only needed when simulating truncation on MSB
     ## ---------------------------------------------------------

     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
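`round_bit` here is the usual round-half-up constant for dropping the low `chunk_trun_bits` bits of an int32 accumulator. A plain-Python sketch of the idea (illustrative only; the kernel applies the equivalent arithmetic element-wise on tiles):

def truncate_lsb_rounded(acc: int, chunk_trun_bits: int) -> int:
    # Illustrative only: add half of the truncation step (round_bit) before
    # dropping the low bits, so truncation rounds to nearest instead of
    # always rounding down.
    if chunk_trun_bits == 0:
        return acc
    round_bit = 1 << (chunk_trun_bits - 1)
    return ((acc + round_bit) >> chunk_trun_bits) << chunk_trun_bits

# e.g. truncate_lsb_rounded(0b1011_1000, 4) == 0b1100_0000
#      (184 -> 192: the halfway point rounds up, not down)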
@@ -326,7 +327,7 @@ def grid(META):
         kernel_config = {
             "BLOCK_SIZE_M": 128,
             "BLOCK_SIZE_K": chunk_size,
-            "BLOCK_SIZE_N": 32,
+            "BLOCK_SIZE_N": 128,  # was 32
             "GROUP_SIZE_M": 8,
             "num_warps": 2,
             "num_stages": 5,

@@ -335,7 +336,7 @@ def grid(META):
         kernel_config = {
             "BLOCK_SIZE_M": 128,
             "BLOCK_SIZE_K": chunk_size,
-            "BLOCK_SIZE_N": 64,
+            "BLOCK_SIZE_N": 128,  # was 64
             "GROUP_SIZE_M": 8,
             "num_warps": 4,
             "num_stages": 4,

fms_mo/modules/linear.py

Lines changed: 14 additions & 11 deletions

@@ -752,7 +752,7 @@ def from_fms_mo(cls, fms_mo_qlinear, **kwargs):
         qlin_int.max_acc_bits = kwargs.get("max_acc_bits", 32)
         qlin_int.accminmax = (
             -(1 << (qlin_int.max_acc_bits - 1)),
-            1 << (qlin_int.max_acc_bits - 1) - 1,
+            (1 << (qlin_int.max_acc_bits - 1)) - 1,
         )
         qlin_int.truncate_lsb = kwargs.get("truncate_lsb", 0)
         qlin_int.chunk_size = kwargs.get("chunk_size", 100000)
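This is the core arithmetic bug of the PR: in Python, binary `-` binds tighter than `<<`, so the old expression parsed as `1 << ((max_acc_bits - 1) - 1)` and set the positive clamp roughly a factor of two too low. A quick check:

bits = 32
old = 1 << (bits - 1) - 1    # parses as 1 << ((bits - 1) - 1) == 2**30
new = (1 << (bits - 1)) - 1  # 2**31 - 1, the intended INT32 max
print(old, new)              # 1073741824 2147483647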
@@ -871,16 +871,16 @@ def from_torch_iW(cls, nnlin_iW, prec, a_cv, a_cvn, w_cv, zero_shift, **kwargs):

         qlinear_iW.nbits_a = 8  # Only support INT8 for now
         qlinear_iW.nbits_w = 8
-        qlinear_iW.acc_dtype = torch.float16
+        qlinear_iW.acc_dtype = kwargs.get("acc_dtype", torch.float)
         qlinear_iW.usePTnativeQfunc = kwargs.get("use_PT_native_Qfunc", True)
-        qlinear_iW.use_int_kernel = True
+        qlinear_iW.use_int_kernel = kwargs.get("use_int_kernel", "triton")
         qlinear_iW.weight = nn.Parameter(
             nnlin_iW.weight.to(torch.int8), requires_grad=False
         )
         qlinear_iW.max_acc_bits = kwargs.get("max_acc_bits", 32)
         qlinear_iW.accminmax = (
             -(1 << (qlinear_iW.max_acc_bits - 1)),
-            1 << (qlinear_iW.max_acc_bits - 1) - 1,
+            (1 << (qlinear_iW.max_acc_bits - 1)) - 1,
         )
         qlinear_iW.truncate_lsb = kwargs.get("truncate_lsb", False)
         qlinear_iW.chunk_size = kwargs.get("chunk_size", 100000)
@@ -1027,11 +1027,11 @@ def iaddmm_int(self, bias, m1, m2):
         else:
             m1 = self.qa_fmo_mo_qfunc(m1)

-        if m1.shape[1] > self.chunk_size:
+        if m1.shape[1] > self.chunk_size and self.use_int_kernel != "triton":
             idx = list(range(0, m1.shape[1], self.chunk_size))
             Nchunk = len(idx)
             idx.append(m1.shape[1])
-            fp16_out = torch.zeros(
+            accumulator = torch.zeros(
                 (m1.shape[0], m2.shape[1]), dtype=torch.float16, device=m1.device
             )
             trun_scale = 1

@@ -1052,11 +1052,11 @@ def iaddmm_int(self, bias, m1, m2):
                 # could cast to smaller data type to further simulate HW behavior, for example,
                 # if HW truncates 8b from both sides of i32 accumulator, the remaining data can
                 # be cast to i16 to be more realistic. pay attention to overflow handling
-                fp16_out += imm_out.to(torch.float16)
+                accumulator += imm_out.to(torch.float16)

             return (
-                fp16_out
-                * (trun_scale * self.input_scale * self.w_scale).to(torch.float16)
+                accumulator
+                * (trun_scale * self.input_scale * self.w_scale)  # .to(torch.float16)
                 + bias
             ).to(self.acc_dtype)
             # The safest casting, i32 -> f32
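Besides the `fp16_out` to `accumulator` rename and the scale now staying in full precision until the final cast, the chunked fallback is now skipped for the Triton backend, presumably because the Triton kernel already chunks along K internally (note `"BLOCK_SIZE_K": chunk_size` in its config). A minimal sketch of what this fallback computes, assuming int8 operands with int32 per-chunk accumulation (names follow the diff, but the code is illustrative, not the module's exact logic):

import torch

def chunked_int_matmul(m1_int8, m2_int8, chunk_size):
    # Illustrative sketch: split K into chunks, run each chunk as an integer
    # matmul with int32 accumulation, then sum the partial products in fp16,
    # mimicking hardware with a limited-width accumulator.
    K = m1_int8.shape[1]
    accumulator = torch.zeros(
        (m1_int8.shape[0], m2_int8.shape[1]), dtype=torch.float16
    )
    for k0 in range(0, K, chunk_size):
        k1 = min(k0 + chunk_size, K)
        imm_out = m1_int8[:, k0:k1].to(torch.int32) @ m2_int8[k0:k1, :].to(torch.int32)
        accumulator += imm_out.to(torch.float16)
    return accumulator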
@@ -1145,10 +1145,13 @@ def extra_repr(self) -> str:
         """
         Returns an alternative string representation of the object
         """
-        return (
+        repr_str = (
             f"in={self.in_features}, out={self.out_features}, bias={self.bias is not None}, "
-            f"use_int_kernel={self.use_int_kernel}"
+            f"int_kernel={self.use_int_kernel}"
         )
+        if self.truncate_lsb > 0 or self.max_acc_bits < 32:
+            repr_str += f", acc_bits={self.max_acc_bits}, trun_lsb={self.truncate_lsb}"
+        return repr_str

     def __getstate__(self):
         """

pyproject.toml

Lines changed: 2 additions & 1 deletion

@@ -26,12 +26,13 @@ dependencies = [
     "accelerate>=0.20.3,!=0.34,<1.4",
     "transformers>=4.45,<4.49",
     "torch>=2.2.0,<2.5",
+    "triton>=3.0,<3.2",
     "tqdm>=4.66.2,<5.0",
     "datasets>=3.0.0,<4.0",
     "ninja>=1.11.1.1,<2.0",
     "tensorboard",
     "notebook",
-    "torchvision>=0.8",
+    "torchvision>=0.17",
     "evaluate",
     "huggingface_hub",
     "pandas",
