Skip to content

Commit 31bc3f6

Browse files
Merge pull request #21 from andrea-fasoli/get_act_scales_fix
Fix device for smoothquant activation scales
2 parents aa45b29 + 4b99b24 commit 31bc3f6

File tree

1 file changed

+20
-9
lines changed

1 file changed

+20
-9
lines changed

fms_mo/quant/ptq.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,15 @@
1515
"""
1616
Post-Training Quantization (PTQ) functions
1717
18-
Class StraightThrough, function _fold_bn, fold_bn_into_conv, reset_bn, and
18+
Class StraightThrough, function _fold_bn, fold_bn_into_conv, reset_bn, and
1919
search_fold_and_remove_bn are modified from QDROP repo https://github.com/wimh966/QDrop
2020
2121
2222
"""
2323

2424
# Standard
2525
from functools import partial
26+
from typing import Optional, Union
2627
import logging
2728
import math
2829
import random
@@ -2383,15 +2384,26 @@ def input_stats_hook(m, x, _y, name, act_scales):
23832384

23842385

23852386
@torch.no_grad()
2386-
def get_act_scales(model, dloader, qcfg):
2387-
"""
2388-
To get max() of activations for linear layers on one device.
2389-
Model size will be limited by memory (GPU) or speed (cpu)
2387+
def get_act_scales(
2388+
model,
2389+
dloader,
2390+
qcfg: dict,
2391+
device: Optional[Union[str, torch.device]] = None,
2392+
):
2393+
"""Compute smoothquant activation scales of quantized linear layers.
2394+
Model and examples are moved to selected device, if provided.
23902395
"""
23912396

23922397
model.eval()
2393-
model.cuda()
2394-
dev = next(model.parameters()).device
2398+
2399+
if device is None:
2400+
device = next(model.parameters()).device
2401+
else:
2402+
logger.info(
2403+
f"Moving model to {device} to compute smoothquant activation scales"
2404+
)
2405+
model.to(device)
2406+
23952407
act_scales = {}
23962408
qcfg["sample_id"] = 0
23972409
hooks = []
@@ -2408,8 +2420,7 @@ def get_act_scales(model, dloader, qcfg):
24082420

24092421
for data_mb, _ in zip(pbar, range(n_samples)):
24102422
qcfg["sample_id"] += 1
2411-
# logger.info("Now for sample: ", qcfg["sample_id"] )
2412-
data_mb = move_to(data_mb, dev)
2423+
data_mb = move_to(data_mb, device)
24132424
if (
24142425
qcfg["nbits_bmm1"] < 32
24152426
or qcfg["nbits_bmm2"] < 32

0 commit comments

Comments
 (0)