 from ..models.wan_video_text_encoder import T5RelativeEmbedding, T5LayerNorm
 from ..models.wan_video_dit import RMSNorm, sinusoidal_embedding_1d
 from ..models.wan_video_vae import RMS_norm, CausalConv3d, Upsample
+from ..models.wan_video_motion_controller import WanMotionControllerModel



@@ -31,7 +32,8 @@ def __init__(self, device="cuda", torch_dtype=torch.float16, tokenizer_path=None
         self.image_encoder: WanImageEncoder = None
         self.dit: WanModel = None
         self.vae: WanVideoVAE = None
-        self.model_names = ['text_encoder', 'dit', 'vae', 'image_encoder']
+        self.motion_controller: WanMotionControllerModel = None
+        self.model_names = ['text_encoder', 'dit', 'vae', 'image_encoder', 'motion_controller']
         self.height_division_factor = 16
         self.width_division_factor = 16
         self.use_unified_sequence_parallel = False
@@ -122,6 +124,22 @@ def enable_vram_management(self, num_persistent_param_in_dit=None):
                     computation_device=self.device,
                 ),
             )
+        if self.motion_controller is not None:
+            dtype = next(iter(self.motion_controller.parameters())).dtype
+            enable_vram_management(
+                self.motion_controller,
+                module_map = {
+                    torch.nn.Linear: AutoWrappedLinear,
+                },
+                module_config = dict(
+                    offload_dtype=dtype,
+                    offload_device="cpu",
+                    onload_dtype=dtype,
+                    onload_device="cpu",
+                    computation_dtype=dtype,
+                    computation_device=self.device,
+                ),
+            )
         self.enable_cpu_offload()


@@ -134,6 +152,7 @@ def fetch_models(self, model_manager: ModelManager):
         self.dit = model_manager.fetch_model("wan_video_dit")
         self.vae = model_manager.fetch_model("wan_video_vae")
         self.image_encoder = model_manager.fetch_model("wan_video_image_encoder")
+        self.motion_controller = model_manager.fetch_model("wan_video_motion_controller")


     @staticmethod
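For context, a minimal setup sketch showing where the new `wan_video_motion_controller` fetch and the extended VRAM management come into play. It follows the library's usual `ModelManager` / `from_model_manager` pattern; the file paths and dtype are illustrative placeholders, not part of this commit.

```python
# Sketch only: load a motion-controller checkpoint next to the usual Wan models.
# Paths are placeholders; fetch_models() above looks the weights up under the
# "wan_video_motion_controller" key, and enable_vram_management() now wraps them too.
import torch
from diffsynth import ModelManager, WanVideoPipeline

model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
model_manager.load_models([
    "models/wan/diffusion_pytorch_model.safetensors",  # DiT (placeholder path)
    "models/wan/models_t5_umt5-xxl-enc-bf16.pth",      # text encoder (placeholder path)
    "models/wan/Wan2.1_VAE.pth",                       # VAE (placeholder path)
    "models/wan/motion_controller.safetensors",        # motion controller (placeholder path)
])
pipe = WanVideoPipeline.from_model_manager(model_manager, device="cuda")
pipe.enable_vram_management()  # now also covers the motion controller when it is loaded
```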
@@ -163,22 +182,47 @@ def encode_prompt(self, prompt, positive=True):
         return {"context": prompt_emb}


-    def encode_image(self, image, num_frames, height, width):
+    def encode_image(self, image, end_image, num_frames, height, width):
         image = self.preprocess_image(image.resize((width, height))).to(self.device)
         clip_context = self.image_encoder.encode_image([image])
         msk = torch.ones(1, num_frames, height//8, width//8, device=self.device)
         msk[:, 1:] = 0
+        if end_image is not None:
+            end_image = self.preprocess_image(end_image.resize((width, height))).to(self.device)
+            vae_input = torch.concat([image.transpose(0,1), torch.zeros(3, num_frames-2, height, width).to(image.device), end_image.transpose(0,1)], dim=1)
+            msk[:, -1:] = 1
+        else:
+            vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)
+
         msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
         msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
         msk = msk.transpose(1, 2)[0]

-        vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)
         y = self.vae.encode([vae_input.to(dtype=self.torch_dtype, device=self.device)], device=self.device)[0]
         y = torch.concat([msk, y])
         y = y.unsqueeze(0)
         clip_context = clip_context.to(dtype=self.torch_dtype, device=self.device)
         y = y.to(dtype=self.torch_dtype, device=self.device)
         return {"clip_feature": clip_context, "y": y}
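The reworked `encode_image` enables first-and-last-frame conditioning: with `end_image` set, the mask marks both the first and the last frame as known, and the VAE input places the two frames at the ends of a zero-filled clip. A hedged usage sketch follows, continuing the setup above; the prompt and frame counts are illustrative, and `num_frames`/`height`/`width`/`num_inference_steps` are assumed to keep their existing meaning in the pipeline signature.

```python
# Sketch only: first/last-frame interpolation via the new end_image argument.
from PIL import Image

first_frame = Image.open("first_frame.png").convert("RGB")  # placeholder input
last_frame = Image.open("last_frame.png").convert("RGB")    # placeholder input

video = pipe(
    prompt="a sailing boat drifts toward the sunset",
    negative_prompt="static, blurry",
    input_image=first_frame,  # conditions frame 0 (msk[:, 0] stays 1)
    end_image=last_frame,     # conditions the last frame (msk[:, -1:] is set back to 1)
    num_frames=81, height=480, width=832,
    num_inference_steps=50, seed=0,
)
```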
+
+
+    def encode_control_video(self, control_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
+        control_video = self.preprocess_images(control_video)
+        control_video = torch.stack(control_video, dim=2).to(dtype=self.torch_dtype, device=self.device)
+        latents = self.encode_video(control_video, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=self.torch_dtype, device=self.device)
+        return latents
+
+
+    def prepare_controlnet_kwargs(self, control_video, num_frames, height, width, clip_feature=None, y=None, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
+        if control_video is not None:
+            control_latents = self.encode_control_video(control_video, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+            if clip_feature is None or y is None:
+                clip_feature = torch.zeros((1, 257, 1280), dtype=self.torch_dtype, device=self.device)
+                y = torch.zeros((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), dtype=self.torch_dtype, device=self.device)
+            else:
+                y = y[:, -16:]
+            y = torch.concat([control_latents, y], dim=1)
+        return {"clip_feature": clip_feature, "y": y}


     def tensor2video(self, frames):
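The two new helpers implement control-video conditioning: `encode_control_video` runs the control frames through the VAE, and `prepare_controlnet_kwargs` prepends those 16 control-latent channels to `y` (keeping only the 16 image-latent channels via `y[:, -16:]`, or substituting zeros when no input image is given). A hedged usage sketch, continuing the setup above; the control frames are placeholders and are assumed to match `num_frames` in length.

```python
# Sketch only: conditioning the generation on a control clip (e.g. depth or pose frames).
from PIL import Image

control_frames = [
    Image.open(f"control/{i:04d}.png").convert("RGB").resize((832, 480))  # placeholder frames
    for i in range(81)
]

video = pipe(
    prompt="a dancer on a neon-lit stage",
    control_video=control_frames,  # encoded to latents and concatenated in front of y
    num_frames=81, height=480, width=832,
    num_inference_steps=50, seed=0,
)
```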
@@ -204,6 +248,11 @@ def decode_video(self, latents, tiled=True, tile_size=(34, 34), tile_stride=(18,

     def prepare_unified_sequence_parallel(self):
         return {"use_unified_sequence_parallel": self.use_unified_sequence_parallel}
+
+
+    def prepare_motion_bucket_id(self, motion_bucket_id):
+        motion_bucket_id = torch.Tensor((motion_bucket_id,)).to(dtype=self.torch_dtype, device=self.device)
+        return {"motion_bucket_id": motion_bucket_id}


     @torch.no_grad()
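`prepare_motion_bucket_id` simply wraps the user-supplied scalar into a one-element tensor in the pipeline dtype; it is later forwarded to the motion controller during denoising. A hedged usage sketch, continuing the setup above; the value 100 is illustrative and only has an effect when a motion-controller checkpoint was loaded.

```python
# Sketch only: requesting a motion strength via the new motion_bucket_id argument.
video = pipe(
    prompt="waves crashing on a rocky shore",
    motion_bucket_id=100,  # wrapped into a 1-element tensor by prepare_motion_bucket_id()
    num_frames=81, height=480, width=832,
    num_inference_steps=50, seed=0,
)
```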
@@ -212,7 +261,9 @@ def __call__(
         prompt,
         negative_prompt="",
         input_image=None,
+        end_image=None,
         input_video=None,
+        control_video=None,
         denoising_strength=1.0,
         seed=None,
         rand_device="cpu",
@@ -222,6 +273,7 @@ def __call__(
         cfg_scale=5.0,
         num_inference_steps=50,
         sigma_shift=5.0,
+        motion_bucket_id=None,
         tiled=True,
         tile_size=(30, 52),
         tile_stride=(15, 26),
@@ -263,10 +315,21 @@ def __call__(
         # Encode image
         if input_image is not None and self.image_encoder is not None:
             self.load_models_to_device(["image_encoder", "vae"])
-            image_emb = self.encode_image(input_image, num_frames, height, width)
+            image_emb = self.encode_image(input_image, end_image, num_frames, height, width)
         else:
             image_emb = {}

+        # ControlNet
+        if control_video is not None:
+            self.load_models_to_device(["image_encoder", "vae"])
+            image_emb = self.prepare_controlnet_kwargs(control_video, num_frames, height, width, **image_emb, **tiler_kwargs)
+
+        # Motion Controller
+        if self.motion_controller is not None and motion_bucket_id is not None:
+            motion_kwargs = self.prepare_motion_bucket_id(motion_bucket_id)
+        else:
+            motion_kwargs = {}
+
         # Extra input
         extra_input = self.prepare_extra_input(latents)

@@ -278,14 +341,24 @@ def __call__(
         usp_kwargs = self.prepare_unified_sequence_parallel()

         # Denoise
-        self.load_models_to_device(["dit"])
+        self.load_models_to_device(["dit", "motion_controller"])
         for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
             timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)

             # Inference
-            noise_pred_posi = model_fn_wan_video(self.dit, latents, timestep=timestep, **prompt_emb_posi, **image_emb, **extra_input, **tea_cache_posi, **usp_kwargs)
+            noise_pred_posi = model_fn_wan_video(
+                self.dit, motion_controller=self.motion_controller,
+                x=latents, timestep=timestep,
+                **prompt_emb_posi, **image_emb, **extra_input,
+                **tea_cache_posi, **usp_kwargs, **motion_kwargs
+            )
             if cfg_scale != 1.0:
-                noise_pred_nega = model_fn_wan_video(self.dit, latents, timestep=timestep, **prompt_emb_nega, **image_emb, **extra_input, **tea_cache_nega, **usp_kwargs)
+                noise_pred_nega = model_fn_wan_video(
+                    self.dit, motion_controller=self.motion_controller,
+                    x=latents, timestep=timestep,
+                    **prompt_emb_nega, **image_emb, **extra_input,
+                    **tea_cache_nega, **usp_kwargs, **motion_kwargs
+                )
                 noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
             else:
                 noise_pred = noise_pred_posi
@@ -358,13 +431,15 @@ def update(self, hidden_states):

 def model_fn_wan_video(
     dit: WanModel,
-    x: torch.Tensor,
-    timestep: torch.Tensor,
-    context: torch.Tensor,
+    motion_controller: WanMotionControllerModel = None,
+    x: torch.Tensor = None,
+    timestep: torch.Tensor = None,
+    context: torch.Tensor = None,
     clip_feature: Optional[torch.Tensor] = None,
     y: Optional[torch.Tensor] = None,
     tea_cache: TeaCache = None,
     use_unified_sequence_parallel: bool = False,
+    motion_bucket_id: Optional[torch.Tensor] = None,
     **kwargs,
 ):
     if use_unified_sequence_parallel:
@@ -375,6 +450,8 @@ def model_fn_wan_video(

     t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
     t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim))
+    if motion_bucket_id is not None and motion_controller is not None:
+        t_mod = t_mod + motion_controller(motion_bucket_id).unflatten(1, (6, dit.dim))
     context = dit.text_embedding(context)

     if dit.has_image_input:
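In `model_fn_wan_video`, the motion controller contributes an additive offset to the time-modulation tensor `t_mod`: its output of shape `(batch, 6 * dim)` is unflattened to `(batch, 6, dim)` and added on top of the timestep projection. The toy module below is a hypothetical stand-in with that interface only, not the actual `WanMotionControllerModel`; `dim=1536`, the hidden width, and the bucket-id scaling are assumed placeholders.

```python
# Hypothetical stand-in (NOT the real WanMotionControllerModel): any module that maps a
# (batch,) motion_bucket_id tensor to (batch, 6 * dim) satisfies the interface used above.
import torch

class ToyMotionController(torch.nn.Module):
    def __init__(self, dim=1536, hidden=256):
        super().__init__()
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(1, hidden),
            torch.nn.SiLU(),
            torch.nn.Linear(hidden, 6 * dim),
        )

    def forward(self, motion_bucket_id):
        x = motion_bucket_id.unsqueeze(-1) / 100.0  # scale the raw bucket id before the MLP
        return self.mlp(x)                          # shape (batch, 6 * dim)

# Mirrors the added lines: t_mod = t_mod + motion_controller(motion_bucket_id).unflatten(1, (6, dim))
dim = 1536
t_mod = torch.zeros(1, 6, dim)
offset = ToyMotionController(dim)(torch.tensor([100.0])).unflatten(1, (6, dim))
t_mod = t_mod + offset
```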