# Module-level logger, one per module per the diffusers convention.
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
3535
3636
class WanI2VLoopBeforeDenoiser(PipelineBlock):
    # Fixed: was "stable-diffusion-xl", a copy-paste leftover from the SDXL
    # blocks; every Wan block in this module uses "wan".
    model_name = "wan"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("scheduler", UniPCMultistepScheduler),
        ]

    @property
    def description(self) -> str:
        return (
            "Step within the denoising loop that prepares the latent input for the denoiser. "
            "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
            "object (e.g. `WanI2VDenoiseLoopWrapper`)"
        )

    @property
    def intermediate_inputs(self) -> List[str]:
        return [
            InputParam(
                "latents",
                required=True,
                type_hint=torch.Tensor,
                description="The initial latents to use for the denoising process.",
            ),
            InputParam(
                "latent_condition",
                required=True,
                type_hint=torch.Tensor,
                description="The latent condition to use for the denoising process.",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "concatenated_latents",
                type_hint=torch.Tensor,
                description="The concatenated noisy and conditioning latents to use for the denoising process.",
            ),
        ]

    @torch.no_grad()
    def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: int):
        # Stack the noisy latents with the conditioning latents along dim=1
        # (presumably the channel axis — TODO confirm against the transformer's
        # expected input layout); the I2V denoiser consumes both together.
        block_state.concatenated_latents = torch.cat([block_state.latents, block_state.latent_condition], dim=1)
        return components, block_state
86+
3787class WanLoopDenoiser (PipelineBlock ):
3888 model_name = "wan"
3989
@@ -102,7 +152,7 @@ def __call__(
102152 components .guider .set_state (step = i , num_inference_steps = block_state .num_inference_steps , timestep = t )
103153
104154 # Prepare mini‐batches according to guidance method and `guider_input_fields`
105- # Each guider_state_batch will have .prompt_embeds, .time_ids, text_embeds, image_embeds .
155+ # Each guider_state_batch will have .prompt_embeds.
106156 # e.g. for CFG, we prepare two batches: one for uncond, one for cond
107157 # for first batch, guider_state_batch.prompt_embeds correspond to block_state.prompt_embeds
108158 # for second batch, guider_state_batch.prompt_embeds correspond to block_state.negative_prompt_embeds
@@ -120,7 +170,112 @@ def __call__(
120170 guider_state_batch .noise_pred = components .transformer (
121171 hidden_states = block_state .latents .to (transformer_dtype ),
122172 timestep = t .flatten (),
123- encoder_hidden_states = prompt_embeds ,
173+ encoder_hidden_states = prompt_embeds .to (transformer_dtype ),
174+ attention_kwargs = block_state .attention_kwargs ,
175+ return_dict = False ,
176+ )[0 ]
177+ components .guider .cleanup_models (components .transformer )
178+
179+ # Perform guidance
180+ block_state .noise_pred , block_state .scheduler_step_kwargs = components .guider (guider_state )
181+
182+ return components , block_state
183+
184+
185+ class WanI2VLoopDenoiser (PipelineBlock ):
186+ model_name = "wan"
187+
188+ @property
189+ def expected_components (self ) -> List [ComponentSpec ]:
190+ return [
191+ ComponentSpec (
192+ "guider" ,
193+ ClassifierFreeGuidance ,
194+ config = FrozenDict ({"guidance_scale" : 5.0 }),
195+ default_creation_method = "from_config" ,
196+ ),
197+ ComponentSpec ("transformer" , WanTransformer3DModel ),
198+ ]
199+
200+ @property
201+ def description (self ) -> str :
202+ return (
203+ "Step within the denoising loop that denoise the latents with guidance. "
204+ "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
205+ "object (e.g. `WanDenoiseLoopWrapper`)"
206+ )
207+
208+ @property
209+ def inputs (self ) -> List [Tuple [str , Any ]]:
210+ return [
211+ InputParam ("attention_kwargs" ),
212+ ]
213+
214+ @property
215+ def intermediate_inputs (self ) -> List [str ]:
216+ return [
217+ InputParam (
218+ "concatenated_latents" ,
219+ required = True ,
220+ type_hint = torch .Tensor ,
221+ description = "The initial latents to use for the denoising process." ,
222+ ),
223+ InputParam (
224+ "encoder_hidden_states_image" ,
225+ required = True ,
226+ type_hint = torch .Tensor ,
227+ description = "The encoder hidden states for the image inputs." ,
228+ ),
229+ InputParam (
230+ "num_inference_steps" ,
231+ required = True ,
232+ type_hint = int ,
233+ description = "The number of inference steps to use for the denoising process." ,
234+ ),
235+ InputParam (
236+ kwargs_type = "guider_input_fields" ,
237+ description = (
238+ "All conditional model inputs that need to be prepared with guider. "
239+ "It should contain prompt_embeds/negative_prompt_embeds. "
240+ "Please add `kwargs_type=guider_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state"
241+ ),
242+ ),
243+ ]
244+
245+ @torch .no_grad ()
246+ def __call__ (
247+ self , components : WanModularPipeline , block_state : BlockState , i : int , t : torch .Tensor
248+ ) -> PipelineState :
249+ # Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds)
250+ # to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds)
251+ guider_input_fields = {
252+ "prompt_embeds" : ("prompt_embeds" , "negative_prompt_embeds" ),
253+ }
254+ transformer_dtype = components .transformer .dtype
255+
256+ components .guider .set_state (step = i , num_inference_steps = block_state .num_inference_steps , timestep = t )
257+
258+ # Prepare mini‐batches according to guidance method and `guider_input_fields`
259+ # Each guider_state_batch will have .prompt_embeds.
260+ # e.g. for CFG, we prepare two batches: one for uncond, one for cond
261+ # for first batch, guider_state_batch.prompt_embeds correspond to block_state.prompt_embeds
262+ # for second batch, guider_state_batch.prompt_embeds correspond to block_state.negative_prompt_embeds
263+ guider_state = components .guider .prepare_inputs (block_state , guider_input_fields )
264+
265+ # run the denoiser for each guidance batch
266+ for guider_state_batch in guider_state :
267+ components .guider .prepare_models (components .transformer )
268+ cond_kwargs = guider_state_batch .as_dict ()
269+ cond_kwargs = {k : v for k , v in cond_kwargs .items () if k in guider_input_fields }
270+ prompt_embeds = cond_kwargs .pop ("prompt_embeds" )
271+
272+ # Predict the noise residual
273+ # store the noise_pred in guider_state_batch so that we can apply guidance across all batches
274+ guider_state_batch .noise_pred = components .transformer (
275+ hidden_states = block_state .concatenated_latents .to (transformer_dtype ),
276+ timestep = t .flatten (),
277+ encoder_hidden_states = prompt_embeds .to (transformer_dtype ),
278+ encoder_hidden_states_image = block_state .encoder_hidden_states_image .to (transformer_dtype ),
124279 attention_kwargs = block_state .attention_kwargs ,
125280 return_dict = False ,
126281 )[0 ]
@@ -247,7 +402,7 @@ class WanDenoiseStep(WanDenoiseLoopWrapper):
247402 WanLoopDenoiser ,
248403 WanLoopAfterDenoiser ,
249404 ]
250- block_names = ["before_denoiser" , " denoiser" , "after_denoiser" ]
405+ block_names = ["denoiser" , "after_denoiser" ]
251406
252407 @property
253408 def description (self ) -> str :
@@ -257,5 +412,26 @@ def description(self) -> str:
257412 "At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n "
258413 " - `WanLoopDenoiser`\n "
259414 " - `WanLoopAfterDenoiser`\n "
260- "This block supports both text2vid tasks."
415+ "This block supports the text2vid task."
416+ )
417+
418+
class WanI2VDenoiseStep(WanDenoiseLoopWrapper):
    block_classes = [
        WanI2VLoopBeforeDenoiser,
        WanI2VLoopDenoiser,
        WanLoopAfterDenoiser,
    ]
    block_names = ["before_denoiser", "denoiser", "after_denoiser"]

    @property
    def description(self) -> str:
        # Fixed: the description previously advertised a `WanI2VLoopAfterDenoiser`
        # sub-block, but `block_classes` above actually uses `WanLoopAfterDenoiser`;
        # the listing now matches the composed blocks. Also fixed the
        # "sequencially" typo.
        return (
            "Denoise step that iteratively denoises the latents with conditional first- and last-frame support. \n"
            "Its loop logic is defined in `WanDenoiseLoopWrapper.__call__` method \n"
            "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
            " - `WanI2VLoopBeforeDenoiser`\n"
            " - `WanI2VLoopDenoiser`\n"
            " - `WanLoopAfterDenoiser`\n"
            "This block supports the image-to-video and first-last-frame-to-video tasks."
        )
0 commit comments