Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/peft/utils/save_and_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,10 @@ def renamed_dora_weights(k):
)
if low_cpu_mem_usage:
load_result = model.load_state_dict(peft_model_state_dict, strict=False, assign=True)
# ensure that the correct device is set
for module in model.modules():
if hasattr(module, "_move_adapter_to_device_of_base_layer"):
module._move_adapter_to_device_of_base_layer(adapter_name)
else:
load_result = model.load_state_dict(peft_model_state_dict, strict=False)

Expand Down
51 changes: 51 additions & 0 deletions tests/test_gpu_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,11 @@
PromptEncoderConfig,
TaskType,
get_peft_model,
get_peft_model_state_dict,
inject_adapter_in_model,
prepare_model_for_kbit_training,
replace_lora_weights_loftq,
set_peft_model_state_dict,
)
from peft.tuners import boft
from peft.utils import SAFETENSORS_WEIGHTS_NAME, infer_device
Expand Down Expand Up @@ -3226,3 +3229,51 @@ def test_p_tuning_exactly_reproducible_after_loading(self, tmp_path):

torch.testing.assert_close(output_loaded, output_peft)
torch.testing.assert_close(gen_loaded, gen_peft)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires a GPU")
@pytest.mark.single_gpu_tests
class TestLowCpuMemUsageDifferentDevices:
"""Test for the low CPU memory usage option for loading PEFT models.

There are already tests for this in test_initialization.py but here we want to specifically test diverging devices
for the model and state_dict.

"""

model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"

@pytest.mark.parametrize("device_model, device_sd", [("cpu", "cuda"), ("cuda", "cpu")])
def test_low_cpu_mem_usage_model_model_on_gpu_state_dict_on_cpu_works(self, device_model, device_sd):
inputs = {"input_ids": torch.randint(0, 100, (1, 10)), "attention_mask": torch.ones(1, 10)}
inputs = {k: v.to(device_model) for k, v in inputs.items()}

model = AutoModelForCausalLM.from_pretrained(self.model_id).to(device_model)
lora_config = LoraConfig(init_lora_weights=False, target_modules="all-linear")
model = get_peft_model(model, lora_config)
model.eval()
logits_not_low_cpu_mem = model(**inputs).logits

state_dict = get_peft_model_state_dict(model)
peft_model_state_dict = {}
# remap the state dict so that it can be correctly loaded, and move weights to the other device
prefix = "base_model.model."
for k, v in state_dict.items():
k = k[len(prefix) :]
peft_model_state_dict[k] = v.to(device_sd)

del model

model = AutoModelForCausalLM.from_pretrained(self.model_id).to(device_model)
model.eval()
inject_adapter_in_model(lora_config, model, low_cpu_mem_usage=True)
load_result = set_peft_model_state_dict(model, peft_model_state_dict, low_cpu_mem_usage=True)

# sanity check: all lora keys are matched
assert not any("lora" in k for k in load_result.missing_keys)
assert not any("lora" in k for k in load_result.unexpected_keys)

logits_low_cpu_mem = model(**inputs).logits

assert torch.allclose(logits_low_cpu_mem, logits_not_low_cpu_mem)
assert {p.device.type for p in model.parameters()} == {device_model}
Loading