2222from ...schedulers import FlowMatchEulerDiscreteScheduler
2323from ...utils import logging
2424from ...utils .torch_utils import randn_tensor
25- from ..modular_pipeline import PipelineBlock , PipelineState
25+ from ..modular_pipeline import ModularPipelineBlocks , PipelineState
2626from ..modular_pipeline_utils import ComponentSpec , InputParam , OutputParam
2727from .modular_pipeline import FluxModularPipeline
2828
@@ -231,7 +231,7 @@ def _get_initial_timesteps_and_optionals(
231231 return timesteps , num_inference_steps , sigmas , guidance
232232
233233
234- class FluxInputStep (PipelineBlock ):
234+ class FluxInputStep (ModularPipelineBlocks ):
235235 model_name = "flux"
236236
237237 @property
@@ -249,11 +249,6 @@ def description(self) -> str:
249249 def inputs (self ) -> List [InputParam ]:
250250 return [
251251 InputParam ("num_images_per_prompt" , default = 1 ),
252- ]
253-
254- @property
255- def intermediate_inputs (self ) -> List [str ]:
256- return [
257252 InputParam (
258253 "prompt_embeds" ,
259254 required = True ,
@@ -322,7 +317,7 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip
322317 return components , state
323318
324319
325- class FluxSetTimestepsStep (PipelineBlock ):
320+ class FluxSetTimestepsStep (ModularPipelineBlocks ):
326321 model_name = "flux"
327322
328323 @property
@@ -340,14 +335,10 @@ def inputs(self) -> List[InputParam]:
340335 InputParam ("timesteps" ),
341336 InputParam ("sigmas" ),
342337 InputParam ("guidance_scale" , default = 3.5 ),
338+ InputParam ("latents" , type_hint = torch .Tensor ),
343339 InputParam ("num_images_per_prompt" , default = 1 ),
344340 InputParam ("height" , type_hint = int ),
345341 InputParam ("width" , type_hint = int ),
346- ]
347-
348- @property
349- def intermediate_inputs (self ) -> List [str ]:
350- return [
351342 InputParam (
352343 "batch_size" ,
353344 required = True ,
@@ -398,7 +389,7 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip
398389 return components , state
399390
400391
401- class FluxImg2ImgSetTimestepsStep (PipelineBlock ):
392+ class FluxImg2ImgSetTimestepsStep (ModularPipelineBlocks ):
402393 model_name = "flux"
403394
404395 @property
@@ -420,11 +411,6 @@ def inputs(self) -> List[InputParam]:
420411 InputParam ("num_images_per_prompt" , default = 1 ),
421412 InputParam ("height" , type_hint = int ),
422413 InputParam ("width" , type_hint = int ),
423- ]
424-
425- @property
426- def intermediate_inputs (self ) -> List [str ]:
427- return [
428414 InputParam (
429415 "batch_size" ,
430416 required = True ,
@@ -497,7 +483,7 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip
497483 return components , state
498484
499485
500- class FluxPrepareLatentsStep (PipelineBlock ):
486+ class FluxPrepareLatentsStep (ModularPipelineBlocks ):
501487 model_name = "flux"
502488
503489 @property
@@ -515,11 +501,6 @@ def inputs(self) -> List[InputParam]:
515501 InputParam ("width" , type_hint = int ),
516502 InputParam ("latents" , type_hint = Optional [torch .Tensor ]),
517503 InputParam ("num_images_per_prompt" , type_hint = int , default = 1 ),
518- ]
519-
520- @property
521- def intermediate_inputs (self ) -> List [InputParam ]:
522- return [
523504 InputParam ("generator" ),
524505 InputParam (
525506 "batch_size" ,
@@ -621,7 +602,7 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip
621602 return components , state
622603
623604
624- class FluxImg2ImgPrepareLatentsStep (PipelineBlock ):
605+ class FluxImg2ImgPrepareLatentsStep (ModularPipelineBlocks ):
625606 model_name = "flux"
626607
627608 @property
@@ -639,11 +620,6 @@ def inputs(self) -> List[Tuple[str, Any]]:
639620 InputParam ("width" , type_hint = int ),
640621 InputParam ("latents" , type_hint = Optional [torch .Tensor ]),
641622 InputParam ("num_images_per_prompt" , type_hint = int , default = 1 ),
642- ]
643-
644- @property
645- def intermediate_inputs (self ) -> List [InputParam ]:
646- return [
647623 InputParam ("generator" ),
648624 InputParam (
649625 "image_latents" ,
0 commit comments