
Commit 28eb168

Fix Issue with Resizing Parameters on the Meta Device in Low CPU Mem Mode (#96)
* fix: quant models on meta device cannot have embedding resized
* fix: grad reduce hook

Signed-off-by: Yu Chin Fabian Lim <[email protected]>
1 parent 98fcd2e commit 28eb168
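The failure this commit addresses: in low CPU memory mode the quantized model can come back with some weights, notably the input and output embeddings, still on the meta device, and any operation that needs real data, such as resizing the token embeddings, then fails. A minimal standalone sketch of that behavior, using only stock PyTorch (the tensor names are illustrative, not from the patch):

import torch

# tensors on the "meta" device carry shape and dtype but have no storage
w = torch.empty(4, 8, device="meta")

try:
    w.to("cpu")  # moving off meta needs data that does not exist
except NotImplementedError as err:
    print(err)   # e.g. "Cannot copy out of meta tensor; no data!"

# the fix side-steps this by allocating an (uninitialized) cpu tensor of the
# same shape and dtype, which later operations such as embedding resizing
# can then work with
materialized = torch.empty(*w.size(), dtype=w.dtype)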

File tree: 4 files changed, +61 -4 lines changed

plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py

Lines changed: 6 additions & 0 deletions
@@ -38,6 +38,7 @@
     register_tensors_as_parameters_patch_rule,
     requires_installation_on_all_linears,
 )
+from .fsdp_utils import put_selected_meta_tensors_on_cpu
 
 
 class AutoGPTQAccelerationPlugin(AccelerationPlugin):
@@ -219,6 +220,11 @@ def model_loader(self, model_name: str, **kwargs):
         # replace
         AutoModelForCausalLM.from_config = _old_from_config
 
+        # in low_cpu_mem_mode, if certain tensors like embeddings
+        # are on the meta device, then certain operations like
+        # embedding resizing will fail
+        put_selected_meta_tensors_on_cpu(model)
+
         # AutoGPTQ does not set the torch_dtype of the model carefully
         model.config.torch_dtype = torch_dtype

plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_bnb.py

Lines changed: 24 additions & 0 deletions
@@ -28,6 +28,9 @@
 from transformers.utils.import_utils import _is_package_available
 import torch
 
+# Local
+from .fsdp_utils import put_selected_meta_tensors_on_cpu
+
 
 # this is a modified copy of the function from peft.utils.other, that we
 # will instead use
@@ -154,6 +157,27 @@ def model_loader(self, model_name: str, **kwargs):
             attn_implementation=attn_implementation,
         )
 
+        if (
+            world_size > 1
+            and os.environ.get("ACCELERATE_USE_FSDP", "false").lower() == "true"
+        ):
+            config_kwargs["bnb_4bit_quant_storage"] = torch_dtype
+
+        _, _transformers_version = _is_package_available(
+            "transformers", return_version=True
+        )
+        _trl_installed, _trl_version = _is_package_available(
+            "trl", return_version=True
+        )
+
+        if _transformers_version >= "4.45" and (
+            not _trl_installed or (_trl_installed and _trl_version >= "0.12")
+        ):
+            # in low_cpu_mem_mode, if certain tensors like embeddings
+            # are on the meta device, then certain operations like
+            # embedding resizing will fail
+            put_selected_meta_tensors_on_cpu(model)
+
         return model
 
     @property
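The fix is only applied on newer library stacks: transformers at or above 4.45 and, when trl is present, trl at or above 0.12. The diff compares the raw version strings returned by _is_package_available; below is a hedged sketch of the same gate, assuming the packaging library is available, which gives a robust version ordering rather than string comparison:

from packaging import version
from transformers.utils.import_utils import _is_package_available

_, transformers_version = _is_package_available("transformers", return_version=True)
trl_installed, trl_version = _is_package_available("trl", return_version=True)

# mirror the condition in the diff above, but with proper version parsing;
# if trl is absent, its reported version is never parsed thanks to short-circuiting
apply_meta_fix = version.parse(transformers_version) >= version.parse("4.45") and (
    not trl_installed or version.parse(trl_version) >= version.parse("0.12")
)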

plugins/accelerated-peft/src/fms_acceleration_peft/fsdp_utils.py

Lines changed: 26 additions & 0 deletions
@@ -2,6 +2,8 @@
 from collections import defaultdict
 
 # Third Party
+from accelerate.utils import set_module_tensor_to_device
+from transformers import PreTrainedModel
 import torch
 
 # Copyright The IBM Tuning Team
@@ -70,3 +72,27 @@ def param_init_fn_tied_param(module: torch.nn.Module):
         return module
 
     return param_init_fn_tied_param
+
+
+# utility to put selected meta tensors on the cpu
+def put_selected_meta_tensors_on_cpu(model: PreTrainedModel):
+
+    done = {}
+    # - for now we only handle the input and output embeddings
+    for module in [
+        model.get_input_embeddings(),
+        model.get_output_embeddings(),
+    ]:
+
+        for param_name, param in module.named_parameters(recurse=False):
+            param_id = id(param)
+
+            if param.device == torch.device("meta"):
+                if param_id not in done:
+                    value = torch.empty(*param.size(), dtype=param.dtype)
+                    done[param_id] = value  # memoize
+                else:
+                    # this is a tied weight, get back the previous value
+                    value = done[param_id]
+
+                set_module_tensor_to_device(module, param_name, "cpu", value)
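
A hypothetical usage sketch of the new helper (the model name and resize target are placeholders; whether any weights actually remain on meta depends on how the plugin loads the quantized checkpoint). Note that tied input/output embeddings share the same parameter object, so the done dict hands back the same cpu tensor for both and the tie is preserved:

from transformers import AutoModelForCausalLM, AutoTokenizer

from fms_acceleration_peft.fsdp_utils import put_selected_meta_tensors_on_cpu

model = AutoModelForCausalLM.from_pretrained("org/some-model", low_cpu_mem_usage=True)
tokenizer = AutoTokenizer.from_pretrained("org/some-model")

# if the embedding weights were left on the meta device, resizing them would
# raise "NotImplementedError: Cannot copy out of meta tensor; no data!";
# materializing them on cpu first avoids that
put_selected_meta_tensors_on_cpu(model)
model.resize_token_embeddings(len(tokenizer))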

plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_quantized_peft.py

Lines changed: 5 additions & 4 deletions
@@ -48,10 +48,6 @@ def _all_reduce_hook(grad):
         A = mod.lora_A.default
         B = mod.lora_B.default
 
-        # install hooks on the adapters
-        A.weight.register_hook(_all_reduce_hook)
-        B.weight.register_hook(_all_reduce_hook)
-
         # because we will ignore these from FSDP, we need to manually
         # move them to gpu if they are already not on them
         # - if the adapters are on meta, we assume that this is for FSDP
@@ -80,6 +76,11 @@ def _all_reduce_hook(grad):
         if is_fsdp_enabled():
             dist.broadcast(B.weight, src=0)
 
+        # install hooks on the adapters
+        # - this has to be done after all weight replacement happens
+        A.weight.register_hook(_all_reduce_hook)
+        B.weight.register_hook(_all_reduce_hook)
+
 def register_foak_model_patch_rules(base_type):
     # Third Party
     from fms_acceleration.model_patcher import (  # pylint: disable=import-outside-toplevel
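
The second fix moves the all-reduce hook registration to after the adapter weights have been broadcast and replaced. The reason: register_hook attaches to a specific tensor, so if the parameter object is swapped out afterwards, the hook registered on the old tensor never fires. A minimal standalone sketch of that behavior (toy module, not the plugin code):

import torch

lin = torch.nn.Linear(2, 2, bias=False)
calls = []
lin.weight.register_hook(lambda grad: calls.append("fired"))

# replacing the parameter afterwards, as weight re-materialization does,
# leaves the hook attached to the discarded tensor
lin.weight = torch.nn.Parameter(torch.randn(2, 2))

lin(torch.randn(1, 2)).sum().backward()
print(calls)  # [] -- the hook never fires for the new weight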
