Commit ed76001

Merge pull request #69 from chichun-charlie-liu/bug_fix
fix: bug fix and minor changes on triton kernel:
2 parents 02f5ff3 + 19f3ae3 commit ed76001

3 files changed: +42 −20 lines changed


fms_mo/custom_ext_kernels/triton_kernels.py

Lines changed: 36 additions & 16 deletions
@@ -20,6 +20,10 @@
 import triton
 import triton.language as tl
 
+DTYPE_I8 = [torch.int8]
+DTYPE_F8 = [torch.float8_e4m3fn, torch.float8_e5m2]
+DTYPE_8BIT = DTYPE_I8 + DTYPE_F8
+
 
 def get_cuda_autotune_config(chunk_size=None):
     """Basic use of triton.Config() is like:
@@ -145,8 +149,7 @@ def matmul_kernel(
     # NOTE mask will be applied on accumulator, which is alway FP32, so we may truncate up to 23b
     # e.g., 20b -> trun_mask = 0xFFF00000, round_bit = 0x00080000
     #       8b -> trun_mask = 0xFFFFFF00, round_bit = 0x00000080
-    full_32b_mask = 0xFFFFFFFF
-    trun_mask = (full_32b_mask << chunk_trun_bits) & full_32b_mask
+    trun_mask = tl.cast((0xFFFFFFFF >> chunk_trun_bits) << chunk_trun_bits, tl.uint32)
     round_bit = 1 << (chunk_trun_bits - 1) if chunk_trun_bits > 0 else 0
     ## ---------------------------------------------------------
 
@@ -160,7 +163,7 @@ def matmul_kernel(
     # tl.dot() default is using TF32 approximation, not good enough for LSB truncation exp
 
     ## ------ add chunky LSB rounding/masking --------
-    if chunk_trun_bits != 0:
+    if chunk_trun_bits > 0:
         accumulator = libdevice.uint_as_float(
             (libdevice.float_as_uint(accumulator) + round_bit) & trun_mask
         )
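For reference, the rounding applied to the FP32 accumulator can be reproduced on the host by reinterpreting the float's bit pattern: add round_bit, then clear the low chunk_trun_bits mantissa bits with trun_mask. A standalone sketch (not part of the kernel) that mirrors the same arithmetic:

import struct

def truncate_fp32_lsbs(x: float, chunk_trun_bits: int) -> float:
    """Host-side illustration of the kernel's round-then-truncate step."""
    trun_mask = (0xFFFFFFFF >> chunk_trun_bits) << chunk_trun_bits
    round_bit = 1 << (chunk_trun_bits - 1) if chunk_trun_bits > 0 else 0
    bits = struct.unpack("<I", struct.pack("<f", x))[0]  # float -> uint32 bit pattern
    bits = (bits + round_bit) & trun_mask                # round, then clear the LSBs
    return struct.unpack("<f", struct.pack("<I", bits))[0]

print(truncate_fp32_lsbs(1.2345678, 8))   # 8 LSBs dropped: 15 mantissa bits remain
print(truncate_fp32_lsbs(1.2345678, 20))  # 20 LSBs dropped: only 3 mantissa bits remain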
@@ -269,7 +272,14 @@ def leaky_relu(x):
     return tl.where(x >= 0, x, 0.01 * x)
 
 
-def tl_matmul_chunk_truncate(a, b, activation="", chunk_trun_bits=0, chunk_size=16):
+def tl_matmul_chunk_truncate(
+    a,
+    b,
+    activation="",
+    chunk_trun_bits=0,
+    chunk_size=16,
+    cast_output_to_input_dtype=None,
+):
     """Triton matmul for HW behavior simulation. Supports float and int8.
     a. variable chunk size (i.e., BLOCK_SIZE_K)
     b. LSB truncation, must <23 if using float.
@@ -279,6 +289,9 @@ def tl_matmul_chunk_truncate(a, b, activation="", chunk_trun_bits=0, chunk_size=
         activation (str, optional): activation func to be fused, see relu example.
         chunk_trun_bits (int, optional): number of LSBs to be truncated/rounded.
         chunk_size (int, optional): BLOCK_SIZE_K, some HW has specific chunk size. must >= 16.
+        cast_output_to_input_dtype (bool, optional): accumulator has higher prec than input, usually
+                                                     FP32 or INT32. by default we cast the final
+                                                     output to the same dtype as input for non-8bits.
 
     Returns:
         _type_: _description_
@@ -292,30 +305,37 @@ def tl_matmul_chunk_truncate(a, b, activation="", chunk_trun_bits=0, chunk_size=
     assert a.is_contiguous(), "Matrix A must be contiguous"
     assert a.dtype == b.dtype, "Input dtypes inconsistent"
 
+    if cast_output_to_input_dtype is None:
+        cast_output_to_input_dtype = a.dtype not in DTYPE_8BIT
     allowed_dtypes = [torch.float, torch.bfloat16, torch.float16]
     cuda_cc = torch.cuda.get_device_capability()
     if cuda_cc[0] >= 8:
-        allowed_dtypes.append(torch.int8)
+        allowed_dtypes += DTYPE_I8
     if cuda_cc[0] >= 9 or cuda_cc == (8, 9):
-        allowed_dtypes += [torch.float8_e4m3fn, torch.float8_e5m2]
+        allowed_dtypes += DTYPE_F8
     assert a.dtype in allowed_dtypes, "Input dtype is not supported"
     M, K = a.shape
     K, N = b.shape
 
-    # Allocates output, always accumulate in FP32/INT32 then cast (if floats)
+    # Allocates output, always accumulate in FP32 (if floats) or INT32 then cast
     def isPowerofTwo(x):
        """triton-specific limitation: block size needs to be power of 2."""
        return (x & (x - 1)) == 0
 
-    if a.dtype == torch.int8:
+    min_chunk_size = 32 if a.dtype in DTYPE_8BIT else 16
+    if isPowerofTwo(chunk_size):
+        chunk_size = max(chunk_size, min_chunk_size)
+    else:
+        chunk_size = min_chunk_size
+
+    if a.dtype in DTYPE_I8:
+        acc_dtype = torch.int32
         mm_kernel = imatmul_kernel
-        chunk_size = max(chunk_size, 32) if isPowerofTwo(chunk_size) else 32
-        c = torch.zeros((M, N), device=a.device, dtype=torch.int32)
     else:
-        assert chunk_trun_bits < 23, "FP32 accumulator only has 23 mantissa bits"
+        acc_dtype = torch.float32
         mm_kernel = matmul_kernel
-        chunk_size = max(chunk_size, 16) if isPowerofTwo(chunk_size) else 16
-        c = torch.zeros((M, N), device=a.device, dtype=torch.float32)
+        assert chunk_trun_bits < 23, "FP32 accumulator only has 23 mantissa bits"
+    c = torch.zeros((M, N), device=a.device, dtype=acc_dtype)
 
     # 1D launch kernel where each block gets its own program.
     def grid(META):
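The refactor pulls the chunk-size handling out of the dtype branches: a non-power-of-two request falls back to the dtype's minimum (32 for 8-bit inputs, 16 otherwise), and a power-of-two request is clamped up to that minimum. A small standalone sketch of the same rule (the helper name is illustrative only; float8 dtypes need a recent PyTorch):

import torch

DTYPE_8BIT = [torch.int8, torch.float8_e4m3fn, torch.float8_e5m2]

def normalize_chunk_size(dtype, chunk_size):
    """Illustrative mirror of the new logic: Triton block sizes must be powers of
    two, and 8-bit inputs need BLOCK_SIZE_K >= 32 (>= 16 for everything else)."""
    min_chunk_size = 32 if dtype in DTYPE_8BIT else 16
    if (chunk_size & (chunk_size - 1)) == 0:  # power of two
        return max(chunk_size, min_chunk_size)
    return min_chunk_size

print(normalize_chunk_size(torch.int8, 16))     # 32: clamped up for 8-bit inputs
print(normalize_chunk_size(torch.float16, 64))  # 64: already a valid power of two
print(normalize_chunk_size(torch.float16, 48))  # 16: not a power of two, fall back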
@@ -327,7 +347,7 @@ def grid(META):
         kernel_config = {
             "BLOCK_SIZE_M": 128,
             "BLOCK_SIZE_K": chunk_size,
-            "BLOCK_SIZE_N": 128,  # was 32
+            "BLOCK_SIZE_N": 32,
             "GROUP_SIZE_M": 8,
             "num_warps": 2,
             "num_stages": 5,
@@ -336,7 +356,7 @@ def grid(META):
         kernel_config = {
             "BLOCK_SIZE_M": 128,
             "BLOCK_SIZE_K": chunk_size,
-            "BLOCK_SIZE_N": 128,  # was 64
+            "BLOCK_SIZE_N": 64,
             "GROUP_SIZE_M": 8,
             "num_warps": 4,
             "num_stages": 4,
@@ -359,4 +379,4 @@ def grid(META):
         ACTIVATION=activation,
         **kernel_config,  # if using auto-tune, comment this line out.
     )
-    return c.to(a.dtype) if a.dtype != torch.int8 else c
+    return c.to(a.dtype) if cast_output_to_input_dtype else c
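With the new argument, callers can either receive the result in the input dtype (the default for non-8-bit inputs) or keep the raw accumulator. A usage sketch, assuming a CUDA device with compute capability >= 8.0:

import torch
from fms_mo.custom_ext_kernels.triton_kernels import tl_matmul_chunk_truncate

a = torch.randn(256, 512, device="cuda", dtype=torch.float16)
b = torch.randn(512, 128, device="cuda", dtype=torch.float16)

# Default for non-8-bit inputs: accumulate in FP32, round/truncate 8 LSBs per chunk,
# then cast the result back to float16.
c = tl_matmul_chunk_truncate(a, b, chunk_trun_bits=8, chunk_size=32)
print(c.dtype)  # torch.float16

# Keep the FP32 accumulator instead of casting back to the input dtype.
c_fp32 = tl_matmul_chunk_truncate(
    a, b, chunk_trun_bits=8, chunk_size=32, cast_output_to_input_dtype=False
)
print(c_fp32.dtype)  # torch.float32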

fms_mo/run_quant.py

Lines changed: 5 additions & 0 deletions
@@ -37,6 +37,7 @@
 from huggingface_hub.errors import HFValidationError
 from torch.cuda import OutOfMemoryError
 from transformers import AutoTokenizer
+import torch
 import transformers
 
 # Local
@@ -280,6 +281,10 @@ def parse_arguments(parser, json_config=None):
         _,
     ) = parser.parse_args_into_dataclasses(return_remaining_strings=True)
 
+    model_args.torch_dtype = getattr(
+        torch, model_args.torch_dtype.replace("torch.", ""), torch.bfloat16
+    )
+
     return (
         model_args,
         data_args,
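The added call turns the CLI string into a real torch.dtype, accepting either "bfloat16" or "torch.bfloat16" spellings and falling back to torch.bfloat16 for unrecognized names. A standalone sketch of the same conversion (the helper name is illustrative only):

import torch

def resolve_torch_dtype(name: str) -> torch.dtype:
    """Mirror of the conversion added in parse_arguments()."""
    return getattr(torch, name.replace("torch.", ""), torch.bfloat16)

print(resolve_torch_dtype("float16"))        # torch.float16
print(resolve_torch_dtype("torch.float32"))  # torch.float32
print(resolve_torch_dtype("not_a_dtype"))    # torch.bfloat16 (fallback)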

fms_mo/training_args.py

Lines changed: 1 addition & 4 deletions
@@ -20,9 +20,6 @@
 from dataclasses import dataclass, field
 from typing import List, Optional, Union, get_args, get_origin
 
-# Third Party
-import torch
-
 
 @dataclass
 class TypeChecker:
@@ -58,7 +55,7 @@ class ModelArguments(TypeChecker):
     """Dataclass for model related arguments."""
 
     model_name_or_path: str = field(default="facebook/opt-125m")
-    torch_dtype: Union[torch.dtype, str] = torch.bfloat16
+    torch_dtype: str = field(default="bfloat16")
     use_fast_tokenizer: bool = field(
         default=True,
         metadata={
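torch_dtype is now a plain string on the dataclass, so training_args.py no longer needs to import torch; the string is materialized into a torch.dtype only in run_quant.py. A simplified sketch of the two halves together (the stand-in dataclass below is illustrative, not the real ModelArguments):

from dataclasses import dataclass, field

import torch

@dataclass
class ModelArgumentsSketch:
    model_name_or_path: str = field(default="facebook/opt-125m")
    torch_dtype: str = field(default="bfloat16")

args = ModelArgumentsSketch(torch_dtype="torch.float16")
# Materialized at argument-parsing time, as run_quant.py now does.
args.torch_dtype = getattr(torch, args.torch_dtype.replace("torch.", ""), torch.bfloat16)
print(args.torch_dtype)  # torch.float16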
