
Commit 8ffad1b

minor adjustment
Signed-off-by: cliu-us <[email protected]>
1 parent 3655518 commit 8ffad1b

File tree: 2 files changed (+12, -6 lines)

fms_mo/custom_ext_kernels/triton_kernels.py

Lines changed: 2 additions & 2 deletions
@@ -264,7 +264,8 @@ def imatmul_kernel(
         else:
             accumulator_inner = tl.dot(a, b, accumulator, input_precision="ieee")
 
-        ## ------ MSB truncation by clamp, chunky LSB truncation by rounding/masking --------
+        ## ------ INT MSB truncation is simulated by clamping,
+        # "special" INT LSB truncation by right and left shift --------
         if max_acc_bits < 32:
             accumulator_inner = tl.maximum(
                 tl.minimum(accumulator_inner, acc_max), acc_min
@@ -530,7 +531,6 @@ def isPowerofTwo(x):
     c_org_dtype = c.dtype
     c = c.to(acc_dtype)
     assert c.shape[0] == M and c.shape[1] == N, "C shape is inconsistent with A B."
-    # assert acc_dtype == torch.float32, "INT truncation is not yet supported."
 
     # 1D launch kernel where each block gets its own program.
     def grid(META):
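To make the revised comment concrete, here is a minimal plain-PyTorch sketch of the two truncation ideas it names: clamping an integer accumulator into a signed max_acc_bits range (MSB truncation) and zeroing its low chunk_trun_bits bits with a right-then-left shift (LSB truncation). This is an illustration only, not the Triton kernel; the helper name simulate_int_truncation and the exact clamp bounds are assumptions made for the sketch.

import torch

def simulate_int_truncation(acc, max_acc_bits=18, chunk_trun_bits=8):
    # Illustrative sketch only (plain PyTorch, not the Triton kernel).
    # MSB truncation: clamp to the range of a signed max_acc_bits integer.
    acc_max = 2 ** (max_acc_bits - 1) - 1
    acc_min = -(2 ** (max_acc_bits - 1))
    acc = torch.clamp(acc, acc_min, acc_max)
    # LSB truncation: a right shift followed by a left shift zeroes the
    # lowest chunk_trun_bits bits of the accumulator.
    acc = (acc >> chunk_trun_bits) << chunk_trun_bits
    return acc

# toy usage: one 32-element int8 dot product accumulated exactly,
# then passed through the simulated truncation
a = torch.randint(-128, 128, (32,), dtype=torch.int32)
b = torch.randint(-128, 128, (32,), dtype=torch.int32)
print(simulate_int_truncation((a * b).sum()))

The right-then-left shift keeps the accumulator on its original scale while discarding low-order information, which is why it can stand in for LSB truncation without a separate rounding step.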

tests/triton_kernels/test_triton_mm.py

Lines changed: 10 additions & 4 deletions
@@ -94,23 +94,29 @@ def test_triton_matmul_int8(mkn):
     torch_output = torch.matmul(a.to(torch.float), b.to(torch.float))
     # cast tl_matmul results to float because torch.norm only supports float
     tl_output_no_trun = tl_matmul(a, b).to(torch.float)
-    # check LSB truncation effect
+    # check LSB truncation effect (underflow)
     tl_output_trun_8b = tl_matmul(a, b, chunk_trun_bits=8).to(torch.float)
-    # check MSB truncation effect
-    # max(1 int8 * 1 int8) ~ 2^17 -> each chunk acc 32 elem, possible max ~ 2^22
-    # -> truncate to 18b -> should see large err than LSB-only case
+    # check MSB truncation effect (overflow)
+    # max(1 int8 * 1 int8) ~ 2^14 -> each chunk acc 32 elem only, achievable max ~ 2^19
+    # -> truncate to 18b -> should see slightly large err than LSB-only case
     tl_output_trun_18b8b = tl_matmul(a, b, max_acc_bits=18, chunk_trun_bits=8).to(
         torch.float
     )
+    # use larger chunk size to accumulate more elem, MSB truncation (overflow) issue should worsen
+    tl_output_trun_18b8b_128 = tl_matmul(
+        a, b, max_acc_bits=18, chunk_trun_bits=8, chunk_size=min(128, k)
+    ).to(torch.float)
 
     ref = torch.norm(torch_output)
     rel_err_no_trun = torch.norm(torch_output - tl_output_no_trun) / ref
     rel_err_trun_8b = torch.norm(torch_output - tl_output_trun_8b) / ref
     rel_err_trun_18b8b = torch.norm(torch_output - tl_output_trun_18b8b) / ref
+    rel_err_trun_18b8b_128 = torch.norm(torch_output - tl_output_trun_18b8b_128) / ref
 
     assert rel_err_no_trun < 1e-5
     assert rel_err_trun_8b < 1e-2
     assert rel_err_trun_18b8b < 1e-2
+    assert rel_err_trun_18b8b_128 >= rel_err_trun_18b8b
 
 
 @pytest.mark.parametrize("feat_in_out", [(64, 128), (256, 1024), (1024, 4096)])
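The revised comment's arithmetic: |int8 * int8| is at most 128 * 128 = 2^14, so a chunk of 32 products can reach roughly 32 * 2^14 = 2^19, which already exceeds an 18-bit signed accumulator; a 128-element chunk can reach about 2^21, so overflow should get worse. Below is a rough plain-PyTorch sketch (not tl_matmul; the helper name chunked_int_matmul_sim is an assumption, and float64 stands in for exact integer accumulation) that mimics per-chunk clamping to show why the test expects the larger chunk_size to produce a relative error at least as large.

import torch

def chunked_int_matmul_sim(a, b, chunk_size=32, max_acc_bits=18):
    # Illustrative sketch only: accumulate an int8 GEMM chunk-by-chunk along K,
    # clamping each chunk's partial sum to a signed max_acc_bits range to mimic
    # MSB truncation (overflow). float64 is exact at these magnitudes (<< 2^53).
    acc_max = 2 ** (max_acc_bits - 1) - 1
    acc_min = -(2 ** (max_acc_bits - 1))
    m, k = a.shape
    out = torch.zeros(m, b.shape[1], dtype=torch.float64)
    for k0 in range(0, k, chunk_size):
        partial = a[:, k0:k0 + chunk_size].to(torch.float64) @ b[k0:k0 + chunk_size, :].to(torch.float64)
        out += torch.clamp(partial, acc_min, acc_max)
    return out

m, k, n = 64, 256, 64
a = torch.randint(-128, 128, (m, k), dtype=torch.int8)
b = torch.randint(-128, 128, (k, n), dtype=torch.int8)
ref = a.to(torch.float64) @ b.to(torch.float64)
for cs in (32, 128):
    sim = chunked_int_matmul_sim(a, b, chunk_size=cs)
    rel_err = torch.norm(ref - sim) / torch.norm(ref)
    print(cs, rel_err.item())  # larger chunk size is expected to give a larger error

Like the new assertion rel_err_trun_18b8b_128 >= rel_err_trun_18b8b, this only checks the direction of the effect rather than a fixed error threshold.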

0 commit comments
