
Commit 275d47d

new triton version doesn't like 0xFFFFFFFF as a const
Signed-off-by: cliu-us <[email protected]>
1 parent 8d6bb65 commit 275d47d

5 files changed, +42 -21 lines changed


fms_mo/custom_ext_kernels/triton_kernels.py

Lines changed: 5 additions & 4 deletions
@@ -164,7 +164,7 @@ def matmul_kernel(
     # NOTE mask will be applied on accumulator, which is alway FP32, so we may truncate up to 23b
     # e.g., 20b -> trun_mask = 0xFFF00000, round_bit = 0x00080000
     #       8b -> trun_mask = 0xFFFFFF00, round_bit = 0x00000080
-    trun_mask = tl.cast((0xFFFFFFFF >> chunk_trun_bits) << chunk_trun_bits, tl.uint32)
+    trun_mask = ~tl.cast((1 << chunk_trun_bits) - 1, tl.uint32)
     round_bit = 1 << (chunk_trun_bits - 1) if chunk_trun_bits > 0 else 0
     ## ---------------------------------------------------------

@@ -386,7 +386,7 @@ def matmul_kernel_DABC(
     # NOTE mask will be applied on accumulator, which is alway FP32, so we may truncate up to 23b
     # e.g., 20b -> trun_mask = 0xFFF00000, round_bit = 0x00080000
     #       8b -> trun_mask = 0xFFFFFF00, round_bit = 0x00000080
-    trun_mask = tl.cast((0xFFFFFFFF >> chunk_trun_bits) << chunk_trun_bits, tl.uint32)
+    trun_mask = ~tl.cast((1 << chunk_trun_bits) - 1, tl.uint32)
     round_bit = 1 << (chunk_trun_bits - 1) if chunk_trun_bits > 0 else 0
     ## ---------------------------------------------------------
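Both hunks build the same truncation mask without the 0xFFFFFFFF literal (which newer Triton versions reject, presumably because it no longer fits the default int32 constant handling). A standalone NumPy sketch, illustration only and not Triton code, checking that the old and new expressions produce identical masks:

import numpy as np

# Standalone check (not part of the commit): for n truncated low bits, the old
# literal-based mask and the new (1 << n) - 1 based mask are bit-identical.
for n in range(0, 24):  # accumulator is FP32, so up to 23 mantissa bits may be truncated
    old_mask = np.uint32((0xFFFFFFFF >> n) << n)
    new_mask = ~np.uint32((1 << n) - 1)
    assert old_mask == new_mask, (n, hex(int(old_mask)), hex(int(new_mask)))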

@@ -448,10 +448,11 @@ def round_and_trun(x, round_bit, trun_mask):
 @triton.jit
 def fp32_clamp_to_dl16(x):
     """clamp FP32 (1-8-23) TENSOR x to DL16 (1-6-9) range."""
-    # 1. rounding: add round bit to full uint representation, zero out last 13 bits, back to float
+    # 1. rounding: add round bit, zero out last 13 bits, back to float
     x = libdevice.float_as_uint(x)
     round_bit = 1 << (23 - 9 - 1)
-    x = libdevice.uint_as_float(((x + round_bit) >> 13) << 13)
+    mask_13x0 = ~tl.cast((1 << 13) - 1, tl.uint32)
+    x = libdevice.uint_as_float((x + round_bit) & mask_13x0)

     # 2. clamp to min/max:
     # max = 2^32 * 1.(1111 1111 0)_base2 => 2^32*1.(1111 1111 1) will become inf
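The same mask idiom replaces the shift-right/shift-left pair in the DL16 rounding step. A plain-Python sketch of that step using struct for the bit reinterpretation (illustrative only; the exponent clamping that follows in the kernel is omitted):

import struct

def round_mantissa_to_dl16(x: float) -> float:
    """Round an FP32 value to a 9-bit mantissa: add the round bit, then zero the
    13 low mantissa bits, mirroring step 1 of fp32_clamp_to_dl16."""
    u = struct.unpack("<I", struct.pack("<f", x))[0]  # float bits as uint32
    round_bit = 1 << (23 - 9 - 1)                     # half ULP of the kept 9-bit mantissa
    mask_13x0 = ~((1 << 13) - 1) & 0xFFFFFFFF         # keep sign, exponent, top 9 mantissa bits
    u = ((u + round_bit) & mask_13x0) & 0xFFFFFFFF
    return struct.unpack("<f", struct.pack("<I", u))[0]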

fms_mo/custom_ext_kernels/utils.py

Lines changed: 6 additions & 0 deletions
@@ -918,6 +918,12 @@ def lower_qmodel_triton(

     if layer_to_exclude is None:
         layer_to_exclude = []
+    elif isinstance(layer_to_exclude, str):
+        layer_to_exclude = [
+            layer_to_exclude,
+        ]
+    elif not isinstance(layer_to_exclude, (list, tuple)):
+        raise RuntimeError("layer_to_exclude has to be either str, list, or tuple.")

     for name, m in model.named_modules():
         if not isinstance(m, (QLinear, torch.nn.Linear)) or name in layer_to_exclude:
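With this normalization, a single layer name and a list of names are treated the same way; a usage sketch (model is any already-prepared torch model, other arguments left at their defaults for brevity):

from fms_mo.custom_ext_kernels.utils import lower_qmodel_triton

# Both forms are accepted after this change; anything that is not None, str,
# list, or tuple raises RuntimeError.
lower_qmodel_triton(model, layer_to_exclude="lm_head")
lower_qmodel_triton(model, layer_to_exclude=["lm_head"])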

fms_mo/dq.py

Lines changed: 16 additions & 13 deletions
@@ -216,17 +216,18 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
         act_scales = get_act_scales(model, dq_dataloader, qcfg)
         torch.save(act_scales, scale_file)

-    qmodel_prep(
-        model,
-        dq_dataloader,
-        qcfg,
-        use_layer_name_pattern_matching=use_layer_name_pattern_matching,
-        use_dynamo=use_dynamo,
-        dev=dev,
-        save_fname="dq",
-    )
-    logger.info(f"Quantized model {model}")
-    logger.info("==" * 20)
+    if fms_mo_args.aiu_sim_triton != "fp8":
+        qmodel_prep(
+            model,
+            dq_dataloader,
+            qcfg,
+            use_layer_name_pattern_matching=use_layer_name_pattern_matching,
+            use_dynamo=use_dynamo,
+            dev=dev,
+            save_fname="dq",
+        )
+        logger.info(f"Quantized model {model}")
+        logger.info("==" * 20)

     if qcfg["smoothq"]:
         logger.info("Starting to apply smooth scale")
@@ -260,14 +261,16 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
         tokenizer.save_pretrained(opt_args.output_dir)

     if fms_mo_args.aiu_sim_triton:
+        # NOTE plz apply correct HW settings here, defaults are not real HW params
         lower_qmodel_triton(
             model,
             use_dyn_max_act=-1 if qcfg["qa_mode"] == "pertokenmax" else False,
             max_acc_bits=qcfg.get("max_acc_bits", 32),
             num_lsb_to_truncate=qcfg.get("lsb_trun_bits", 0),
-            chunk_size=qcfg.get("chunk_size", 1024),
+            chunk_size=qcfg.get("chunk_size", 32),  # 1024
+            clamp_acc_to_dl16=False,  # fms_mo_args.aiu_sim_triton == "fp8"
+            # layer_to_exclude=["lm_head",]
         )
-
     if fms_mo_args.eval_ppl:
         path_test = Path(data_args.test_data_path)
         arrow_files = list(path_test.glob("*.arrow"))
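The net effect in run_dq is that the simulation mode now decides which steps run; a small self-contained sketch of that decision (hypothetical helper, not part of the commit):

def aiu_sim_steps(aiu_sim_triton):
    """Hypothetical summary of the branches added above."""
    run_qmodel_prep = aiu_sim_triton != "fp8"   # "int8" and None still go through qmodel_prep()
    run_triton_lowering = bool(aiu_sim_triton)  # any non-None mode calls lower_qmodel_triton()
    return run_qmodel_prep, run_triton_lowering

assert aiu_sim_steps("int8") == (True, True)
assert aiu_sim_steps("fp8") == (False, True)
assert aiu_sim_steps(None) == (True, False)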

fms_mo/modules/linear.py

Lines changed: 6 additions & 2 deletions
@@ -1934,8 +1934,12 @@ def forward(
         ctx.fp8_e4m3_max = torch.finfo(torch.float8_e4m3fn).max
         ctx.fp8_e5m2_max = torch.finfo(torch.float8_e5m2).max
         reduce_dim = None if fp8_dyn == "per_tensor" else 1
-        x_scale = x.abs().amax(dim=reduce_dim) / ctx.fp8_e4m3_max
-        w_scale = weight.abs().amax(dim=reduce_dim) / ctx.fp8_e4m3_max
+        x_scale = (
+            x.abs().amax(dim=reduce_dim, keepdim=True) / ctx.fp8_e4m3_max
+        ).clamp(min=1e-5)
+        w_scale = (
+            weight.abs().amax(dim=reduce_dim, keepdim=True) / ctx.fp8_e4m3_max
+        ).clamp(min=1e-5)

         x = (x / x_scale).to(torch.float8_e4m3fn).to(org_dtype) * x_scale
         weight = (weight / w_scale).to(torch.float8_e4m3fn).to(org_dtype) * w_scale
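keepdim=True keeps the per-token/per-channel scale broadcastable against the tensor it scales, and the clamp avoids dividing by zero when an entire row is zero. A minimal torch sketch of the same fake-FP8 round trip (assumes a PyTorch build with float8 dtypes; shapes are illustrative):

import torch

def fake_fp8(t: torch.Tensor, per_tensor: bool = False) -> torch.Tensor:
    """Quantize-dequantize through float8_e4m3fn, mirroring the updated scale math."""
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    reduce_dim = None if per_tensor else 1
    scale = (t.abs().amax(dim=reduce_dim, keepdim=True) / fp8_max).clamp(min=1e-5)
    return (t / scale).to(torch.float8_e4m3fn).to(t.dtype) * scale

x = torch.randn(4, 8)
x[0] = 0.0                 # an all-zero token now yields zeros instead of NaNs
y = fake_fp8(x)            # y keeps the shape and dtype of x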

fms_mo/training_args.py

Lines changed: 9 additions & 2 deletions
@@ -181,8 +181,15 @@ class FMSMOArguments(TypeChecker):
         default=2048, metadata={"help": "input sequence length after tokenization"}
     )
     eval_ppl: bool = field(default=False)
-    aiu_sim_triton: bool = field(
-        default=False, metadata={"help": ("AIU simulation with triton kernel")}
+    aiu_sim_triton: str = field(
+        default=None,
+        metadata={
+            "help": (
+                "AIU simulation with triton kernel. ['int8', 'fp8', None]\n"
+                "'int8' mode will trigger qmodel_prep() and swap QLinears"
+                "'fp8' mode will directly replace existing nn.Linears"
+            )
+        },
     )
     recompute_narrow_weights: bool = field(
         default=False,
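Because aiu_sim_triton is now a string field, configs and CLI arguments choose the mode explicitly; a usage sketch (assumes the remaining dataclass fields keep their defaults):

from fms_mo.training_args import FMSMOArguments

# "int8": run qmodel_prep() and swap in QLinear modules before triton lowering
# "fp8":  skip qmodel_prep() and let the lowering replace nn.Linear directly
# None:   no AIU simulation (default)
fms_mo_args = FMSMOArguments(aiu_sim_triton="fp8")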
