Skip to content

Commit fcd357c

Browse files
committed
Fix how model is sent to device to calculate smoothquant activation scales
Signed-off-by: Andrea Fasoli <[email protected]>
1 parent aa45b29 commit fcd357c

File tree

1 file changed

+19
-7
lines changed

1 file changed

+19
-7
lines changed

fms_mo/quant/ptq.py

Lines changed: 19 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -15,14 +15,15 @@
1515
"""
1616
Post-Training Quantization (PTQ) functions
1717
18-
Class StraightThrough, function _fold_bn, fold_bn_into_conv, reset_bn, and
18+
Class StraightThrough, function _fold_bn, fold_bn_into_conv, reset_bn, and
1919
search_fold_and_remove_bn are modified from QDROP repo https://github.com/wimh966/QDrop
2020
2121
2222
"""
2323

2424
# Standard
2525
from functools import partial
26+
from typing import Optional
2627
import logging
2728
import math
2829
import random
@@ -2383,14 +2384,26 @@ def input_stats_hook(m, x, _y, name, act_scales):
23832384

23842385

23852386
@torch.no_grad()
2386-
def get_act_scales(model, dloader, qcfg):
2387-
"""
2388-
To get max() of activations for linear layers on one device.
2389-
Model size will be limited by memory (GPU) or speed (cpu)
2387+
def get_act_scales(
2388+
model,
2389+
dloader,
2390+
qcfg: dict,
2391+
device: Optional[str | torch.device] = None,
2392+
):
2393+
"""Compute smoothquant activation scales of quantized linear layers.
2394+
Model and examples are moved to selected device, if provided.
23902395
"""
23912396

23922397
model.eval()
2393-
model.cuda()
2398+
2399+
if device is None:
2400+
device = next(model.parameters()).device
2401+
else:
2402+
logger.info(
2403+
f"Moving model to {device} to compute smoothquant activation scales"
2404+
)
2405+
model.to(device)
2406+
23942407
dev = next(model.parameters()).device
23952408
act_scales = {}
23962409
qcfg["sample_id"] = 0
@@ -2408,7 +2421,6 @@ def get_act_scales(model, dloader, qcfg):
24082421

24092422
for data_mb, _ in zip(pbar, range(n_samples)):
24102423
qcfg["sample_id"] += 1
2411-
# logger.info("Now for sample: ", qcfg["sample_id"] )
24122424
data_mb = move_to(data_mb, dev)
24132425
if (
24142426
qcfg["nbits_bmm1"] < 32

0 commit comments

Comments
 (0)