Skip to content

Commit 6985906

Browse files
committed
controlnet input & remove the MultiPipelineBlocks class
1 parent 54f410d commit 6985906

File tree

2 files changed

+182
-115
lines changed

2 files changed

+182
-115
lines changed

src/diffusers/pipelines/modular_pipeline.py

Lines changed: 107 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -197,14 +197,38 @@ def combine_inputs(*named_input_lists: List[Tuple[str, List[Tuple[str, Any]]]])
197197
return list(combined_dict.items())
198198

199199

200-
class MultiPipelineBlocks:
200+
class AutoPipelineBlocks:
201201
"""
202-
A class that combines multiple pipeline block classes into one. When used, it has same API and properties as
203-
PipelineBlock. And it can be used in ModularPipeline as a single pipeline block.
202+
A class that automatically selects a block to run based on the inputs.
203+
204+
Attributes:
205+
block_classes: List of block classes to be used
206+
block_names: List of prefixes for each block
207+
block_trigger_inputs: List of input names that trigger specific blocks, with None for default
204208
"""
205209

206210
block_classes = []
207-
block_prefixes = []
211+
block_names = []
212+
block_trigger_inputs = []
213+
214+
def __init__(self):
215+
blocks = OrderedDict()
216+
for block_name, block_cls in zip(self.block_names, self.block_classes):
217+
blocks[block_name] = block_cls()
218+
self.blocks = blocks
219+
if not (len(self.block_classes) == len(self.block_names) == len(self.block_trigger_inputs)):
220+
raise ValueError(f"In {self.__class__.__name__}, the number of block_classes, block_names, and block_trigger_inputs must be the same.")
221+
default_blocks = [t for t in self.block_trigger_inputs if t is None]
222+
if len(default_blocks) > 1 or (
223+
len(default_blocks) == 1 and self.block_trigger_inputs[-1] is not None
224+
):
225+
raise ValueError(
226+
f"In {self.__class__.__name__}, exactly one None must be specified as the last element "
227+
"in block_trigger_inputs."
228+
)
229+
230+
# Map trigger inputs to block objects
231+
self.trigger_to_block_map = dict(zip(self.block_trigger_inputs, self.blocks.values()))
208232

209233
@property
210234
def model_name(self):
@@ -228,13 +252,6 @@ def expected_configs(self):
228252
expected_configs.append(config)
229253
return expected_configs
230254

231-
def __init__(self):
232-
blocks = OrderedDict()
233-
for block_prefix, block_cls in zip(self.block_prefixes, self.block_classes):
234-
block_name = f"{block_prefix}_step" if block_prefix != "" else "step"
235-
blocks[block_name] = block_cls()
236-
self.blocks = blocks
237-
238255
# YiYi TODO: address the case where multiple blocks have the same component/auxiliary/config; give out warning etc
239256
@property
240257
def components(self):
@@ -265,60 +282,6 @@ def configs(self):
265282
configs.update(block.configs)
266283
return configs
267284

268-
@property
269-
def inputs(self) -> List[Tuple[str, Any]]:
270-
raise NotImplementedError("inputs property must be implemented in subclasses")
271-
272-
@property
273-
def intermediates_inputs(self) -> List[str]:
274-
raise NotImplementedError("intermediates_inputs property must be implemented in subclasses")
275-
276-
@property
277-
def intermediates_outputs(self) -> List[str]:
278-
raise NotImplementedError("intermediates_outputs property must be implemented in subclasses")
279-
280-
def __call__(self, pipeline, state):
281-
raise NotImplementedError("__call__ method must be implemented in subclasses")
282-
283-
284-
285-
286-
# YiYi TODO: remove the trigger input logic and keep it more flexible and less convenient:
287-
# user will need to explicitly write the dispatch logic in __call__ for each subclass of this
288-
class AutoPipelineBlocks(MultiPipelineBlocks):
289-
"""
290-
A class that automatically selects which block to run based on trigger inputs.
291-
292-
Attributes:
293-
block_classes: List of block classes to be used
294-
block_prefixes: List of prefixes for each block
295-
block_trigger_inputs: List of input names that trigger specific blocks, with None for default
296-
"""
297-
298-
block_classes = []
299-
block_prefixes = []
300-
block_trigger_inputs = []
301-
302-
def __init__(self):
303-
super().__init__()
304-
self.__post_init__()
305-
306-
def __post_init__(self):
307-
"""
308-
Create mapping of trigger inputs directly to block objects. Validates that there is at most one default block
309-
(None trigger).
310-
"""
311-
# Check for at most one default block
312-
default_blocks = [t for t in self.block_trigger_inputs if t is None]
313-
if len(default_blocks) > 1:
314-
raise ValueError(
315-
f"Multiple default blocks specified in {self.__class__.__name__}. "
316-
"Must include at most one None in block_trigger_inputs."
317-
)
318-
319-
# Map trigger inputs to block objects
320-
self.trigger_to_block_map = dict(zip(self.block_trigger_inputs, self.blocks.values()))
321-
322285
@property
323286
def inputs(self) -> List[Tuple[str, Any]]:
324287
named_inputs = [(name, block.inputs) for name, block in self.blocks.items()]
@@ -335,30 +298,15 @@ def intermediates_outputs(self) -> List[str]:
335298
@torch.no_grad()
336299
def __call__(self, pipeline, state: PipelineState) -> PipelineState:
337300
# Find default block first (if any)
338-
default_block = self.trigger_to_block_map.get(None)
339-
340-
# Check which trigger inputs are present
341-
active_triggers = [
342-
input_name
343-
for input_name in self.block_trigger_inputs
344-
if input_name is not None and state.get_input(input_name) is not None
345-
]
346-
347-
# If multiple triggers are active, raise error
348-
if len(active_triggers) > 1:
349-
trigger_names = [f"'{t}'" for t in active_triggers]
350-
raise ValueError(
351-
f"Multiple trigger inputs found ({', '.join(trigger_names)}). "
352-
f"Only one trigger input can be provided for {self.__class__.__name__}."
353-
)
354301

355-
# Get the block to run (use default if no triggers active)
356-
block = self.trigger_to_block_map.get(active_triggers[0]) if active_triggers else default_block
357-
if block is None:
358-
logger.warning(f"No valid block found in {self.__class__.__name__}, skipping.")
359-
return pipeline, state
302+
block = self.trigger_to_block_map.get(None)
303+
for input_name in self.block_trigger_inputs:
304+
if input_name is not None and state.get_input(input_name) is not None:
305+
block = self.trigger_to_block_map[input_name]
306+
break
360307

361308
try:
309+
logger.info(f"Running block: {block.__class__.__name__}, trigger: {input_name}")
362310
return block(pipeline, state)
363311
except Exception as e:
364312
error_msg = (
@@ -440,10 +388,70 @@ def __repr__(self):
440388
)
441389

442390

443-
class SequentialPipelineBlocks(MultiPipelineBlocks):
391+
class SequentialPipelineBlocks:
444392
"""
445393
A class that combines multiple pipeline block classes into one. When called, it will call each block in sequence.
446394
"""
395+
block_classes = []
396+
block_names = []
397+
398+
@property
399+
def model_name(self):
400+
return next(iter(self.blocks.values())).model_name
401+
402+
@property
403+
def expected_components(self):
404+
expected_components = []
405+
for block in self.blocks.values():
406+
for component in block.expected_components:
407+
if component not in expected_components:
408+
expected_components.append(component)
409+
return expected_components
410+
411+
@property
412+
def expected_configs(self):
413+
expected_configs = []
414+
for block in self.blocks.values():
415+
for config in block.expected_configs:
416+
if config not in expected_configs:
417+
expected_configs.append(config)
418+
return expected_configs
419+
420+
def __init__(self):
421+
blocks = OrderedDict()
422+
for block_name, block_cls in zip(self.block_names, self.block_classes):
423+
blocks[block_name] = block_cls()
424+
self.blocks = blocks
425+
426+
# YiYi TODO: address the case where multiple blocks have the same component/auxiliary/config; give out warning etc
427+
@property
428+
def components(self):
429+
# Combine components from all blocks
430+
components = {}
431+
for block_name, block in self.blocks.items():
432+
for key, value in block.components.items():
433+
# Only update if:
434+
# 1. Key doesn't exist yet in components, OR
435+
# 2. New value is not None
436+
if key not in components or value is not None:
437+
components[key] = value
438+
return components
439+
440+
@property
441+
def auxiliaries(self):
442+
# Combine auxiliaries from all blocks
443+
auxiliaries = {}
444+
for block_name, block in self.blocks.items():
445+
auxiliaries.update(block.auxiliaries)
446+
return auxiliaries
447+
448+
@property
449+
def configs(self):
450+
# Combine configs from all blocks
451+
configs = {}
452+
for block_name, block in self.blocks.items():
453+
configs.update(block.configs)
454+
return configs
447455

448456
@property
449457
def inputs(self) -> List[Tuple[str, Any]]:
@@ -467,7 +475,11 @@ def intermediates_inputs(self) -> List[str]:
467475
@property
468476
def intermediates_outputs(self) -> List[str]:
469477
return list(set().union(*(block.intermediates_outputs for block in self.blocks.values())))
470-
478+
479+
@property
480+
def final_intermediates_outputs(self) -> List[str]:
481+
return next(reversed(self.blocks.values())).intermediates_outputs
482+
471483
@torch.no_grad()
472484
def __call__(self, pipeline, state: PipelineState) -> PipelineState:
473485
for block_name, block in self.blocks.items():
@@ -536,7 +548,8 @@ def __repr__(self):
536548
intermediates_str = (
537549
"\n Intermediates:\n"
538550
f" - inputs: {', '.join(self.intermediates_inputs)}\n"
539-
f" - outputs: {', '.join(self.intermediates_outputs)}"
551+
f" - outputs: {', '.join(self.intermediates_outputs)}\n"
552+
f" - final outputs: {', '.join(self.final_intermediates_outputs)}"
540553
)
541554

542555
return (
@@ -772,7 +785,7 @@ def __repr__(self):
772785
output += "Pipeline Block:\n"
773786
output += "--------------\n"
774787
block = self.pipeline_block
775-
if isinstance(block, MultiPipelineBlocks):
788+
if hasattr(block, "blocks"):
776789
output += f"{block.__class__.__name__}\n"
777790
# Add sub-blocks information
778791
for sub_block_name, sub_block in block.blocks.items():
@@ -787,13 +800,10 @@ def __repr__(self):
787800
output += "\n"
788801

789802
# Add final intermediate outputs for SequentialPipelineBlocks
790-
if isinstance(block, SequentialPipelineBlocks):
791-
last_block = list(block.blocks.values())[-1]
792-
if hasattr(last_block, "intermediates_outputs"):
793-
final_outputs = last_block.intermediates_outputs
794-
final_intermediates_str = f" (final intermediate outputs: {', '.join(final_outputs)})"
795-
output += f" {final_intermediates_str}\n"
796-
output += "\n"
803+
if hasattr(block, "final_intermediate_output"):
804+
final_intermediates_str = f" (final intermediate outputs: {', '.join(block.final_intermediate_output)})"
805+
output += f" {final_intermediates_str}\n"
806+
output += "\n"
797807

798808
# List the components registered in the pipeline
799809
output += "Registered Components:\n"

0 commit comments

Comments
 (0)