 from ...models import FluxTransformer2DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
 from ...utils import logging
+from ...configuration_utils import FrozenDict
+from ...guiders import ClassifierFreeGuidance
 from ..modular_pipeline import (
     BlockState,
     LoopSequentialPipelineBlocks,
@@ -37,7 +39,9 @@ class FluxLoopDenoiser(PipelineBlock):
 
     @property
     def expected_components(self) -> List[ComponentSpec]:
-        return [ComponentSpec("transformer", FluxTransformer2DModel)]
+        return [
+            ComponentSpec("transformer", FluxTransformer2DModel)
+        ]
 
     @property
     def description(self) -> str:
@@ -49,9 +53,7 @@ def description(self) -> str:
 
     @property
     def inputs(self) -> List[Tuple[str, Any]]:
-        return [
-            InputParam("attention_kwargs"),
-        ]
+        return [InputParam("joint_attention_kwargs")]
 
     @property
     def intermediate_inputs(self) -> List[str]:
@@ -63,10 +65,34 @@ def intermediate_inputs(self) -> List[str]:
                 description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
             ),
             InputParam(
-                "num_inference_steps",
+                "guidance",
                 required=True,
-                type_hint=int,
-                description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
+                type_hint=torch.Tensor,
+                description="Guidance scale as a tensor",
+            ),
+            InputParam(
+                "prompt_embeds",
+                required=True,
+                type_hint=torch.Tensor,
+                description="Prompt embeddings",
+            ),
+            InputParam(
+                "pooled_prompt_embeds",
+                required=True,
+                type_hint=torch.Tensor,
+                description="Pooled prompt embeddings",
+            ),
+            InputParam(
+                "text_ids",
+                required=True,
+                type_hint=torch.Tensor,
+                description="IDs computed from text sequence needed for RoPE",
+            ),
+            InputParam(
+                "latent_image_ids",
+                required=True,
+                type_hint=torch.Tensor,
+                description="IDs computed from image sequence needed for RoPE",
             ),
             # TODO: guidance
         ]
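The new required intermediate inputs simply pull into this block's state the conditioning tensors that the text-encoder and prepare-latents steps already produce. For orientation: text_ids is normally just zeros of shape (sequence_length, 3), while the latent image ids carry one (y, x) position per 2x2 latent patch for the transformer's RoPE. A minimal sketch of how such ids can be built (the standalone helper below is illustrative, not the pipeline's actual function; the halving reflects Flux's 2x2 patch packing):

    import torch

    def prepare_latent_image_ids(height, width, device, dtype):
        # Channel 0 stays zero; channels 1 and 2 hold the patch row/column,
        # which the transformer's RoPE consumes as positions.
        ids = torch.zeros(height // 2, width // 2, 3)
        ids[..., 1] += torch.arange(height // 2)[:, None]
        ids[..., 2] += torch.arange(width // 2)[None, :]
        return ids.reshape(-1, 3).to(device=device, dtype=dtype)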
@@ -78,9 +104,10 @@ def __call__(
         noise_pred = components.transformer(
             hidden_states=block_state.latents,
             timestep=t.flatten() / 1000,
+            guidance=block_state.guidance,
             encoder_hidden_states=block_state.prompt_embeds,
             pooled_projections=block_state.pooled_prompt_embeds,
-            attention_kwargs=block_state.attention_kwargs,
+            joint_attention_kwargs=block_state.joint_attention_kwargs,
             txt_ids=block_state.text_ids,
             img_ids=block_state.latent_image_ids,
             return_dict=False,
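Passing guidance through to the transformer matches how the guidance-distilled Flux checkpoints work: the scale is embedded inside the model rather than applied as classifier-free guidance in the loop. A hedged sketch of how an upstream step would typically materialize this tensor (variable names here are illustrative):

    import torch

    # Distilled checkpoints set guidance_embeds in the transformer config;
    # checkpoints without it expect guidance=None instead.
    if components.transformer.config.guidance_embeds:
        guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
        guidance = guidance.expand(latents.shape[0])
    else:
        guidance = None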
@@ -96,7 +123,7 @@ class FluxLoopAfterDenoiser(PipelineBlock):
     @property
     def expected_components(self) -> List[ComponentSpec]:
         return [
-            ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
+            ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)
         ]
 
     @property
@@ -113,9 +140,7 @@ def inputs(self) -> List[Tuple[str, Any]]:
 
     @property
     def intermediate_inputs(self) -> List[str]:
-        return [
-            InputParam("generator"),
-        ]
+        return [InputParam("generator")]
 
     @property
     def intermediate_outputs(self) -> List[OutputParam]:
@@ -129,7 +154,6 @@ def __call__(self, components: FluxModularPipeline, block_state: BlockState, i:
             block_state.noise_pred,
             t,
             block_state.latents,
-            **block_state.scheduler_step_kwargs,
             return_dict=False,
         )[0]
 
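With scheduler_step_kwargs dropped, the after-denoiser block reduces to the bare flow-match Euler update. For a predicted velocity noise_pred, each step amounts to the following (a sketch of the math, not the scheduler internals):

    # Move the sample along the predicted velocity by the gap between
    # consecutive sigmas (sigma_next < sigma while denoising).
    latents = latents + (sigma_next - sigma) * noise_pred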
@@ -199,9 +223,9 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip
 class FluxDenoiseStep(FluxDenoiseLoopWrapper):
     block_classes = [
         FluxLoopDenoiser,
-        FluxLoopAfterDenoiser,
+        FluxLoopAfterDenoiser
     ]
-    block_names = ["before_denoiser", "denoiser", "after_denoiser"]
+    block_names = ["denoiser", "after_denoiser"]
 
     @property
     def description(self) -> str:
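After this change block_names pairs one-to-one with block_classes, removing the stale "before_denoiser" entry. Conceptually, the loop wrapper dispatches both sub-blocks once per timestep, roughly like this (a simplified sketch with hypothetical local names, not the wrapper's actual implementation):

    blocks = {"denoiser": FluxLoopDenoiser(), "after_denoiser": FluxLoopAfterDenoiser()}
    for i, t in enumerate(timesteps):
        # Each loop block receives the shared state plus the step index
        # and current timestep, matching the __call__ signature above.
        for block in blocks.values():
            components, block_state = block(components, block_state, i, t)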