Skip to content

Commit aa45b29

Browse files
Merge pull request #20 from andrea-fasoli/smoothquant_update
Set smoothq_alpha as buffer
2 parents 5c53e4a + 3f50347 commit aa45b29

File tree

1 file changed

+9
-5
lines changed

1 file changed

+9
-5
lines changed

fms_mo/modules/linear.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,10 @@ def __init__(
223223
self.smoothq = qcfg.get("smoothq", False)
224224
if self.smoothq:
225225
self.register_buffer("smoothq_act_scale", torch.zeros(w_size[1]))
226-
self.smoothq_alpha = qcfg.get("smoothq_alpha", 0.5)
226+
self.register_buffer(
227+
"smoothq_alpha",
228+
torch.tensor([qcfg.get("smoothq_alpha", 0.5)], dtype=torch.float32),
229+
)
227230

228231
def forward(self, x):
229232
"""
@@ -335,11 +338,12 @@ def get_smoothq_scale(self, x):
335338
smoothq_scale = torch.tensor([1.0]).to(x.dtype).to(x.device)
336339
else:
337340
weight_scale = self.weight.abs().max(dim=0, keepdim=True)[0].clamp(min=1e-5)
341+
if isinstance(self.smoothq_alpha, torch.Tensor):
342+
alpha = self.smoothq_alpha.item()
343+
else:
344+
alpha = self.smoothq_alpha
338345
smoothq_scale = (
339-
(
340-
self.smoothq_act_scale.pow(self.smoothq_alpha)
341-
/ weight_scale.pow(1.0 - self.smoothq_alpha)
342-
)
346+
(self.smoothq_act_scale.pow(alpha) / weight_scale.pow(1.0 - alpha))
343347
.clamp(min=1e-5)
344348
.to(x.dtype)
345349
)

0 commit comments

Comments
 (0)