
Commit 5283323

fix DL8/DL16 bugs and a couple of other minor bugs
Signed-off-by: cliu-us <[email protected]>
1 parent b685ea8 commit 5283323

File tree (4 files changed, +67 -67 lines changed):

  fms_mo/dq.py
  fms_mo/fx/dynamo_utils.py
  fms_mo/fx/utils.py
  fms_mo/modules/linear.py

fms_mo/dq.py (1 addition, 27 deletions)
@@ -268,35 +268,9 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
         max_acc_bits=qcfg.get("max_acc_bits", 32),
         num_lsb_to_truncate=qcfg.get("lsb_trun_bits", 0),
         chunk_size=qcfg.get("chunk_size", 32),  # 1024
-        clamp_acc_to_dl16=False,  # fms_mo_args.aiu_sim_triton == "fp8"
+        clamp_acc_to_dl16=fms_mo_args.aiu_sim_triton == "fp8",
         # layer_to_exclude=["lm_head",]
     )
-    # [CL] -------- record W, A, qW, qA with hooks ----------------
-    # from fms_mo.modules.linear import QLinear, QLinearINT8Deploy
-    # from fms_mo.quant.ptq import HookRecPostQuantInOut
-    # cache_dict = {}
-    # hook_handles = []
-    # for n, m in model.named_modules():
-    #     if not isinstance(m, (QLinear, QLinearINT8Deploy, torch.nn.Linear)):
-    #         continue
-
-    #     m.mod_name = n
-    #     hook_handles.append(
-    #         m.register_forward_hook( HookRecPostQuantInOut(cache_dict, n))
-    #     )
-
-    # data_mb = next(iter(eval_dataloader))
-    # with torch.no_grad():
-    #     model(**data_mb)
-
-    # for h in hook_handles:
-    #     h.remove()
-
-    # torch.save(
-    #     cache_dict,
-    #     f"roberta_sqv2_data_dump_{qcfg['qa_mode']}_{qcfg['qw_mode']}_chunk64_lsb{args.aiu_int_lsb_trun}_dq.pt"
-    # )
-    # return
 
     if fms_mo_args.eval_ppl:
         path_test = Path(data_args.test_data_path)
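
Besides dropping the dead hook-recording block, the only functional change in this file is the flag wiring: clamp_acc_to_dl16 had been hard-coded to False and is now derived from the simulation mode. A tiny sketch of the intended behavior (illustration only, not part of the commit; the helper name is hypothetical):

def dl16_clamp_enabled(aiu_sim_triton: str) -> bool:
    # clamp the Triton matmul accumulator to DL16 only when simulating FP8 (DL8) on AIU
    return aiu_sim_triton == "fp8"

assert dl16_clamp_enabled("fp8") is True    # FP8 simulation -> DL16 accumulator clamp on
assert dl16_clamp_enabled("int8") is False  # any other mode -> clamp stays off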

fms_mo/fx/dynamo_utils.py (9 additions, 3 deletions)
@@ -1180,14 +1180,20 @@ def cus_backend_model_analyzer(
     if is_transformers:
         # NOTE simplified method to determine 1st/last modules for transformers.
         # will not work if model has multiple parallel heads at the end, e.g. obj det
-        def call_seq_hook(mod, *_args, **_kwargs):
-            qcfg["mod_call_seq"].append(lut_weight2modname[mod.weight])
+        def call_seq_hook(mod, *_args, **kwargs):
+            mod_name = kwargs.get("mod_name", lut_weight2modname.get(mod.weight, None))
+            if mod_name is None:
+                raise RuntimeError("cannot determine module name, plz check model.")
+
+            qcfg["mod_call_seq"].append(mod_name)
 
         h_hooks = []
         qcfg["mod_call_seq"] = []
         for n, m in model.named_modules():
             if isinstance(m, (torch.nn.Linear, torch.nn.Conv2d)):
-                h_hooks.append(m.register_forward_hook(call_seq_hook))
+                h_hooks.append(
+                    m.register_forward_hook(partial(call_seq_hook, mod_name=n))
+                )
 
         with torch.no_grad():
             run_fwd_once(model, sample_inp)
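
The fix binds each module's name into its forward hook with functools.partial, falling back to the weight-to-name lookup only when no name was passed. A self-contained sketch of that pattern (illustration only, not the repository's helper; the toy Sequential model and the call_seq list are made up here):

from functools import partial

import torch

def call_seq_hook(mod, *_args, call_seq=None, mod_name=None, **_kwargs):
    # forward hooks are called as hook(module, inputs, output); the module's name is
    # bound in via partial, so no weight-to-name lookup table is needed
    if mod_name is None:
        raise RuntimeError("cannot determine module name, please check the model.")
    call_seq.append(mod_name)

model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.ReLU(), torch.nn.Linear(8, 2))
call_seq, handles = [], []
for n, m in model.named_modules():
    if isinstance(m, (torch.nn.Linear, torch.nn.Conv2d)):
        handles.append(m.register_forward_hook(partial(call_seq_hook, call_seq=call_seq, mod_name=n)))

with torch.no_grad():
    model(torch.randn(1, 4))
for h in handles:
    h.remove()
print(call_seq)  # ['0', '2'] -- the two Linear layers, in call order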

fms_mo/fx/utils.py (2 additions, 2 deletions)
@@ -461,14 +461,14 @@ def model_size_Wb(mod, unit="MB", print_to_file=True, show_details=False):
                 w_mat.numel() * w_mat.element_size()
                 + b_mat.numel() * b_mat.element_size()
             )
-            w_dtype = w_mat.dtype
+            w_dtype = str(w_mat.dtype)
             w_shape = w_mat.shape
 
         elif isinstance(w, torch.Tensor):
             mem_use = w.numel() * w.element_size()
             if hasattr(m, "bias") and m.bias is not None:
                 mem_use += m.bias.numel() * m.bias.element_size()
-            w_dtype = w.dtype
+            w_dtype = str(w.dtype)
             w_shape = w.shape
 
         if w_shape:
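
The str() around the dtype is a reporting fix: a torch.dtype prints as torch.float16 but the object itself is not, for example, JSON-serializable, so stringifying it keeps the per-layer size summary easy to dump. A minimal sketch under that assumption (the dict layout is hypothetical, not the function's actual output):

import json
import torch

w = torch.randn(8, 8, dtype=torch.float16)
entry = {
    "shape": list(w.shape),
    "dtype": str(w.dtype),                  # "torch.float16"; w.dtype itself would break json.dumps
    "bytes": w.numel() * w.element_size(),  # 64 elements * 2 bytes = 128
}
print(json.dumps(entry))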

fms_mo/modules/linear.py (55 additions, 35 deletions)
@@ -1926,14 +1926,16 @@ def forward(
         ctx.chunk_size = chunk_size
         ctx.fp8_dyn = fp8_dyn
         ctx.clamp_acc_to_dl16 = clamp_acc_to_dl16
+        ctx.fp8_e4m3_max = torch.finfo(torch.float8_e4m3fn).max
+        ctx.fp8_e5m2_max = torch.finfo(torch.float8_e5m2).max
         ctx.dl8_min = 0.0087890625
 
+        x_scale = torch.tensor(1.0, device=x.device, dtype=org_dtype)
+        w_scale = x_scale.clone()
         if fp8_dyn:
             # use Q/dQ simulation for now, meaning still compute in fp16/bf16
             # if choose per_token for input, use per_channel for W
             # (W saved as [out, in], reduce inCh-dim, => reduce_dim=1)
-            ctx.fp8_e4m3_max = torch.finfo(torch.float8_e4m3fn).max
-            ctx.fp8_e5m2_max = torch.finfo(torch.float8_e5m2).max
             reduce_dim = None if fp8_dyn == "per_tensor" else 1
             x_scale = (
                 x.abs().amax(dim=reduce_dim, keepdim=True) / ctx.fp8_e4m3_max
@@ -1942,22 +1944,30 @@
                 weight.abs().amax(dim=reduce_dim, keepdim=True) / ctx.fp8_e4m3_max
             ).clamp(min=1e-5)
 
-            x = (x / x_scale).to(torch.float8_e4m3fn).to(org_dtype) * x_scale
-            weight = (weight / w_scale).to(torch.float8_e4m3fn).to(org_dtype) * w_scale
+            x = (x / x_scale).to(torch.float8_e4m3fn).to(torch.float32)
+            weight = (weight / w_scale).to(torch.float8_e4m3fn).to(torch.float32)
             if clamp_acc_to_dl16:
-                # NOTE For DL8@DL8 acc in DL16, as DL8 doesn't support subnorm numbers like PyTorch
-                # (whose real min for e4m3fn is 2^-9), need to flush subnorm numbers to 0
-                x.masked_fill_(x < ctx.dl8_min, 0)
-                weight.masked_fill_(weight < ctx.dl8_min, 0)
+                # at this point, x and W are clamped to PT's FP8 range (2^-9 to 448). But since DL8
+                # doesn't support subnorm like PyTorch, need to flush subnorms to 0 BEFORE descaling
+                x.masked_fill_(x.abs() < ctx.dl8_min, 0)
+                weight.masked_fill_(weight.abs() < ctx.dl8_min, 0)
 
         # triton kernel assumes 2D inputs and cast the return to input.dtype
-        output = tl_matmul(
-            x,
-            weight.t().to(org_dtype),
-            chunk_trun_bits=trun_bits,
-            chunk_size=chunk_size,
-            clamp_acc_to_dl16=clamp_acc_to_dl16,
-        ).reshape(target_shape_output)
+        output = (
+            (
+                tl_matmul(
+                    x,
+                    weight.t(),
+                    chunk_trun_bits=trun_bits,
+                    chunk_size=chunk_size,
+                    clamp_acc_to_dl16=clamp_acc_to_dl16,
+                )
+                * x_scale
+                * w_scale.t()
+            )
+            .to(org_dtype)
+            .reshape(target_shape_output)
+        )
 
         if bias is not None:
             output = output + bias.to(org_dtype)
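
The forward fix changes the order of operations in the FP8 (DL8) simulation: x and W now stay on the quantized FP8 grid (held in float32), subnormals are flushed by magnitude before any de-scaling, and the per-token/per-channel scales are applied to the matmul output rather than folded back into the operands. A minimal self-contained sketch of that ordering, with a plain torch matmul standing in for the repository's tl_matmul Triton kernel (illustration only, not the commit's code):

import torch

DL8_MIN = 0.0087890625  # smallest normal DL8 magnitude, same constant as ctx.dl8_min above

def fp8_dl16_sim_matmul(x, w):
    # per-token scale for x, per-output-channel scale for w (w stored as [out, in])
    e4m3_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0
    x_scale = (x.abs().amax(dim=-1, keepdim=True) / e4m3_max).clamp(min=1e-5)
    w_scale = (w.abs().amax(dim=1, keepdim=True) / e4m3_max).clamp(min=1e-5)
    # quantize to the FP8 grid but keep the values in float32 (Q/dQ simulation)
    xq = (x / x_scale).to(torch.float8_e4m3fn).to(torch.float32)
    wq = (w / w_scale).to(torch.float8_e4m3fn).to(torch.float32)
    # flush DL8 subnormals by magnitude while still in the quantized domain; the old code
    # compared signed, already de-scaled values, which wrongly zeroed every negative entry
    xq.masked_fill_(xq.abs() < DL8_MIN, 0)
    wq.masked_fill_(wq.abs() < DL8_MIN, 0)
    # de-scale only after the matmul (tl_matmul would sit here in the real code path)
    return (xq @ wq.t()) * x_scale * w_scale.t()

x, w = torch.randn(4, 16), torch.randn(8, 16)
print((fp8_dl16_sim_matmul(x, w) - x @ w.t()).abs().max())  # small quantization error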
@@ -1977,44 +1987,54 @@ def backward(ctx, grad_output):
         target_shape_grad_input = grad_output.shape[:-1] + (in_dim,)
         grad_output_2D = grad_output.reshape(-1, out_dim).to(dtype_input)
 
+        x_scale = torch.tensor(1.0, device=x.device, dtype=dtype_input)
+        w_scale = x_scale.clone()
         if ctx.fp8_dyn:
             reduce_dim = None if ctx.fp8_dyn == "per_tensor" else 1
             x_scale = x.abs().amax(dim=reduce_dim) / ctx.fp8_e5m2_max
             w_scale = weight.abs().amax(dim=reduce_dim) / ctx.fp8_e5m2_max
             # always assume perT in this case
             grad_out_scale = grad_output_2D.abs().amax(dim=None) / ctx.fp8_e5m2_max
 
-            x = (x / x_scale).to(torch.float8_e5m2).to(dtype_input) * x_scale
-            weight = (weight / w_scale).to(torch.float8_e5m2).to(weight.dtype) * w_scale
-            grad_output_2D = (grad_output_2D / grad_out_scale).to(torch.float8_e5m2).to(
-                grad_output.dtype
-            ) * grad_out_scale
+            x = (x / x_scale).to(torch.float8_e5m2).to(torch.float)
+            weight = (weight / w_scale).to(torch.float8_e5m2).to(torch.float)
+            grad_output_2D = (
+                (grad_output_2D / grad_out_scale).to(torch.float8_e5m2).to(torch.float)
+            )
             if ctx.clamp_acc_to_dl16:
                 # flush subnorm numbers to 0 as DL8 doesn't support it
-                x.masked_fill_(x < ctx.dl8_min, 0)
-                weight.masked_fill_(weight < ctx.dl8_min, 0)
-                grad_output_2D.masked_fill_(grad_output_2D < ctx.dl8_min, 0)
+                x.masked_fill_(x.abs() < ctx.dl8_min, 0)
+                weight.masked_fill_(weight.abs() < ctx.dl8_min, 0)
+                grad_output_2D.masked_fill_(grad_output_2D.abs() < ctx.dl8_min, 0)
 
         # Compute grad_weight, shape = [out, in]
         # NOTE: this triton kernel requires A matrix to be contiguous
-        grad_weight = tl_matmul(
-            grad_output_2D.transpose(0, 1).contiguous(),
-            x,
-            chunk_trun_bits=trun_bits,
-            chunk_size=chunk_size,
-            clamp_acc_to_dl16=ctx.clamp_acc_to_dl16,
-        ).to(weight.dtype)
-        # Compute grad_input in 2D then reshape to target shape, could be 3D or 2D
-        grad_input = (
+        grad_weight = (
             tl_matmul(
-                grad_output_2D,
-                weight.to(dtype_input),
+                grad_output_2D.transpose(0, 1).contiguous(),
+                x,
                 chunk_trun_bits=trun_bits,
                 chunk_size=chunk_size,
                 clamp_acc_to_dl16=ctx.clamp_acc_to_dl16,
             )
-            .reshape(target_shape_grad_input)
+            * grad_out_scale.t()
+            * x_scale
+        ).to(weight.dtype)
+        # Compute grad_input in 2D then reshape to target shape, could be 3D or 2D
+        grad_input = (
+            (
+                tl_matmul(
+                    grad_output_2D,
+                    weight,
+                    chunk_trun_bits=trun_bits,
+                    chunk_size=chunk_size,
+                    clamp_acc_to_dl16=ctx.clamp_acc_to_dl16,
+                )
+                * grad_out_scale
+                * w_scale
+            )
             .to(dtype_input)
+            .reshape(target_shape_grad_input)
         )
 
         if not ctx.has_bias:
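
The backward pass gets the same treatment: activations, weights, and gradients are quantized to float8_e5m2, subnormals are flushed by magnitude, and the scales are multiplied back in only after each tl_matmul. The sketch below shows just that de-scaling algebra, using per-tensor scales (which keeps the factorization exact) and plain matmuls in place of tl_matmul; it is an illustration under those simplifying assumptions, not the commit's code:

import torch

def q_e5m2(t):
    # per-tensor dynamic quantization to float8_e5m2; values are kept in float32 for the
    # simulated matmul and the scale is returned so it can be re-applied afterwards
    scale = (t.abs().max() / torch.finfo(torch.float8_e5m2).max).clamp(min=1e-5)
    return (t / scale).to(torch.float8_e5m2).to(torch.float32), scale

x = torch.randn(4, 16)   # saved input   [tokens, in]
w = torch.randn(8, 16)   # weight        [out, in]
g = torch.randn(4, 8)    # grad_output   [tokens, out]

xq, sx = q_e5m2(x)
wq, sw = q_e5m2(w)
gq, sg = q_e5m2(g)

grad_w = (gq.t() @ xq) * sg * sx  # grad wrt weight, [out, in]; scales applied after the matmul
grad_x = (gq @ wq) * sg * sw      # grad wrt input,  [tokens, in]
print((grad_w - g.t() @ x).abs().max(), (grad_x - g @ w).abs().max())  # small quantization error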
