from compressed_tensors.quantization.quant_config import QuantizationStatus
from compressed_tensors.quantization.quant_scheme import QuantizationScheme
from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme
-from compressed_tensors.utils import get_execution_device, is_module_offloaded
+from compressed_tensors.utils import (
+    disable_hf_hook,
+    has_offloaded_params,
+    register_offload_parameter,
+)
from torch.nn import Module, Parameter

@@ -112,43 +116,10 @@ def initialize_module_for_quantization(
    module.quantization_scheme = scheme
    module.quantization_status = QuantizationStatus.INITIALIZED

-    offloaded = False
-    # What is this doing/why isn't this in the attn case?
-    if is_module_offloaded(module):
-        try:
-            from accelerate.hooks import add_hook_to_module, remove_hook_from_module
-            from accelerate.utils import PrefixedDataset
-        except ModuleNotFoundError:
-            raise ModuleNotFoundError(
-                "Offloaded model detected. To use CPU offloading with "
-                "compressed-tensors the `accelerate` package must be installed, "
-                "run `pip install compressed-tensors[accelerate]`"
-            )
-
-        offloaded = True
-        hook = module._hf_hook
-        prefix_dict = module._hf_hook.weights_map
-        new_prefix = {}
-
-        # recreate the prefix dict (since it is immutable)
-        # and add quantization parameters
-        for key, data in module.named_parameters():
-            if key not in prefix_dict:
-                new_prefix[f"{prefix_dict.prefix}{key}"] = data
-            else:
-                new_prefix[f"{prefix_dict.prefix}{key}"] = prefix_dict[key]
-        new_prefix_dict = PrefixedDataset(new_prefix, prefix_dict.prefix)
-        remove_hook_from_module(module)
-
-    # wrap forward call of module to perform
-    # quantized actions based on calltime status
-    wrap_module_forward_quantized(module, scheme)
-
-    if offloaded:
-        # we need to re-add the hook for offloading now that we've wrapped forward
-        add_hook_to_module(module, hook)
-        if prefix_dict is not None:
-            module._hf_hook.weights_map = new_prefix_dict
+    with disable_hf_hook(module):
+        # wrap forward call of module to perform
+        # quantized actions based on calltime status
+        wrap_module_forward_quantized(module, scheme)


def is_attention_module(module: Module):
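The manual hook bookkeeping removed above is now handled by the `disable_hf_hook` context manager imported from `compressed_tensors.utils`. Below is a minimal sketch of what such a context manager could look like, assuming only the public `accelerate` hook helpers; the name `_disable_hf_hook_sketch` is hypothetical and this is not necessarily the library's actual implementation.

import contextlib

from accelerate.hooks import add_hook_to_module, remove_hook_from_module
from torch.nn import Module


@contextlib.contextmanager
def _disable_hf_hook_sketch(module: Module):
    # detach the accelerate hook (if any) so the forward method can be re-wrapped safely
    hook = getattr(module, "_hf_hook", None)
    if hook is not None:
        remove_hook_from_module(module)
    try:
        yield module
    finally:
        # re-attach the hook so offloading keeps working after the wrap
        if hook is not None:
            add_hook_to_module(module, hook)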
@@ -169,9 +140,11 @@ def _initialize_scale_zero_point(
    if quantization_args.dynamic:
        return

-    device = next(module.parameters()).device
-    if is_module_offloaded(module):
-        device = get_execution_device(module)
+    # begin on the same device as other parameters or cpu if offloaded.
+    # in the offloaded case, there's no point moving tensors to the execution device
+    # if they're going to be immediately offloaded by `register_offload_parameter`
+    params_device = next(module.parameters()).device
+    device = "cpu" if has_offloaded_params(module) else params_device

    # infer expected scale/zero point shape
    if quantization_args.strategy == QuantizationStrategy.TOKEN:
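As a rough usage sketch of the device choice added above, assuming a plain (non-offloaded) module purely for illustration:

import torch
from compressed_tensors.utils import has_offloaded_params

module = torch.nn.Linear(16, 16)  # illustrative module, not from the source
params_device = next(module.parameters()).device
device = "cpu" if has_offloaded_params(module) else params_device
# non-offloaded module: device matches the existing parameters;
# offloaded module: "cpu", since register_offload_parameter offloads the tensor anyway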
@@ -196,15 +169,15 @@ def _initialize_scale_zero_point(
        torch.empty(expected_shape, dtype=scale_dtype, device=device),
        requires_grad=False,
    )
-    module.register_parameter(f"{base_name}_scale", init_scale)
+    register_offload_parameter(module, f"{base_name}_scale", init_scale)

    if force_zero_point or not quantization_args.symmetric:
        zp_dtype = quantization_args.pytorch_dtype()
        init_zero_point = Parameter(
            torch.zeros(expected_shape, device=device, dtype=zp_dtype),
            requires_grad=False,
        )
-        module.register_parameter(f"{base_name}_zero_point", init_zero_point)
+        register_offload_parameter(module, f"{base_name}_zero_point", init_zero_point)

    # only grouped activation ordering has g_idx
    if quantization_args.actorder == ActivationOrdering.GROUP:
@@ -214,7 +187,7 @@ def _initialize_scale_zero_point(
            torch.full(g_idx_shape, -1, device=device, dtype=g_idx_dtype),
            requires_grad=False,
        )
-        module.register_parameter(f"{base_name}_g_idx", init_g_idx)
+        register_offload_parameter(module, f"{base_name}_g_idx", init_g_idx)


def _initialize_attn_scales(module: Module) -> None:
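A hedged end-to-end sketch of the new registration path: the module, parameter name, and shape below are illustrative only, but the `register_offload_parameter(module, name, parameter)` call mirrors its use in this diff.

import torch
from torch.nn import Linear, Parameter
from compressed_tensors.utils import has_offloaded_params, register_offload_parameter

module = Linear(16, 16)  # illustrative module
device = "cpu" if has_offloaded_params(module) else next(module.parameters()).device

# create the parameter on a cheap starting device, then let the utility
# register it on the module (and offload it if the module is offloaded)
init_scale = Parameter(
    torch.empty((16, 1), dtype=torch.float16, device=device),
    requires_grad=False,
)
register_offload_parameter(module, "weight_scale", init_scale)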