Skip to content

Commit 8ddb20b

Browse files
committed
up
1 parent e5089d7 commit 8ddb20b

File tree

2 files changed

+120
-88
lines changed

2 files changed

+120
-88
lines changed

src/diffusers/pipelines/modular_pipeline.py

Lines changed: 86 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,6 @@ def add_input(self, key: str, value: Any):
6161
def add_intermediate(self, key: str, value: Any):
6262
self.intermediates[key] = value
6363

64-
def add_output(self, key: str, value: Any):
65-
self.outputs[key] = value
66-
6764
def get_input(self, key: str, default: Any = None) -> Any:
6865
return self.inputs.get(key, default)
6966

@@ -194,45 +191,45 @@ def format_intermediates_short(intermediates_inputs: List[InputParam], required_
194191
Formats intermediate inputs and outputs of a block into a string representation.
195192
196193
Args:
197-
block: Pipeline block with potential intermediates
194+
intermediates_inputs: List of intermediate input parameters
195+
required_intermediates_inputs: List of required intermediate input names
196+
intermediates_outputs: List of intermediate output parameters
198197
199198
Returns:
200-
str: Formatted string like "input1, Required(input2) -> output1, output2"
199+
str: Formatted string like:
200+
Intermediates:
201+
- inputs: Required(latents), dtype
202+
- modified: latents # variables that appear in both inputs and outputs
203+
- outputs: images # new outputs only
201204
"""
202205
# Handle inputs
203206
input_parts = []
204-
205207
for inp in intermediates_inputs:
206-
parts = []
207-
# Check if input is required
208208
if inp.name in required_intermediates_inputs:
209-
parts.append("Required")
210-
211-
# Get base name or modified name
212-
name = inp.name
213-
if name in {out.name for out in intermediates_outputs}:
214-
name = f"*{name}"
215-
216-
# Combine Required() wrapper with possibly starred name
217-
if parts:
218-
input_parts.append(f"Required({name})")
209+
input_parts.append(f"Required({inp.name})")
219210
else:
220-
input_parts.append(name)
211+
input_parts.append(inp.name)
221212

222-
# Handle outputs
223-
output_parts = []
224-
outputs = [out.name for out in intermediates_outputs]
225-
# Only show new outputs if we have inputs
213+
# Handle modified variables (appear in both inputs and outputs)
226214
inputs_set = {inp.name for inp in intermediates_inputs}
227-
outputs = [out for out in outputs if out not in inputs_set]
228-
output_parts.extend(outputs)
215+
modified_parts = []
216+
new_output_parts = []
229217

230-
# Combine with arrow notation if both inputs and outputs exist
231-
if output_parts:
232-
return f"-> {', '.join(output_parts)}" if not input_parts else f"{', '.join(input_parts)} -> {', '.join(output_parts)}"
233-
elif input_parts:
234-
return ', '.join(input_parts)
235-
return ""
218+
for out in intermediates_outputs:
219+
if out.name in inputs_set:
220+
modified_parts.append(out.name)
221+
else:
222+
new_output_parts.append(out.name)
223+
224+
result = []
225+
if input_parts:
226+
result.append(f" - inputs: {', '.join(input_parts)}")
227+
if modified_parts:
228+
result.append(f" - modified: {', '.join(modified_parts)}")
229+
if new_output_parts:
230+
result.append(f" - outputs: {', '.join(new_output_parts)}")
231+
232+
return "\n".join(result) if result else " (none)"
236233

237234

238235
def format_params(params: List[Union[InputParam, OutputParam]], header: str = "Args", indent_level: int = 4, max_line_length: int = 115) -> str:
@@ -323,7 +320,7 @@ def format_output_params(output_params: List[OutputParam], indent_level: int = 4
323320

324321

325322

326-
def make_doc_string(inputs, intermediates_inputs, intermediates_outputs, final_intermediates_outputs=None, description=""):
323+
def make_doc_string(inputs, intermediates_inputs, outputs, description=""):
327324
"""
328325
Generates a formatted documentation string describing the pipeline block's parameters and structure.
329326
@@ -340,20 +337,8 @@ def make_doc_string(inputs, intermediates_inputs, intermediates_outputs, final_i
340337

341338
output += format_input_params(inputs + intermediates_inputs, indent_level=2)
342339

343-
# YiYi TODO: refactor to remove this and `outputs` attribute instead
344-
if final_intermediates_outputs:
345-
output += "\n\n"
346-
output += format_output_params(final_intermediates_outputs, indent_level=2)
347-
348-
if intermediates_outputs:
349-
output += "\n\n------------------------\n"
350-
intermediates_str = format_params(intermediates_outputs, "Intermediates Outputs", indent_level=2)
351-
output += intermediates_str
352-
353-
elif intermediates_outputs:
354-
output +="\n\n"
355-
output += format_output_params(intermediates_outputs, indent_level=2)
356-
340+
output += "\n\n"
341+
output += format_output_params(outputs, indent_level=2)
357342

358343
return output
359344

@@ -367,23 +352,28 @@ class PipelineBlock:
367352

368353
@property
369354
def description(self) -> str:
370-
return ""
355+
"""Description of the block. Must be implemented by subclasses."""
356+
raise NotImplementedError("description method must be implemented in subclasses")
371357

372358
@property
373359
def inputs(self) -> List[InputParam]:
374-
return []
360+
"""List of input parameters. Must be implemented by subclasses."""
361+
raise NotImplementedError("inputs method must be implemented in subclasses")
375362

376363
@property
377364
def intermediates_inputs(self) -> List[InputParam]:
378-
return []
365+
"""List of intermediate input parameters. Must be implemented by subclasses."""
366+
raise NotImplementedError("intermediates_inputs method must be implemented in subclasses")
379367

380368
@property
381369
def intermediates_outputs(self) -> List[OutputParam]:
382-
return []
383-
370+
"""List of intermediate output parameters. Must be implemented by subclasses."""
371+
raise NotImplementedError("intermediates_outputs method must be implemented in subclasses")
372+
373+
# Adding outputs attributes here for consistency between PipelineBlock/AutoPipelineBlocks/SequentialPipelineBlocks
384374
@property
385375
def outputs(self) -> List[OutputParam]:
386-
return []
376+
return self.intermediates_outputs
387377

388378
@property
389379
def required_inputs(self) -> List[str]:
@@ -413,7 +403,7 @@ def __repr__(self):
413403
class_name = self.__class__.__name__
414404
base_class = self.__class__.__bases__[0].__name__
415405

416-
# Components section - group into main components and auxiliaries if needed
406+
# Components section
417407
expected_components = set(getattr(self, "expected_components", []))
418408
loaded_components = set(self.components.keys())
419409
all_components = sorted(expected_components | loaded_components)
@@ -446,7 +436,7 @@ def __repr__(self):
446436

447437
# Intermediates section
448438
intermediates_str = format_intermediates_short(self.intermediates_inputs, self.required_intermediates_inputs, self.intermediates_outputs)
449-
intermediates = f"Intermediates(`*` = modified):\n {intermediates_str}"
439+
intermediates = f"Intermediates:\n{intermediates_str}"
450440

451441
return (
452442
f"{class_name}(\n"
@@ -461,7 +451,7 @@ def __repr__(self):
461451

462452
@property
463453
def doc(self):
464-
return make_doc_string(self.inputs, self.intermediates_inputs, self.intermediates_outputs, None, self.description)
454+
return make_doc_string(self.inputs, self.intermediates_inputs, self.outputs, self.description)
465455

466456

467457
def get_block_state(self, state: PipelineState) -> dict:
@@ -489,11 +479,6 @@ def add_block_state(self, state: PipelineState, block_state: BlockState):
489479
if not hasattr(block_state, output_param.name):
490480
raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state")
491481
state.add_intermediate(output_param.name, getattr(block_state, output_param.name))
492-
493-
for output_param in self.outputs:
494-
if not hasattr(block_state, output_param.name):
495-
raise ValueError(f"Output '{output_param.name}' is missing in block state")
496-
state.add_output(output_param.name, getattr(block_state, output_param.name))
497482

498483

499484
def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> List[InputParam]:
@@ -704,6 +689,12 @@ def intermediates_outputs(self) -> List[str]:
704689
named_outputs = [(name, block.intermediates_outputs) for name, block in self.blocks.items()]
705690
combined_outputs = combine_outputs(*named_outputs)
706691
return combined_outputs
692+
693+
@property
694+
def outputs(self) -> List[str]:
695+
named_outputs = [(name, block.outputs) for name, block in self.blocks.items()]
696+
combined_outputs = combine_outputs(*named_outputs)
697+
return combined_outputs
707698

708699
@torch.no_grad()
709700
def __call__(self, pipeline, state: PipelineState) -> PipelineState:
@@ -803,10 +794,21 @@ def __repr__(self):
803794
sections.append(f" Block: {block.__class__.__name__}")
804795

805796
inputs_str = format_inputs_short(block.inputs)
806-
sections.append(f" inputs:\n {inputs_str}")
797+
sections.append(f" inputs: {inputs_str}")
807798

808-
intermediates_str = f" intermediates(`*` = modified):\n {format_intermediates_short(block.intermediates_inputs, block.required_intermediates_inputs, block.intermediates_outputs)}"
809-
sections.append(intermediates_str)
799+
# Format intermediates with proper indentation
800+
intermediates_str = format_intermediates_short(
801+
block.intermediates_inputs,
802+
block.required_intermediates_inputs,
803+
block.intermediates_outputs
804+
)
805+
if intermediates_str != " (none)": # Only add if there are intermediates
806+
sections.append(" intermediates:")
807+
# Add extra indentation to each line of intermediates
808+
indented_intermediates = "\n".join(
809+
" " + line for line in intermediates_str.split("\n")
810+
)
811+
sections.append(indented_intermediates)
810812

811813
sections.append("")
812814

@@ -819,7 +821,7 @@ def __repr__(self):
819821

820822
@property
821823
def doc(self):
822-
return make_doc_string(self.inputs, self.intermediates_inputs, self.intermediates_outputs, None, self.description)
824+
return make_doc_string(self.inputs, self.intermediates_inputs, self.outputs, self.description)
823825

824826
class SequentialPipelineBlocks:
825827
"""
@@ -962,7 +964,7 @@ def intermediates_outputs(self) -> List[str]:
962964
return combined_outputs
963965

964966
@property
965-
def final_intermediates_outputs(self) -> List[str]:
967+
def outputs(self) -> List[str]:
966968
return next(reversed(self.blocks.values())).intermediates_outputs
967969

968970
@torch.no_grad()
@@ -1121,28 +1123,34 @@ def __repr__(self):
11211123
for i, (name, block) in enumerate(self.blocks.items()):
11221124
blocks_str += f" {i}. {name} ({block.__class__.__name__})\n"
11231125

1126+
# Format inputs
11241127
inputs_str = format_inputs_short(block.inputs)
1125-
11261128
blocks_str += f" inputs: {inputs_str}\n"
11271129

1128-
intermediates_str = format_intermediates_short(block.intermediates_inputs, block.required_intermediates_inputs, block.intermediates_outputs)
1129-
1130-
if intermediates_str:
1131-
blocks_str += f" intermediates(`*` = modified): {intermediates_str}\n"
1130+
# Format intermediates with proper indentation
1131+
intermediates_str = format_intermediates_short(
1132+
block.intermediates_inputs,
1133+
block.required_intermediates_inputs,
1134+
block.intermediates_outputs
1135+
)
1136+
if intermediates_str != " (none)": # Only add if there are intermediates
1137+
blocks_str += " intermediates:\n"
1138+
# Add extra indentation to each line of intermediates
1139+
indented_intermediates = "\n".join(
1140+
" " + line for line in intermediates_str.split("\n")
1141+
)
1142+
blocks_str += f"{indented_intermediates}\n"
11321143
blocks_str += "\n"
11331144

11341145
inputs_str = format_inputs_short(self.inputs)
11351146
inputs_str = " Inputs:\n " + inputs_str
1136-
final_intermediates_outputs = [out.name for out in self.final_intermediates_outputs]
1147+
outputs = [out.name for out in self.outputs]
11371148

1138-
intermediates_str_short = format_intermediates_short(self.intermediates_inputs, self.required_intermediates_inputs, self.intermediates_outputs)
1139-
intermediates_input_str = intermediates_str_short.split('->')[0].strip() # "Required(latents), crops_coords"
1140-
intermediates_output_str = intermediates_str_short.split('->')[1].strip()
1149+
intermediates_str = format_intermediates_short(self.intermediates_inputs, self.required_intermediates_inputs, self.intermediates_outputs)
11411150
intermediates_str = (
11421151
"\n Intermediates:\n"
1143-
f" - inputs: {intermediates_input_str}\n"
1144-
f" - outputs: {intermediates_output_str}\n"
1145-
f" - final outputs: {', '.join(final_intermediates_outputs)}"
1152+
f"{intermediates_str}\n"
1153+
f" - final outputs: {', '.join(outputs)}"
11461154
)
11471155

11481156
return (
@@ -1158,7 +1166,7 @@ def __repr__(self):
11581166

11591167
@property
11601168
def doc(self):
1161-
return make_doc_string(self.inputs, self.intermediates_inputs, self.intermediates_outputs, self.final_intermediates_outputs, self.description)
1169+
return make_doc_string(self.inputs, self.intermediates_inputs, self.outputs, self.description)
11621170

11631171
class ModularPipeline(ConfigMixin):
11641172
"""

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,19 @@ def description(self) -> str:
139139
" for more details"
140140
)
141141

142+
143+
@property
144+
def inputs(self) -> List[InputParam]:
145+
return []
146+
147+
@property
148+
def intermediates_inputs(self) -> List[InputParam]:
149+
return []
150+
151+
@property
152+
def intermediates_outputs(self) -> List[OutputParam]:
153+
return []
154+
142155
def __init__(self):
143156
super().__init__()
144157
self.components["text_encoder"] = None
@@ -178,11 +191,17 @@ def inputs(self) -> List[InputParam]:
178191
description="Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). Guidance scale is enabled by setting `guidance_scale > 1`."
179192
),
180193
]
181-
194+
182195
@property
183-
def intermediates_outputs(self) -> List[str]:
184-
return [OutputParam("ip_adapter_embeds", type_hint=torch.Tensor, description="IP adapter image embeddings"),
185-
OutputParam("negative_ip_adapter_embeds", type_hint=torch.Tensor, description="Negative IP adapter image embeddings")]
196+
def intermediates_inputs(self) -> List[InputParam]:
197+
return []
198+
199+
@property
200+
def intermediates_outputs(self) -> List[OutputParam]:
201+
return [
202+
OutputParam("ip_adapter_embeds", type_hint=torch.Tensor, description="IP adapter image embeddings"),
203+
OutputParam("negative_ip_adapter_embeds", type_hint=torch.Tensor, description="Negative IP adapter image embeddings")
204+
]
186205

187206
def __init__(self):
188207
super().__init__()
@@ -270,7 +289,11 @@ def inputs(self) -> List[InputParam]:
270289

271290

272291
@property
273-
def intermediates_outputs(self) -> List[str]:
292+
def intermediates_inputs(self) -> List[InputParam]:
293+
return []
294+
295+
@property
296+
def intermediates_outputs(self) -> List[OutputParam]:
274297
return [
275298
OutputParam("prompt_embeds", type_hint=torch.Tensor, description="text embeddings used to guide the image generation"),
276299
OutputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="negative text embeddings used to guide the image generation"),
@@ -378,13 +401,13 @@ def inputs(self) -> List[InputParam]:
378401
]
379402

380403
@property
381-
def intermediates_inputs(self) -> List[str]:
404+
def intermediates_inputs(self) -> List[InputParam]:
382405
return [
383406
InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
384407
InputParam("preprocess_kwargs", type_hint=Optional[dict], description="A kwargs dictionary that if specified is passed along to the `ImageProcessor` as defined under `self.image_processor` in [diffusers.image_processor.VaeImageProcessor]")]
385408

386409
@property
387-
def intermediates_outputs(self) -> List[str]:
410+
def intermediates_outputs(self) -> List[OutputParam]:
388411
return [OutputParam("image_latents", type_hint=torch.Tensor, description="The latents representing the reference image for image-to-image/inpainting generation")]
389412

390413
def __init__(self):
@@ -818,6 +841,10 @@ def intermediates_outputs(self) -> List[OutputParam]:
818841
return [OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"),
819842
OutputParam("num_inference_steps", type_hint=int, description="The number of denoising steps to perform at inference time")]
820843

844+
@property
845+
def intermediates_inputs(self) -> List[InputParam]:
846+
return []
847+
821848
def __init__(self):
822849
super().__init__()
823850
self.components["scheduler"] = None
@@ -2831,9 +2858,6 @@ def intermediates_inputs(self) -> List[str]:
28312858
def intermediates_outputs(self) -> List[str]:
28322859
return [OutputParam("images", description="The final images output, can be a tuple or a `StableDiffusionXLPipelineOutput`")]
28332860

2834-
@property
2835-
def outputs(self) -> List[Tuple[str, Any]]:
2836-
return [(OutputParam("images", type_hint=Union[Tuple[PIL.Image.Image], StableDiffusionXLPipelineOutput], description="The final images output, can be a tuple or a `StableDiffusionXLPipelineOutput`"))]
28372861

28382862
@torch.no_grad()
28392863
def __call__(self, pipeline, state: PipelineState) -> PipelineState:

0 commit comments

Comments
 (0)