Skip to content

Commit 54f410d

Browse files
committed
add inpaint
1 parent c12a05b commit 54f410d

File tree

2 files changed

+456
-239
lines changed

2 files changed

+456
-239
lines changed

src/diffusers/pipelines/modular_pipeline.py

Lines changed: 29 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -166,25 +166,34 @@ def __repr__(self):
166166
)
167167

168168

169-
def combine_inputs(*input_lists: List[Tuple[str, Any]]) -> List[Tuple[str, Any]]:
169+
def combine_inputs(*named_input_lists: List[Tuple[str, List[Tuple[str, Any]]]]) -> List[Tuple[str, Any]]:
170170
"""
171-
Combines multiple lists of (name, default_value) tuples. For duplicate inputs, updates only if current value is
172-
None and new value is not None. Warns if multiple non-None default values exist for the same input.
171+
Combines multiple lists of (name, default_value) tuples from different blocks. For duplicate inputs, updates only if
172+
current value is None and new value is not None. Warns if multiple non-None default values exist for the same input.
173+
174+
Args:
175+
named_input_lists: List of tuples containing (block_name, input_list) pairs
173176
"""
174177
combined_dict = {}
175-
for inputs in input_lists:
178+
# Track which block provided which value
179+
value_sources = {}
180+
181+
for block_name, inputs in named_input_lists:
176182
for name, value in inputs:
177183
if name in combined_dict:
178184
current_value = combined_dict[name]
179185
if current_value is not None and value is not None and current_value != value:
180186
warnings.warn(
181187
f"Multiple different default values found for input '{name}': "
182-
f"{current_value} and {value}. Using {current_value}."
188+
f"{current_value} (from block '{value_sources[name]}') and "
189+
f"{value} (from block '{block_name}'). Using {current_value}."
183190
)
184191
if current_value is None and value is not None:
185192
combined_dict[name] = value
193+
value_sources[name] = block_name
186194
else:
187195
combined_dict[name] = value
196+
value_sources[name] = block_name
188197
return list(combined_dict.items())
189198

190199

@@ -268,62 +277,10 @@ def intermediates_inputs(self) -> List[str]:
268277
def intermediates_outputs(self) -> List[str]:
269278
raise NotImplementedError("intermediates_outputs property must be implemented in subclasses")
270279

271-
@property
272-
def model_cpu_offload_seq(self):
273-
raise NotImplementedError("model_cpu_offload_seq property must be implemented in subclasses")
274-
275280
def __call__(self, pipeline, state):
276281
raise NotImplementedError("__call__ method must be implemented in subclasses")
277282

278-
def __repr__(self):
279-
class_name = self.__class__.__name__
280-
281-
# Components section
282-
expected_components = set(getattr(self, "expected_components", []))
283-
loaded_components = set(self.components.keys())
284-
all_components = sorted(expected_components | loaded_components)
285-
components_str = " Components:\n" + "\n".join(
286-
f" - {k}={type(self.components[k]).__name__}" if k in loaded_components else f" - {k}"
287-
for k in all_components
288-
)
289-
290-
# Auxiliaries section
291-
auxiliaries_str = " Auxiliaries:\n" + "\n".join(
292-
f" - {k}={type(v).__name__}" for k, v in self.auxiliaries.items()
293-
)
294-
295-
# Configs section
296-
expected_configs = set(getattr(self, "expected_configs", []))
297-
loaded_configs = set(self.configs.keys())
298-
all_configs = sorted(expected_configs | loaded_configs)
299-
configs_str = " Configs:\n" + "\n".join(
300-
f" - {k}={self.configs[k]}" if k in loaded_configs else f" - {k}" for k in all_configs
301-
)
302-
303-
# Blocks section
304-
blocks_str = " Blocks:\n" + "\n".join(
305-
f" - {name}={block.__class__.__name__}" for name, block in self.blocks.items()
306-
)
307-
308-
# Other information
309-
inputs_str = " Inputs:\n" + "\n".join(f" - {name}={default}" for name, default in self.inputs)
310-
311-
intermediates_str = (
312-
" Intermediates:\n"
313-
f" - inputs: {', '.join(self.intermediates_inputs)}\n"
314-
f" - outputs: {', '.join(self.intermediates_outputs)}"
315-
)
316283

317-
return (
318-
f"{class_name}(\n"
319-
f"{components_str}\n"
320-
f"{auxiliaries_str}\n"
321-
f"{configs_str}\n"
322-
f"{blocks_str}\n"
323-
f"{inputs_str}\n"
324-
f"{intermediates_str}\n"
325-
f")"
326-
)
327284

328285

329286
# YiYi TODO: remove the trigger input logic and keep it more flexible and less convenient:
@@ -364,7 +321,8 @@ def __post_init__(self):
364321

365322
@property
366323
def inputs(self) -> List[Tuple[str, Any]]:
367-
return combine_inputs(*(block.inputs for block in self.blocks.values()))
324+
named_inputs = [(name, block.inputs) for name, block in self.blocks.items()]
325+
return combine_inputs(*named_inputs)
368326

369327
@property
370328
def intermediates_inputs(self) -> List[str]:
@@ -489,7 +447,8 @@ class SequentialPipelineBlocks(MultiPipelineBlocks):
489447

490448
@property
491449
def inputs(self) -> List[Tuple[str, Any]]:
492-
return combine_inputs(*(block.inputs for block in self.blocks.values()))
450+
named_inputs = [(name, block.inputs) for name, block in self.blocks.items()]
451+
return combine_inputs(*named_inputs)
493452

494453
@property
495454
def intermediates_inputs(self) -> List[str]:
@@ -822,21 +781,19 @@ def __repr__(self):
822781
output += f"{block.__class__.__name__}\n"
823782
output += "\n"
824783

825-
intermediates_str = ""
826-
if hasattr(block, "intermediates_inputs"):
827-
intermediates_str += f"{', '.join(block.intermediates_inputs)}"
828-
829784
if hasattr(block, "intermediates_outputs"):
830-
if intermediates_str:
831-
intermediates_str += " -> "
832-
else:
833-
intermediates_str += "-> "
834-
intermediates_str += f"{', '.join(block.intermediates_outputs)}"
835-
836-
if intermediates_str:
785+
intermediates_str = f"-> {', '.join(block.intermediates_outputs)}"
837786
output += f" {intermediates_str}\n"
787+
output += "\n"
838788

839-
output += "\n"
789+
# Add final intermediate outputs for SequentialPipelineBlocks
790+
if isinstance(block, SequentialPipelineBlocks):
791+
last_block = list(block.blocks.values())[-1]
792+
if hasattr(last_block, "intermediates_outputs"):
793+
final_outputs = last_block.intermediates_outputs
794+
final_intermediates_str = f" (final intermediate outputs: {', '.join(final_outputs)})"
795+
output += f" {final_intermediates_str}\n"
796+
output += "\n"
840797

841798
# List the components registered in the pipeline
842799
output += "Registered Components:\n"
@@ -861,7 +818,7 @@ def __repr__(self):
861818
for name, default in self.default_call_parameters.items():
862819
output += f"{name}: {default!r}\n"
863820

864-
output += "\nRequired intermediate inputs:\n"
821+
output += "\nIntermediate inputs:\n"
865822
output += "--------------------------\n"
866823
for name in self.pipeline_block.intermediates_inputs:
867824
output += f"{name}: \n"

0 commit comments

Comments
 (0)