@@ -18,7 +18,7 @@
 from fastvideo.distributed.parallel_state import get_local_torch_device
 from fastvideo.fastvideo_args import FastVideoArgs
 from fastvideo.logger import init_logger
-from fastvideo.pipelines import ForwardBatch, build_pipeline
+from fastvideo.pipelines import ForwardBatch, LoRAPipeline, build_pipeline
 from fastvideo.platforms import current_platform
 from fastvideo.utils import (get_exception_traceback,
                              kill_itself_when_parent_died)
@@ -117,6 +117,14 @@ def shutdown(self) -> dict[str, Any]:
                     local_main_process_only=False)
         return {"status": "shutdown_complete"}
 
+    def unmerge_lora_weights(self) -> None:
+        if isinstance(self.pipeline, LoRAPipeline):
+            self.pipeline.unmerge_lora_weights()
+
+    def merge_lora_weights(self) -> None:
+        if isinstance(self.pipeline, LoRAPipeline):
+            self.pipeline.merge_lora_weights()
+
     def event_loop(self) -> None:
         """Event loop for the worker."""
         logger.info("Worker %d starting event loop...",
@@ -154,6 +162,14 @@ def event_loop(self) -> None:
                 logger.info("Worker %d set LoRA adapter %s with path %s",
                             self.rank, lora_nickname, lora_path)
                 self.pipe.send({"status": "lora_adapter_set"})
+            elif method_name == 'unmerge_lora_weights':
+                self.unmerge_lora_weights()
+                logger.info("Worker %d unmerged LoRA weights", self.rank)
+                self.pipe.send({"status": "lora_adapter_unmerged"})
+            elif method_name == 'merge_lora_weights':
+                self.merge_lora_weights()
+                logger.info("Worker %d merged LoRA weights", self.rank)
+                self.pipe.send({"status": "lora_adapter_merged"})
             else:
                 # Handle other methods dynamically if needed
                 args = recv_rpc.get('args', ())
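For context, a minimal sketch of how the parent process might drive the two new handlers over the worker pipe. The request shape (a dict with a `method` key) and the blocking send/recv pattern are assumptions inferred from the existing `set_lora_adapter` handler and the `recv_rpc.get('args', ())` fallback above; they are not part of this diff.

```python
# Hypothetical client-side helpers; `pipe` is assumed to be the
# multiprocessing Connection that the worker reads RPCs from in event_loop().

def merge_lora_on_worker(pipe) -> dict:
    # Send an RPC request and block until the worker acknowledges.
    # The {"method": ...} shape is an assumption, not confirmed by this diff.
    pipe.send({"method": "merge_lora_weights"})
    return pipe.recv()  # expected reply: {"status": "lora_adapter_merged"}

def unmerge_lora_on_worker(pipe) -> dict:
    pipe.send({"method": "unmerge_lora_weights"})
    return pipe.recv()  # expected reply: {"status": "lora_adapter_unmerged"}
```

Note that both worker-side methods are no-ops when the pipeline is not a `LoRAPipeline`, so the caller still receives the status reply either way.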