Skip to content

Commit 6ed1131

Browse files
committed
meta device fixes.
1 parent 2f05455 commit 6ed1131

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

src/diffusers/loaders/lora_pipeline.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2310,12 +2310,14 @@ def unload_lora_weights(self):
23102310
dtype=module_weight.dtype,
23112311
)
23122312

2313-
original_module.weight.data.copy_(current_param_weight)
2313+
tmp_state_dict = {"weight": current_param_weight}
23142314
if module_bias is not None:
2315-
original_module.bias.data.copy_(overwritten_params[f"{name}.bias"])
2316-
2315+
tmp_state_dict.update({"bias": overwritten_params[f"{name}.bias"]})
2316+
original_module.load_state_dict(tmp_state_dict, assign=True, strict=True)
23172317
setattr(parent_module, current_module_name, original_module)
23182318

2319+
del tmp_state_dict
2320+
23192321
if current_module_name in _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX:
23202322
attribute_name = _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX[current_module_name]
23212323
new_value = int(current_param_weight.shape[1])

0 commit comments

Comments
 (0)