Commit 988add0

Merge pull request #141 from chichun-charlie-liu/triton_aiu_sim

fix: feat: fix for new transformers (>4.48) and new QLinear for INT8 training with HW emulation

2 parents: b47b08f + cb1dfca

File tree: 10 files changed (+332 additions, −62 deletions)

fms_mo/calib.py (13 additions, 7 deletions)

```diff
@@ -574,13 +574,19 @@ def qmodel_calib(
                 f"Qmodel calibration (clip_val analysis) in progress: {i}/{Nbatch}"
             )
 
-    if "perCh" not in qcfg["qw_mode"]:
-        cv_sum_dict = {"layer": [], "value": []}
-        for k, v in tempmodel.state_dict().items():
-            if "clip" in k:
-                cv_sum_dict["layer"].append(k)
-                cv_sum_dict["value"].append(v.item())
-        logger.info(f"Observed clipvals: \n{ pd.DataFrame(cv_sum_dict) }")
+    cv_sum_dict = {"layer": [], "value": []}
+    for k, v in tempmodel.state_dict().items():
+        if "clip" not in k:
+            continue
+
+        if v.numel() > 1:
+            k = k + "*"
+            v = v.mean()
+        cv_sum_dict["layer"].append(k)
+        cv_sum_dict["value"].append(v.item())
+    logger.info(
+        f"Observed clipvals: ('*' if it's a vector) \n{ pd.DataFrame(cv_sum_dict) }"
+    )
 
     # Step 3: extract new clip_vals, params and buffers, then remove handles if needed
     temp_new_clipvals = {
```
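
The rewritten summary no longer skips per-channel (`perCh`) runs: a vector `clip_val` is collapsed to its mean and flagged with `*` in the logged table. A minimal sketch of the new behavior against a hypothetical state_dict:

```python
# Sketch of the new clip-value summary; the state_dict contents are made up,
# the loop body mirrors the diff above.
import pandas as pd
import torch

state_dict = {
    "layer1.quantize_weight.clip_val": torch.tensor([0.8, 1.2, 1.0]),  # perCh vector
    "layer1.quantize_feature.clip_val": torch.tensor([2.5]),           # scalar
}

cv_sum_dict = {"layer": [], "value": []}
for k, v in state_dict.items():
    if "clip" not in k:
        continue
    if v.numel() > 1:
        k = k + "*"   # flag vector clip_vals in the table
        v = v.mean()  # report a single scalar for them
    cv_sum_dict["layer"].append(k)
    cv_sum_dict["value"].append(v.item())

# Vector entry shows up as "...clip_val*" with value ~1.0, scalar as 2.5
print(pd.DataFrame(cv_sum_dict))
```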

fms_mo/dq.py (12 additions, 0 deletions)

```diff
@@ -36,6 +36,9 @@
 
 # Local
 from fms_mo import qconfig_init, qmodel_prep
+from fms_mo.custom_ext_kernels.utils import (
+    lower_qmodel_triton,  # pylint: disable=unused-import
+)
 from fms_mo.fx.utils import model_size_Wb
 from fms_mo.quant.ptq import (
     calibration_llm_1GPU_v2,
@@ -256,6 +259,15 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
         model.save_pretrained(opt_args.output_dir, use_safetensors=True)
         tokenizer.save_pretrained(opt_args.output_dir)
 
+    if fms_mo_args.aiu_sim_triton:
+        lower_qmodel_triton(
+            model,
+            use_dyn_max_act=-1 if qcfg["qa_mode"] == "pertokenmax" else False,
+            max_acc_bits=qcfg.get("max_acc_bits", 32),
+            num_lsb_to_truncate=qcfg.get("lsb_trun_bits", 0),
+            chunk_size=qcfg.get("chunk_size", 1024),
+        )
+
     if fms_mo_args.eval_ppl:
         path_test = Path(data_args.test_data_path)
         arrow_files = list(path_test.glob("*.arrow"))
```
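
The new hook lowers the quantized model onto Triton kernels for AIU hardware emulation, driven by optional `qcfg` keys. A small sketch of the knobs the block reads (key names come from the diff; the example values are illustrative assumptions, and the `.get()` fallbacks are the defaults):

```python
# Illustrative qcfg fragment for the aiu_sim_triton path; values are assumptions.
qcfg = {
    "qa_mode": "pertokenmax",  # dynamic per-token max activation quantization
    "max_acc_bits": 24,        # emulate a narrower HW accumulator (fallback: 32)
    "lsb_trun_bits": 0,        # accumulator LSBs to truncate (fallback: 0)
    "chunk_size": 1024,        # chunked-accumulation length (fallback: 1024)
}

# -1 selects dynamic (per-token max) activation scaling in the lowered kernels;
# False keeps the static calibrated scales.
use_dyn_max_act = -1 if qcfg["qa_mode"] == "pertokenmax" else False
print(use_dyn_max_act, qcfg.get("max_acc_bits", 32))
```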

fms_mo/modules/bmm.py (1 addition, 1 deletion)

```diff
@@ -192,7 +192,7 @@ def forward(self, m1, m2):
             torch.Tensor: Output tensor after quantized bmm.
         """
         # pylint: disable = access-member-before-definition
-        if self.calib_counter:
+        if self.calib_counter > 0:
             with torch.no_grad():
                 qm1 = self.quantize_calib_m1(m1)
                 qm2 = self.quantize_calib_m2(m2)
```

fms_mo/modules/conv.py (1 addition, 1 deletion)

```diff
@@ -270,7 +270,7 @@ def forward(self, x):
             torch.Tensor: Output tensor of shape (batch_size, out_channels, out_height, out_width).
         """
         # pylint: disable = access-member-before-definition
-        if self.calib_counter:
+        if self.calib_counter > 0:
             with torch.no_grad():
                 qinput = self.quantize_calib_feature(x)
                 qweight = self.quantize_calib_weight(self.weight)
```
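
The bmm and conv modules get the same one-line guard change. A bare truthiness test and `> 0` agree when the counter hits zero, but diverge if the counter ever goes negative, e.g. if it is decremented once more after reaching zero. That overshoot scenario is an assumption used here for illustration; only the two `if` forms come from the diff:

```python
# Why `if calib_counter > 0` is safer than `if calib_counter`:
calib_counter = 1

calib_counter -= 1              # reaches 0: both checks stop calibration
assert not calib_counter and not (calib_counter > 0)

calib_counter -= 1              # hypothetical overshoot to -1
assert calib_counter            # old check: -1 is truthy, calibration wrongly re-fires
assert not (calib_counter > 0)  # new check: calibration stays off
```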

fms_mo/modules/linear.py (291 additions, 46 deletions)

Large diffs are not rendered by default.
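
Per the commit title, the bulk of this PR lives here: a new QLinear for INT8 training with HW emulation, but the diff is not rendered in this view. For orientation only, below is a generic sketch of fake-quantized INT8 linear training with a straight-through estimator and symmetric per-tensor scales. It is an assumption-laden illustration of the technique, not the QLinear added by this commit:

```python
# Generic fake-quantized INT8 linear layer for QAT (illustrative only).
import torch
import torch.nn as nn
import torch.nn.functional as F

class FakeQuantLinear(nn.Linear):
    """Linear layer that fake-quantizes weights and activations in forward."""

    @staticmethod
    def _fake_quant(x: torch.Tensor, num_bits: int = 8) -> torch.Tensor:
        # symmetric per-tensor scale from the running max
        levels = 2 ** (num_bits - 1) - 1
        scale = x.detach().abs().max().clamp(min=1e-5) / levels
        xq = (x / scale).round().clamp(-levels, levels) * scale
        # straight-through estimator: forward sees xq, backward sees x
        return x + (xq - x).detach()

    def forward(self, x):
        return F.linear(self._fake_quant(x), self._fake_quant(self.weight), self.bias)

layer = FakeQuantLinear(16, 8)
out = layer(torch.randn(2, 16))
out.sum().backward()  # gradients flow through the STE
```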

fms_mo/quant/quantizers.py (5 additions, 2 deletions)

```diff
@@ -3476,11 +3476,14 @@ def __init__(self, num_bits):
         """
         super().__init__()
         self.num_bits = num_bits
+        self.register_buffer("clip_val", torch.Tensor([0.0]))
+        self.register_buffer("clip_valn", torch.Tensor([0.0]))
 
     def forward(self, input_tensor):
-        scales = input_tensor.abs().max(dim=-1, keepdim=True)[0]
+        self.clip_val = input_tensor.abs().max(dim=-1, keepdim=True)[0]
+        self.clip_valn = -self.clip_val
         levels = 2 ** (self.num_bits - 1) - 1
-        scales.clamp_(min=1e-5).div_(levels)
+        scales = self.clip_val.clamp(min=1e-5).div(levels)
         input_tensor.div_(scales).round_().mul_(scales)
         return input_tensor
```
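
Registering `clip_val`/`clip_valn` as buffers exposes the dynamically observed per-token maxima through `state_dict()`, and switching from in-place `clamp_().div_()` to out-of-place ops keeps `clip_val` at the raw max instead of overwriting it with the derived scale. A self-contained sketch of the patched quantizer (the class name is assumed for illustration):

```python
# Per-token dynamic max quantizer, mirroring the diff above.
import torch
import torch.nn as nn

class PerTokenMaxQuantizer(nn.Module):
    def __init__(self, num_bits: int = 8):
        super().__init__()
        self.num_bits = num_bits
        self.register_buffer("clip_val", torch.Tensor([0.0]))
        self.register_buffer("clip_valn", torch.Tensor([0.0]))

    def forward(self, input_tensor):
        # per-token (last-dim) max, recorded in the buffers for later inspection
        self.clip_val = input_tensor.abs().max(dim=-1, keepdim=True)[0]
        self.clip_valn = -self.clip_val
        levels = 2 ** (self.num_bits - 1) - 1
        # out-of-place clamp/div so clip_val itself stays the raw max
        scales = self.clip_val.clamp(min=1e-5).div(levels)
        input_tensor.div_(scales).round_().mul_(scales)
        return input_tensor

q = PerTokenMaxQuantizer()
q(torch.randn(2, 4))
print(q.state_dict()["clip_val"].shape)  # torch.Size([2, 1]): one clip per token
```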

fms_mo/utils/aiu_utils.py (1 addition, 1 deletion)

```diff
@@ -309,7 +309,7 @@ def process_zero_shift(
     a_cvn = model.state_dict()[a_cvn_name]
 
     # compute "zero_shift" correction factor only for asymmetric activations
-    if a_cv and a_cvn and a_cv != -a_cvn:
+    if not (a_cv is None or a_cvn is None or torch.equal(a_cv, -a_cvn)):
         if weight_int is None:
             logger.info(
                 f"As weights appear to be not quantized, zero shift for {k} "
```

fms_mo/utils/eval_utils.py (3 additions, 1 deletion)

```diff
@@ -152,7 +152,9 @@ def evaluate(self, model, block_size=2048):
                 model.device
             )
             with torch.no_grad():
-                lm_logits = model(batch, return_dict=True).logits
+                mod_out = model(batch, return_dict=True)
+                # for newer transformers, model output could be simply a tuple
+                lm_logits = getattr(mod_out, "logits", mod_out[0])
             shift_logits = lm_logits[:, :-1, :].contiguous().float()
             shift_labels = self.dataset[:, (i * block_size) : ((i + 1) * block_size)][
                 :, 1:
```
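
The `getattr` fallback accepts both output conventions: an output object carrying a `.logits` attribute, or a plain tuple whose first element is the logits. Demonstrated with stand-in types (not actual transformers classes):

```python
# Both output styles resolve to the same logits tensor.
from collections import namedtuple
import torch

CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])  # stand-in output class

logits = torch.randn(1, 4, 32)
as_object = CausalLMOutput(logits=logits)  # has a .logits attribute
as_tuple = (logits,)                       # bare tuple, no attribute

for mod_out in (as_object, as_tuple):
    lm_logits = getattr(mod_out, "logits", mod_out[0])
    assert torch.equal(lm_logits, logits)
```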

tests/models/test_model_utils.py (1 addition, 1 deletion)

```diff
@@ -232,7 +232,7 @@ def check_linear_dtypes(state_dict: dict, linear_names: list):
         if any(n in k for n in linear_names):
             if k.endswith(".weight"):
                 assert v.dtype == torch.int8
-            elif k.endswith(".zero_point"):
+            elif k.endswith(".zero_point") or k.endswith(".zero_shift"):
                 assert v.dtype == torch.float32
             else:
                 assert v.dtype == torch.float16
```

tests/triton_kernels/test_triton_mm.py (4 additions, 2 deletions)

```diff
@@ -69,8 +69,10 @@ def test_triton_matmul_fp(mkn, dtype_to_test):
         .to("cuda")
         .to(torch.float)
     )
-    tl_output_no_trun = tl_matmul(a, b).to(torch.float)
-    tl_output_trun_8b = tl_matmul(a, b, chunk_trun_bits=8).to(torch.float)
+    tl_output_no_trun = tl_matmul(a, b, truncate_then_accumulate=False).to(torch.float)
+    tl_output_trun_8b = tl_matmul(
+        a, b, chunk_trun_bits=8, truncate_then_accumulate=False
+    ).to(torch.float)
 
     diff_no_trun = torch_output - tl_output_no_trun
     diff_trun_8b = torch_output - tl_output_trun_8b
```
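
`chunk_trun_bits` sets how many accumulator LSBs are dropped per chunk, and the new `truncate_then_accumulate` flag appears to control whether that truncation is applied to each partial sum before or after it is folded into the running total (an inference from the parameter names; the kernel internals are not shown in this diff). What LSB truncation does to an integer accumulator, emulated in plain PyTorch:

```python
# Emulation of LSB truncation on integer partial sums; the real kernel
# performs this inside Triton on chunked accumulators.
import torch

def truncate_lsb(acc: torch.Tensor, trun_bits: int) -> torch.Tensor:
    # zero out the lowest `trun_bits` bits of an integer accumulator
    return (acc >> trun_bits) << trun_bits

acc = torch.tensor([1023, 4096, 70000], dtype=torch.int32)
print(truncate_lsb(acc, 8))  # tensor([  768,  4096, 69888], dtype=torch.int32)
```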
