
Commit 31c4e2f

Add QMaxDynamic to unify Qmax, Qminmax, and pertokenmax
Signed-off-by: Iqbal Saraf <[email protected]>
1 parent 9623337

File tree

1 file changed: +55 -15 lines changed

fms_mo/quant/quantizers.py

Lines changed: 55 additions & 15 deletions
@@ -123,12 +123,24 @@ def get_activation_quantizer(
             )
         elif qa_mode == "dorefa":
             act_quantizer = dorefa_quantize_activation
-        elif (
-            qa_mode == "max"
-        ):  # NOTE Need to be careful using this for activation, particular to 1 sided.
-            act_quantizer = Qmax(nbits, align_zero=align_zero, minmax=False)
-        elif qa_mode == "minmax":
-            act_quantizer = Qmax(nbits, align_zero=align_zero, minmax=True)
+
+        elif "max" in qa_mode:
+            # NOTE Need to be careful using this for activations, particularly one-sided ones.
+            if "min" in qa_mode:
+                act_quantizer = Qmax(nbits, align_zero=align_zero, minmax=True)
+            elif "pertoken" in qa_mode or "perToken" in qa_mode:
+                act_quantizer = QMaxDynamic(nbits, dim=-1)
+            elif "per_channel" in qa_mode or "perCh" in qa_mode:
+                act_quantizer = QMaxDynamic(nbits, dim=-2)
+            elif "sym" in qa_mode:
+                act_quantizer = Qmax(
+                    nbits,
+                    align_zero=True,
+                    minmax=False,
+                    extend_act_range=extend_act_range,
+                )
+            else:
+                act_quantizer = Qmax(nbits, align_zero=align_zero, minmax=False)
         elif qa_mode == "fix":
             act_quantizer = QFixSymmetric(
                 nbits, init_clip_val=clip_val, align_zero=align_zero
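
With this change, any qa_mode containing "max" is routed by substring matching instead of an exact mode name. Below is a minimal, illustrative sketch of how a few mode strings would be classified by the new branch; the example strings are assumptions for illustration, not an exhaustive list from this repo, and the printed names only indicate which quantizer the real code would construct.

# Hypothetical mode strings run through the same substring checks as the new
# 'elif "max" in qa_mode' branch above.
for qa_mode in ["minmax", "pertokenmax", "maxperCh", "maxsym", "max"]:
    if "min" in qa_mode:
        picked = "Qmax(minmax=True)"                    # static min/max range
    elif "pertoken" in qa_mode or "perToken" in qa_mode:
        picked = "QMaxDynamic(nbits, dim=-1)"           # dynamic per-token scales
    elif "per_channel" in qa_mode or "perCh" in qa_mode:
        picked = "QMaxDynamic(nbits, dim=-2)"           # dynamic per-channel scales
    elif "sym" in qa_mode:
        picked = "Qmax(align_zero=True, extend_act_range=...)"
    else:
        picked = "Qmax(minmax=False)"                   # plain abs-max
    print(f"{qa_mode:12s} -> {picked}")
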
@@ -140,13 +152,7 @@ def get_activation_quantizer(
                 minmax=False,
                 extend_act_range=extend_act_range,
             )
-        elif qa_mode == "pactsym":
-            act_quantizer = PACT2Sym(
-                nbits,
-                init_clip_val=clip_val,
-                dequantize=True,
-                inplace=False,
-            )
+
         elif qa_mode == "pactsym+":
             act_quantizer = PACTplusSym(
                 nbits,
@@ -179,8 +185,6 @@ def get_activation_quantizer(
                 perToken=perToken,
                 emulate=True,
             )
-        elif qa_mode == "pertokenmax":
-            act_quantizer = PerTokenMax(nbits)
         else:
             raise ValueError(f"unrecognized activation quantization mode {qa_mode}")
     else:  # swcap-compatible activation quantizers
@@ -3488,6 +3492,42 @@ def __repr__(self):
         return f"{self.__class__.__name__}(num_bits={self.num_bits}, quantizer=)"


+class QMaxDynamic(nn.Module):
+    def __init__(self, num_bits, dim=-1):
+        """
+        Per-token or per-channel quantizer that uses abs().max() as the scale, usually for
+        activations; it could be used for Qbmm M2 as well.
+        (reduce) dim = -1 -> per-token scales (abs().max() over the last dim, a column vector for 2D input)
+                 dim = -2 -> per-channel scales
+        Zero is aligned so that the levels are symmetric around zero (losing one level).
+        Since the token length is unknown before running, the quantizer can only compute the
+        scales dynamically at run time, meaning no trainable quantization scales are allowed
+        (unless the input seq length is always the same, not just padded to a fixed length).
+        """
+        super().__init__()
+        self.num_bits = num_bits
+        self.levels = 2 ** (self.num_bits - 1) - 1
+        if isinstance(dim, str):
+            if "perCh" in dim or "per_channel" in dim:
+                dim = -2
+            elif "perToken" in dim or "per_token" in dim or "per_Token" in dim:
+                dim = -1
+        if dim in [-1, -2]:
+            self.reduce_dim = dim
+        else:
+            raise ValueError(
+                f"Reduce dim can only be [-1, -2] or ['perCh', 'perToken'] but found {dim}"
+            )
+
+    def forward(self, input_tensor):
+        amax_dim = input_tensor.abs().max(dim=self.reduce_dim, keepdim=True)[0]
+        scales = amax_dim.clamp(min=1e-5).div(self.levels)
+        return input_tensor.div(scales).round().mul(scales)
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}(num_bits={self.num_bits}, quantizer=)"
+
+
 class Qdynamic(nn.Module):
     def __init__(
         self,
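
For reference, a minimal usage sketch of the new QMaxDynamic quantizer follows. The tensor shape, bit width, and the assumption that fms_mo is installed with this commit applied (so the import path below resolves) are illustrative choices, not taken from the diff itself.

import torch

from fms_mo.quant.quantizers import QMaxDynamic  # assumes the package is installed with this commit

x = torch.randn(4, 16)                        # e.g. (tokens, channels)
quantizer = QMaxDynamic(num_bits=8, dim=-1)   # dim=-1 -> one dynamic scale per token (row)

x_q = quantizer(x)
# With 8 bits, levels = 2**7 - 1 = 127, so each row is snapped to the symmetric grid
# scale * {-127, ..., 0, ..., 127} with scale = row.abs().max() / 127.
row_scale = x.abs().max(dim=-1, keepdim=True)[0].clamp(min=1e-5) / 127
assert torch.allclose(x_q, (x / row_scale).round() * row_scale)
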
