@@ -94,23 +94,29 @@ def test_triton_matmul_int8(mkn):
     torch_output = torch.matmul(a.to(torch.float), b.to(torch.float))
     # cast tl_matmul results to float because torch.norm only supports float
     tl_output_no_trun = tl_matmul(a, b).to(torch.float)
-    # check LSB truncation effect
+    # check LSB truncation effect (underflow)
     tl_output_trun_8b = tl_matmul(a, b, chunk_trun_bits=8).to(torch.float)
-    # check MSB truncation effect
-    # max(1 int8 * 1 int8) ~ 2^17 -> each chunk acc 32 elem, possible max ~ 2^22
-    # -> truncate to 18b -> should see large err than LSB-only case
+    # check MSB truncation effect (overflow)
+    # max(1 int8 * 1 int8) ~ 2^14 -> each chunk acc 32 elem only, achievable max ~ 2^19
+    # -> truncate to 18b -> should see slightly larger err than LSB-only case
     tl_output_trun_18b8b = tl_matmul(a, b, max_acc_bits=18, chunk_trun_bits=8).to(
         torch.float
     )
+    # use larger chunk size to accumulate more elem; MSB truncation (overflow) error should worsen
+    tl_output_trun_18b8b_128 = tl_matmul(
+        a, b, max_acc_bits=18, chunk_trun_bits=8, chunk_size=min(128, k)
+    ).to(torch.float)
 
     ref = torch.norm(torch_output)
     rel_err_no_trun = torch.norm(torch_output - tl_output_no_trun) / ref
     rel_err_trun_8b = torch.norm(torch_output - tl_output_trun_8b) / ref
     rel_err_trun_18b8b = torch.norm(torch_output - tl_output_trun_18b8b) / ref
+    rel_err_trun_18b8b_128 = torch.norm(torch_output - tl_output_trun_18b8b_128) / ref
 
     assert rel_err_no_trun < 1e-5
     assert rel_err_trun_8b < 1e-2
     assert rel_err_trun_18b8b < 1e-2
+    assert rel_err_trun_18b8b_128 >= rel_err_trun_18b8b
 
 
 @pytest.mark.parametrize("feat_in_out", [(64, 128), (256, 1024), (1024, 4096)])
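
The bit-width reasoning in the new comments can be sanity-checked with a small reference emulation. The sketch below is not the Triton kernel: it is a plain-PyTorch mock-up, and the function name `emulate_chunked_int8_matmul` plus the exact truncation semantics (dropping low bits toward negative infinity for LSB truncation, two's-complement wrap for MSB truncation; the real kernel could round or saturate instead) are assumptions made only to illustrate why the error should grow with `chunk_size`.

```python
import torch


def emulate_chunked_int8_matmul(a, b, chunk_size=32, chunk_trun_bits=0, max_acc_bits=32):
    """Slow reference emulation (NOT the Triton kernel) of a chunked int8 matmul
    with LSB/MSB truncation applied to each per-chunk partial sum.
    Truncation semantics are assumed for illustration only."""
    m, k = a.shape
    assert b.shape[0] == k and k % chunk_size == 0
    acc = torch.zeros(m, b.shape[1], dtype=torch.int64)
    wrap = 1 << max_acc_bits          # e.g. 2^18 for an 18-bit accumulator
    half = wrap >> 1
    for s in range(0, k, chunk_size):
        # exact integer accumulation of one chunk: (m, chunk, n) products summed over chunk
        a_c = a[:, s : s + chunk_size].to(torch.int64)
        b_c = b[s : s + chunk_size, :].to(torch.int64)
        partial = (a_c.unsqueeze(2) * b_c.unsqueeze(0)).sum(dim=1)
        # LSB truncation: zero the low chunk_trun_bits bits (arithmetic shift down and back up)
        partial = (partial >> chunk_trun_bits) << chunk_trun_bits
        # MSB truncation: keep only the low max_acc_bits bits (two's-complement wrap)
        partial = (partial + half) % wrap - half
        acc += partial
    return acc


# worked bit-width check from the comments above:
# |int8 * int8| <= 128 * 128 = 2^14; a 32-element chunk sums to at most 32 * 2^14 = 2^19,
# which only occasionally exceeds a signed 18-bit accumulator with random data,
# while a 128-element chunk can reach 2^21, so overflow (MSB truncation) error should grow.
m, k, n = 64, 256, 64
a = torch.randint(-128, 128, (m, k), dtype=torch.int8)
b = torch.randint(-128, 128, (k, n), dtype=torch.int8)
exact = emulate_chunked_int8_matmul(a, b).to(torch.float)
trun_32 = emulate_chunked_int8_matmul(a, b, chunk_size=32, chunk_trun_bits=8, max_acc_bits=18)
trun_128 = emulate_chunked_int8_matmul(a, b, chunk_size=128, chunk_trun_bits=8, max_acc_bits=18)
err_32 = torch.norm(exact - trun_32.to(torch.float)) / torch.norm(exact)
err_128 = torch.norm(exact - trun_128.to(torch.float)) / torch.norm(exact)
print(err_32, err_128)  # err_128 is expected to exceed err_32, mirroring the new assert
```

Under these assumed semantics the relative-error ordering matches the new `rel_err_trun_18b8b_128 >= rel_err_trun_18b8b` assertion: larger chunks accumulate more products before the 18-bit limit is applied, so more partial sums overflow.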