Skip to content

Commit 5ef79f3

Browse files
committed
use torch.device meta for state dict expansion.
1 parent 258a398 commit 5ef79f3

File tree

1 file changed: +9 additions, −6 deletions

src/diffusers/loaders/lora_pipeline.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2353,22 +2353,25 @@ def _maybe_expand_transformer_param_shape_or_error_(
             parent_module_name, _, current_module_name = name.rpartition(".")
             parent_module = transformer.get_submodule(parent_module_name)

-            # TODO: consider initializing this under meta device for optims.
-            expanded_module = torch.nn.Linear(
-                in_features, out_features, bias=bias, device=module_weight.device, dtype=module_weight.dtype
-            )
+            with torch.device("meta"):
+                expanded_module = torch.nn.Linear(
+                    in_features, out_features, bias=bias, dtype=module_weight.dtype
+                )
             # Only weights are expanded and biases are not.
             new_weight = torch.zeros_like(
                 expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype
             )
             slices = tuple(slice(0, dim) for dim in module_weight.shape)
             new_weight[slices] = module_weight
-            expanded_module.weight.data.copy_(new_weight)
+            tmp_state_dict = {"weight": new_weight}
             if module_bias is not None:
-                expanded_module.bias.data.copy_(module_bias)
+                tmp_state_dict["bias"] = module_bias
+            expanded_module.load_state_dict(tmp_state_dict, strict=True, assign=True)

             setattr(parent_module, current_module_name, expanded_module)

+            del tmp_state_dict
+
             if current_module_name in _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX:
                 attribute_name = _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX[current_module_name]
                 new_value = int(expanded_module.weight.data.shape[1])

0 commit comments

Comments (0)