Skip to content

Commit 144eae4

Browse files
committed
add block state will also make sure modifed intermediates_inputs will be updated
1 parent 796453c commit 144eae4

File tree

1 file changed

+206
-35
lines changed

1 file changed

+206
-35
lines changed

src/diffusers/modular_pipelines/modular_pipeline.py

Lines changed: 206 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,7 @@ def run(self, state: PipelineState = None, output: Union[str, List[str]] = None,
282282
state = PipelineState()
283283

284284
if not hasattr(self, "loader"):
285-
logger.warning("Loader is not set, please call `setup_loader()` if you need to load checkpoints for your pipeline.")
285+
logger.info("Loader is not set, please call `setup_loader()` if you need to load checkpoints for your pipeline.")
286286
self.loader = None
287287

288288
# Make a copy of the input kwargs
@@ -313,7 +313,7 @@ def run(self, state: PipelineState = None, output: Union[str, List[str]] = None,
313313

314314
# Warn about unexpected inputs
315315
if len(passed_kwargs) > 0:
316-
logger.warning(f"Unexpected input '{passed_kwargs.keys()}' provided. This input will be ignored.")
316+
warnings.warn(f"Unexpected input '{passed_kwargs.keys()}' provided. This input will be ignored.")
317317
# Run the pipeline
318318
with torch.no_grad():
319319
try:
@@ -373,7 +373,6 @@ def expected_configs(self) -> List[ConfigSpec]:
373373
return []
374374

375375

376-
# YiYi TODO: can we combine inputs and intermediates_inputs? the difference is inputs are immutable
377376
@property
378377
def inputs(self) -> List[InputParam]:
379378
"""List of input parameters. Must be implemented by subclasses."""
@@ -389,27 +388,40 @@ def intermediates_outputs(self) -> List[OutputParam]:
389388
"""List of intermediate output parameters. Must be implemented by subclasses."""
390389
return []
391390

391+
def _get_outputs(self):
392+
return self.intermediates_outputs
393+
394+
# YiYi TODO: is it too easy for user to unintentionally override these properties?
392395
# Adding outputs attributes here for consistency between PipelineBlock/AutoPipelineBlocks/SequentialPipelineBlocks
393396
@property
394397
def outputs(self) -> List[OutputParam]:
395-
return self.intermediates_outputs
398+
return self._get_outputs()
396399

397-
@property
398-
def required_inputs(self) -> List[str]:
400+
def _get_required_inputs(self):
399401
input_names = []
400402
for input_param in self.inputs:
401403
if input_param.required:
402404
input_names.append(input_param.name)
403405
return input_names
404406

405407
@property
406-
def required_intermediates_inputs(self) -> List[str]:
408+
def required_inputs(self) -> List[str]:
409+
return self._get_required_inputs()
410+
411+
412+
def _get_required_intermediates_inputs(self):
407413
input_names = []
408414
for input_param in self.intermediates_inputs:
409415
if input_param.required:
410416
input_names.append(input_param.name)
411417
return input_names
412418

419+
# YiYi TODO: maybe we do not need this, it is only used in docstring,
420+
# intermediate_inputs is by default required, unless you manually handle it inside the block
421+
@property
422+
def required_intermediates_inputs(self) -> List[str]:
423+
return self._get_required_intermediates_inputs()
424+
413425

414426
def __call__(self, pipeline, state: PipelineState) -> PipelineState:
415427
raise NotImplementedError("__call__ method must be implemented in subclasses")
@@ -521,6 +533,30 @@ def add_block_state(self, state: PipelineState, block_state: BlockState):
521533
raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state")
522534
param = getattr(block_state, output_param.name)
523535
state.add_intermediate(output_param.name, param, output_param.kwargs_type)
536+
537+
for input_param in self.intermediates_inputs:
538+
if hasattr(block_state, input_param.name):
539+
param = getattr(block_state, input_param.name)
540+
# Only add if the value is different from what's in the state
541+
current_value = state.get_intermediate(input_param.name)
542+
if current_value is not param: # Using identity comparison to check if object was modified
543+
state.add_intermediate(input_param.name, param, input_param.kwargs_type)
544+
545+
for input_param in self.intermediates_inputs:
546+
if input_param.name and hasattr(block_state, input_param.name):
547+
param = getattr(block_state, input_param.name)
548+
# Only add if the value is different from what's in the state
549+
current_value = state.get_intermediate(input_param.name)
550+
if current_value is not param: # Using identity comparison to check if object was modified
551+
state.add_intermediate(input_param.name, param, input_param.kwargs_type)
552+
elif input_param.kwargs_type:
553+
# if it is a kwargs type, e.g. "guider_input_fields", it is likely to be a list of parameters
554+
# we need to first find out which inputs are and loop through them.
555+
intermediates_kwargs = state.get_intermediates_kwargs(input_param.kwargs_type)
556+
for param_name, current_value in intermediates_kwargs.items():
557+
param = getattr(block_state, param_name)
558+
if current_value is not param: # Using identity comparison to check if object was modified
559+
state.add_intermediate(param_name, param, input_param.kwargs_type)
524560

525561

526562
def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> List[InputParam]:
@@ -550,16 +586,16 @@ def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> Li
550586
input_param.default is not None and
551587
current_param.default != input_param.default):
552588
warnings.warn(
553-
f"Multiple different default values found for input '{input_param.name}': "
554-
f"{current_param.default} (from block '{value_sources[input_param.name]}') and "
589+
f"Multiple different default values found for input '{input_name}': "
590+
f"{current_param.default} (from block '{value_sources[input_name]}') and "
555591
f"{input_param.default} (from block '{block_name}'). Using {current_param.default}."
556592
)
557593
if current_param.default is None and input_param.default is not None:
558-
combined_dict[input_param.name] = input_param
559-
value_sources[input_param.name] = block_name
594+
combined_dict[input_name] = input_param
595+
value_sources[input_name] = block_name
560596
else:
561-
combined_dict[input_param.name] = input_param
562-
value_sources[input_param.name] = block_name
597+
combined_dict[input_name] = input_param
598+
value_sources[input_name] = block_name
563599

564600
return list(combined_dict.values())
565601

@@ -661,7 +697,9 @@ def required_inputs(self) -> List[str]:
661697
required_by_all.intersection_update(block_required)
662698

663699
return list(required_by_all)
664-
700+
701+
# YiYi TODO: maybe we do not need this, it is only used in docstring,
702+
# intermediate_inputs is by default required, unless you manually handle it inside the block
665703
@property
666704
def required_intermediates_inputs(self) -> List[str]:
667705
first_block = next(iter(self.blocks.values()))
@@ -838,14 +876,21 @@ def __repr__(self):
838876
indented_desc += '\n' + '\n'.join(' ' + line for line in desc_lines[1:])
839877
blocks_str += f" Description: {indented_desc}\n\n"
840878

841-
return (
842-
f"{header}\n"
843-
f"{desc}\n\n"
844-
f"{components_str}\n\n"
845-
f"{configs_str}\n\n"
846-
f"{blocks_str}"
847-
f")"
848-
)
879+
# Build the representation with conditional sections
880+
result = f"{header}\n{desc}"
881+
882+
# Only add components section if it has content
883+
if components_str.strip():
884+
result += f"\n\n{components_str}"
885+
886+
# Only add configs section if it has content
887+
if configs_str.strip():
888+
result += f"\n\n{configs_str}"
889+
890+
# Always add blocks section
891+
result += f"\n\n{blocks_str})"
892+
893+
return result
849894

850895

851896
@property
@@ -867,13 +912,15 @@ class SequentialPipelineBlocks(ModularPipelineMixin):
867912
block_classes = []
868913
block_names = []
869914

870-
@property
871-
def model_name(self):
872-
return next(iter(self.blocks.values())).model_name
873915

874916
@property
875917
def description(self):
876918
return ""
919+
920+
@property
921+
def model_name(self):
922+
return next(iter(self.blocks.values())).model_name
923+
877924

878925
@property
879926
def expected_components(self):
@@ -929,6 +976,8 @@ def required_inputs(self) -> List[str]:
929976

930977
return list(required_by_any)
931978

979+
# YiYi TODO: maybe we do not need this, it is only used in docstring,
980+
# intermediate_inputs is by default required, unless you manually handle it inside the block
932981
@property
933982
def required_intermediates_inputs(self) -> List[str]:
934983
required_intermediates_inputs = []
@@ -960,11 +1009,15 @@ def intermediates_inputs(self) -> List[str]:
9601009
def get_intermediates_inputs(self):
9611010
inputs = []
9621011
outputs = set()
1012+
added_inputs = set()
9631013

9641014
# Go through all blocks in order
9651015
for block in self.blocks.values():
9661016
# Add inputs that aren't in outputs yet
967-
inputs.extend(input_name for input_name in block.intermediates_inputs if input_name.name not in outputs)
1017+
for inp in block.intermediates_inputs:
1018+
if inp.name not in outputs and inp.name not in added_inputs:
1019+
inputs.append(inp)
1020+
added_inputs.add(inp.name)
9681021

9691022
# Only add outputs if the block cannot be skipped
9701023
should_add_outputs = True
@@ -1176,14 +1229,21 @@ def __repr__(self):
11761229
indented_desc += '\n' + '\n'.join(' ' + line for line in desc_lines[1:])
11771230
blocks_str += f" Description: {indented_desc}\n\n"
11781231

1179-
return (
1180-
f"{header}\n"
1181-
f"{desc}\n\n"
1182-
f"{components_str}\n\n"
1183-
f"{configs_str}\n\n"
1184-
f"{blocks_str}"
1185-
f")"
1186-
)
1232+
# Build the representation with conditional sections
1233+
result = f"{header}\n{desc}"
1234+
1235+
# Only add components section if it has content
1236+
if components_str.strip():
1237+
result += f"\n\n{components_str}"
1238+
1239+
# Only add configs section if it has content
1240+
if configs_str.strip():
1241+
result += f"\n\n{configs_str}"
1242+
1243+
# Always add blocks section
1244+
result += f"\n\n{blocks_str})"
1245+
1246+
return result
11871247

11881248

11891249
@property
@@ -1348,7 +1408,8 @@ def required_inputs(self) -> List[str]:
13481408

13491409
return list(required_by_any)
13501410

1351-
# modified from SequentialPipelineBlocks, if any additional intermediate input required by the loop is required by the block
1411+
# YiYi TODO: maybe we do not need this, it is only used in docstring,
1412+
# intermediate_inputs is by default required, unless you manually handle it inside the block
13521413
@property
13531414
def required_intermediates_inputs(self) -> List[str]:
13541415
required_intermediates_inputs = []
@@ -1384,6 +1445,22 @@ def __init__(self):
13841445
for block_name, block_cls in zip(self.block_names, self.block_classes):
13851446
blocks[block_name] = block_cls()
13861447
self.blocks = blocks
1448+
1449+
@classmethod
1450+
def from_blocks_dict(cls, blocks_dict: Dict[str, Any]) -> "LoopSequentialPipelineBlocks":
1451+
"""Creates a LoopSequentialPipelineBlocks instance from a dictionary of blocks.
1452+
1453+
Args:
1454+
blocks_dict: Dictionary mapping block names to block instances
1455+
1456+
Returns:
1457+
A new LoopSequentialPipelineBlocks instance
1458+
"""
1459+
instance = cls()
1460+
instance.block_classes = [block.__class__ for block in blocks_dict.values()]
1461+
instance.block_names = list(blocks_dict.keys())
1462+
instance.blocks = blocks_dict
1463+
return instance
13871464

13881465
def loop_step(self, components, state: PipelineState, **kwargs):
13891466

@@ -1455,6 +1532,100 @@ def add_block_state(self, state: PipelineState, block_state: BlockState):
14551532
param = getattr(block_state, output_param.name)
14561533
state.add_intermediate(output_param.name, param, output_param.kwargs_type)
14571534

1535+
for input_param in self.intermediates_inputs:
1536+
if input_param.name and hasattr(block_state, input_param.name):
1537+
param = getattr(block_state, input_param.name)
1538+
# Only add if the value is different from what's in the state
1539+
current_value = state.get_intermediate(input_param.name)
1540+
if current_value is not param: # Using identity comparison to check if object was modified
1541+
state.add_intermediate(input_param.name, param, input_param.kwargs_type)
1542+
elif input_param.kwargs_type:
1543+
# if it is a kwargs type, e.g. "guider_input_fields", it is likely to be a list of parameters
1544+
# we need to first find out which inputs are and loop through them.
1545+
intermediates_kwargs = state.get_intermediates_kwargs(input_param.kwargs_type)
1546+
for param_name, current_value in intermediates_kwargs.items():
1547+
if not hasattr(block_state, param_name):
1548+
continue
1549+
param = getattr(block_state, param_name)
1550+
if current_value is not param: # Using identity comparison to check if object was modified
1551+
state.add_intermediate(param_name, param, input_param.kwargs_type)
1552+
1553+
1554+
@property
1555+
def doc(self):
1556+
return make_doc_string(
1557+
self.inputs,
1558+
self.intermediates_inputs,
1559+
self.outputs,
1560+
self.description,
1561+
class_name=self.__class__.__name__,
1562+
expected_components=self.expected_components,
1563+
expected_configs=self.expected_configs
1564+
)
1565+
1566+
# modified from SequentialPipelineBlocks,
1567+
#(does not need trigger_inputs related part so removed them,
1568+
# do not need to support auto block for loop blocks)
1569+
def __repr__(self):
1570+
class_name = self.__class__.__name__
1571+
base_class = self.__class__.__bases__[0].__name__
1572+
header = (
1573+
f"{class_name}(\n Class: {base_class}\n"
1574+
if base_class and base_class != "object"
1575+
else f"{class_name}(\n"
1576+
)
1577+
1578+
# Format description with proper indentation
1579+
desc_lines = self.description.split('\n')
1580+
desc = []
1581+
# First line with "Description:" label
1582+
desc.append(f" Description: {desc_lines[0]}")
1583+
# Subsequent lines with proper indentation
1584+
if len(desc_lines) > 1:
1585+
desc.extend(f" {line}" for line in desc_lines[1:])
1586+
desc = '\n'.join(desc) + '\n'
1587+
1588+
# Components section - focus only on expected components
1589+
expected_components = getattr(self, "expected_components", [])
1590+
components_str = format_components(expected_components, indent_level=2, add_empty_lines=False)
1591+
1592+
# Configs section - use format_configs with add_empty_lines=False
1593+
expected_configs = getattr(self, "expected_configs", [])
1594+
configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False)
1595+
1596+
# Blocks section - moved to the end with simplified format
1597+
blocks_str = " Blocks:\n"
1598+
for i, (name, block) in enumerate(self.blocks.items()):
1599+
1600+
# For SequentialPipelineBlocks, show execution order
1601+
blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n"
1602+
1603+
# Add block description
1604+
desc_lines = block.description.split('\n')
1605+
indented_desc = desc_lines[0]
1606+
if len(desc_lines) > 1:
1607+
indented_desc += '\n' + '\n'.join(' ' + line for line in desc_lines[1:])
1608+
blocks_str += f" Description: {indented_desc}\n\n"
1609+
1610+
# Build the representation with conditional sections
1611+
result = f"{header}\n{desc}"
1612+
1613+
# Only add components section if it has content
1614+
if components_str.strip():
1615+
result += f"\n\n{components_str}"
1616+
1617+
# Only add configs section if it has content
1618+
if configs_str.strip():
1619+
result += f"\n\n{configs_str}"
1620+
1621+
# Always add blocks section
1622+
result += f"\n\n{blocks_str})"
1623+
1624+
return result
1625+
1626+
1627+
1628+
14581629
# YiYi TODO:
14591630
# 1. look into the serialization of modular_model_index.json, make sure the items are properly ordered like model_index.json (currently a mess)
14601631
# 2. do we need ConfigSpec? seems pretty unnecessrary for loader, can just add and kwargs to the loader

0 commit comments

Comments
 (0)