@@ -104,6 +104,7 @@ def calculate_shift(
104104 return mu
105105
106106
107+ # Adapted from the original implementation.
107108def prepare_latents_img2img (
108109 vae , scheduler , image , timestep , batch_size , num_channels_latents , height , width , dtype , device , generator
109110):
@@ -196,8 +197,19 @@ def _encode_vae_image(vae, image: torch.Tensor, generator: torch.Generator):
196197 return image_latents
197198
198199
199- def _get_timesteps_and_optionals (transformer , scheduler , latents , num_inference_steps , guidance_scale , sigmas , device ):
200- image_seq_len = latents .shape [1 ]
200+ def _get_initial_timesteps_and_optionals (
201+ transformer ,
202+ scheduler ,
203+ batch_size ,
204+ height ,
205+ width ,
206+ vae_scale_factor ,
207+ num_inference_steps ,
208+ guidance_scale ,
209+ sigmas ,
210+ device ,
211+ ):
212+ image_seq_len = (int (height ) // vae_scale_factor // 2 ) * (int (width ) // vae_scale_factor // 2 )
201213
202214 sigmas = np .linspace (1.0 , 1 / num_inference_steps , num_inference_steps ) if sigmas is None else sigmas
203215 if hasattr (scheduler .config , "use_flow_sigmas" ) and scheduler .config .use_flow_sigmas :
@@ -212,7 +224,7 @@ def _get_timesteps_and_optionals(transformer, scheduler, latents, num_inference_
212224 timesteps , num_inference_steps = retrieve_timesteps (scheduler , num_inference_steps , device , sigmas = sigmas , mu = mu )
213225 if transformer .config .guidance_embeds :
214226 guidance = torch .full ([1 ], guidance_scale , device = device , dtype = torch .float32 )
215- guidance = guidance .expand (latents . shape [ 0 ] )
227+ guidance = guidance .expand (batch_size )
216228 else :
217229 guidance = None
218230
@@ -328,18 +340,20 @@ def inputs(self) -> List[InputParam]:
328340 InputParam ("timesteps" ),
329341 InputParam ("sigmas" ),
330342 InputParam ("guidance_scale" , default = 3.5 ),
331- InputParam ("latents" , type_hint = torch .Tensor ),
343+ InputParam ("num_images_per_prompt" , default = 1 ),
344+ InputParam ("height" , type_hint = int ),
345+ InputParam ("width" , type_hint = int ),
332346 ]
333347
@property
def intermediate_inputs(self) -> List[InputParam]:
    """Declare the intermediate values this step requires from the pipeline state.

    Returns:
        A one-element list describing the required integer ``batch_size`` input. Per its description this is
        the number of prompts; `__call__` multiplies it by ``num_images_per_prompt`` to get the effective
        batch size.
    """
    return [
        InputParam(
            "batch_size",
            required=True,
            type_hint=int,
            description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`. Can be generated in input step.",
        ),
    ]
344358
345359 @property
@@ -362,10 +376,14 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip
362376 scheduler = components .scheduler
363377 transformer = components .transformer
364378
365- timesteps , num_inference_steps , sigmas , guidance = _get_timesteps_and_optionals (
379+ batch_size = block_state .batch_size * block_state .num_images_per_prompt
380+ timesteps , num_inference_steps , sigmas , guidance = _get_initial_timesteps_and_optionals (
366381 transformer ,
367382 scheduler ,
368- block_state .latents ,
383+ batch_size ,
384+ block_state .height ,
385+ block_state .width ,
386+ components .vae_scale_factor ,
369387 block_state .num_inference_steps ,
370388 block_state .guidance_scale ,
371389 block_state .sigmas ,
@@ -397,20 +415,22 @@ def inputs(self) -> List[InputParam]:
397415 InputParam ("num_inference_steps" , default = 50 ),
398416 InputParam ("timesteps" ),
399417 InputParam ("sigmas" ),
418+ InputParam ("strength" , default = 0.6 ),
400419 InputParam ("guidance_scale" , default = 3.5 ),
401- InputParam ("latents" , type_hint = torch .Tensor ),
402420 InputParam ("num_images_per_prompt" , default = 1 ),
421+ InputParam ("height" , type_hint = int ),
422+ InputParam ("width" , type_hint = int ),
403423 ]
404424
@property
def intermediate_inputs(self) -> List[InputParam]:
    """Declare the intermediate values this step requires from the pipeline state.

    Returns:
        A one-element list describing the required integer ``batch_size`` input. Per its description this is
        the number of prompts; `__call__` multiplies it by ``num_images_per_prompt`` to get the effective
        batch size.
    """
    return [
        InputParam(
            "batch_size",
            required=True,
            type_hint=int,
            description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`. Can be generated in input step.",
        ),
    ]
415435
416436 @property
@@ -430,30 +450,48 @@ def intermediate_outputs(self) -> List[OutputParam]:
430450 OutputParam ("guidance" , type_hint = torch .Tensor , description = "Optional guidance to be used." ),
431451 ]
432452
# Trim the scheduler's timestep schedule according to `strength`.
#
# Documented here rather than in a docstring so that the body below stays a verbatim copy of the method
# named in the "Copied from" marker.
#
# Args:
#     scheduler: object exposing `timesteps` and `order`; `set_begin_index` is called only if present.
#     num_inference_steps: length of the full schedule before trimming.
#     strength: fraction of the schedule to keep. `num_inference_steps * strength` is capped at
#         `num_inference_steps`, so values >= 1 keep every step.
#     device: accepted for signature compatibility; not used in the body.
#
# Returns:
#     `(timesteps, remaining_steps)` where `timesteps` is `scheduler.timesteps` with the first
#     `t_start * scheduler.order` entries dropped and `remaining_steps` is `num_inference_steps - t_start`.
#
# Side effect: when available, `scheduler.set_begin_index(t_start * scheduler.order)` is invoked.
@staticmethod
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps with self->scheduler
def get_timesteps(scheduler, num_inference_steps, strength, device):
    # get the original timestep using init_timestep
    init_timestep = min(num_inference_steps * strength, num_inference_steps)

    t_start = int(max(num_inference_steps - init_timestep, 0))
    timesteps = scheduler.timesteps[t_start * scheduler.order :]
    if hasattr(scheduler, "set_begin_index"):
        scheduler.set_begin_index(t_start * scheduler.order)

    return timesteps, num_inference_steps - t_start
465+
@torch.no_grad()
def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
    """Compute the image-to-image timestep schedule and store it on the block state.

    The full schedule is built by `_get_initial_timesteps_and_optionals` from the requested `height`/`width`
    and the effective batch size (`batch_size * num_images_per_prompt`). It is then trimmed by `strength`
    through `get_timesteps`.

    Writes to the block state: `device`, `timesteps`, `num_inference_steps`, `sigmas`, `guidance` and
    `latent_timestep` (the first remaining timestep, repeated once per sample of the effective batch).

    Args:
        components: pipeline components; `scheduler`, `transformer`, `vae_scale_factor` and
            `_execution_device` are read here.
        state: pipeline state from which the block state is fetched and into which it is written back.

    Returns:
        The `(components, state)` pair, with `state` updated.

    Raises:
        ValueError: if `strength` leaves fewer than one denoising step.
    """
    block_state = self.get_block_state(state)
    block_state.device = components._execution_device

    scheduler = components.scheduler
    transformer = components.transformer

    # Effective batch: one entry per generated image, not per prompt.
    batch_size = block_state.batch_size * block_state.num_images_per_prompt
    timesteps, num_inference_steps, sigmas, guidance = _get_initial_timesteps_and_optionals(
        transformer,
        scheduler,
        batch_size,
        block_state.height,
        block_state.width,
        components.vae_scale_factor,
        block_state.num_inference_steps,
        block_state.guidance_scale,
        block_state.sigmas,
        block_state.device,
    )
    # Keep only the tail of the schedule selected by `strength`.
    timesteps, num_inference_steps = self.get_timesteps(
        scheduler, num_inference_steps, block_state.strength, block_state.device
    )
    # Without this guard `timesteps[:1]` below would be empty and the failure would surface much later
    # with an unrelated message.
    if num_inference_steps < 1:
        raise ValueError(
            f"After adjusting the num_inference_steps by strength parameter: {block_state.strength}, the number "
            f"of pipeline steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
        )

    block_state.timesteps = timesteps
    block_state.num_inference_steps = num_inference_steps
    block_state.sigmas = sigmas
    block_state.guidance = guidance

    block_state.latent_timestep = timesteps[:1].repeat(batch_size)

    self.set_block_state(state, block_state)
    return components, state
@@ -468,7 +506,7 @@ def expected_components(self) -> List[ComponentSpec]:
468506
@property
def description(self) -> str:
    """Return the one-line, human-readable summary of this step."""
    summary = "Prepare latents step that prepares the latents for the text-to-image generation process"
    return summary
472510
473511 @property
474512 def inputs (self ) -> List [InputParam ]:
@@ -565,10 +603,10 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip
565603 block_state .num_channels_latents = components .num_channels_latents
566604
567605 self .check_inputs (components , block_state )
568-
606+ batch_size = block_state . batch_size * block_state . num_images_per_prompt
569607 block_state .latents , block_state .latent_image_ids = self .prepare_latents (
570608 components ,
571- block_state . batch_size * block_state . num_images_per_prompt ,
609+ batch_size ,
572610 block_state .num_channels_latents ,
573611 block_state .height ,
574612 block_state .width ,
@@ -601,7 +639,6 @@ def inputs(self) -> List[Tuple[str, Any]]:
601639 InputParam ("width" , type_hint = int ),
602640 InputParam ("latents" , type_hint = Optional [torch .Tensor ]),
603641 InputParam ("num_images_per_prompt" , type_hint = int , default = 1 ),
604- InputParam ("latents" ),
605642 ]
606643
607644 @property
@@ -655,14 +692,14 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip
655692 block_state .device = components ._execution_device
656693
657694 # TODO: implement `check_inputs`
658-
695+ batch_size = block_state . batch_size * block_state . num_images_per_prompt
659696 if block_state .latents is None :
660697 block_state .latents , block_state .latent_image_ids = prepare_latents_img2img (
661698 components .vae ,
662699 components .scheduler ,
663700 block_state .image_latents ,
664701 block_state .latent_timestep ,
665- block_state . batch_size * block_state . num_images_per_prompt ,
702+ batch_size ,
666703 block_state .num_channels_latents ,
667704 block_state .height ,
668705 block_state .width ,
0 commit comments