
Commit 3a08171

Bug fix and minor changes on Triton kernel:
1. torch_dtype parsing didn't work with strings
2. minor adjustments to the Triton code for FP8

Signed-off-by: cliu-us <[email protected]>
1 parent 9fc7c75 · commit 3a08171

File tree: 3 files changed (+38 −17 lines)


fms_mo/custom_ext_kernels/triton_kernels.py

Lines changed: 35 additions & 16 deletions
@@ -20,6 +20,10 @@
 import triton
 import triton.language as tl
 
+DTYPE_I8 = [torch.int8]
+DTYPE_F8 = [torch.float8_e4m3fn, torch.float8_e5m2]
+DTYPE_8BIT = DTYPE_I8 + DTYPE_F8
+
 
 def get_cuda_autotune_config(chunk_size=None):
     """Basic use of triton.Config() is like:
@@ -145,8 +149,7 @@ def matmul_kernel(
     # NOTE mask will be applied on accumulator, which is alway FP32, so we may truncate up to 23b
     # e.g., 20b -> trun_mask = 0xFFF00000, round_bit = 0x00080000
     #       8b -> trun_mask = 0xFFFFFF00, round_bit = 0x00000080
-    full_32b_mask = 0xFFFFFFFF
-    trun_mask = (full_32b_mask << chunk_trun_bits) & full_32b_mask
+    trun_mask = tl.cast((0xFFFFFFFF >> chunk_trun_bits) << chunk_trun_bits, tl.uint32)
     round_bit = 1 << (chunk_trun_bits - 1) if chunk_trun_bits > 0 else 0
     ## ---------------------------------------------------------
 
@@ -160,7 +163,7 @@ def matmul_kernel(
     # tl.dot() default is using TF32 approximation, not good enough for LSB truncation exp
 
     ## ------ add chunky LSB rounding/masking --------
-    if chunk_trun_bits != 0:
+    if chunk_trun_bits > 0:
         accumulator = libdevice.uint_as_float(
             (libdevice.float_as_uint(accumulator) + round_bit) & trun_mask
         )
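
The rounding-then-masking above can be reproduced outside Triton for a sanity check. A minimal sketch that reinterprets the FP32 bit pattern with struct (the helper name and sample value are illustrative, not part of the repo):

import struct

def truncate_fp32_lsbs(x: float, chunk_trun_bits: int) -> float:
    """Round-to-nearest, then zero the lowest `chunk_trun_bits` bits of an FP32 value."""
    trun_mask = (0xFFFFFFFF >> chunk_trun_bits) << chunk_trun_bits
    round_bit = 1 << (chunk_trun_bits - 1) if chunk_trun_bits > 0 else 0
    bits = struct.unpack("<I", struct.pack("<f", x))[0]      # float -> uint32 (cf. float_as_uint)
    bits = (bits + round_bit) & trun_mask                    # add round bit, then drop the LSBs
    return struct.unpack("<f", struct.pack("<I", bits))[0]   # uint32 -> float (cf. uint_as_float)

# chunk_trun_bits=20 gives trun_mask=0xFFF00000 and round_bit=0x00080000,
# matching the worked values in the kernel comment above.
print(truncate_fp32_lsbs(3.14159265, 20))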
@@ -269,7 +272,14 @@ def leaky_relu(x):
     return tl.where(x >= 0, x, 0.01 * x)
 
 
-def tl_matmul_chunk_truncate(a, b, activation="", chunk_trun_bits=0, chunk_size=16):
+def tl_matmul_chunk_truncate(
+    a,
+    b,
+    activation="",
+    chunk_trun_bits=0,
+    chunk_size=16,
+    cast_output_to_input_dtype=True,
+):
     """Triton matmul for HW behavior simulation. Supports float and int8.
     a. variable chunk size (i.e., BLOCK_SIZE_K)
     b. LSB truncation, must <23 if using float.
@@ -279,6 +289,10 @@ def tl_matmul_chunk_truncate(a, b, activation="", chunk_trun_bits=0, chunk_size=
         activation (str, optional): activation func to be fused, see relu example.
         chunk_trun_bits (int, optional): number of LSBs to be truncated/rounded.
         chunk_size (int, optional): BLOCK_SIZE_K, some HW has specific chunk size. must >= 16.
+        cast_output_to_input_dtype (bool, optional): accumulator has higher prec than input, usually
+                                                     FP32 or INT32. by default we cast the final
+                                                     output to the same dtype as input, but can be
+                                                     changed if needed.
 
     Returns:
         _type_: _description_
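
With the widened signature, a call could look like the sketch below (shapes, dtypes, and the import path are assumptions based on this file's location; the dtype/CUDA-capability checks inside the function still apply):

import torch
from fms_mo.custom_ext_kernels.triton_kernels import tl_matmul_chunk_truncate

a = torch.randn(256, 512, device="cuda", dtype=torch.float16)
b = torch.randn(512, 128, device="cuda", dtype=torch.float16)

# Simulate HW that accumulates in chunks of 16 and rounds/truncates 8 LSBs of the FP32 accumulator.
c = tl_matmul_chunk_truncate(a, b, chunk_trun_bits=8, chunk_size=16)

# Keep the raw FP32 accumulator instead of casting back to the float16 input dtype.
c_fp32 = tl_matmul_chunk_truncate(a, b, chunk_trun_bits=8, cast_output_to_input_dtype=False)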
@@ -295,27 +309,32 @@ def tl_matmul_chunk_truncate(a, b, activation="", chunk_trun_bits=0, chunk_size=
     allowed_dtypes = [torch.float, torch.bfloat16, torch.float16]
     cuda_cc = torch.cuda.get_device_capability()
     if cuda_cc[0] >= 8:
-        allowed_dtypes.append(torch.int8)
+        allowed_dtypes += DTYPE_I8
     if cuda_cc[0] >= 9 or cuda_cc == (8, 9):
-        allowed_dtypes += [torch.float8_e4m3fn, torch.float8_e5m2]
+        allowed_dtypes += DTYPE_F8
     assert a.dtype in allowed_dtypes, "Input dtype is not supported"
     M, K = a.shape
     K, N = b.shape
 
-    # Allocates output, always accumulate in FP32/INT32 then cast (if floats)
+    # Allocates output, always accumulate in FP32 (if floats) or INT32 then cast
     def isPowerofTwo(x):
         """triton-specific limitation: block size needs to be power of 2."""
         return (x & (x - 1)) == 0
 
-    if a.dtype == torch.int8:
+    min_chunk_size = 32 if a.dtype in DTYPE_8BIT else 16
+    if isPowerofTwo(chunk_size):
+        chunk_size = max(chunk_size, min_chunk_size)
+    else:
+        chunk_size = min_chunk_size
+
+    if a.dtype in DTYPE_I8:
+        acc_dtype = torch.int32
         mm_kernel = imatmul_kernel
-        chunk_size = max(chunk_size, 32) if isPowerofTwo(chunk_size) else 32
-        c = torch.zeros((M, N), device=a.device, dtype=torch.int32)
     else:
-        assert chunk_trun_bits < 23, "FP32 accumulator only has 23 mantissa bits"
+        acc_dtype = torch.float32
         mm_kernel = matmul_kernel
-        chunk_size = max(chunk_size, 16) if isPowerofTwo(chunk_size) else 16
-        c = torch.zeros((M, N), device=a.device, dtype=torch.float32)
+        assert chunk_trun_bits < 23, "FP32 accumulator only has 23 mantissa bits"
+    c = torch.zeros((M, N), device=a.device, dtype=acc_dtype)
 
     # 1D launch kernel where each block gets its own program.
     def grid(META):
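
The new chunk_size handling folds the two per-dtype branches into one rule: enforce a dtype-dependent minimum for power-of-two requests, and fall back to that minimum otherwise. Restated as a standalone helper (the function name is mine, not in the repo):

def normalize_chunk_size(chunk_size: int, is_8bit: bool) -> int:
    """Mirror of the block-size selection: Triton needs a power of 2, and 8-bit dtypes need >= 32."""
    min_chunk_size = 32 if is_8bit else 16
    if (chunk_size & (chunk_size - 1)) == 0:        # power of two: just enforce the minimum
        return max(chunk_size, min_chunk_size)
    return min_chunk_size                           # not a power of two: fall back to the minimum

assert normalize_chunk_size(64, is_8bit=True) == 64
assert normalize_chunk_size(48, is_8bit=False) == 16   # 48 is not a power of two
assert normalize_chunk_size(16, is_8bit=True) == 32    # int8/fp8 inputs need at least 32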
@@ -327,7 +346,7 @@ def grid(META):
         kernel_config = {
             "BLOCK_SIZE_M": 128,
             "BLOCK_SIZE_K": chunk_size,
-            "BLOCK_SIZE_N": 128,  # was 32
+            "BLOCK_SIZE_N": 32,
             "GROUP_SIZE_M": 8,
             "num_warps": 2,
             "num_stages": 5,
@@ -336,7 +355,7 @@ def grid(META):
         kernel_config = {
             "BLOCK_SIZE_M": 128,
             "BLOCK_SIZE_K": chunk_size,
-            "BLOCK_SIZE_N": 128,  # was 64
+            "BLOCK_SIZE_N": 64,
             "GROUP_SIZE_M": 8,
             "num_warps": 4,
             "num_stages": 4,
@@ -359,4 +378,4 @@ def grid(META):
         ACTIVATION=activation,
         **kernel_config,  # if using auto-tune, comment this line out.
     )
-    return c.to(a.dtype) if a.dtype != torch.int8 else c
+    return c.to(a.dtype) if cast_output_to_input_dtype else c

fms_mo/run_quant.py

Lines changed: 2 additions & 0 deletions
@@ -37,6 +37,7 @@
 from huggingface_hub.errors import HFValidationError
 from torch.cuda import OutOfMemoryError
 from transformers import AutoTokenizer
+import torch
 import transformers
 
 # Local
@@ -306,6 +307,7 @@ def main():
         gptq_args,
         fp8_args,
     ) = parse_arguments(parser, job_config)
+    model_args.torch_dtype = getattr(torch, model_args.torch_dtype, torch.bfloat16)
 
     logger = set_log_level(opt_args.log_level, __name__)
 
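
The fix resolves the string from ModelArguments into an actual torch.dtype right after argument parsing, falling back to bfloat16 for unknown names. A quick illustration of the getattr pattern (names are illustrative):

import torch

for name in ["bfloat16", "float16", "not_a_dtype"]:
    dtype = getattr(torch, name, torch.bfloat16)
    print(name, "->", dtype)
# bfloat16    -> torch.bfloat16
# float16     -> torch.float16
# not_a_dtype -> torch.bfloat16   (falls back to the getattr default)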

fms_mo/training_args.py

Lines changed: 1 addition & 1 deletion
@@ -58,7 +58,7 @@ class ModelArguments(TypeChecker):
     """Dataclass for model related arguments."""
 
     model_name_or_path: str = field(default="facebook/opt-125m")
-    torch_dtype: Union[torch.dtype, str] = torch.bfloat16
+    torch_dtype: str = field(default="bfloat16")
     use_fast_tokenizer: bool = field(
         default=True,
         metadata={
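
Taken together with the run_quant.py change, the field is now declared as a plain string and only resolved to a torch.dtype after parsing. A stripped-down sketch of that flow (other fields and the TypeChecker base omitted; values illustrative):

from dataclasses import dataclass, field
import torch

@dataclass
class ModelArguments:
    model_name_or_path: str = field(default="facebook/opt-125m")
    torch_dtype: str = field(default="bfloat16")  # stays a string at parse time

args = ModelArguments(torch_dtype="float16")
# run_quant.main() then converts it, falling back to bfloat16 for unknown names:
dtype = getattr(torch, args.torch_dtype, torch.bfloat16)
print(dtype)  # torch.float16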
