@@ -95,8 +95,14 @@ def convert(self, state_dict):
9595 return state_dict
9696
9797
98+ class WanLowNoiseLoRAConverter (WanLoRAConverter ):
99+ def convert (self , state_dict ):
100+ return {"dit2" : super ().convert (state_dict )["dit" ]}
101+
102+
98103class WanVideoPipeline (BasePipeline ):
99104 lora_converter = WanLoRAConverter ()
105+ low_noise_lora_converter = WanLowNoiseLoRAConverter ()
100106
101107 def __init__ (
102108 self ,
@@ -133,7 +139,13 @@ def __init__(
133139 self .image_encoder = image_encoder
134140 self .model_names = ["text_encoder" , "dit" , "dit2" , "vae" , "image_encoder" ]
135141
136- def load_loras (self , lora_list : List [Tuple [str , float ]], fused : bool = True , save_original_weight : bool = False ):
142+ def load_loras (
143+ self ,
144+ lora_list : List [Tuple [str , float ]],
145+ fused : bool = True ,
146+ save_original_weight : bool = False ,
147+ lora_converter : Optional [WanLoRAConverter ] = None
148+ ):
137149 assert self .config .tp_degree is None or self .config .tp_degree == 1 , (
138150 "load LoRA is not allowed when tensor parallel is enabled; "
139151 "set tp_degree=None or tp_degree=1 during pipeline initialization"
@@ -142,10 +154,20 @@ def load_loras(self, lora_list: List[Tuple[str, float]], fused: bool = True, sav
142154 "load fused LoRA is not allowed when fully sharded data parallel is enabled; "
143155 "either load LoRA with fused=False or set use_fsdp=False during pipeline initialization"
144156 )
145- super ().load_loras (lora_list , fused , save_original_weight )
157+ super ().load_loras (lora_list , fused , save_original_weight , lora_converter )
158+
159+ def load_loras_low_noise (self , lora_list : List [Tuple [str , float ]], fused : bool = True , save_original_weight : bool = False ):
160+ assert self .dit2 is not None , "low noise LoRA can only be applied to Wan2.2"
161+ self .load_loras (lora_list , fused , save_original_weight , self .low_noise_lora_converter )
162+
163+ def load_loras_high_noise (self , lora_list : List [Tuple [str , float ]], fused : bool = True , save_original_weight : bool = False ):
164+ assert self .dit2 is not None , "high noise LoRA can only be applied to Wan2.2"
165+ self .load_loras (lora_list , fused , save_original_weight )
146166
147167 def unload_loras (self ):
148168 self .dit .unload_loras ()
169+ if self .dit2 is not None :
170+ self .dit2 .unload_loras ()
149171 self .text_encoder .unload_loras ()
150172
151173 def get_default_fps (self ) -> int :
0 commit comments