Skip to content

Commit addaad0

Browse files
committed
more more more refactor
1 parent 485f8d1 commit addaad0

File tree

4 files changed

+138
-75
lines changed

4 files changed

+138
-75
lines changed

src/diffusers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -406,6 +406,7 @@
406406
"StableDiffusionXLPAGInpaintPipeline",
407407
"StableDiffusionXLPAGPipeline",
408408
"StableDiffusionXLPipeline",
409+
"StableDiffusionXLAutoPipeline",
409410
"StableUnCLIPImg2ImgPipeline",
410411
"StableUnCLIPPipeline",
411412
"StableVideoDiffusionPipeline",
@@ -897,6 +898,7 @@
897898
StableDiffusionXLPAGInpaintPipeline,
898899
StableDiffusionXLPAGPipeline,
899900
StableDiffusionXLPipeline,
901+
StableDiffusionXLAutoPipeline,
900902
StableUnCLIPImg2ImgPipeline,
901903
StableUnCLIPPipeline,
902904
StableVideoDiffusionPipeline,

src/diffusers/pipelines/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,7 @@
317317
"StableDiffusionXLInstructPix2PixPipeline",
318318
"StableDiffusionXLPipeline",
319319
"StableDiffusionXLModularPipeline",
320+
"StableDiffusionXLAutoPipeline",
320321
]
321322
)
322323
_import_structure["stable_diffusion_diffedit"] = ["StableDiffusionDiffEditPipeline"]
@@ -667,6 +668,7 @@
667668
StableDiffusionXLInstructPix2PixPipeline,
668669
StableDiffusionXLModularPipeline,
669670
StableDiffusionXLPipeline,
671+
StableDiffusionXLAutoPipeline,
670672
)
671673
from .stable_video_diffusion import StableVideoDiffusionPipeline
672674
from .t2i_adapter import (

src/diffusers/pipelines/modular_pipeline.py

Lines changed: 132 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -765,6 +765,22 @@ def trigger_inputs(self):
765765
def __repr__(self):
766766
class_name = self.__class__.__name__
767767
base_class = self.__class__.__bases__[0].__name__
768+
header = (
769+
f"{class_name}(\n Class: {base_class}\n"
770+
if base_class and base_class != "object"
771+
else f"{class_name}(\n"
772+
)
773+
774+
775+
if self.trigger_inputs:
776+
header += "\n"
777+
header += " " + "=" * 100 + "\n"
778+
header += " This pipeline contains blocks that are selected at runtime based on inputs.\n"
779+
header += f" Trigger Inputs: {self.trigger_inputs}\n"
780+
# Get first trigger input as example
781+
example_input = next(t for t in self.trigger_inputs if t is not None)
782+
header += f" Use `get_execution_blocks()` with input names to see selected blocks (e.g. `get_execution_blocks('{example_input}')`).\n"
783+
header += " " + "=" * 100 + "\n\n"
768784

769785
# Format description with proper indentation
770786
desc_lines = self.description.split('\n')
@@ -776,70 +792,92 @@ def __repr__(self):
776792
desc.extend(f" {line}" for line in desc_lines[1:])
777793
desc = '\n'.join(desc) + '\n'
778794

779-
sections = []
780-
all_triggers = set(self.trigger_to_block_map.keys())
781-
for trigger in sorted(all_triggers, key=lambda x: str(x)):
782-
sections.append(f"\n Trigger Input: {trigger}\n")
783-
784-
block = self.trigger_to_block_map.get(trigger)
785-
if block is None:
786-
continue
795+
# Components section
796+
expected_components = set(getattr(self, "expected_components", []))
797+
loaded_components = set(self.components.keys())
798+
all_components = sorted(expected_components | loaded_components)
799+
components_str = " Components:\n" + "\n".join(
800+
f" - {k}={type(self.components[k]).__name__}" if k in loaded_components else f" - {k}"
801+
for k in all_components
802+
)
787803

788-
# Add block description with proper indentation
804+
# Auxiliaries section
805+
auxiliaries_str = " Auxiliaries:\n" + "\n".join(
806+
f" - {k}={type(v).__name__}" for k, v in self.auxiliaries.items()
807+
)
808+
809+
# Configs section
810+
expected_configs = set(getattr(self, "expected_configs", []))
811+
loaded_configs = set(self.configs.keys())
812+
all_configs = sorted(expected_configs | loaded_configs)
813+
configs_str = " Configs:\n" + "\n".join(
814+
f" - {k}={v}" if k in loaded_configs else f" - {k}" for k, v in self.configs.items()
815+
)
816+
817+
blocks_str = " Blocks:\n"
818+
for i, (name, block) in enumerate(self.blocks.items()):
819+
# Get trigger input for this block
820+
trigger = None
821+
if hasattr(self, 'block_to_trigger_map'):
822+
trigger = self.block_to_trigger_map.get(name)
823+
# Format the trigger info
824+
if trigger is None:
825+
trigger_str = "[default]"
826+
elif isinstance(trigger, (list, tuple)):
827+
trigger_str = f"[trigger: {', '.join(str(t) for t in trigger)}]"
828+
else:
829+
trigger_str = f"[trigger: {trigger}]"
830+
# For AutoPipelineBlocks, add bullet points
831+
blocks_str += f" • {name} {trigger_str} ({block.__class__.__name__})\n"
832+
else:
833+
# For SequentialPipelineBlocks, show execution order
834+
blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n"
835+
836+
# Add block description
789837
desc_lines = block.description.split('\n')
790-
# First line starts right after "Description:", subsequent lines get indented
791838
indented_desc = desc_lines[0]
792839
if len(desc_lines) > 1:
793-
indented_desc += '\n' + '\n'.join(' ' + line for line in desc_lines[1:]) # Align with first line
794-
sections.append(f" Description: {indented_desc}\n")
795-
796-
expected_components = set(getattr(block, "expected_components", []))
797-
loaded_components = set(k for k, v in self.components.items()
798-
if v is not None and hasattr(block, k))
799-
all_components = sorted(expected_components | loaded_components)
800-
if all_components:
801-
sections.append(" Components:\n" + "\n".join(
802-
f" - {k}={type(self.components[k]).__name__}" if k in loaded_components
803-
else f" - {k}" for k in all_components
804-
))
805-
806-
if self.auxiliaries:
807-
sections.append(" Auxiliaries:\n" + "\n".join(
808-
f" - {k}={type(v).__name__}"
809-
for k, v in self.auxiliaries.items()
810-
))
811-
812-
if self.configs:
813-
sections.append(" Configs:\n" + "\n".join(
814-
f" - {k}={v}" for k, v in self.configs.items()
815-
))
816-
817-
sections.append(f" Block: {block.__class__.__name__}")
818-
840+
indented_desc += '\n' + '\n'.join(' ' + line for line in desc_lines[1:])
841+
blocks_str += f" Description: {indented_desc}\n"
842+
843+
# Format inputs
819844
inputs_str = format_inputs_short(block.inputs)
820-
sections.append(f" inputs: {inputs_str}")
845+
blocks_str += f" inputs: {inputs_str}\n"
821846

822-
# Format intermediates with proper indentation
847+
# Format intermediates
823848
intermediates_str = format_intermediates_short(
824-
block.intermediates_inputs,
825-
block.required_intermediates_inputs,
849+
block.intermediates_inputs,
850+
block.required_intermediates_inputs,
826851
block.intermediates_outputs
827852
)
828-
if intermediates_str != " (none)": # Only add if there are intermediates
829-
sections.append(" intermediates:")
830-
# Add extra indentation to each line of intermediates
853+
if intermediates_str != " (none)":
854+
blocks_str += " intermediates:\n"
831855
indented_intermediates = "\n".join(
832856
" " + line for line in intermediates_str.split("\n")
833857
)
834-
sections.append(indented_intermediates)
835-
836-
sections.append("")
858+
blocks_str += f"{indented_intermediates}\n"
859+
blocks_str += "\n"
860+
861+
inputs_str = format_inputs_short(self.inputs)
862+
inputs_str = " Inputs:\n " + inputs_str
863+
outputs = [out.name for out in self.outputs]
864+
865+
intermediates_str = format_intermediates_short(self.intermediates_inputs, self.required_intermediates_inputs, self.intermediates_outputs)
866+
intermediates_str = (
867+
"\n Intermediates:\n"
868+
f"{intermediates_str}\n"
869+
f" - final outputs: {', '.join(outputs)}"
870+
)
837871

838872
return (
839-
f"{class_name}(\n"
840-
f" Class: {base_class}\n"
873+
f"{header}\n"
841874
f"{desc}"
842-
f"{chr(10).join(sections)}"
875+
f"{components_str}\n"
876+
f"{auxiliaries_str}\n"
877+
f"{configs_str}\n"
878+
f"{blocks_str}\n"
879+
f"{inputs_str}\n"
880+
f"{intermediates_str}\n"
843881
f")"
844882
)
845883

@@ -1097,7 +1135,7 @@ def fn_recursive_traverse(block, block_name, active_triggers):
10971135
all_blocks.update(blocks_to_update)
10981136
return all_blocks
10991137

1100-
def get_triggered_blocks(self, *trigger_inputs):
1138+
def get_execution_blocks(self, *trigger_inputs):
11011139
trigger_inputs_all = self.trigger_inputs
11021140

11031141
if trigger_inputs is not None:
@@ -1130,14 +1168,14 @@ def __repr__(self):
11301168

11311169

11321170
if self.trigger_inputs:
1133-
header += "\n" # Add empty line before
1134-
header += " " + "=" * 100 + "\n" # Add decorative line
1135-
header += " This pipeline block contains dynamic blocks that are selected at runtime based on your inputs.\n"
1136-
header += " You can use `get_triggered_blocks(input1, input2,...)` to see which blocks will be used for your trigger inputs.\n"
1137-
header += " Use `get_triggered_blocks()` to see blocks will be used for default inputs (when no trigger inputs are provided)\n"
1171+
header += "\n"
1172+
header += " " + "=" * 100 + "\n"
1173+
header += " This pipeline contains blocks that are selected at runtime based on inputs.\n"
11381174
header += f" Trigger Inputs: {self.trigger_inputs}\n"
1139-
header += " " + "=" * 100 + "\n" # Add decorative line
1140-
header += "\n" # Add empty line after
1175+
# Get first trigger input as example
1176+
example_input = next(t for t in self.trigger_inputs if t is not None)
1177+
header += f" Use `get_execution_blocks()` with input names to see selected blocks (e.g. `get_execution_blocks('{example_input}')`).\n"
1178+
header += " " + "=" * 100 + "\n\n"
11411179

11421180
# Format description with proper indentation
11431181
desc_lines = self.description.split('\n')
@@ -1173,28 +1211,42 @@ def __repr__(self):
11731211

11741212
blocks_str = " Blocks:\n"
11751213
for i, (name, block) in enumerate(self.blocks.items()):
1176-
blocks_str += f" {i}. {name} ({block.__class__.__name__})\n"
1214+
# Get trigger input for this block
1215+
trigger = None
1216+
if hasattr(self, 'block_to_trigger_map'):
1217+
trigger = self.block_to_trigger_map.get(name)
1218+
# Format the trigger info
1219+
if trigger is None:
1220+
trigger_str = "[default]"
1221+
elif isinstance(trigger, (list, tuple)):
1222+
trigger_str = f"[trigger: {', '.join(str(t) for t in trigger)}]"
1223+
else:
1224+
trigger_str = f"[trigger: {trigger}]"
1225+
# For AutoPipelineBlocks, add bullet points
1226+
blocks_str += f" • {name} {trigger_str} ({block.__class__.__name__})\n"
1227+
else:
1228+
# For SequentialPipelineBlocks, show execution order
1229+
blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n"
11771230

1231+
# Add block description
11781232
desc_lines = block.description.split('\n')
1179-
# First line starts right after "Description:", subsequent lines get indented
11801233
indented_desc = desc_lines[0]
11811234
if len(desc_lines) > 1:
1182-
indented_desc += '\n' + '\n'.join(' ' + line for line in desc_lines[1:]) # Align with first line
1235+
indented_desc += '\n' + '\n'.join(' ' + line for line in desc_lines[1:])
11831236
blocks_str += f" Description: {indented_desc}\n"
11841237

11851238
# Format inputs
11861239
inputs_str = format_inputs_short(block.inputs)
11871240
blocks_str += f" inputs: {inputs_str}\n"
11881241

1189-
# Format intermediates with proper indentation
1242+
# Format intermediates
11901243
intermediates_str = format_intermediates_short(
1191-
block.intermediates_inputs,
1192-
block.required_intermediates_inputs,
1244+
block.intermediates_inputs,
1245+
block.required_intermediates_inputs,
11931246
block.intermediates_outputs
11941247
)
1195-
if intermediates_str != " (none)": # Only add if there are intermediates
1248+
if intermediates_str != " (none)":
11961249
blocks_str += " intermediates:\n"
1197-
# Add extra indentation to each line of intermediates
11981250
indented_intermediates = "\n".join(
11991251
" " + line for line in intermediates_str.split("\n")
12001252
)
@@ -1295,6 +1347,10 @@ def _execution_device(self):
12951347
return torch.device(module._hf_hook.execution_device)
12961348
return self.device
12971349

1350+
1351+
def get_execution_blocks(self, *trigger_inputs):
1352+
return self.pipeline_block.get_execution_blocks(*trigger_inputs)
1353+
12981354
@property
12991355
def dtype(self) -> torch.dtype:
13001356
r"""
@@ -1449,16 +1505,7 @@ def __repr__(self):
14491505

14501506
block = self.pipeline_block
14511507

1452-
if hasattr(block, "trigger_inputs") and block.trigger_inputs:
1453-
output += "\n"
1454-
output += " Trigger Inputs:\n"
1455-
output += " --------------\n"
1456-
output += f" This pipeline contains dynamic blocks that are selected at runtime based on your inputs.\n"
1457-
output += f" • Trigger inputs: {block.trigger_inputs}\n"
1458-
output += f" • Use .pipeline_block.get_triggered_blocks(*inputs) to see which blocks will be used for specific inputs\n"
1459-
output += f" • Use .pipeline_block.get_triggered_blocks() to see blocks will be used for default inputs (when no trigger inputs are provided)\n"
1460-
output += "\n"
1461-
1508+
# List the pipeline block structure first
14621509
output += "Pipeline Block:\n"
14631510
output += "--------------\n"
14641511
if hasattr(block, "blocks"):
@@ -1493,6 +1540,16 @@ def __repr__(self):
14931540
output += f"{name}: {config!r}\n"
14941541
output += "\n"
14951542

1543+
# Add auto blocks section
1544+
if hasattr(block, "trigger_inputs") and block.trigger_inputs:
1545+
output += "------------------\n"
1546+
output += "This pipeline contains blocks that are selected at runtime based on inputs.\n\n"
1547+
output += f"Trigger Inputs: {block.trigger_inputs}\n"
1548+
# Get first trigger input as example
1549+
example_input = next(t for t in block.trigger_inputs if t is not None)
1550+
output += f" Use `get_execution_blocks()` with input names to see selected blocks (e.g. `get_execution_blocks('{example_input}')`).\n"
1551+
output += "Check `.doc` of returned object for more information.\n\n"
1552+
14961553
# List the call parameters
14971554
full_doc = self.pipeline_block.doc
14981555
if "------------------------" in full_doc:

src/diffusers/pipelines/stable_diffusion_xl/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
"StableDiffusionXLPrepareLatentsStep",
4040
"StableDiffusionXLSetTimestepsStep",
4141
"StableDiffusionXLTextEncoderStep",
42+
"StableDiffusionXLAutoPipeline",
4243
]
4344

4445
if is_transformers_available() and is_flax_available():
@@ -69,6 +70,7 @@
6970
StableDiffusionXLPrepareLatentsStep,
7071
StableDiffusionXLSetTimestepsStep,
7172
StableDiffusionXLTextEncoderStep,
73+
StableDiffusionXLAutoPipeline,
7274
)
7375

7476
try:

0 commit comments

Comments
 (0)