
Commit 766c805

Fix for new transformers (>4.48) and a new QLinear for INT8 training that takes HW simulation into account
Signed-off-by: cliu-us <[email protected]>
1 parent 1f9a1cc commit 766c805

File tree

3 files changed (+168, -7 lines)


fms_mo/dq.py

Lines changed: 4 additions & 4 deletions
@@ -262,10 +262,10 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
     if fms_mo_args.aiu_sim_triton:
         lower_qmodel_triton(
             model,
-            use_dyn_max_act=-1,
-            max_acc_bits=24,
-            num_lsb_to_truncate=8,
-            chunk_size=32,
+            use_dyn_max_act=-1 if qcfg["qa_mode"] == "pertokenmax" else False,
+            max_acc_bits=qcfg.get("max_acc_bits", 32),
+            num_lsb_to_truncate=qcfg.get("lsb_trun_bits", 0),
+            chunk_size=qcfg.get("chunk_size", 1024),
         )

     if fms_mo_args.eval_ppl:
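
For context, a minimal sketch of how the new qcfg-driven arguments resolve when the optional HW-simulation keys are absent; the dict below is illustrative only, but the keys (qa_mode, max_acc_bits, lsb_trun_bits, chunk_size) are exactly the ones read in the diff above:

# Illustrative qcfg with none of the optional HW-simulation keys set
qcfg = {"qa_mode": "pertokenmax"}

kwargs = dict(
    # dynamic per-token max activation quantization only when qa_mode is "pertokenmax"
    use_dyn_max_act=-1 if qcfg["qa_mode"] == "pertokenmax" else False,
    max_acc_bits=qcfg.get("max_acc_bits", 32),         # full 32-bit accumulator by default
    num_lsb_to_truncate=qcfg.get("lsb_trun_bits", 0),  # no LSB truncation by default
    chunk_size=qcfg.get("chunk_size", 1024),
)
# -> {'use_dyn_max_act': -1, 'max_acc_bits': 32, 'num_lsb_to_truncate': 0, 'chunk_size': 1024}

In other words, the previously hard-coded HW settings (24-bit accumulator, 8 truncated LSBs, chunk size 32) now apply only if the quantization config explicitly requests them.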

fms_mo/modules/linear.py

Lines changed: 161 additions & 2 deletions
@@ -269,7 +269,6 @@ def forward(self, x):
             if self.calib_counter == 0:
                 self.quantize_calib_feature = None
                 self.quantize_calib_weight = None
-                self.calib_counter = None  # [optional] this should release the memory

         elif self.ptqmode == "fp32_out":
             if self.W_fp is None:
@@ -1101,7 +1100,6 @@ def qa_dyn_max_fake_zero_shift(self, x):
         to multiply input_scale. (Assuming per-tensor, can shift left or right)
         """
         amax = x.abs().max()
-        shift_dir = 1 if amax == x.max() else -1
         levels = 2 ** (self.nbits_a - 1) - 1 - self.input_zp
         self.cvs[0] = amax
         self.cvs[1] = -amax
@@ -2061,6 +2059,167 @@ def extra_repr(self) -> str:
         )


+class LinearFuncINT8FwdFP32Bwd(torch.autograd.Function):
+    """[Experimental] Linear autograd function using INT matmul/accumulation to simulate HW
+    behavior during QAT, in order to adjust weights for a specific HW design.
+
+    Args:
+        activation x: FP tensor, needs to be reshaped to 2D and quantized to INT8.
+        weight: FP 2D tensor, W.shape = [out, in].
+        bias: bias from original Linear, does not include INT ZP correction term yet.
+    NOTE:
+        1. main purpose is to utilize the triton INT kernel to simulate MSB/LSB truncation in FWD.
+        2. BWD simply uses torch.matmul.
+        3. *Max per-Ch* for weights and *dynamic max per-Token* for activations.
+    """
+
+    @staticmethod
+    def forward(
+        ctx,
+        x,
+        weight,
+        bias=None,
+        lsb_trun_bits=0,
+        chunk_size=64,
+        max_acc_bits=32,
+    ):
+        assert x.dtype in [torch.float, torch.bfloat16, torch.float16]
+        # input can be 2D or 3D, need to reshape before tl_matmul
+        org_dtype = x.dtype
+        target_shape_output = x.shape[:-1] + (weight.shape[0],)
+        x = x.reshape(-1, x.shape[-1])
+
+        if bias is not None:
+            ctx.has_bias = True
+            ctx.bias_dtype = bias.dtype
+        else:
+            ctx.has_bias = False
+
+        ctx.save_for_backward(x, weight)  # x, W are saved in their original dtype
+
+        # max per_token for input -> reduce_dim = -1
+        # per_channel for W but W.shape = [out, in] -> reduce_dim = -1
+        # sym activation -> correction term = 0
+        x_scale = x.abs().amax(dim=-1, keepdim=True) / 127
+        w_scale = weight.abs().amax(dim=-1, keepdim=True) / 127
+
+        x_i8 = torch.round(x / x_scale).to(torch.int8)
+        w_i8 = torch.round(weight / w_scale).to(torch.int8)
+
+        # triton kernel accepts 2d int8 then returns int32
+        output = tl_matmul(
+            x_i8,
+            w_i8.t(),
+            chunk_trun_bits=lsb_trun_bits,
+            chunk_size=chunk_size,
+            max_acc_bits=max_acc_bits,
+        )
+        output = (
+            (output.to(torch.float) * x_scale * w_scale.t())
+            .reshape(target_shape_output)
+            .to(org_dtype)
+        )
+        if bias is not None:
+            output = output + bias.to(org_dtype)
+
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        # load x and W from context, x is 2D already. no quant.
+        # option 1: use compute dtype = x.dtype
+        # option 2: compute in fp32 for best results.
+        x, weight = ctx.saved_tensors  # x, W are saved in original dtype
+        out_dim = weight.shape[0]
+        in_dim = weight.shape[1]
+        dtype_grad = x.dtype  # torch.float
+        # grad_input and grad_output could be 3D as x
+        target_shape_grad_input = grad_output.shape[:-1] + (in_dim,)
+        grad_output_2D = grad_output.reshape(-1, out_dim).to(dtype_grad)
+
+        # Compute grad_weight, shape = [out, in]
+        grad_weight = torch.matmul(
+            grad_output_2D.transpose(0, 1).contiguous(),
+            x.to(dtype_grad),
+        ).to(weight.dtype)
+        # Compute grad_input in 2D then reshape to target shape, could be 3D or 2D
+        grad_input = (
+            torch.matmul(
+                grad_output_2D,
+                weight.to(dtype_grad),
+            )
+            .reshape(target_shape_grad_input)
+            .to(x.dtype)
+        )
+
+        if not ctx.has_bias:
+            grad_bias = None
+        else:
+            grad_bias = grad_output_2D.sum(0).to(ctx.bias_dtype)
+
+        return grad_input, grad_weight, grad_bias, None, None, None
+
+
+class QLinearINT8Train(torch.nn.Linear):
+    """QLinear layer wrapper that simulates INT8 HW behavior, e.g. MSB/LSB truncation, in the
+    forward pass and uses full precision in the backward pass.
+    """
+
+    @classmethod
+    def from_fms_mo(cls, nnlin, lsb_trun_bits=0, **kwargs):
+        """Converts a torch.nn.Linear or QLinear module to QLinearINT8Train.
+
+        Args:
+            cls (class): The class to be created.
+            nnlin (torch.nn.Linear or QLinear): The Linear module to be converted.
+            lsb_trun_bits (int): INT8 LSB truncation, [0 to 16].
+            chunk_size (int): usually >= 64 for INT8, based on HW design.
+            max_acc_bits (int): accumulator max bits, <= 32, based on HW design.
+
+        Returns:
+            QLinearINT8Train: The converted linear layer.
+        """
+
+        target_device = kwargs.get(
+            "target_device", kwargs.get("device", next(nnlin.parameters()).device)
+        )
+
+        lin_int8fwd = cls(
+            nnlin.in_features,
+            nnlin.out_features,
+            bias=nnlin.bias is not None,
+            device="meta",  # target_device,
+        )
+
+        lin_int8fwd.weight = nnlin.weight
+        lin_int8fwd.lsb_trun_bits = lsb_trun_bits
+        lin_int8fwd.chunk_size = kwargs.get("chunk_size", 64)
+        lin_int8fwd.max_acc_bits = kwargs.get("max_acc_bits", 32)
+
+        if nnlin.bias is not None:
+            lin_int8fwd.bias = nnlin.bias
+        return lin_int8fwd.to(target_device)
+
+    def forward(self, inputs):
+        return LinearFuncINT8FwdFP32Bwd.apply(
+            inputs,
+            self.weight,
+            self.bias,
+            self.lsb_trun_bits,
+            self.chunk_size,
+            self.max_acc_bits,
+        )
+
+    def extra_repr(self) -> str:
+        """Returns an alternative string representation of the object."""
+        repr_str = f"{self.in_features},{self.out_features}"
+        repr_str += f",bias={self.bias is not None},chunk_size={self.chunk_size}"
+        if self.lsb_trun_bits > 0:
+            repr_str += f",lsb_trun={self.lsb_trun_bits}"
+        if self.max_acc_bits < 32:
+            repr_str += f",max_acc_bits={self.max_acc_bits}"
+        return repr_str
+
+
 if available_packages["mx"]:
     # Third Party
     # pylint: disable = import-error
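
A hedged usage sketch of the new QLinearINT8Train wrapper, based only on the signatures added above; the layer sizes and keyword values are illustrative, and a CUDA device with the Triton tl_matmul kernel (imported elsewhere in linear.py) is assumed for the INT8 forward:

import torch

from fms_mo.modules.linear import QLinearINT8Train

# Illustrative sizes; any torch.nn.Linear (or fms-mo QLinear) could be converted.
lin = torch.nn.Linear(1024, 4096, bias=True).to("cuda", torch.bfloat16)

# Forward runs INT8 matmul with HW-style chunked accumulation and LSB truncation;
# backward stays in the original precision via plain torch.matmul.
qlin = QLinearINT8Train.from_fms_mo(
    lin,
    lsb_trun_bits=8,   # truncate 8 LSBs per chunk, per HW design
    chunk_size=64,     # accumulate along the reduction dim in chunks of 64
    max_acc_bits=24,   # limited-width accumulator
)

x = torch.randn(2, 16, 1024, device="cuda", dtype=torch.bfloat16, requires_grad=True)
y = qlin(x)                  # LinearFuncINT8FwdFP32Bwd.forward (INT8, Triton kernel)
y.float().sum().backward()   # LinearFuncINT8FwdFP32Bwd.backward (plain matmul)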

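For intuition about what the forward pass simulates, a pure-PyTorch sketch of chunked INT8 accumulation with LSB truncation and a limited-width accumulator; this is one plausible interpretation for illustration only, and the exact rounding/saturation behavior of the tl_matmul Triton kernel may differ:

import torch

def chunked_int8_matmul_ref(x_i8, w_i8_t, chunk_size=64, lsb_trun_bits=0, max_acc_bits=32):
    """Toy reference: accumulate INT8 products chunk-by-chunk along K, drop the
    low-order bits of each chunk's partial sum, and clamp the running accumulator
    to a limited bit width (illustrative semantics only)."""
    m, k = x_i8.shape
    n = w_i8_t.shape[1]
    acc = torch.zeros(m, n, dtype=torch.int64)
    acc_max = 2 ** (max_acc_bits - 1) - 1
    for k0 in range(0, k, chunk_size):
        xa = x_i8[:, k0 : k0 + chunk_size].to(torch.int64)
        wb = w_i8_t[k0 : k0 + chunk_size, :].to(torch.int64)
        partial = xa @ wb
        if lsb_trun_bits > 0:
            partial = (partial >> lsb_trun_bits) << lsb_trun_bits  # LSB truncation
        acc = torch.clamp(acc + partial, -acc_max - 1, acc_max)  # bounded accumulator
    return acc.to(torch.int32)

x_i8 = torch.randint(-128, 128, (4, 256), dtype=torch.int8)
w_i8 = torch.randint(-128, 128, (8, 256), dtype=torch.int8)
out = chunked_int8_matmul_ref(x_i8, w_i8.t(), chunk_size=64, lsb_trun_bits=8, max_acc_bits=24)
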
fms_mo/utils/eval_utils.py

Lines changed: 3 additions & 1 deletion
@@ -152,7 +152,9 @@ def evaluate(self, model, block_size=2048):
                 model.device
             )
             with torch.no_grad():
-                lm_logits = model(batch, return_dict=True).logits
+                mod_out = model(batch, return_dict=True)
+                # for newer transformers, model output could be simply a tuple
+                lm_logits = getattr(mod_out, "logits", mod_out[0])
             shift_logits = lm_logits[:, :-1, :].contiguous().float()
             shift_labels = self.dataset[:, (i * block_size) : ((i + 1) * block_size)][
                 :, 1:
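
The getattr fallback handles both output styles in one line; a minimal illustration with stand-in objects (not actual transformers outputs):

import torch
from collections import namedtuple

logits = torch.zeros(1, 4, 8)

# Older behavior: an output object exposing .logits (a transformers ModelOutput is also indexable)
Output = namedtuple("Output", ["logits"])
out_obj = Output(logits=logits)
assert getattr(out_obj, "logits", out_obj[0]) is logits

# Newer behavior in some configurations: a plain tuple with logits first
out_tuple = (logits,)
assert getattr(out_tuple, "logits", out_tuple[0]) is logits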
