
Commit 622f721

Model Offloading Support (#113)
Authored by Sara Adkins and George Ohashi
* compute zp, scale if weight exists in module
* WIP, gets through 1 forward pass
* fix for zeroed out scales
* fix model load
* style
* offload helper fns
* pass tests
* add test to check that observers are used to populate zp and scale in initialization
* fix no calibration case
* clean up for PR
* fix test
* update dependencies
* fix forward bug
* don't calibrate on weights
* dont calib weight in forward
* fix zp load
* check calibration

---------

Co-authored-by: George Ohashi <[email protected]>
1 parent c214cbc commit 622f721

File tree: 10 files changed, +189 −27 lines


setup.py

Lines changed: 1 addition & 1 deletion
@@ -46,7 +46,7 @@ def _setup_packages() -> List:
     )

 def _setup_install_requires() -> List:
-    return ["torch>=1.7.0", "transformers", "pydantic>=2.0"]
+    return ["torch>=1.7.0", "transformers", "accelerate", "pydantic>=2.0"]

 def _setup_extras() -> Dict:
     return {"dev": ["black==22.12.0", "isort==5.8.0", "wheel>=0.36.2", "flake8>=3.8.3", "pytest>=6.0.0", "nbconvert>=7.16.3"]}

src/compressed_tensors/compressors/model_compressor.py

Lines changed: 6 additions & 8 deletions
@@ -39,10 +39,10 @@
     is_module_quantized,
     iter_named_leaf_modules,
 )
-from compressed_tensors.utils import get_safetensors_folder
+from compressed_tensors.utils import get_safetensors_folder, update_parameter_data
 from compressed_tensors.utils.helpers import fix_fsdp_module_name
 from torch import Tensor
-from torch.nn import Module, Parameter
+from torch.nn import Module
 from tqdm import tqdm
 from transformers import AutoConfig
 from transformers.file_utils import CONFIG_NAME
@@ -307,12 +307,10 @@ def update_config(self, save_directory: str):

     def _replace_weights(self, dense_weight_generator, model):
         for name, data in tqdm(dense_weight_generator, desc="Decompressing model"):
-            # loading the decompressed weights into the model
-            model_device = operator.attrgetter(name)(model).device
-            data_old = operator.attrgetter(name)(model)
-            data_dtype = data_old.dtype
-            data_new = Parameter(data.to(model_device).to(data_dtype))
-            data_old.data = data_new.data
+            split_name = name.split(".")
+            prefix, param_name = ".".join(split_name[:-1]), split_name[-1]
+            module = operator.attrgetter(prefix)(model)
+            update_parameter_data(module, data, param_name)


 def map_modules_to_quant_args(model: Module) -> Dict:
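Every decompressed weight is now written back through update_parameter_data (defined in the new offload.py further down), so the offloaded copy of the tensor stays in sync as well. A minimal sketch of the name-splitting pattern above, using a plain Sequential as a stand-in for the real model:

import operator

import torch

model = torch.nn.Sequential(torch.nn.Linear(4, 4))
name, data = "0.weight", torch.zeros(4, 4)  # as yielded by dense_weight_generator

split_name = name.split(".")
prefix, param_name = ".".join(split_name[:-1]), split_name[-1]
module = operator.attrgetter(prefix)(model)  # resolves "0" to the Linear layer
getattr(module, param_name).data = data  # stand-in for update_parameter_data(...)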

src/compressed_tensors/quantization/lifecycle/apply.py

Lines changed: 12 additions & 13 deletions
@@ -43,6 +43,7 @@
     iter_named_leaf_modules,
 )
 from compressed_tensors.utils.helpers import fix_fsdp_module_name
+from compressed_tensors.utils.offload import update_parameter_data
 from compressed_tensors.utils.safetensors_load import get_safetensors_folder
 from torch.nn import Module

@@ -265,19 +266,17 @@ def _load_quant_args_from_state_dict(
     """
     scale_name = f"{base_name}_scale"
     zp_name = f"{base_name}_zero_point"
-    device = next(module.parameters()).device
-
-    scale = getattr(module, scale_name, None)
-    zp = getattr(module, zp_name, None)
-    if scale is not None:
-        state_dict_scale = state_dict[f"{module_name}.{scale_name}"]
-        scale.data = state_dict_scale.to(device).to(scale.dtype)
-    if zp is not None:
-        zp_from_state = state_dict.get(f"{module_name}.{zp_name}", None)
-        if zp_from_state is not None:  # load the non-zero zero points
-            zp.data = zp_from_state.to(device).to(zp.dtype)
-        else:  # fill with zeros matching scale shape
-            zp.data = torch.zeros_like(scale, dtype=zp.dtype).to(device)
+
+    state_dict_scale = state_dict.get(f"{module_name}.{scale_name}", None)
+    state_dict_zp = state_dict.get(f"{module_name}.{zp_name}", None)
+
+    if state_dict_scale is not None:
+        # module is quantized
+        update_parameter_data(module, state_dict_scale, scale_name)
+        if state_dict_zp is None:
+            # fill in zero point for symmetric quantization
+            state_dict_zp = torch.zeros_like(state_dict_scale, device="cpu")
+        update_parameter_data(module, state_dict_zp, zp_name)


 def _scheme_from_targets(
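The rewritten loader treats a missing zero point as symmetric quantization. A small self-contained sketch of that fallback, with a hypothetical state-dict key:

import torch

state_dict = {"model.layer.weight_scale": torch.tensor([0.02])}  # no zero point saved

scale = state_dict.get("model.layer.weight_scale", None)
zero_point = state_dict.get("model.layer.weight_zero_point", None)

if scale is not None and zero_point is None:
    # symmetric quantization stores no zero point; it is implicitly all zeros
    zero_point = torch.zeros_like(scale, device="cpu")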

src/compressed_tensors/quantization/lifecycle/calibration.py

Lines changed: 17 additions & 0 deletions
@@ -16,6 +16,7 @@
 import logging

 from compressed_tensors.quantization.quant_config import QuantizationStatus
+from compressed_tensors.utils import is_module_offloaded, update_parameter_data
 from torch.nn import Module


@@ -48,4 +49,20 @@ def set_module_for_calibration(module: Module):
            "to re-calibrate a frozen module"
        )

+    if module.quantization_scheme.weights is not None:
+        # set weight scale and zero_point up front, calibration data doesn't affect it
+        observer = module.weight_observer
+
+        offloaded = False
+        if is_module_offloaded(module):
+            module._hf_hook.pre_forward(module)
+            offloaded = True
+
+        scale, zero_point = observer(module.weight)
+        update_parameter_data(module, scale, "weight_scale")
+        update_parameter_data(module, zero_point, "weight_zero_point")
+
+        if offloaded:
+            module._hf_hook.post_forward(module, None)
+
     module.quantization_status = QuantizationStatus.CALIBRATION
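The _hf_hook calls above follow accelerate's offloading protocol: pre_forward materializes the module's weights on the execution device and post_forward offloads them again. A minimal sketch of that pattern, assuming accelerate is installed and using a plain Linear on CPU as a stand-in for a quantized module:

import torch
from accelerate import cpu_offload

layer = torch.nn.Linear(8, 8)
cpu_offload(layer, execution_device=torch.device("cpu"))

if hasattr(layer, "_hf_hook") and layer._hf_hook.offload:
    layer._hf_hook.pre_forward(layer)  # load weights onto the execution device
    weight_copy = layer.weight.detach().clone()  # weight is safe to read here
    layer._hf_hook.post_forward(layer, None)  # offload the weights again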

src/compressed_tensors/quantization/lifecycle/forward.py

Lines changed: 8 additions & 4 deletions
@@ -25,6 +25,7 @@
 )
 from compressed_tensors.quantization.quant_config import QuantizationStatus
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
+from compressed_tensors.utils import update_parameter_data
 from torch.nn import Module


@@ -312,16 +313,19 @@ def maybe_calibrate_or_quantize(
     scale = getattr(module, f"{base_name}_scale")
     zero_point = getattr(module, f"{base_name}_zero_point")

-    if module.quantization_status == QuantizationStatus.CALIBRATION:
+    if (
+        module.quantization_status == QuantizationStatus.CALIBRATION
+        and base_name != "weight"
+    ):
         # calibration mode - get new quant params from observer
         observer = getattr(module, f"{base_name}_observer")

         updated_scale, updated_zero_point = observer(value)

         # update scale and zero point
-        device = next(module.parameters()).device
-        scale.data = updated_scale.to(device)
-        zero_point.data = updated_zero_point.to(device)
+        update_parameter_data(module, updated_scale, f"{base_name}_scale")
+        update_parameter_data(module, updated_zero_point, f"{base_name}_zero_point")

     return fake_quantize(value, scale, zero_point, args)
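Because set_module_for_calibration now fixes the weight scale and zero point up front, the forward wrapper only consults observers for activations. A condensed sketch of the new guard (names follow the diff above):

from compressed_tensors.quantization import QuantizationStatus


def should_observe(module, base_name: str) -> bool:
    # weights were calibrated once, up front; only activations are
    # re-observed on each calibration-time forward pass
    return (
        module.quantization_status == QuantizationStatus.CALIBRATION
        and base_name != "weight"
    )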
src/compressed_tensors/quantization/lifecycle/initialize.py

Lines changed: 28 additions & 0 deletions
@@ -17,6 +17,8 @@
 from typing import Optional

 import torch
+from accelerate.hooks import add_hook_to_module, remove_hook_from_module
+from accelerate.utils import PrefixedDataset
 from compressed_tensors.quantization.lifecycle.forward import (
     wrap_module_forward_quantized,
 )
@@ -26,6 +28,7 @@
 )
 from compressed_tensors.quantization.quant_config import QuantizationStatus
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
+from compressed_tensors.utils import get_execution_device, is_module_offloaded
 from torch.nn import Module, Parameter


@@ -81,9 +84,32 @@ def initialize_module_for_quantization(
     module.quantization_scheme = scheme
     module.quantization_status = QuantizationStatus.INITIALIZED

+    offloaded = False
+    if is_module_offloaded(module):
+        offloaded = True
+        hook = module._hf_hook
+        prefix_dict = module._hf_hook.weights_map
+        new_prefix = {}
+
+        # recreate the prefix dict (since it is immutable)
+        # and add quantization parameters
+        for key, data in module.named_parameters():
+            if key not in prefix_dict:
+                new_prefix[f"{prefix_dict.prefix}{key}"] = data
+            else:
+                new_prefix[f"{prefix_dict.prefix}{key}"] = prefix_dict[key]
+        new_prefix_dict = PrefixedDataset(new_prefix, prefix_dict.prefix)
+        remove_hook_from_module(module)
+
     # wrap forward call of module to perform quantized actions based on calltime status
     wrap_module_forward_quantized(module, scheme)

+    if offloaded:
+        # we need to re-add the hook for offloading now that we've wrapped forward
+        add_hook_to_module(module, hook)
+        if prefix_dict is not None:
+            module._hf_hook.weights_map = new_prefix_dict
+

 def _initialize_scale_zero_point_observer(
     module: Module,
@@ -99,6 +125,8 @@ def _initialize_scale_zero_point_observer(
         return  # no need to register a scale and zero point for a dynamic observer

     device = next(module.parameters()).device
+    if is_module_offloaded(module):
+        device = get_execution_device(module)

     # infer expected scale/zero point shape
     expected_shape = 1  # per tensor
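The get_execution_device fallback matters because an offloaded module's parameters may live on the "meta" device, while newly created scale and zero-point parameters must live where the forward pass actually runs. A brief sketch of the device selection using the new helpers:

import torch
from compressed_tensors.utils import get_execution_device, is_module_offloaded

module = torch.nn.Linear(4, 4)

device = next(module.parameters()).device
if is_module_offloaded(module):
    device = get_execution_device(module)  # the hook's execution device, not "meta"

scale = torch.nn.Parameter(torch.empty(1, device=device), requires_grad=False)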

src/compressed_tensors/quantization/quant_config.py

Lines changed: 11 additions & 0 deletions
@@ -239,3 +239,14 @@ def from_pretrained(
             format=format,
             ignore=consolidated_ignore,
         )
+
+    def requires_calibration_data(self):
+        for _, scheme in self.config_groups.items():
+            if scheme.input_activations is not None:
+                if not scheme.input_activations.dynamic:
+                    return True
+            if scheme.output_activations is not None:
+                if not scheme.output_activations.dynamic:
+                    return True
+
+        return False
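A hedged usage sketch: only static (non-dynamic) activation quantization needs calibration data, so callers can branch on the new helper before gathering a dataset. This assumes the W8A8 preset statically quantizes activations and W4A16 is weight-only, as the test change below suggests:

from compressed_tensors.quantization import QuantizationConfig

# static 8-bit activations -> observers need calibration data
config = QuantizationConfig(config_groups=dict(W8A8=["Linear"]))
assert config.requires_calibration_data()

# weight-only quantization -> no activation observers to calibrate
weight_only = QuantizationConfig(config_groups=dict(W4A16=["Linear"]))
assert not weight_only.requires_calibration_data()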

src/compressed_tensors/utils/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -14,6 +14,7 @@
 # flake8: noqa

 from .helpers import *
+from .offload import *
 from .permutations_24 import *
 from .safetensors_load import *
 from .semi_structured_conversions import *
src/compressed_tensors/utils/offload.py

Lines changed: 104 additions & 0 deletions (new file)

@@ -0,0 +1,104 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from torch.nn import Module
+
+
+__all__ = [
+    "is_module_offloaded",
+    "get_execution_device",
+    "get_offloaded_device",
+    "update_prefix_dict",
+    "update_parameter_data",
+]
+
+
+def is_module_offloaded(module: Module) -> bool:
+    """
+    :param module: layer to check
+    :return: True if layer is offloaded from GPU, False otherwise
+    """
+    return hasattr(module, "_hf_hook") and module._hf_hook.offload
+
+
+def get_execution_device(module: Module) -> torch.device:
+    """
+    :param module: layer to check
+    :return: device layer is loaded onto during forward pass
+    """
+    if is_module_offloaded(module):
+        return module._hf_hook.execution_device
+    return next(module.parameters()).device
+
+
+def get_offloaded_device(module: Module) -> torch.device:
+    """
+    :param module: layer to check
+    :return: device layer is offloaded to after forward pass
+    """
+    if is_module_offloaded(module):
+        first_key = list(module._hf_hook.weights_map.keys())[0]
+        prefix_dataset = module._hf_hook.weights_map.dataset
+        return prefix_dataset[first_key].device
+    return next(module.parameters()).device
+
+
+def update_prefix_dict(module: Module, key: str, data: torch.Tensor):
+    """
+    Updates the offloaded state dict for a given module. Parameter named key is replaced
+    by data. This is necessary because parameter updates for offloaded modules do not
+    persist automatically between loads. This function only affects the offloaded
+    state dict and not the current state of the loaded module.
+
+    :param module: layer containing the parameter to update
+    :param key: name of parameter to update
+    :param data: tensor to update parameter with in the offloaded state dict
+    """
+    if not is_module_offloaded(module):
+        raise ValueError("Prefix dict is only applicable to offloaded modules")
+    prefix_dict = module._hf_hook.weights_map
+    prefix_dict.dataset[f"{prefix_dict.prefix}{key}"] = data
+
+
+def update_parameter_data(
+    module: Module, new_param_data: torch.Tensor, param_name: str
+):
+    """
+    Updates the parameter value named param_name for a given module. This function
+    updates both the current loaded module state and the offloaded state dict if
+    the module is offloaded. This is necessary because parameter updates for offloaded
+    modules do not persist automatically between loads.
+
+    :param module: layer containing the parameter to update
+    :param new_param_data: tensor to update parameter with
+    :param param_name: name of the parameter to update
+    """
+    device = next(module.parameters()).device
+
+    offloaded = False
+    if is_module_offloaded(module):
+        offload_device = get_offloaded_device(module)
+        offloaded = True
+
+    parameter = getattr(module, param_name, None)
+    dtype = parameter.dtype
+    parameter.data = new_param_data.to(device).to(dtype)
+
+    if offloaded:
+        prefix_dict = module._hf_hook.weights_map.dataset
+        prefix = module._hf_hook.weights_map.prefix
+        prefix_dict[f"{prefix}{param_name}"] = new_param_data.to(offload_device).to(
+            dtype
+        )
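A minimal usage sketch of update_parameter_data on a regular (non-offloaded) module; for an offloaded module the same call would also rewrite the matching entry in the hook's weights map, so the update survives the next load/offload cycle:

import torch
from compressed_tensors.utils import update_parameter_data

layer = torch.nn.Linear(16, 16)
new_weight = torch.zeros(16, 16)

# replaces layer.weight in place, cast to the parameter's device and dtype
update_parameter_data(layer, new_weight, "weight")
assert torch.equal(layer.weight.data, new_weight)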

tests/test_quantization/lifecycle/test_enabled.py

Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@ def test_quantization_enabled_disabled():
     apply_quantization_config(
         model=quantized_model,
         config=QuantizationConfig(
-            config_groups=dict(W4A16=["Linear"]),
+            config_groups=dict(W8A8=["Linear"]),
             quantization_status="calibration",
         ),
     )
