@@ -278,7 +278,7 @@ def tl_matmul_chunk_truncate(
     activation="",
     chunk_trun_bits=0,
     chunk_size=16,
-    cast_output_to_input_dtype=True,
+    cast_output_to_input_dtype=None,
 ):
     """Triton matmul for HW behavior simulation. Supports float and int8.
         a. variable chunk size (i.e., BLOCK_SIZE_K)
@@ -291,8 +291,7 @@ def tl_matmul_chunk_truncate(
         chunk_size (int, optional): BLOCK_SIZE_K, some HW has specific chunk size. must >= 16.
         cast_output_to_input_dtype (bool, optional): accumulator has higher prec than input, usually
                                                      FP32 or INT32. by default we cast the final
-                                                     output to the same dtype as input, but can be
-                                                     changed if needed.
+                                                     output to the same dtype as input for non-8bits.
 
     Returns:
         _type_: _description_
@@ -306,6 +305,8 @@ def tl_matmul_chunk_truncate(
     assert a.is_contiguous(), "Matrix A must be contiguous"
     assert a.dtype == b.dtype, "Input dtypes inconsistent"
 
+    if cast_output_to_input_dtype is None:
+        cast_output_to_input_dtype = a.dtype not in DTYPE_8BIT
     allowed_dtypes = [torch.float, torch.bfloat16, torch.float16]
     cuda_cc = torch.cuda.get_device_capability()
     if cuda_cc[0] >= 8:
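
For context, a minimal sketch of the None-means-auto default introduced above. `DTYPE_8BIT` and `resolve_cast_flag` here are stand-ins for illustration only, not the repo's actual definitions (in the source, `DTYPE_8BIT` is defined elsewhere in the module):

```python
import torch

# Assumption: stub for the repo's DTYPE_8BIT collection of 8-bit dtypes.
DTYPE_8BIT = (torch.int8,)

def resolve_cast_flag(a, cast_output_to_input_dtype=None):
    """Return True when the high-precision accumulator (FP32/INT32) should be
    cast back to the input dtype, mirroring the None-means-auto default."""
    if cast_output_to_input_dtype is None:
        # Auto mode: 16/32-bit inputs are cast back to their own dtype,
        # while 8-bit inputs keep the accumulator's precision.
        cast_output_to_input_dtype = a.dtype not in DTYPE_8BIT
    return cast_output_to_input_dtype

# fp16 input -> cast back to fp16; int8 input -> keep the wider output dtype.
print(resolve_cast_flag(torch.empty(2, 2, dtype=torch.float16)))  # True
print(resolve_cast_flag(torch.empty(2, 2, dtype=torch.int8)))     # False
```

This keeps the previous behavior for float inputs (an explicit `True` is no longer needed) while changing only the 8-bit default.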