Skip to content

Commit 695ad14

Browse files
committed
fix
1 parent 36eb48b commit 695ad14

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

src/diffusers/loaders/lora_pipeline.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1906,6 +1906,7 @@ def unload_lora_weights(self, reset_to_overwritten_params=False):
19061906

19071907
for name, module in transformer.named_modules():
19081908
if isinstance(module, torch.nn.Linear) and name in module_names:
1909+
module_weight = module.weight.data
19091910
module_bias = module.bias.data if module.bias is not None else None
19101911
bias = module_bias is not None
19111912

@@ -1919,6 +1920,7 @@ def unload_lora_weights(self, reset_to_overwritten_params=False):
19191920
in_features,
19201921
out_features,
19211922
bias=bias,
1923+
dtype=module_weight.dtype,
19221924
)
19231925

19241926
tmp_state_dict = {"weight": current_param_weight}
@@ -2021,12 +2023,16 @@ def _maybe_expand_transformer_param_shape_or_error_(
20212023
parent_module = transformer.get_submodule(parent_module_name)
20222024

20232025
with torch.device("meta"):
2024-
expanded_module = torch.nn.Linear(in_features, out_features, bias=bias)
2026+
expanded_module = torch.nn.Linear(
2027+
in_features, out_features, bias=bias, dtype=module_weight.dtype
2028+
)
20252029
# Only weights are expanded and biases are not. This is because only the input dimensions
20262030
# are changed while the output dimensions remain the same. The shape of the weight tensor
20272031
# is (out_features, in_features), while the shape of bias tensor is (out_features,), which
20282032
# explains the reason why only weights are expanded.
2029-
new_weight = torch.zeros_like(expanded_module.weight.data)
2033+
new_weight = torch.zeros_like(
2034+
expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype
2035+
)
20302036
slices = tuple(slice(0, dim) for dim in module_weight_shape)
20312037
new_weight[slices] = module_weight
20322038
tmp_state_dict = {"weight": new_weight}

0 commit comments

Comments
 (0)