Commit d555ef5

Merge pull request #139 from iqbal-saraf/QmaxDynamic
feat: Add QMaxDynamic to unify Qmax, Qminmax, and pertokenmax
2 parents 20f3d82 + 343abc7 commit d555ef5

File tree: 1 file changed, +54 −15 lines changed


fms_mo/quant/quantizers.py

Lines changed: 54 additions & 15 deletions
@@ -123,23 +123,28 @@ def get_activation_quantizer(
             )
         elif qa_mode == "dorefa":
            act_quantizer = dorefa_quantize_activation
-        elif (
-            qa_mode == "max"
-        ):  # NOTE Need to be careful using this for activation, particular to 1 sided.
-            act_quantizer = Qmax(nbits, align_zero=align_zero, minmax=False)
-        elif qa_mode == "minmax":
-            act_quantizer = Qmax(nbits, align_zero=align_zero, minmax=True)
+
+        elif "max" in qa_mode:
+            # NOTE Need to be careful using this for activation, particularly 1-sided.
+            if "min" in qa_mode:
+                act_quantizer = Qmax(nbits, align_zero=align_zero, minmax=True)
+            elif "pertoken" in qa_mode or "perToken" in qa_mode:
+                act_quantizer = QMaxDynamic(nbits, dim=-1)
+            elif "per_channel" in qa_mode or "perCh" in qa_mode:
+                act_quantizer = QMaxDynamic(nbits, dim=-2)
+            elif "sym" in qa_mode:
+                act_quantizer = Qmax(
+                    nbits,
+                    align_zero=True,
+                    minmax=False,
+                    extend_act_range=extend_act_range,
+                )
+            else:
+                act_quantizer = Qmax(nbits, align_zero=align_zero, minmax=False)
         elif qa_mode == "fix":
             act_quantizer = QFixSymmetric(
                 nbits, init_clip_val=clip_val, align_zero=align_zero
             )
-        elif qa_mode == "maxsym":
-            act_quantizer = Qmax(
-                nbits,
-                align_zero=True,
-                minmax=False,
-                extend_act_range=extend_act_range,
-            )
         elif qa_mode == "pactsym":
             act_quantizer = PACT2Sym(
                 nbits,
@@ -179,8 +184,6 @@ def get_activation_quantizer(
                 perToken=perToken,
                 emulate=True,
             )
-        elif qa_mode == "pertokenmax":
-            act_quantizer = PerTokenMax(nbits)
         else:
             raise ValueError(f"unrecognized activation quantization mode {qa_mode}")
     else:  # swcap-compatible activation quantizers
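With this change, the old exact-match modes "max", "minmax", and "maxsym", plus the separate "pertokenmax" branch removed in the second hunk, are all routed through the single elif "max" in qa_mode chain. A minimal sketch of the resulting dispatch, mirroring the branch order in the diff; resolve_max_mode is a hypothetical helper for illustration, not part of the commit:

# Hypothetical helper mirroring the unified "max" branch above. The returned
# strings name the quantizer each mode resolves to. Branch order matters:
# e.g. "minmax" hits the "min" check, "pertokenmax" the "pertoken" check.
def resolve_max_mode(qa_mode: str) -> str:
    assert "max" in qa_mode  # only the unified branch is modeled here
    if "min" in qa_mode:
        return "Qmax(minmax=True)"  # old 'minmax'
    if "pertoken" in qa_mode or "perToken" in qa_mode:
        return "QMaxDynamic(dim=-1)"  # old 'pertokenmax'
    if "per_channel" in qa_mode or "perCh" in qa_mode:
        return "QMaxDynamic(dim=-2)"  # per-channel counterpart
    if "sym" in qa_mode:
        return "Qmax(align_zero=True, extend_act_range=...)"  # old 'maxsym'
    return "Qmax(minmax=False)"  # plain 'max'

for mode in ("max", "minmax", "maxsym", "pertokenmax"):
    print(f"{mode:12s} -> {resolve_max_mode(mode)}")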
@@ -3491,6 +3494,42 @@ def __repr__(self):
         return f"{self.__class__.__name__}(num_bits={self.num_bits}, quantizer=)"
 
 
+class QMaxDynamic(nn.Module):
+    def __init__(self, num_bits, dim=-1):
+        """
+        For per-token or per-channel quantization using abs().max() as the scale, usually
+        for activations; could be used for Qbmm M2 as well.
+        (reduce) dim = -1 -> abs().max() outputs a column vector (for a 2D input) => per-token
+        dim = -2 -> per-channel
+        Zero is aligned so that the levels are symmetric around zero (losing one level).
+        Since the token length is unknown before running, the quantizer can only calculate
+        the scales dynamically at runtime, meaning no trainable quantization scales are
+        allowed (unless the input seq length is always the same, not just padded).
+        """
+        super().__init__()
+        self.num_bits = num_bits
+        self.levels = 2 ** (self.num_bits - 1) - 1
+        if isinstance(dim, str):
+            # map string aliases to the corresponding reduce dim
+            if "perCh" in dim or "per_channel" in dim:
+                dim = -2
+            elif "perToken" in dim or "per_token" in dim or "per_Token" in dim:
+                dim = -1
+        if dim in [-1, -2]:
+            self.reduce_dim = dim
+        else:
+            raise ValueError(
+                f"Reduce dim can only be [-1, -2] or ['perCh', 'perToken'] but found {dim}"
+            )
+
+    def forward(self, input_tensor):
+        amax_dim = input_tensor.abs().max(dim=self.reduce_dim, keepdim=True)[0]
+        scales = amax_dim.clamp(min=1e-5).div(self.levels)
+        return input_tensor.div(scales).round().mul(scales)
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}(num_bits={self.num_bits}, quantizer=)"
+
+
 class Qdynamic(nn.Module):
     def __init__(
         self,
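For num_bits=8, levels = 2**(8-1) - 1 = 127, so inputs are mapped onto the symmetric grid [-127, 127] and the -128 code is given up to keep zero exactly representable. A short usage sketch of the new class (illustrative, not part of the commit; the import path is inferred from the file being changed):

# Illustrative usage of QMaxDynamic, assuming it is importable from the module
# this commit edits; shapes and values are made up for the demo.
import torch

from fms_mo.quant.quantizers import QMaxDynamic

x = torch.randn(2, 4, 8)  # (batch, tokens, channels)
per_token = QMaxDynamic(num_bits=8, dim=-1)  # one scale per token (row)
per_channel = QMaxDynamic(num_bits=8, dim=-2)  # one scale per channel (column)

xq_tok = per_token(x)  # fake-quantize: x.div(scale).round().mul(scale)
xq_ch = per_channel(x)

# Scales are recomputed from abs().max() on every forward call, so nothing is
# trainable; per-element rounding error is bounded by half a step (scale / 2).
print((x - xq_tok).abs().max(), (x - xq_ch).abs().max())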

0 commit comments
