Skip to content

Commit ad16289

Browse files
authored
[LoRA] Fix lora merge weights (#579)
1 parent 2a41da1 commit ad16289

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

fastvideo/v1/layers/lora/linear.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -71,15 +71,16 @@ def merge_lora_weights(self) -> None:
7171
f"cuda:{torch.cuda.current_device()}").full_tensor()
7272
data += (self.slice_lora_b_weights(self.lora_B)
7373
@ self.slice_lora_a_weights(self.lora_A)).to(data)
74-
self.base_layer.weight.data = distribute_tensor(
75-
data, mesh, placements=placements).to(current_device)
74+
self.base_layer.weight = nn.Parameter(
75+
distribute_tensor(data, mesh,
76+
placements=placements).to(current_device))
7677
else:
7778
current_device = self.base_layer.weight.data.device
78-
data = self.base_layer.weight.data.to(
79+
data = self.base_layer.weight.to(
7980
f"cuda:{torch.cuda.current_device()}")
8081
data += \
8182
(self.slice_lora_b_weights(self.lora_B) @ self.slice_lora_a_weights(self.lora_A)).to(data)
82-
self.base_layer.weight.data = data.to(current_device)
83+
self.base_layer.weight = nn.Parameter(data.to(current_device))
8384
self.merged = True
8485

8586
@torch.no_grad()
@@ -106,8 +107,8 @@ def unmerge_lora_weights(self) -> None:
106107
f"cuda:{torch.cuda.current_device()}").full_tensor()
107108
data -= self.slice_lora_b_weights(
108109
self.lora_B) @ self.slice_lora_a_weights(self.lora_A)
109-
self.base_layer.weight.data = distribute_tensor(
110-
data, mesh, placements=placement).to(device)
110+
self.base_layer.weight = nn.Parameter(
111+
distribute_tensor(data, mesh, placements=placement).to(device))
111112
else:
112113
self.base_layer.weight.data -= \
113114
self.slice_lora_b_weights(self.lora_B) @\

0 commit comments

Comments (0)