Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions toolkit/network_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,15 +677,32 @@ def load_weights(self: Network, file, force_weight_mapping=False):

elif "lora_down" in key and src_h < tgt_h:
print_once(f"Expanding {key} from {load_value.shape} to {blank_val.shape}")
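# Seed the new rows with small Gaussian noise instead of zeros so that
# gradients can flow from step 1. If the new rows here and the matching
# new lora_up columns were both zero, each would receive a zero gradient
# and the added rank could never learn. The std is Kaiming-style
# (divided by sqrt(fan_in)), so each new row has roughly 10% of the norm
# of a typical existing row.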
new_val = torch.zeros((tgt_h, tgt_w), device=load_value.device, dtype=load_value.dtype)
new_val[:src_h, :src_w] = load_value # src_w should already match
new_val[:src_h, :src_w] = load_value
# Scale noise relative to existing weights
weight_scale = load_value.float().norm() / (src_h ** 0.5) if src_h > 0 else 1e-3
noise_scale = weight_scale * 0.1 # 10% of typical row magnitude
fan_in = tgt_w
std = noise_scale / (fan_in ** 0.5)
new_val[src_h:, :tgt_w] = (torch.randn(
(tgt_h - src_h, tgt_w), device=load_value.device
) * std).to(load_value.dtype)
load_sd[key] = new_val
self.did_change_weights = True

elif "lora_up" in key and src_w < tgt_w:
print_once(f"Expanding {key} from {load_value.shape} to {blank_val.shape}")
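# Seed the new columns with small Gaussian noise, as the lora_down branch
# above does for its new rows. The scale differs here: the per-element std
# is 1% of the typical existing column norm, with no sqrt(fan_in) division.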
new_val = torch.zeros((tgt_h, tgt_w), device=load_value.device, dtype=load_value.dtype)
new_val[:src_h, :src_w] = load_value # src_h should already match
new_val[:src_h, :src_w] = load_value
weight_scale = load_value.float().norm() / (src_w ** 0.5) if src_w > 0 else 1e-3
noise_scale = weight_scale * 0.01 # 1% of typical column magnitude
new_val[:tgt_h, src_w:] = (torch.randn(
(tgt_h, tgt_w - src_w), device=load_value.device
) * noise_scale).to(load_value.dtype)
load_sd[key] = new_val
self.did_change_weights = True

Expand Down