Skip to content

Commit 3a89c7b

Browse files
minor changes per Derrick's feedback
Signed-off-by: cliu-us <[email protected]>
1 parent bc9155d commit 3a89c7b

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

examples/QAT_INT8/run_qa_no_trainer_qat.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,7 @@ def parse_args():
389389
parser.add_argument(
390390
"--do_lowering",
391391
type=str,
392-
default=None,
392+
default="triton",
393393
help="convert QAT model to utilize real INT8 GPU kernel, 'cutlass' or 'triton'",
394394
)
395395

fms_mo/modules/linear.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1049,8 +1049,9 @@ def iaddmm_int(self, bias, m1, m2):
10491049
imm_out = torch.bitwise_right_shift(
10501050
imm_out + round_bit, self.truncate_lsb
10511051
)
1052-
# imm_out = imm_out.to(torch.int16)
1053-
# only cast to i16 when truncating 8b from both side
1052+
# could cast to smaller data type to further simulate HW behavior, for example,
1053+
# if HW truncates 8b from both sides of i32 accumulator, the remaining data can
1054+
# be cast to i16 to be more realistic. pay attention to overflow handling
10541055
fp16_out += imm_out.to(torch.float16)
10551056

10561057
return (

0 commit comments

Comments (0)