1515"""
1616Post-Training Quantization (PTQ) functions
1717
18- Class StraightThrough, function _fold_bn, fold_bn_into_conv, reset_bn, and
18+ Class StraightThrough, function _fold_bn, fold_bn_into_conv, reset_bn, and
1919search_fold_and_remove_bn are modified from QDROP repo https://github.com/wimh966/QDrop
2020
2121
2222"""
2323
2424# Standard
2525from functools import partial
26+ from typing import Optional
2627import logging
2728import math
2829import random
@@ -2383,14 +2384,26 @@ def input_stats_hook(m, x, _y, name, act_scales):


@torch.no_grad()
- def get_act_scales(model, dloader, qcfg):
-     """
-     To get max() of activations for linear layers on one device.
-     Model size will be limited by memory (GPU) or speed (cpu)
+ def get_act_scales(
+     model,
+     dloader,
+     qcfg: dict,
+     device: Optional[str | torch.device] = None,
+ ):
+     """Compute smoothquant activation scales for quantized linear layers.
+     Model and examples are moved to the selected device, if provided.
    """

    model.eval()
-     model.cuda()
+
+     if device is None:
+         device = next(model.parameters()).device
+     else:
+         logger.info(
+             f"Moving model to {device} to compute smoothquant activation scales"
+         )
+         model.to(device)
+
    dev = next(model.parameters()).device
    act_scales = {}
    qcfg["sample_id"] = 0
@@ -2408,7 +2421,6 @@ def get_act_scales(model, dloader, qcfg):

    for data_mb, _ in zip(pbar, range(n_samples)):
        qcfg["sample_id"] += 1
-         # logger.info("Now for sample: ", qcfg["sample_id"])
        data_mb = move_to(data_mb, dev)
        if (
            qcfg["nbits_bmm1"] < 32
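
Note on the change above: the hard-coded model.cuda() is replaced by an opt-in device move, so CPU-only or multi-GPU hosts can run calibration without an implicit jump to the default GPU. Below is a minimal, self-contained sketch of that idiom; the helper name resolve_device and the toy model are illustrative assumptions, not code from this repo.

import torch


def resolve_device(model: torch.nn.Module, device=None) -> torch.device:
    # Same pattern as the diff: with no device given, keep the model where it
    # already is; otherwise move it to the requested device before calibration.
    if device is None:
        device = next(model.parameters()).device
    else:
        model.to(device)
    return next(model.parameters()).device


toy = torch.nn.Linear(4, 4)
print(resolve_device(toy))         # falls back to the model's current device (cpu here)
print(resolve_device(toy, "cpu"))  # explicit request; a no-op for this toy model
# resolve_device(toy, "cuda:0")    # would move the model when a GPU is present

Callers of get_act_scales that previously relied on the implicit model.cuda() would now pass device="cuda" (or a torch.device) explicitly, while omitting the argument keeps the model on whatever device it already occupies.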