@@ -55,7 +55,8 @@ def matmul_kernel(
         stride_bn,
         stride_cm,
         stride_cn,
-        scale,
+        a_scale_ptr,
+        b_scale_ptr,
         # Meta-parameters
         BLOCK_SIZE_M: tl.constexpr,
         BLOCK_SIZE_N: tl.constexpr,
@@ -92,6 +93,9 @@ def matmul_kernel(
     offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
     a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
     b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+    if APPLY_SCALE:
+        a_scale = tl.load(a_scale_ptr)
+        b_scale = tl.load(b_scale_ptr)
 
     acc_dtype = tl.float32 if c_ptr.type.element_ty != tl.int8 else tl.int32
     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype)
@@ -110,12 +114,13 @@ def matmul_kernel(
         # Advance the ptrs to the next K block.
         a_ptrs += BLOCK_SIZE_K * stride_ak
         b_ptrs += BLOCK_SIZE_K * stride_bk
+    # Apply scale to recover dynamic range reduced due to lower precision inputs.
+    if APPLY_SCALE:
+        accumulator = accumulator * a_scale * b_scale
     # Apply activation function, if specified.
+    # TODO(vgokhale): Add different types of activations.
     if ACTIVATION == "leaky_relu":
         accumulator = leaky_relu(accumulator)
-    # Apply scale to recover dynamic range reduced due to lower precision inputs.
-    if APPLY_SCALE:
-        accumulator = accumulator * scale
     c = accumulator.to(c_ptr.type.element_ty)
 
     # Write back the block of the output matrix C with masks.
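Note: the rescale above compensates for per-tensor quantization of the 8-bit inputs. A minimal reference sketch of that math in plain PyTorch (the function name and the quantization recipe are illustrative assumptions, not code from this file):

import torch

def reference_scaled_matmul(a, b, fp8_dtype=torch.float8_e4m3fnuz):
    # Assumed per-tensor quantization: x_fp8 = x / x_scale, with the scale chosen so the
    # largest magnitude maps to the dtype's max representable value.
    fp8_max = torch.finfo(fp8_dtype).max
    a_scale = a.abs().max() / fp8_max
    b_scale = b.abs().max() / fp8_max
    a_fp8 = (a / a_scale).to(fp8_dtype)
    b_fp8 = (b / b_scale).to(fp8_dtype)
    # Accumulate in fp32, then multiply by a_scale * b_scale to undo the quantization,
    # which is the same rescale the kernel applies to its accumulator.
    acc = torch.matmul(a_fp8.to(torch.float32), b_fp8.to(torch.float32))
    return acc * a_scale * b_scale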
@@ -134,15 +139,13 @@ def leaky_relu(x):
 
 
 # Wrapper for gemm kernel.
-def matmul(a, b, c, a_scale, b_scale, activation=""):
+def matmul(a, b, c, a_scale, b_scale, scale_a8_b8=False, activation=""):
     # Check constraints.
     assert a.shape[1] == b.shape[0], "Incompatible dimensions!!!"
     assert a.dtype == b.dtype, "Mixed dtype GEMMs are not supported!!!"
     M, K = a.shape
     K, N = b.shape
     grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
-    apply_scale = a_scale is not None and b_scale is not None
-    scale = a_scale * b_scale if apply_scale else None
     matmul_kernel[grid](
         a,
         b,
@@ -156,8 +159,9 @@ def matmul(a, b, c, a_scale, b_scale, activation=""):
         b.stride(1),
         c.stride(0),
         c.stride(1),
-        scale,
-        APPLY_SCALE=apply_scale,
+        a_scale,
+        b_scale,
+        APPLY_SCALE=scale_a8_b8,
         ACTIVATION=activation,
     )
 
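Note: since the kernel now loads the scales with tl.load, the wrapper expects a_scale and b_scale as device tensors rather than Python floats. A hypothetical call site (shapes, dtypes, and scale values are made up for illustration):

a = torch.randn((256, 256), device='cuda').to(torch.float8_e4m3fnuz)
b = torch.randn((256, 256), device='cuda').to(torch.float8_e4m3fnuz)
c = torch.empty((256, 256), device='cuda', dtype=torch.float16)
a_scale = torch.tensor(0.5, device='cuda', dtype=torch.float32)   # scalar scale, resident on device
b_scale = torch.tensor(0.25, device='cuda', dtype=torch.float32)  # because the kernel tl.load()s it
matmul(a, b, c, a_scale, b_scale, scale_a8_b8=True, activation="")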
@@ -173,9 +177,12 @@ def matmul(a, b, c, a_scale, b_scale, activation=""):
 }
 
 dtype_max = {
-    torch.float8_e5m2fnuz: 57344,
-    torch.float8_e4m3fnuz: 240,
-    torch.int8: 127,
+    dtype: (torch.finfo(dtype) if dtype.is_floating_point else torch.iinfo(dtype)).max
+    for dtype in [
+        torch.float8_e5m2fnuz,
+        torch.float8_e4m3fnuz,
+        torch.int8,
+    ]
 }
 
 
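Note: the comprehension reproduces the limits the removed literal dict hard-coded; a quick sanity check of the generated values (the expected numbers come from the removed entries):

assert torch.finfo(torch.float8_e5m2fnuz).max == 57344
assert torch.finfo(torch.float8_e4m3fnuz).max == 240
assert torch.iinfo(torch.int8).max == 127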
@@ -213,6 +220,7 @@ def get_x_vals():
 
 
 # Unit tests
+# TODO(vgokhale): Test activation.
 @pytest.mark.parametrize(
     "M, N, K, in_dtype, out_dtype, col_a, col_b",
     [(*shape, in_dtype, out_dtype, col_a, col_b)
@@ -232,12 +240,12 @@ def test_correctness(M, N, K, col_a, col_b, in_dtype, out_dtype):
     # This requires us to compute in fp32 because for e5m2, the range is same as fp16 (e5m10).
     # If we use fp16 it is possible to return infs from the torch.matmul call.
     if dtype_is_8_bit(torch_in_dtype):
-        matmul(a, b, c, a_scale.item(), b_scale.item(), activation="")
+        matmul(a, b, c, a_scale, b_scale, scale_a8_b8=True, activation="")
         torch_output = torch.matmul(a_fp32, b_fp32)
         torch_output = torch_output * a_scale * b_scale
     # For other dtypes, use the same torch matmul as the dtype.
     else:
-        matmul(a, b, c, a_scale=None, b_scale=None, activation="")
+        matmul(a, b, c, a_scale=None, b_scale=None, scale_a8_b8=False, activation="")
         torch_output = torch.matmul(a.to(torch_in_dtype), b.to(torch_in_dtype))
     if out_dtype == 'int8':
         torch.testing.assert_close(c.to(torch.float32),