1515"""
1616Post-Training Quantization (PTQ) functions
1717
18- Class StraightThrough, function _fold_bn, fold_bn_into_conv, reset_bn, and
18+ Class StraightThrough, function _fold_bn, fold_bn_into_conv, reset_bn, and
1919search_fold_and_remove_bn are modified from QDROP repo https://github.com/wimh966/QDrop
2020
2121
2222"""
2323
2424# Standard
2525from functools import partial
26+ from typing import Optional , Union
2627import logging
2728import math
2829import random
@@ -2383,15 +2384,26 @@ def input_stats_hook(m, x, _y, name, act_scales):


 @torch.no_grad()
-def get_act_scales(model, dloader, qcfg):
-    """
-    To get max() of activations for linear layers on one device.
-    Model size will be limited by memory (GPU) or speed (cpu)
+def get_act_scales(
+    model,
+    dloader,
+    qcfg: dict,
+    device: Optional[Union[str, torch.device]] = None,
+):
+    """Compute smoothquant activation scales of quantized linear layers.
+    The model and examples are moved to the selected device, if provided.
     """

     model.eval()
-    model.cuda()
-    dev = next(model.parameters()).device
+
+    if device is None:
+        device = next(model.parameters()).device
+    else:
+        logger.info(
+            f"Moving model to {device} to compute smoothquant activation scales"
+        )
+        model.to(device)
+
     act_scales = {}
     qcfg["sample_id"] = 0
     hooks = []
@@ -2408,8 +2420,7 @@ def get_act_scales(model, dloader, qcfg):

     for data_mb, _ in zip(pbar, range(n_samples)):
         qcfg["sample_id"] += 1
-        # logger.info("Now for sample: ", qcfg["sample_id"] )
-        data_mb = move_to(data_mb, dev)
+        data_mb = move_to(data_mb, device)
         if (
             qcfg["nbits_bmm1"] < 32
             or qcfg["nbits_bmm2"] < 32
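A minimal usage sketch of the revised function follows. It is not taken from the repo: the import path, the toy model, the calibration loader format, and any qcfg keys beyond those visible in the diff are illustrative assumptions.

# Hypothetical usage of get_act_scales with the new optional `device` argument.
# The import path, toy model, loader format, and qcfg contents are assumptions.
import torch
from torch import nn

from fms_mo.ptq import get_act_scales  # hypothetical import path

model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))
calib_loader = [torch.randn(2, 16) for _ in range(8)]  # assumed minibatch format

qcfg = {
    "nbits_bmm1": 32,  # bit-width keys referenced in the diff above
    "nbits_bmm2": 32,
    # ...plus whatever other keys the real function expects
}

# device=None keeps the model on its current device; passing a device logs a
# message and moves the model there before activation statistics are collected.
act_scales = get_act_scales(model, calib_loader, qcfg, device="cpu")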