Commit 19f3ae3

fix tests
Signed-off-by: cliu-us <[email protected]>
1 parent f093ea4 · commit 19f3ae3

File tree

2 files changed: +8 −4 lines changed


fms_mo/custom_ext_kernels/triton_kernels.py

Lines changed: 4 additions & 3 deletions
@@ -278,7 +278,7 @@ def tl_matmul_chunk_truncate(
     activation="",
     chunk_trun_bits=0,
     chunk_size=16,
-    cast_output_to_input_dtype=True,
+    cast_output_to_input_dtype=None,
 ):
     """Triton matmul for HW behavior simulation. Supports float and int8.
     a. variable chunk size (i.e., BLOCK_SIZE_K)
@@ -291,8 +291,7 @@ def tl_matmul_chunk_truncate(
         chunk_size (int, optional): BLOCK_SIZE_K, some HW has specific chunk size. must >= 16.
         cast_output_to_input_dtype (bool, optional): accumulator has higher prec than input, usually
                                                      FP32 or INT32. by default we cast the final
-                                                     output to the same dtype as input, but can be
-                                                     changed if needed.
+                                                     output to the same dtype as input for non-8bits.
 
     Returns:
         _type_: _description_
@@ -306,6 +305,8 @@ def tl_matmul_chunk_truncate(
     assert a.is_contiguous(), "Matrix A must be contiguous"
     assert a.dtype == b.dtype, "Input dtypes inconsistent"
 
+    if cast_output_to_input_dtype is None:
+        cast_output_to_input_dtype = a.dtype not in DTYPE_8BIT
     allowed_dtypes = [torch.float, torch.bfloat16, torch.float16]
     cuda_cc = torch.cuda.get_device_capability()
     if cuda_cc[0] >= 8:
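Net effect of this change: instead of always casting the matmul output back to the input dtype, the default is now derived from the input dtype at call time, so 8-bit inputs keep the high-precision accumulator result (FP32/INT32) unless the caller passes an explicit bool. A minimal standalone sketch of that None-sentinel pattern (resolve_cast_flag is a hypothetical helper, and this DTYPE_8BIT is only a stand-in for the set defined elsewhere in fms-mo):

import torch

# Stand-in for the repo-defined constant; the diff only shows it being used.
DTYPE_8BIT = {torch.int8, torch.float8_e4m3fn, torch.float8_e5m2}

def resolve_cast_flag(a, cast_output_to_input_dtype=None):
    """Mirror the new default logic: None means 'decide from the input dtype'."""
    if cast_output_to_input_dtype is None:
        # Wide inputs (fp16/bf16/fp32) are cast back to their own dtype;
        # 8-bit inputs keep the accumulator dtype by default.
        cast_output_to_input_dtype = a.dtype not in DTYPE_8BIT
    return cast_output_to_input_dtype

assert resolve_cast_flag(torch.zeros(2, dtype=torch.float16)) is True
assert resolve_cast_flag(torch.zeros(2, dtype=torch.int8)) is False
assert resolve_cast_flag(torch.zeros(2, dtype=torch.int8), True) is True  # explicit bool still wins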

fms_mo/run_quant.py

Lines changed: 4 additions & 1 deletion
@@ -281,6 +281,10 @@ def parse_arguments(parser, json_config=None):
         _,
     ) = parser.parse_args_into_dataclasses(return_remaining_strings=True)
 
+    model_args.torch_dtype = getattr(
+        torch, model_args.torch_dtype.replace("torch.", ""), torch.bfloat16
+    )
+
     return (
         model_args,
         data_args,
@@ -307,7 +311,6 @@ def main():
         gptq_args,
         fp8_args,
     ) = parse_arguments(parser, job_config)
-    model_args.torch_dtype = getattr(torch, model_args.torch_dtype, torch.bfloat16)
 
     logger = set_log_level(opt_args.log_level, __name__)
 
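For context, this moves the torch_dtype resolution out of main() and into parse_arguments(), and the added .replace("torch.", "") lets config values spelled either "float16" or "torch.float16" resolve to the same torch.dtype, with torch.bfloat16 as the fallback for unrecognized strings (the old getattr in main() would silently fall back on a "torch."-prefixed value). A small sketch of the lookup (resolve_torch_dtype is a hypothetical name, not part of run_quant.py):

import torch

def resolve_torch_dtype(name):
    """Map a dtype string to a torch.dtype, tolerating a 'torch.' prefix."""
    # getattr falls back to torch.bfloat16 when the attribute doesn't exist.
    return getattr(torch, name.replace("torch.", ""), torch.bfloat16)

assert resolve_torch_dtype("float16") is torch.float16
assert resolve_torch_dtype("torch.float16") is torch.float16
assert resolve_torch_dtype("no_such_dtype") is torch.bfloat16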
