Commit 989a035

Optionally use unmerged weights for inference (#745)
1 parent fa15369

Showing 8 changed files with 66 additions and 4 deletions.

examples/inference/lora/wan_lora_inference_from_ckpt.py

Lines changed: 2 additions & 1 deletion
@@ -14,9 +14,10 @@ def main():
         vae_cpu_offload=True,
         text_encoder_cpu_offload=True,
         pin_cpu_memory=True,  # set to false if low CPU RAM or hit obscure "CUDA error: Invalid argument"
-        lora_path="checkpoints/wan_t2v_finetune_lora/checkpoint-1000/transformer",
+        lora_path="checkpoints/wan_t2v_finetune_lora/checkpoint-160/transformer",
         lora_nickname="crush_smol"
     )
+    generator.unmerge_lora_weights()
     kwargs = {
         "height": 480,
         "width": 832,

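For orientation, here is a hedged sketch of the example script after this change. Only the LoRA arguments and the new unmerge_lora_weights() call are taken from the diff above; the import path, factory call, model id, prompt, and generation call are illustrative assumptions about the surrounding example code, not part of this commit.

from fastvideo import VideoGenerator  # import path assumed for this sketch

def main():
    generator = VideoGenerator.from_pretrained(  # factory call assumed
        "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",      # model id assumed (matches the training script below)
        vae_cpu_offload=True,
        text_encoder_cpu_offload=True,
        pin_cpu_memory=True,  # set to false if low CPU RAM
        lora_path="checkpoints/wan_t2v_finetune_lora/checkpoint-160/transformer",
        lora_nickname="crush_smol",
    )
    # New in this commit: keep the LoRA delta separate and apply it at runtime,
    # so inference matches the validation videos generated during training.
    generator.unmerge_lora_weights()

    kwargs = {
        "height": 480,
        "width": 832,
    }
    generator.generate_video("a toy car being crushed", **kwargs)  # prompt and method name assumed

if __name__ == "__main__":
    main()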
examples/training/finetune/wan_t2v_1.3B/crush_smol/finetune_t2v_lora.sh

Lines changed: 2 additions & 1 deletion
@@ -7,7 +7,7 @@ export WANDB_MODE=online
 MODEL_PATH="Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
 DATA_DIR="data/crush-smol_processed_t2v/combined_parquet_dataset/"
 VALIDATION_DATASET_FILE="$(dirname "$0")/validation.json"
-NUM_GPUS=2
+NUM_GPUS=1
 # export CUDA_VISIBLE_DEVICES=4,5


@@ -76,6 +76,7 @@ miscellaneous_args=(
   --dit_precision "fp32"
   --num_euler_timesteps 50
   --ema_start_step 0
+  --resume_from_checkpoint "checkpoints/wan_t2v_finetune_lora/checkpoint-160"
 )

 torchrun \

fastvideo/entrypoints/video_generator.py

Lines changed: 10 additions & 0 deletions
@@ -351,6 +351,16 @@ def set_lora_adapter(self,
                          lora_path: str | None = None) -> None:
         self.executor.set_lora_adapter(lora_nickname, lora_path)

+    def unmerge_lora_weights(self) -> None:
+        """
+        Use unmerged weights for inference to produce videos that align with
+        validation videos generated during training.
+        """
+        self.executor.unmerge_lora_weights()
+
+    def merge_lora_weights(self) -> None:
+        self.executor.merge_lora_weights()
+
     def shutdown(self):
         """
         Shutdown the video generator.
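
Both new methods simply forward to the executor, so a caller can toggle between the two modes between generations. A small hedged usage sketch follows; the generator construction and the sampling call are assumptions, and only the two method names come from this diff.

# Unmerged: the LoRA delta is computed on the fly each forward pass, which
# reproduces how validation videos were generated during training.
generator.unmerge_lora_weights()
reference_video = generator.generate_video(prompt)  # sampling call assumed

# Merged: the delta is folded back into the base weights, skipping the extra
# low-rank matmul per layer and per step.
generator.merge_lora_weights()
fast_video = generator.generate_video(prompt)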

fastvideo/layers/lora/linear.py

Lines changed: 1 addition & 1 deletion
@@ -76,7 +76,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         lora_B = self.lora_B.to_local()
         lora_A = self.lora_A.to_local()

-        if (self.training_mode or not self.merged) and not self.disable_lora:
+        if not self.merged and not self.disable_lora:
             delta = x @ (
                 self.slice_lora_b_weights(lora_B.to(x, non_blocking=True))
                 @ self.slice_lora_a_weights(lora_A.to(x, non_blocking=True)))
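
With this change the runtime delta depends only on the merged flag: if the weights are not merged (and LoRA is not disabled), forward adds x @ (B @ A) on top of the base projection, regardless of training or inference mode. As a reference for what merging and unmerging mean numerically, here is a minimal standalone sketch on a plain weight matrix; the function names and the scaling convention are illustrative assumptions, not FastVideo's actual implementation.

import torch

def merge(weight, lora_A, lora_B, scale=1.0):
    # Fold the low-rank update into the base weight: W' = W + scale * (B @ A).
    # A plain x @ W'.T then already contains the LoRA contribution.
    return weight + scale * (lora_B @ lora_A)

def unmerge(weight, lora_A, lora_B, scale=1.0):
    # Undo the fold: W = W' - scale * (B @ A); the delta is instead added
    # on the fly in forward(), as in the diff above.
    return weight - scale * (lora_B @ lora_A)

# Illustrative shapes: W is (out, in), A is (r, in), B is (out, r).
W = torch.randn(8, 4)
A = torch.randn(2, 4)
B = torch.randn(8, 2)
x = torch.randn(3, 4)

W_merged = merge(W, A, B)
# Merged and unmerged paths agree up to floating-point accumulation order.
assert torch.allclose(x @ W_merged.T, x @ W.T + x @ (B @ A).T, atol=1e-5)
assert torch.allclose(unmerge(W_merged, A, B), W, atol=1e-5)

The two paths can differ by tiny floating-point amounts, which is presumably why the commit exposes the unmerged path: it reproduces the exact computation used for validation videos during training.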

fastvideo/pipelines/lora_pipeline.py

Lines changed: 8 additions & 0 deletions
@@ -217,3 +217,11 @@ def set_lora_adapter(self,
                 layer.disable_lora = True
         logger.info("Rank %d: LoRA adapter %s applied to %d layers", rank,
                     lora_path, adapted_count)
+
+    def merge_lora_weights(self) -> None:
+        for name, layer in self.lora_layers.items():
+            layer.merge_lora_weights()
+
+    def unmerge_lora_weights(self) -> None:
+        for name, layer in self.lora_layers.items():
+            layer.unmerge_lora_weights()

fastvideo/worker/executor.py

Lines changed: 14 additions & 0 deletions
@@ -57,6 +57,20 @@ def set_lora_adapter(self,
         """
         raise NotImplementedError

+    @abstractmethod
+    def unmerge_lora_weights(self) -> None:
+        """
+        Unmerge the LoRA weights for the workers.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def merge_lora_weights(self) -> None:
+        """
+        Merge the LoRA weights for the workers.
+        """
+        raise NotImplementedError
+
     @abstractmethod
     def collective_rpc(self,
                        method: str | Callable[..., _R],

fastvideo/worker/gpu_worker.py

Lines changed: 17 additions & 1 deletion
@@ -18,7 +18,7 @@
 from fastvideo.distributed.parallel_state import get_local_torch_device
 from fastvideo.fastvideo_args import FastVideoArgs
 from fastvideo.logger import init_logger
-from fastvideo.pipelines import ForwardBatch, build_pipeline
+from fastvideo.pipelines import ForwardBatch, LoRAPipeline, build_pipeline
 from fastvideo.platforms import current_platform
 from fastvideo.utils import (get_exception_traceback,
                              kill_itself_when_parent_died)

@@ -117,6 +117,14 @@ def shutdown(self) -> dict[str, Any]:
                     local_main_process_only=False)
         return {"status": "shutdown_complete"}

+    def unmerge_lora_weights(self) -> None:
+        if isinstance(self.pipeline, LoRAPipeline):
+            self.pipeline.unmerge_lora_weights()
+
+    def merge_lora_weights(self) -> None:
+        if isinstance(self.pipeline, LoRAPipeline):
+            self.pipeline.merge_lora_weights()
+
     def event_loop(self) -> None:
         """Event loop for the worker."""
         logger.info("Worker %d starting event loop...",

@@ -154,6 +162,14 @@ def event_loop(self) -> None:
                 logger.info("Worker %d set LoRA adapter %s with path %s",
                             self.rank, lora_nickname, lora_path)
                 self.pipe.send({"status": "lora_adapter_set"})
+            elif method_name == 'unmerge_lora_weights':
+                self.unmerge_lora_weights()
+                logger.info("Worker %d unmerged LoRA weights", self.rank)
+                self.pipe.send({"status": "lora_adapter_unmerged"})
+            elif method_name == 'merge_lora_weights':
+                self.merge_lora_weights()
+                logger.info("Worker %d merged LoRA weights", self.rank)
+                self.pipe.send({"status": "lora_adapter_merged"})
             else:
                 # Handle other methods dynamically if needed
                 args = recv_rpc.get('args', ())

fastvideo/worker/multiproc_executor.py

Lines changed: 12 additions & 0 deletions
@@ -109,6 +109,18 @@ def set_lora_adapter(self,
                 raise RuntimeError(
                     f"Worker {i} failed to set LoRA adapter to {lora_path}")

+    def unmerge_lora_weights(self) -> None:
+        responses = self.collective_rpc("unmerge_lora_weights", kwargs={})
+        for i, response in enumerate(responses):
+            if response["status"] != "lora_adapter_unmerged":
+                raise RuntimeError(f"Worker {i} failed to unmerge LoRA weights")
+
+    def merge_lora_weights(self) -> None:
+        responses = self.collective_rpc("merge_lora_weights", kwargs={})
+        for i, response in enumerate(responses):
+            if response["status"] != "lora_adapter_merged":
+                raise RuntimeError(f"Worker {i} failed to merge LoRA weights")
+
     def collective_rpc(self,
                        method: str | Callable,
                        timeout: float | None = None,
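
End to end, the new calls follow the same pipe-based RPC pattern as set_lora_adapter: the executor broadcasts the method name to every worker via collective_rpc, each worker's event loop dispatches it (a no-op for non-LoRA pipelines) and replies with a status dict, and the executor raises if any reply does not carry the expected status. A condensed sketch of that round trip follows; the dict key for the method name and the free-function scaffolding are simplified assumptions, not FastVideo's actual class layout.

# Executor side (sketch): fan out and validate every worker's response.
def unmerge_lora_weights(executor) -> None:
    responses = executor.collective_rpc("unmerge_lora_weights", kwargs={})
    for i, response in enumerate(responses):
        if response["status"] != "lora_adapter_unmerged":
            raise RuntimeError(f"Worker {i} failed to unmerge LoRA weights")

# Worker side (sketch): the event loop matches on the method name and answers
# over the pipe; toggling is only applied when the pipeline is a LoRAPipeline.
def handle_rpc(worker, recv_rpc: dict) -> None:
    method_name = recv_rpc.get("method")  # key name assumed
    if method_name == "unmerge_lora_weights":
        worker.unmerge_lora_weights()
        worker.pipe.send({"status": "lora_adapter_unmerged"})
    elif method_name == "merge_lora_weights":
        worker.merge_lora_weights()
        worker.pipe.send({"status": "lora_adapter_merged"})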
