
Commit 85b473e

Accelerate Utilities (#193)
* wip
* add modify_offload_module
* update docs
* WIP
* cleanup functions, begin deprecation
* remove extra space
* revert get_offloaded_device
* update to align_module_device
* add requires skip for accelerate
* fix per token initialization
* remove align_module_device
* respond to nits
* Accelerate Utilities Follow-up (#224)
* rename
* implement recursive case
* remove print
* support OffloadedWeightsLoader
* add lifecycle docstring
* implement offload_to_weights_map with recursive definition
* add docstring
* fix type hint
* add check_accelerate guard
* make device used by clearer
* update update_prefix_dict
* reuse fixture
* use apply rather than recursion
* clearer delete_from_weights_map
* add offload_device argument (#228)

Signed-off-by: Kyle Sayers <[email protected]>
1 parent 975cb22 commit 85b473e

File tree: 5 files changed (+708, -91 lines)

src/compressed_tensors/quantization/lifecycle/initialize.py

Lines changed: 17 additions & 44 deletions
@@ -29,7 +29,11 @@
 from compressed_tensors.quantization.quant_config import QuantizationStatus
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
 from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme
-from compressed_tensors.utils import get_execution_device, is_module_offloaded
+from compressed_tensors.utils import (
+    disable_hf_hook,
+    has_offloaded_params,
+    register_offload_parameter,
+)
 from torch.nn import Module, Parameter


@@ -112,43 +116,10 @@ def initialize_module_for_quantization(
     module.quantization_scheme = scheme
     module.quantization_status = QuantizationStatus.INITIALIZED

-    offloaded = False
-    # What is this doing/why isn't this in the attn case?
-    if is_module_offloaded(module):
-        try:
-            from accelerate.hooks import add_hook_to_module, remove_hook_from_module
-            from accelerate.utils import PrefixedDataset
-        except ModuleNotFoundError:
-            raise ModuleNotFoundError(
-                "Offloaded model detected. To use CPU offloading with "
-                "compressed-tensors the `accelerate` package must be installed, "
-                "run `pip install compressed-tensors[accelerate]`"
-            )
-
-        offloaded = True
-        hook = module._hf_hook
-        prefix_dict = module._hf_hook.weights_map
-        new_prefix = {}
-
-        # recreate the prefix dict (since it is immutable)
-        # and add quantization parameters
-        for key, data in module.named_parameters():
-            if key not in prefix_dict:
-                new_prefix[f"{prefix_dict.prefix}{key}"] = data
-            else:
-                new_prefix[f"{prefix_dict.prefix}{key}"] = prefix_dict[key]
-        new_prefix_dict = PrefixedDataset(new_prefix, prefix_dict.prefix)
-        remove_hook_from_module(module)
-
-    # wrap forward call of module to perform
-    # quantized actions based on calltime status
-    wrap_module_forward_quantized(module, scheme)
-
-    if offloaded:
-        # we need to re-add the hook for offloading now that we've wrapped forward
-        add_hook_to_module(module, hook)
-        if prefix_dict is not None:
-            module._hf_hook.weights_map = new_prefix_dict
+    with disable_hf_hook(module):
+        # wrap forward call of module to perform
+        # quantized actions based on calltime status
+        wrap_module_forward_quantized(module, scheme)


 def is_attention_module(module: Module):
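
The new `disable_hf_hook` context manager packages the remove-hook / re-add-hook dance deleted above into a single scope: inside the `with` block the accelerate hook is detached, so `wrap_module_forward_quantized` wraps the module's real forward method, and the hook is restored on exit. A minimal sketch of the call pattern, assuming only the context-manager behavior visible in this diff (treating hook-less modules as a no-op is an assumption):

    import torch
    from compressed_tensors.utils import disable_hf_hook

    def wrap_forward(module: torch.nn.Module):
        # inside the context, module.forward is the unwrapped method,
        # so it is safe to re-wrap it without capturing accelerate's wrapper
        with disable_hf_hook(module):
            inner_forward = module.forward

            def wrapped(*args, **kwargs):
                # calltime quantization logic would go here
                return inner_forward(*args, **kwargs)

            module.forward = wrapped
        # on exit the hook is re-attached on top of the new forward

    wrap_forward(torch.nn.Linear(4, 4))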
@@ -169,9 +140,11 @@ def _initialize_scale_zero_point(
     if quantization_args.dynamic:
         return

-    device = next(module.parameters()).device
-    if is_module_offloaded(module):
-        device = get_execution_device(module)
+    # begin on the same device as other parameters or cpu if offloaded.
+    # in the offloaded case, there's no point moving tensors to the execution device
+    # if they're going to be immediately offloaded by `register_offload_parameter`
+    params_device = next(module.parameters()).device
+    device = "cpu" if has_offloaded_params(module) else params_device

     # infer expected scale/zero point shape
     if quantization_args.strategy == QuantizationStrategy.TOKEN:
@@ -196,15 +169,15 @@ def _initialize_scale_zero_point(
         torch.empty(expected_shape, dtype=scale_dtype, device=device),
         requires_grad=False,
     )
-    module.register_parameter(f"{base_name}_scale", init_scale)
+    register_offload_parameter(module, f"{base_name}_scale", init_scale)

     if force_zero_point or not quantization_args.symmetric:
         zp_dtype = quantization_args.pytorch_dtype()
         init_zero_point = Parameter(
             torch.zeros(expected_shape, device=device, dtype=zp_dtype),
             requires_grad=False,
         )
-        module.register_parameter(f"{base_name}_zero_point", init_zero_point)
+        register_offload_parameter(module, f"{base_name}_zero_point", init_zero_point)

     # only grouped activation ordering has g_idx
     if quantization_args.actorder == ActivationOrdering.GROUP:
@@ -214,7 +187,7 @@ def _initialize_scale_zero_point(
             torch.full(g_idx_shape, -1, device=device, dtype=g_idx_dtype),
             requires_grad=False,
         )
-        module.register_parameter(f"{base_name}_g_idx", init_g_idx)
+        register_offload_parameter(module, f"{base_name}_g_idx", init_g_idx)


 def _initialize_attn_scales(module: Module) -> None:
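
`register_offload_parameter` replaces the plain `register_parameter` calls plus the manual weights-map surgery: it registers the parameter on the module and, when the module has offloaded parameters, sends the data to the offload location instead of keeping it on the execution device. A minimal sketch of the registration path, reusing only the utilities imported in this diff (shapes and parameter names are illustrative):

    import torch
    from torch.nn import Linear, Parameter
    from compressed_tensors.utils import has_offloaded_params, register_offload_parameter

    module = Linear(64, 64)

    # pick the init device the same way _initialize_scale_zero_point now does:
    # stay on cpu when offloaded, otherwise match the existing parameters
    params_device = next(module.parameters()).device
    device = "cpu" if has_offloaded_params(module) else params_device

    scale = Parameter(torch.empty((64, 1), device=device), requires_grad=False)
    register_offload_parameter(module, "weight_scale", scale)

    # the parameter is visible like any other registered parameter, and
    # would live in the offloaded weights map if the module were offloaded
    assert "weight_scale" in dict(module.named_parameters())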

src/compressed_tensors/utils/helpers.py

Lines changed: 64 additions & 1 deletion
@@ -12,7 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Any, Dict, Optional
+import warnings
+from functools import wraps
+from typing import Any, Callable, Dict, Optional

 import torch
 from transformers import AutoConfig
@@ -24,6 +26,8 @@
     "tensor_follows_mask_structure",
     "replace_module",
     "is_compressed_tensors_config",
+    "getattr_chain",
+    "deprecated",
     "Aliasable",
 ]

@@ -122,6 +126,65 @@ def is_compressed_tensors_config(compression_config: Any) -> bool:
     return False


+def getattr_chain(obj: Any, chain_str: str, *args, **kwargs) -> Any:
+    """
+    Chain multiple getattr calls, separated by `.`
+
+    :param obj: base object whose attributes are being retrieved
+    :param chain_str: attribute names separated by `.`
+    :param default: default value to return if an attribute is missing; raises otherwise
+    """
+    if len(args) >= 1:
+        has_default = True
+        default = args[0]
+    elif "default" in kwargs:
+        has_default = True
+        default = kwargs["default"]
+    else:
+        has_default = False
+
+    attr_names = chain_str.split(".")
+
+    res = obj
+    for attr_name in attr_names:
+        if not hasattr(res, attr_name):
+            if has_default:
+                return default
+            else:
+                raise AttributeError(f"{res} object has no attribute {attr_name}")
+        res = getattr(res, attr_name)
+
+    return res
+
+
+def deprecated(future_name: Optional[str] = None, message: Optional[str] = None):
+    """
+    Decorator to mark functions as deprecated
+
+    :param future_name: name of the function to use instead
+    :param message: deprecation message, replaces the default deprecation message
+    """
+
+    def decorator(func: Callable[[Any], Any]):
+        nonlocal message
+
+        if message is None:
+            message = (
+                f"{func.__name__} is deprecated and will be removed in a future release"
+            )
+            if future_name is not None:
+                message += f". Please use {future_name} instead."
+
+        @wraps(func)
+        def wrapped(*args, **kwargs):
+            warnings.warn(message, DeprecationWarning, stacklevel=2)
+            return func(*args, **kwargs)
+
+        return wrapped
+
+    return decorator
+
+
 class Aliasable:
     """
     A mixin for enums to allow aliasing of enum members
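
`getattr_chain` resolves a dotted attribute path and mirrors the two calling conventions of built-in `getattr`: pass a default (positionally or as `default=`) for fallback behavior, or omit it to raise `AttributeError` at the first missing link. A quick usage example against a stand-in object:

    from types import SimpleNamespace
    from compressed_tensors.utils import getattr_chain

    model = SimpleNamespace(config=SimpleNamespace(quantization_config=None))

    # walks model.config.quantization_config one attribute at a time
    assert getattr_chain(model, "config.quantization_config") is None

    # a provided default is returned as soon as any link is missing
    assert getattr_chain(model, "config.missing.field", "fallback") == "fallback"

    # with no default, a missing link raises AttributeError
    try:
        getattr_chain(model, "config.missing.field")
    except AttributeError:
        pass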
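
`deprecated` is a decorator factory: it is always called, optionally with `future_name` to point callers at the replacement, and the wrapped function emits a `DeprecationWarning` attributed to the caller via `stacklevel=2`. A small self-contained example (`old_helper` and `new_helper` are illustrative names, not functions from this repo):

    import warnings
    from compressed_tensors.utils import deprecated

    @deprecated(future_name="new_helper")
    def old_helper(x: int) -> int:
        return x + 1

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        assert old_helper(1) == 2  # still works, but warns

    # the default message names the suggested replacement
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)
    assert "new_helper" in str(caught[0].message)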
