Commit 485f8d1

more refactor
1 parent cff0fd6

File tree: 3 files changed, +157 -23 lines changed

src/diffusers/pipelines/components_manager.py

Lines changed: 112 additions & 10 deletions
@@ -256,13 +256,99 @@ def remove(self, name):
         if self._auto_offload_enabled:
             self.enable_auto_cpu_offload(self._auto_offload_device)
 
+    # YiYi TODO: looking into improving the search pattern
     def get(self, names: Union[str, List[str]]):
+        """
+        Get components by name with simple pattern matching.
+
+        Args:
+            names: Component name(s) or pattern(s)
+                Patterns:
+                - "unet" : exact match
+                - "!unet" : everything except exact match "unet"
+                - "base_*" : everything starting with "base_"
+                - "!base_*" : everything NOT starting with "base_"
+                - "*unet*" : anything containing "unet"
+                - "!*unet*" : anything NOT containing "unet"
+                - "refiner|vae|unet" : anything containing any of these terms
+                - "!refiner|vae|unet" : anything NOT containing any of these terms
+
+        Returns:
+            Single component if names is str and matches one component,
+            dict of components if names matches multiple components or is a list
+        """
         if isinstance(names, str):
-            if names not in self.components:
+            # Check if this is a "not" pattern
+            is_not_pattern = names.startswith('!')
+            if is_not_pattern:
+                names = names[1:] # Remove the ! prefix
+
+            # Handle OR patterns (containing |)
+            if '|' in names:
+                terms = names.split('|')
+                matches = {
+                    name: comp for name, comp in self.components.items()
+                    if any((term in name) != is_not_pattern for term in terms) # Flip condition if not pattern
+                }
+                if is_not_pattern:
+                    logger.info(f"Getting components NOT containing any of {terms}: {list(matches.keys())}")
+                else:
+                    logger.info(f"Getting components containing any of {terms}: {list(matches.keys())}")
+
+            # Exact match
+            elif names in self.components:
+                if is_not_pattern:
+                    matches = {
+                        name: comp for name, comp in self.components.items()
+                        if name != names
+                    }
+                    logger.info(f"Getting all components except '{names}': {list(matches.keys())}")
+                else:
+                    logger.info(f"Getting component: {names}")
+                    return self.components[names]
+
+            # Prefix match (ends with *)
+            elif names.endswith('*'):
+                prefix = names[:-1]
+                matches = {
+                    name: comp for name, comp in self.components.items()
+                    if name.startswith(prefix) != is_not_pattern # Flip condition if not pattern
+                }
+                if is_not_pattern:
+                    logger.info(f"Getting components NOT starting with '{prefix}': {list(matches.keys())}")
+                else:
+                    logger.info(f"Getting components starting with '{prefix}': {list(matches.keys())}")
+
+            # Contains match (starts with *)
+            elif names.startswith('*'):
+                search = names[1:-1] if names.endswith('*') else names[1:]
+                matches = {
+                    name: comp for name, comp in self.components.items()
+                    if (search in name) != is_not_pattern # Flip condition if not pattern
+                }
+                if is_not_pattern:
+                    logger.info(f"Getting components NOT containing '{search}': {list(matches.keys())}")
+                else:
+                    logger.info(f"Getting components containing '{search}': {list(matches.keys())}")
+
+            else:
                 raise ValueError(f"Component '{names}' not found in ComponentsManager")
-            return self.components[names]
+
+            if not matches:
+                raise ValueError(f"No components found matching pattern '{names}'")
+            return matches if len(matches) > 1 else next(iter(matches.values()))
+
         elif isinstance(names, list):
-            return {n: self.components[n] for n in names}
+            results = {}
+            for name in names:
+                result = self.get(name)
+                if isinstance(result, dict):
+                    results.update(result)
+                else:
+                    results[name] = result
+            logger.info(f"Getting multiple components: {list(results.keys())}")
+            return results
+
         else:
             raise ValueError(f"Invalid type for names: {type(names)}")
 
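As a quick illustration of the matching rules documented above, a minimal usage sketch follows; the manager instance and the component names ("base_unet", "base_vae", "refiner_unet") are assumptions for illustration, not part of this commit.

# Usage sketch only: `manager` and the component names are assumed for illustration.
manager.get("base_unet")                 # exact name -> the single matching component
manager.get("base_*")                    # prefix pattern -> dict of "base_..." components (or the single match)
manager.get("!base_*")                   # negated prefix -> everything not starting with "base_"
manager.get("refiner|vae")               # OR pattern -> everything whose name contains "refiner" or "vae"
manager.get(["base_*", "refiner_unet"])  # list -> each lookup resolved and merged into one dict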

@@ -431,18 +517,34 @@ def __repr__(self):
 
         return output
 
-    def add_from_pretrained(self, pretrained_model_name_or_path, **kwargs):
+    def add_from_pretrained(self, pretrained_model_name_or_path, prefix: Optional[str] = None, **kwargs):
+        """
+        Load components from a pretrained model and add them to the manager.
+
+        Args:
+            pretrained_model_name_or_path (str): The path or identifier of the pretrained model
+            prefix (str, optional): Prefix to add to all component names loaded from this model.
+                If provided, components will be named as "{prefix}_{component_name}"
+            **kwargs: Additional arguments to pass to DiffusionPipeline.from_pretrained()
+        """
         from ..pipelines.pipeline_utils import DiffusionPipeline
 
         pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, **kwargs)
         for name, component in pipe.components.items():
-            if name not in self.components and component is not None:
-                self.add(name, component)
-            elif name in self.components:
+
+            if component is None:
+                continue
+
+            # Add prefix if specified
+            component_name = f"{prefix}_{name}" if prefix else name
+
+            if component_name not in self.components:
+                self.add(component_name, component)
+            else:
                 logger.warning(
-                    f"Component '{name}' already exists in ComponentsManager and will not be added. To add it, either:\n"
-                    f"1. remove the existing component with remove('{name}')\n"
-                    f"2. Use a different name: add('{name}_2', component)"
+                    f"Component '{component_name}' already exists in ComponentsManager and will not be added. To add it, either:\n"
+                    f"1. remove the existing component with remove('{component_name}')\n"
+                    f"2. Use a different prefix: add_from_pretrained(..., prefix='{prefix}_2')"
                 )
 
 def summarize_dict_by_value_and_parts(d: Dict[str, Any]) -> Dict[str, Any]:
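A short usage sketch of the new prefix argument; the checkpoint ids below are assumptions, and only the "{prefix}_{component_name}" naming rule and the collision warning come from the code above.

# Usage sketch: checkpoint ids are assumptions; naming follows "{prefix}_{name}" as above.
from diffusers.pipelines.components_manager import ComponentsManager  # assumed import path

manager = ComponentsManager()
# Components are registered as "base_unet", "base_vae", "base_text_encoder", ...
manager.add_from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", prefix="base")
# A second checkpoint can then be added without name collisions.
manager.add_from_pretrained("stabilityai/stable-diffusion-xl-refiner-1.0", prefix="refiner")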

src/diffusers/pipelines/modular_pipeline.py

Lines changed: 27 additions & 9 deletions
@@ -975,10 +975,16 @@ def intermediates_inputs(self) -> List[str]:
         for block in self.blocks.values():
             # Add inputs that aren't in outputs yet
             inputs.extend(input_name for input_name in block.intermediates_inputs if input_name.name not in outputs)
-            # Add this block's outputs
-            block_intermediates_outputs = [out.name for out in block.intermediates_outputs]
-            outputs.update(block_intermediates_outputs)
 
+            # Only add outputs if the block cannot be skipped
+            should_add_outputs = True
+            if hasattr(block, "block_trigger_inputs") and None not in block.block_trigger_inputs:
+                should_add_outputs = False
+
+            if should_add_outputs:
+                # Add this block's outputs
+                block_intermediates_outputs = [out.name for out in block.intermediates_outputs]
+                outputs.update(block_intermediates_outputs)
         return inputs
 
     @property
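The change above means that outputs of a block that only runs when one of its triggers is present are no longer treated as always available, so later blocks still surface them as required intermediate inputs. A standalone sketch of that rule, using plain strings instead of the named input/output objects in the real code:

# Standalone sketch of the aggregation rule above; plain strings stand in for the
# named input/output objects used in the real code.
def collect_intermediates_inputs(blocks):
    inputs, produced = [], set()
    for block in blocks:
        # anything not produced by an earlier, always-run block is still an input
        inputs.extend(name for name in block.intermediates_inputs if name not in produced)
        triggers = getattr(block, "block_trigger_inputs", None)
        if triggers is None or None in triggers:  # block cannot be skipped
            produced.update(block.intermediates_outputs)
    return inputs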
@@ -1035,47 +1041,59 @@ def trigger_inputs(self):
         return self._get_trigger_inputs()
 
     def _traverse_trigger_blocks(self, trigger_inputs):
+        # Convert trigger_inputs to a set for easier manipulation
+        active_triggers = set(trigger_inputs)
 
-        def fn_recursive_traverse(block, block_name, trigger_inputs):
+        def fn_recursive_traverse(block, block_name, active_triggers):
             result_blocks = OrderedDict()
+
             # sequential or PipelineBlock
             if not hasattr(block, 'block_trigger_inputs'):
                 if hasattr(block, 'blocks'):
                     # sequential
                     for block_name, block in block.blocks.items():
-                        blocks_to_update = fn_recursive_traverse(block, block_name, trigger_inputs)
+                        blocks_to_update = fn_recursive_traverse(block, block_name, active_triggers)
                         result_blocks.update(blocks_to_update)
                 else:
                     # PipelineBlock
                     result_blocks[block_name] = block
+                    # Add this block's output names to active triggers if defined
+                    if hasattr(block, 'outputs'):
+                        active_triggers.update(out.name for out in block.outputs)
                 return result_blocks
 
             # auto
             else:
-                # Find first block_trigger_input that matches any value in our trigger_value tuple
+                # Find first block_trigger_input that matches any value in our active_triggers
                 this_block = None
+                matching_trigger = None
                 for trigger_input in block.block_trigger_inputs:
-                    if trigger_input is not None and trigger_input in trigger_inputs:
+                    if trigger_input is not None and trigger_input in active_triggers:
                         this_block = block.trigger_to_block_map[trigger_input]
+                        matching_trigger = trigger_input
                         break
 
                 # If no matches found, try to get the default (None) block
                 if this_block is None and None in block.block_trigger_inputs:
                     this_block = block.trigger_to_block_map[None]
+                    matching_trigger = None
 
                 if this_block is not None:
                     # sequential/auto
                     if hasattr(this_block, 'blocks'):
-                        result_blocks.update(fn_recursive_traverse(this_block, block_name, trigger_inputs))
+                        result_blocks.update(fn_recursive_traverse(this_block, block_name, active_triggers))
                     else:
                         # PipelineBlock
                         result_blocks[block_name] = this_block
+                        # Add this block's output names to active triggers if defined
+                        if hasattr(this_block, 'outputs'):
+                            active_triggers.update(out.name for out in this_block.outputs)
 
                 return result_blocks
 
         all_blocks = OrderedDict()
         for block_name, block in self.blocks.items():
-            blocks_to_update = fn_recursive_traverse(block, block_name, trigger_inputs)
+            blocks_to_update = fn_recursive_traverse(block, block_name, active_triggers)
             all_blocks.update(blocks_to_update)
         return all_blocks
 
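A simplified, non-recursive sketch of the traversal change above: a selected block's outputs now extend the set of active triggers, so they can activate later auto blocks. Attribute names mirror the diff; everything else is illustrative, and the real method recurses into nested blocks.

from collections import OrderedDict

# Simplified, flat sketch of the traversal above; the real implementation recurses.
def traverse_trigger_blocks(blocks, trigger_inputs):
    active_triggers = set(trigger_inputs)
    result = OrderedDict()
    for name, block in blocks.items():
        if not hasattr(block, "block_trigger_inputs"):
            chosen = block  # plain block: always selected
        else:
            # first trigger that is already active wins, otherwise fall back to the None default
            match = next((t for t in block.block_trigger_inputs
                          if t is not None and t in active_triggers), None)
            if match is None and None not in block.block_trigger_inputs:
                continue  # no active trigger and no default branch: the block is skipped
            chosen = block.trigger_to_block_map[match]
        result[name] = chosen
        # the chosen block's outputs can activate triggers for later blocks
        if hasattr(chosen, "outputs"):
            active_triggers.update(out.name for out in chosen.outputs)
    return result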

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py

Lines changed: 18 additions & 4 deletions
@@ -2994,14 +2994,28 @@ def description(self):
             " - `StableDiffusionXLInpaintDecodeStep` (inpaint) is used when `padding_mask_crop` is provided.\n" + \
             " - `StableDiffusionXLDecodeStep` (non-inpaint) is used when `padding_mask_crop` is not provided."
 
+class StableDiffusionXLAutoIPAdapterStep(AutoPipelineBlocks):
+    block_classes = [StableDiffusionXLIPAdapterStep]
+    block_names = ["ip_adapter"]
+    block_trigger_inputs = ["ip_adapter_image"]
 
-class StableDiffusionAutoPipeline(SequentialPipelineBlocks):
-    block_classes = [StableDiffusionXLTextEncoderStep, StableDiffusionXLAutoVaeEncoderStep, StableDiffusionXLAutoBeforeDenoiseStep, StableDiffusionXLAutoDenoiseStep, StableDiffusionXLAutoDecodeStep]
-    block_names = ["text_encoder", "image_encoder", "before_denoise", "denoise", "decode"]
+    @property
+    def description(self):
+        return "Run IP Adapter step if `ip_adapter_image` is provided."
+
+class StableDiffusionXLAutoPipeline(SequentialPipelineBlocks):
+    block_classes = [StableDiffusionXLTextEncoderStep, StableDiffusionXLAutoIPAdapterStep, StableDiffusionXLAutoVaeEncoderStep, StableDiffusionXLAutoBeforeDenoiseStep, StableDiffusionXLAutoDenoiseStep, StableDiffusionXLAutoDecodeStep]
+    block_names = ["text_encoder", "ip_adapter", "image_encoder", "before_denoise", "denoise", "decode"]
 
     @property
     def description(self):
-        return "Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using Stable Diffusion XL.\n"
+        return "Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using Stable Diffusion XL.\n" + \
+            "- for image-to-image generation, you need to provide either `image` or `image_latents`\n" + \
+            "- for inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop` \n" + \
+            "- to run the controlnet workflow, you need to provide `control_image`\n" + \
+            "- to run the controlnet_union workflow, you need to provide `control_image` and `control_mode`\n" + \
+            "- to run the ip_adapter workflow, you need to provide `ip_adapter_image`\n" + \
+            "- for text-to-image generation, all you need to provide is `prompt`"
 
 # block mapping
 TEXT2IMAGE_BLOCKS = OrderedDict([
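For contrast with the IP-Adapter wrapper added above, a sketch of the two AutoPipelineBlocks shapes in play; the base class and the branch blocks are stand-in stubs, and only the class attributes matter here.

# Sketch only: AutoPipelineBlocks and the branch blocks are stand-in stubs.
class AutoPipelineBlocks:  # stub for the real base class
    pass

class InpaintDecode: pass  # hypothetical branch blocks
class Decode: pass
class IPAdapter: pass

class MyAutoDecodeStep(AutoPipelineBlocks):
    block_classes = [InpaintDecode, Decode]
    block_names = ["inpaint_decode", "decode"]
    # `padding_mask_crop` selects the inpaint branch; the trailing None makes the plain
    # decode a default, so this step always runs one of its branches.
    block_trigger_inputs = ["padding_mask_crop", None]

class MyAutoIPAdapterStep(AutoPipelineBlocks):
    block_classes = [IPAdapter]
    block_names = ["ip_adapter"]
    # no None entry: the whole step is skipped unless `ip_adapter_image` is provided
    block_trigger_inputs = ["ip_adapter_image"]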
