
Commit 4c643a4

rename "smq_" to "smoothq_"
Signed-off-by: cliu-us <[email protected]>
1 parent: a748476

File tree: 1 file changed (+16, −13 lines)


fms_mo/modules/linear.py

Lines changed: 16 additions & 13 deletions

@@ -795,20 +795,23 @@ def from_fms_mo(cls, fms_mo_qlinear, **kwargs):
         qlin_int.cvs = [a_cv, a_cvn, w_cv]  # TODO remove the need of this?
 
         # prepare smoothQuant scale, = (smQ_a_scale ^ alpha)/(smQ_w_scale ^ (1-alpha) )
-        smq_scale = torch.tensor([1.0], device=tar_dev, dtype=fms_mo_w_dtype)
+        smoothq_scale = torch.tensor([1.0], device=tar_dev, dtype=fms_mo_w_dtype)
         if getattr(fms_mo_qlinear, "smoothq", False):
-            smq_a_scale = fms_mo_qlinear.smoothq_act_scale
-            smq_w_scale = (
+            smoothq_a_scale = fms_mo_qlinear.smoothq_act_scale
+            smoothq_w_scale = (
                 fms_mo_qlinear.weight.abs()
                 .max(dim=0, keepdim=True)[0]
                 .clamp(min=1e-5)
             )
-            smq_alpha = fms_mo_qlinear.smoothq_alpha
-            if torch.all(smq_a_scale != 0).item():
-                smq_scale = (
-                    (smq_a_scale**smq_alpha / smq_w_scale ** (1.0 - smq_alpha))
+            smoothq_alpha = fms_mo_qlinear.smoothq_alpha
+            if torch.all(smoothq_a_scale != 0).item():
+                smoothq_scale = (
+                    (
+                        smoothq_a_scale**smoothq_alpha
+                        / smoothq_w_scale ** (1.0 - smoothq_alpha)
+                    )
                     .clamp(min=1e-5)
-                    .to(smq_a_scale.dtype)
+                    .to(smoothq_a_scale.dtype)
                 )
 
         # could trigger Qw.clipval re-calc for SAWB here, if needed
@@ -845,18 +848,18 @@ def from_fms_mo(cls, fms_mo_qlinear, **kwargs):
         qlin_int.register_buffer("input_zp", input_zero_point)
         qlin_int.register_buffer("w_scale", w_scale)
         qlin_int.register_buffer("w_zp", w_zp)
-        qlin_int.register_buffer("smq_scale", smq_scale)
+        qlin_int.register_buffer("smoothq_scale", smoothq_scale)
 
         # NOTE:
         # 1. Keep W transposed to prevent confusion, hence (W.t()/scale).t()
         # 2. only a few quantizer have .dequantize working correctly, e.g. SAWB
         # 3. smooth_quant factor is included in the W here, will also include it in the forward
         if isinstance(Qw, SAWB):
             Qw.dequantize = False
-            w_int8 = Qw(fms_mo_qlinear.weight.float() * smq_scale)
+            w_int8 = Qw(fms_mo_qlinear.weight.float() * smoothq_scale)
         else:
             w_int8 = (
-                torch.round((fms_mo_qlinear.weight * smq_scale).t() / w_scale)
+                torch.round((fms_mo_qlinear.weight * smoothq_scale).t() / w_scale)
                 .clamp(-w_levels / 2, w_levels / 2)
                 .t()
             )
@@ -1323,8 +1326,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             self.weight.shape[0],
         )  # W.shape=[out,in]
 
-        if torch.all(self.smq_scale != 1).item():
-            x = x.view(re_shape) / self.smq_scale
+        if torch.all(self.smoothq_scale != 1).item():
+            x = x.view(re_shape) / self.smoothq_scale
         else:
             x = x.view(re_shape)
 
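For reference, a minimal standalone sketch of the SmoothQuant factor this rename touches: the converted layer computes smoothq_scale = act_scale**alpha / w_scale**(1 - alpha), folds it into the weight at conversion time, and divides it out of the activation in forward(). The tensor names, toy shapes, and calibration values below are illustrative assumptions, not code from fms_mo.

import torch

def smoothquant_scale(act_scale, weight, alpha=0.5):
    # per-input-channel weight magnitude, clamped to avoid division by zero
    w_scale = weight.abs().max(dim=0, keepdim=True)[0].clamp(min=1e-5)
    # migration factor s = act_scale^alpha / w_scale^(1 - alpha)
    return (act_scale**alpha / w_scale ** (1.0 - alpha)).clamp(min=1e-5)

# toy shapes: weight is [out_features, in_features], act_scale is per input channel
weight = torch.randn(8, 16)
act_scale = torch.rand(1, 16) + 0.1   # e.g. per-channel activation max from calibration
s = smoothquant_scale(act_scale, weight, alpha=0.5)

w_smoothed = weight * s               # folded into the (to-be-quantized) weight
x = torch.randn(4, 16)
y = (x / s) @ w_smoothed.t()          # matches x @ weight.t() up to floating-point error

Scaling the weight up and the activation down by the same per-channel factor leaves the matmul result unchanged while moving quantization difficulty from activations to weights, which is why the factor appears both in the weight conversion and in forward() above.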