
Commit a8bd03b

reviewer feedback
1 parent: 55058e2 · commit: a8bd03b

File tree

1 file changed (+2, −2)

src/diffusers/loaders/lora_pipeline.py

Lines changed: 2 additions & 2 deletions
@@ -2318,14 +2318,13 @@ def _maybe_expand_transformer_param_shape_or_error_(
 
                 lora_A_weight_name = f"{name}.lora_A.weight"
                 lora_B_weight_name = f"{name}.lora_B.weight"
-                # lora_B_bias_name = f"{name}.lora_B.bias"
-
                 if lora_A_weight_name not in state_dict.keys():
                     continue
 
                 in_features = state_dict[lora_A_weight_name].shape[1]
                 out_features = state_dict[lora_B_weight_name].shape[0]
 
+                # This means there's no need for an expansion in the params, so we simply skip.
                 if tuple(module_weight.shape) == (out_features, in_features):
                     continue
 
@@ -2349,6 +2348,7 @@ def _maybe_expand_transformer_param_shape_or_error_(
                 parent_module_name, _, current_module_name = name.rpartition(".")
                 parent_module = transformer.get_submodule(parent_module_name)
 
+                # TODO: consider initializing this under meta device for optims.
                 expanded_module = torch.nn.Linear(
                     in_features, out_features, bias=bias, device=module_weight.device, dtype=module_weight.dtype
                 )
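
For context on the comment added in the first hunk: the loop only expands a base Linear when the LoRA factors imply a larger shape than the module already has. A minimal sketch of that check, using hypothetical tensor shapes rather than the actual Flux transformer layers:

import torch

# Sketch only (not the diffusers code path): hypothetical LoRA entries for a Linear named "proj".
state_dict = {
    "proj.lora_A.weight": torch.randn(16, 3072),  # (rank, in_features)
    "proj.lora_B.weight": torch.randn(3072, 16),  # (out_features, rank)
}
module_weight = torch.randn(3072, 3072)           # base weight, shape (out_features, in_features)

in_features = state_dict["proj.lora_A.weight"].shape[1]
out_features = state_dict["proj.lora_B.weight"].shape[0]

# Shapes already match, so no expansion is needed and the loop simply skips this module.
if tuple(module_weight.shape) == (out_features, in_features):
    print("skip: no expansion needed")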

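The TODO in the second hunk is about avoiding a wasted allocation: the expanded Linear is overwritten right after construction, so its default initialization is throwaway work. A rough sketch of the idea, assuming a plain torch.nn.Linear with hypothetical shapes (this is not what the commit implements):

import torch

# Sketch only: hypothetical module and expanded shape, not the diffusers code path.
old = torch.nn.Linear(3072, 3072, bias=True)   # module to be expanded
in_features, out_features = 4096, 3072          # hypothetical expanded shape

# Build the larger module on the meta device: no real memory is allocated or initialized.
expanded = torch.nn.Linear(in_features, out_features, bias=True, device="meta")

# Materialize uninitialized storage on the target device, then fill it explicitly.
expanded = expanded.to_empty(device=old.weight.device)
with torch.no_grad():
    expanded.weight.zero_()
    expanded.weight[:, : old.in_features].copy_(old.weight)
    expanded.bias.copy_(old.bias)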