diff --git a/src/diffusers/pipelines/modular_pipeline.py b/src/diffusers/pipelines/modular_pipeline.py
index 636b543395df..4db8433768e4 100644
--- a/src/diffusers/pipelines/modular_pipeline.py
+++ b/src/diffusers/pipelines/modular_pipeline.py
@@ -12,29 +12,28 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import importlib
+import inspect
+import os
 import traceback
 import warnings
 from collections import OrderedDict
+from copy import deepcopy
 from dataclasses import dataclass, field
-from typing import Any, Dict, List, Tuple, Union, Optional, Type
-
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
-from tqdm.auto import tqdm
-import re
-import os
-import importlib
-
 from huggingface_hub.utils import validate_hf_hub_args
+from tqdm.auto import tqdm
 
 from ..configuration_utils import ConfigMixin, FrozenDict
 from ..utils import (
+    PushToHubMixin,
     is_accelerate_available,
-    is_accelerate_version,
     logging,
-    PushToHubMixin,
 )
-from .pipeline_loading_utils import _get_pipeline_class, simple_get_class_obj,_fetch_class_library_tuple
+from ..utils.dynamic_modules_utils import get_class_from_dynamic_module, resolve_trust_remote_code
+from .components_manager import ComponentsManager
 from .modular_pipeline_utils import (
     ComponentSpec,
     ConfigSpec,
@@ -42,18 +41,15 @@
     OutputParam,
     format_components,
     format_configs,
-    format_input_params,
     format_inputs_short,
     format_intermediates_short,
-    format_output_params,
-    format_params,
     make_doc_string,
 )
-from .components_manager import ComponentsManager
+from .pipeline_loading_utils import _fetch_class_library_tuple, simple_get_class_obj
+
 
-from copy import deepcopy
 if is_accelerate_available():
-    import accelerate
+    pass
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -108,18 +104,16 @@ def format_value(v):
         intermediates = "\n".join(f"    {k}: {format_value(v)}" for k, v in self.intermediates.items())
 
         return (
-            f"PipelineState(\n"
-            f"  inputs={{\n{inputs}\n  }},\n"
-            f"  intermediates={{\n{intermediates}\n  }}\n"
-            f")"
+            f"PipelineState(\n" f"  inputs={{\n{inputs}\n  }},\n" f"  intermediates={{\n{intermediates}\n  }}\n" f")"
         )
 
 
-@dataclass
+@dataclass
 class BlockState:
     """
    Container for block state data with attribute access and formatted representation.
""" + def __init__(self, **kwargs): for key, value in kwargs.items(): setattr(self, key, value) @@ -129,28 +123,28 @@ def format_value(v): # Handle tensors directly if hasattr(v, "shape") and hasattr(v, "dtype"): return f"Tensor(dtype={v.dtype}, shape={v.shape})" - + # Handle lists of tensors elif isinstance(v, list): if len(v) > 0 and hasattr(v[0], "shape") and hasattr(v[0], "dtype"): shapes = [t.shape for t in v] return f"List[{len(v)}] of Tensors with shapes {shapes}" return repr(v) - + # Handle tuples of tensors elif isinstance(v, tuple): if len(v) > 0 and hasattr(v[0], "shape") and hasattr(v[0], "dtype"): shapes = [t.shape for t in v] return f"Tuple[{len(v)}] of Tensors with shapes {shapes}" return repr(v) - + # Handle dicts with tensor values elif isinstance(v, dict): if any(hasattr(val, "shape") and hasattr(val, "dtype") for val in v.values()): shapes = {k: val.shape for k, val in v.items() if hasattr(val, "shape")} return f"Dict of Tensors with shapes {shapes}" return repr(v) - + # Default case return repr(v) @@ -158,31 +152,92 @@ def format_value(v): return f"BlockState(\n{attributes}\n)" - -class ModularPipelineMixin: +class ModularPipelineMixin(ConfigMixin): """ Mixin for all PipelineBlocks: PipelineBlock, AutoPipelineBlocks, SequentialPipelineBlocks """ - - def setup_loader(self, modular_repo: Optional[Union[str, os.PathLike]] = None, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None): + config_name = "config.json" + + @classmethod + def _get_signature_keys(cls, obj): + parameters = inspect.signature(obj.__init__).parameters + required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty} + optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty}) + expected_modules = set(required_parameters.keys()) - {"self"} + + return expected_modules, optional_parameters + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: str, + trust_remote_code: Optional[bool] = None, + **kwargs, + ): + hub_kwargs_names = [ + "cache_dir", + "force_download", + "local_files_only", + "proxies", + "resume_download", + "revision", + "subfolder", + "token", + ] + hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs} + + config = cls.load_config(pretrained_model_name_or_path) + has_remote_code = "auto_map" in config and cls.__name__ in config["auto_map"] + trust_remote_code = resolve_trust_remote_code( + trust_remote_code, pretrained_model_name_or_path, has_remote_code + ) + if not (has_remote_code and trust_remote_code): + raise ValueError("") + + class_ref = config["auto_map"][cls.__name__] + module_file, class_name = class_ref.split(".") + module_file = module_file + ".py" + block_cls = get_class_from_dynamic_module( + pretrained_model_name_or_path, + module_file=module_file, + class_name=class_name, + is_modular=True, + **hub_kwargs, + **kwargs, + ) + expected_kwargs, optional_kwargs = block_cls._get_signature_keys(block_cls) + block_kwargs = { + name: kwargs.pop(name) for name in kwargs if name in expected_kwargs or name in optional_kwargs + } + + return block_cls(**block_kwargs) + + def setup_loader( + self, + modular_repo: Optional[Union[str, os.PathLike]] = None, + component_manager: Optional[ComponentsManager] = None, + collection: Optional[str] = None, + ): """ - create a mouldar loader, optionally accept modular_repo to load from hub. + create a ModularLoader, optionally accept modular_repo to load from hub. 
""" # Import components loader (it is model-specific class) - loader_class_name = MODULAR_LOADER_MAPPING[self.model_name] + loader_class_name = MODULAR_LOADER_MAPPING.get(self.model_name, ModularLoader.__name__) + diffusers_module = importlib.import_module("diffusers") loader_class = getattr(diffusers_module, loader_class_name) - + # Create deep copies to avoid modifying the original specs component_specs = deepcopy(self.expected_components) config_specs = deepcopy(self.expected_configs) # Create the loader with the updated specs specs = component_specs + config_specs - - self.loader = loader_class(specs, modular_repo=modular_repo, component_manager=component_manager, collection=collection) + self.loader = loader_class( + specs, modular_repo=modular_repo, component_manager=component_manager, collection=collection + ) @property def default_call_parameters(self) -> Dict[str, Any]: @@ -238,7 +293,6 @@ def run(self, state: PipelineState = None, output: Union[str, List[str]] = None, if output is None: return state - elif isinstance(output, str): return state.get_intermediate(output) @@ -268,9 +322,8 @@ def set_progress_bar_config(self, **kwargs): class PipelineBlock(ModularPipelineMixin): - model_name = None - + @property def description(self) -> str: """Description of the block. Must be implemented by subclasses.""" @@ -279,12 +332,11 @@ def description(self) -> str: @property def expected_components(self) -> List[ComponentSpec]: return [] - + @property def expected_configs(self) -> List[ConfigSpec]: return [] - # YiYi TODO: can we combine inputs and intermediates_inputs? the difference is inputs are immutable @property def inputs(self) -> List[InputParam]: @@ -322,7 +374,6 @@ def required_intermediates_inputs(self) -> List[str]: input_names.append(input_param.name) return input_names - def __call__(self, pipeline, state: PipelineState) -> PipelineState: raise NotImplementedError("__call__ method must be implemented in subclasses") @@ -331,14 +382,14 @@ def __repr__(self): base_class = self.__class__.__bases__[0].__name__ # Format description with proper indentation - desc_lines = self.description.split('\n') + desc_lines = self.description.split("\n") desc = [] # First line with "Description:" label desc.append(f" Description: {desc_lines[0]}") # Subsequent lines with proper indentation if len(desc_lines) > 1: desc.extend(f" {line}" for line in desc_lines[1:]) - desc = '\n'.join(desc) + '\n' + desc = "\n".join(desc) + "\n" # Components section - use format_components with add_empty_lines=False expected_components = getattr(self, "expected_components", []) @@ -355,7 +406,9 @@ def __repr__(self): inputs = "Inputs:\n " + inputs_str # Intermediates section - intermediates_str = format_intermediates_short(self.intermediates_inputs, self.required_intermediates_inputs, self.intermediates_outputs) + intermediates_str = format_intermediates_short( + self.intermediates_inputs, self.required_intermediates_inputs, self.intermediates_outputs + ) intermediates = f"Intermediates:\n{intermediates_str}" return ( @@ -369,24 +422,22 @@ def __repr__(self): f")" ) - @property def doc(self): return make_doc_string( - self.inputs, - self.intermediates_inputs, - self.outputs, + self.inputs, + self.intermediates_inputs, + self.outputs, self.description, class_name=self.__class__.__name__, expected_components=self.expected_components, - expected_configs=self.expected_configs + expected_configs=self.expected_configs, ) - def get_block_state(self, state: PipelineState) -> dict: """Get all inputs and intermediates in one 
dictionary""" data = {} - + # Check inputs for input_param in self.inputs: value = state.get_input(input_param.name) @@ -402,7 +453,7 @@ def get_block_state(self, state: PipelineState) -> dict: data[input_param.name] = value return BlockState(**data) - + def add_block_state(self, state: PipelineState, block_state: BlockState): for output_param in self.intermediates_outputs: if not hasattr(block_state, output_param.name): @@ -412,26 +463,28 @@ def add_block_state(self, state: PipelineState, block_state: BlockState): def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> List[InputParam]: """ - Combines multiple lists of InputParam objects from different blocks. For duplicate inputs, updates only if - current default value is None and new default value is not None. Warns if multiple non-None default values + Combines multiple lists of InputParam objects from different blocks. For duplicate inputs, updates only if + current default value is None and new default value is not None. Warns if multiple non-None default values exist for the same input. Args: named_input_lists: List of tuples containing (block_name, input_param_list) pairs - + Returns: List[InputParam]: Combined list of unique InputParam objects """ combined_dict = {} # name -> InputParam value_sources = {} # name -> block_name - + for block_name, inputs in named_input_lists: for input_param in inputs: if input_param.name in combined_dict: current_param = combined_dict[input_param.name] - if (current_param.default is not None and - input_param.default is not None and - current_param.default != input_param.default): + if ( + current_param.default is not None + and input_param.default is not None + and current_param.default != input_param.default + ): warnings.warn( f"Multiple different default values found for input '{input_param.name}': " f"{current_param.default} (from block '{value_sources[input_param.name]}') and " @@ -443,9 +496,10 @@ def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> Li else: combined_dict[input_param.name] = input_param value_sources[input_param.name] = block_name - + return list(combined_dict.values()) + def combine_outputs(*named_output_lists: List[Tuple[str, List[OutputParam]]]) -> List[OutputParam]: """ Combines multiple lists of OutputParam objects from different blocks. For duplicate outputs, @@ -453,17 +507,17 @@ def combine_outputs(*named_output_lists: List[Tuple[str, List[OutputParam]]]) -> Args: named_output_lists: List of tuples containing (block_name, output_param_list) pairs - + Returns: List[OutputParam]: Combined list of unique OutputParam objects """ combined_dict = {} # name -> OutputParam - + for block_name, outputs in named_output_lists: for output_param in outputs: if output_param.name not in combined_dict: combined_dict[output_param.name] = output_param - + return list(combined_dict.values()) @@ -487,15 +541,15 @@ def __init__(self): blocks[block_name] = block_cls() self.blocks = blocks if not (len(self.block_classes) == len(self.block_names) == len(self.block_trigger_inputs)): - raise ValueError(f"In {self.__class__.__name__}, the number of block_classes, block_names, and block_trigger_inputs must be the same.") + raise ValueError( + f"In {self.__class__.__name__}, the number of block_classes, block_names, and block_trigger_inputs must be the same." 
@@ -487,15 +541,15 @@ def __init__(self):
             blocks[block_name] = block_cls()
         self.blocks = blocks
         if not (len(self.block_classes) == len(self.block_names) == len(self.block_trigger_inputs)):
-            raise ValueError(f"In {self.__class__.__name__}, the number of block_classes, block_names, and block_trigger_inputs must be the same.")
+            raise ValueError(
+                f"In {self.__class__.__name__}, the number of block_classes, block_names, and block_trigger_inputs must be the same."
+            )
 
         default_blocks = [t for t in self.block_trigger_inputs if t is None]
-        # can only have 1 or 0 default block, and has to put in the last 
+        # can only have 1 or 0 default blocks, and the default has to be last
         # the order of blocks matters here because the first block with a matching trigger is dispatched
         # e.g. blocks = [inpaint, img2img] and block_trigger_inputs = ["mask", "image"]
         # if both mask and image are provided, it is inpaint; if only image is provided, it is img2img
-        if len(default_blocks) > 1 or (
-            len(default_blocks) == 1 and self.block_trigger_inputs[-1] is not None
-        ):
+        if len(default_blocks) > 1 or (len(default_blocks) == 1 and self.block_trigger_inputs[-1] is not None):
             raise ValueError(
                 f"In {self.__class__.__name__}, exactly one None must be specified as the last element "
                 "in block_trigger_inputs."
@@ -509,7 +563,7 @@ def __init__(self):
     @property
     def model_name(self):
         return next(iter(self.blocks.values())).model_name
-    
+
     @property
     def description(self):
         return ""
@@ -532,7 +586,6 @@ def expected_configs(self):
                 expected_configs.append(config)
         return expected_configs
 
-
     @property
     def required_inputs(self) -> List[str]:
         first_block = next(iter(self.blocks.values()))
@@ -557,7 +610,6 @@ def required_intermediates_inputs(self) -> List[str]:
 
         return list(required_by_all)
 
-
     # YiYi TODO: add test for this
     @property
     def inputs(self) -> List[Tuple[str, Any]]:
@@ -571,7 +623,6 @@ def inputs(self) -> List[Tuple[str, Any]]:
                     input_param.required = False
         return combined_inputs
 
-
     @property
     def intermediates_inputs(self) -> List[str]:
         named_inputs = [(name, block.intermediates_inputs) for name, block in self.blocks.items()]
@@ -589,7 +640,7 @@ def intermediates_outputs(self) -> List[str]:
         named_outputs = [(name, block.intermediates_outputs) for name, block in self.blocks.items()]
         combined_outputs = combine_outputs(*named_outputs)
         return combined_outputs
-    
+
     @property
     def outputs(self) -> List[str]:
         named_outputs = [(name, block.outputs) for name, block in self.blocks.items()]
@@ -630,26 +681,27 @@ def _get_trigger_inputs(self):
         Returns a set of all unique trigger input values found in the blocks. Returns: Set[str] containing all unique block_trigger_inputs values
         """
+
         def fn_recursive_get_trigger(blocks):
             trigger_values = set()
-            
+
             if blocks is not None:
                 for name, block in blocks.items():
                    # Check if current block has trigger inputs (i.e. auto block)
-                    if hasattr(block, 'block_trigger_inputs') and block.block_trigger_inputs is not None:
+                    if hasattr(block, "block_trigger_inputs") and block.block_trigger_inputs is not None:
                         # Add all non-None values from the trigger inputs list
                         trigger_values.update(t for t in block.block_trigger_inputs if t is not None)
-                    
+
                     # If block has blocks, recursively check them
-                    if hasattr(block, 'blocks'):
+                    if hasattr(block, "blocks"):
                         nested_triggers = fn_recursive_get_trigger(block.blocks)
                         trigger_values.update(nested_triggers)
-                    
+
             return trigger_values
-        
+
         trigger_inputs = set(self.block_trigger_inputs)
         trigger_inputs.update(fn_recursive_get_trigger(self.blocks))
-        
+
         return trigger_inputs
 
     @property
@@ -660,12 +712,9 @@ def __repr__(self):
         class_name = self.__class__.__name__
         base_class = self.__class__.__bases__[0].__name__
         header = (
-            f"{class_name}(\n  Class: {base_class}\n"
-            if base_class and base_class != "object"
-            else f"{class_name}(\n"
+            f"{class_name}(\n  Class: {base_class}\n" if base_class and base_class != "object" else f"{class_name}(\n"
         )
 
-
         if self.trigger_inputs:
             header += "\n"
             header += "  " + "=" * 100 + "\n"
@@ -677,19 +726,19 @@ def __repr__(self):
             header += "  " + "=" * 100 + "\n\n"
 
         # Format description with proper indentation
-        desc_lines = self.description.split('\n')
+        desc_lines = self.description.split("\n")
         desc = []
         # First line with "Description:" label
         desc.append(f"  Description: {desc_lines[0]}")
         # Subsequent lines with proper indentation
         if len(desc_lines) > 1:
             desc.extend(f"      {line}" for line in desc_lines[1:])
-        desc = '\n'.join(desc) + '\n'
+        desc = "\n".join(desc) + "\n"
 
         # Components section - focus only on expected components
         expected_components = getattr(self, "expected_components", [])
         components_str = format_components(expected_components, indent_level=2, add_empty_lines=False)
-        
+
         # Configs section - use format_configs with add_empty_lines=False
         expected_configs = getattr(self, "expected_configs", [])
         configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False)
@@ -699,7 +748,7 @@ def __repr__(self):
         for i, (name, block) in enumerate(self.blocks.items()):
             # Get trigger input for this block
             trigger = None
-            if hasattr(self, 'block_to_trigger_map'):
+            if hasattr(self, "block_to_trigger_map"):
                 trigger = self.block_to_trigger_map.get(name)
             # Format the trigger info
             if trigger is None:
@@ -713,47 +762,41 @@ def __repr__(self):
             else:
                 # For SequentialPipelineBlocks, show execution order
                 blocks_str += f"    [{i}] {name} ({block.__class__.__name__})\n"
-            
+
             # Add block description
-            desc_lines = block.description.split('\n')
+            desc_lines = block.description.split("\n")
             indented_desc = desc_lines[0]
             if len(desc_lines) > 1:
-                indented_desc += '\n' + '\n'.join('                   ' + line for line in desc_lines[1:])
+                indented_desc += "\n" + "\n".join("                   " + line for line in desc_lines[1:])
             blocks_str += f"       Description: {indented_desc}\n\n"
 
-        return (
-            f"{header}\n"
-            f"{desc}\n\n"
-            f"{components_str}\n\n"
-            f"{configs_str}\n\n"
-            f"{blocks_str}"
-            f")"
-        )
-
+        return f"{header}\n" f"{desc}\n\n" f"{components_str}\n\n" f"{configs_str}\n\n" f"{blocks_str}" f")"
 
     @property
     def doc(self):
         return make_doc_string(
-            self.inputs, 
-            self.intermediates_inputs, 
-            self.outputs, 
+            self.inputs,
+            self.intermediates_inputs,
+            self.outputs,
             self.description,
             class_name=self.__class__.__name__,
             expected_components=self.expected_components,
-            expected_configs=self.expected_configs
+            expected_configs=self.expected_configs,
         )
 
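The ordering rules enforced in `__init__` above are easiest to see in a concrete subclass: triggers are checked left to right, and the optional `None` entry must come last to act as the fallback. A sketch, where the three block classes are hypothetical `PipelineBlock` subclasses:

```python
class AutoImageBlocks(AutoPipelineBlocks):
    block_classes = [InpaintBlock, Img2ImgBlock, Text2ImgBlock]  # hypothetical
    block_names = ["inpaint", "img2img", "text2img"]
    # checked in order: "mask" beats "image"; None is the default branch
    block_trigger_inputs = ["mask", "image", None]
```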
""" A class that combines multiple pipeline block classes into one. When called, it will call each block in sequence. """ + block_classes = [] block_names = [] @property def model_name(self): return next(iter(self.blocks.values())).model_name - + @property def description(self): return "" @@ -779,10 +822,10 @@ def expected_configs(self): @classmethod def from_blocks_dict(cls, blocks_dict: Dict[str, Any]) -> "SequentialPipelineBlocks": """Creates a SequentialPipelineBlocks instance from a dictionary of blocks. - + Args: blocks_dict: Dictionary mapping block names to block instances - + Returns: A new SequentialPipelineBlocks instance """ @@ -791,14 +834,13 @@ def from_blocks_dict(cls, blocks_dict: Dict[str, Any]) -> "SequentialPipelineBlo instance.block_names = list(blocks_dict.keys()) instance.blocks = blocks_dict return instance - + def __init__(self): blocks = OrderedDict() for block_name, block_cls in zip(self.block_names, self.block_classes): blocks[block_name] = block_cls() self.blocks = blocks - @property def required_inputs(self) -> List[str]: # Get the first block from the dictionary @@ -809,9 +851,9 @@ def required_inputs(self) -> List[str]: for block in list(self.blocks.values())[1:]: block_required = set(getattr(block, "required_inputs", set())) required_by_any.update(block_required) - + return list(required_by_any) - + @property def required_intermediates_inputs(self) -> List[str]: required_intermediates_inputs = [] @@ -847,7 +889,7 @@ def intermediates_inputs(self) -> List[str]: should_add_outputs = True if hasattr(block, "block_trigger_inputs") and None not in block.block_trigger_inputs: should_add_outputs = False - + if should_add_outputs: # Add this block's outputs block_intermediates_outputs = [out.name for out in block.intermediates_outputs] @@ -859,11 +901,11 @@ def intermediates_outputs(self) -> List[str]: named_outputs = [(name, block.intermediates_outputs) for name, block in self.blocks.items()] combined_outputs = combine_outputs(*named_outputs) return combined_outputs - + @property def outputs(self) -> List[str]: return next(reversed(self.blocks.values())).intermediates_outputs - + @torch.no_grad() def __call__(self, pipeline, state: PipelineState) -> PipelineState: for block_name, block in self.blocks.items(): @@ -878,29 +920,30 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: logger.error(error_msg) raise return pipeline, state - + def _get_trigger_inputs(self): """ Returns a set of all unique trigger input values found in the blocks. Returns: Set[str] containing all unique block_trigger_inputs values """ + def fn_recursive_get_trigger(blocks): trigger_values = set() - + if blocks is not None: for name, block in blocks.items(): # Check if current block has trigger inputs(i.e. 
                    # Check if current block has trigger inputs (i.e. auto block)
-                    if hasattr(block, 'block_trigger_inputs') and block.block_trigger_inputs is not None:
+                    if hasattr(block, "block_trigger_inputs") and block.block_trigger_inputs is not None:
                         # Add all non-None values from the trigger inputs list
                         trigger_values.update(t for t in block.block_trigger_inputs if t is not None)
-                    
+
                     # If block has blocks, recursively check them
-                    if hasattr(block, 'blocks'):
+                    if hasattr(block, "blocks"):
                         nested_triggers = fn_recursive_get_trigger(block.blocks)
                         trigger_values.update(nested_triggers)
-                    
+
             return trigger_values
-        
+
         return fn_recursive_get_trigger(self.blocks)
 
     @property
@@ -913,10 +956,10 @@ def _traverse_trigger_blocks(self, trigger_inputs):
 
         def fn_recursive_traverse(block, block_name, active_triggers):
             result_blocks = OrderedDict()
-            
+
             # sequential or PipelineBlock
-            if not hasattr(block, 'block_trigger_inputs'):
-                if hasattr(block, 'blocks'):
+            if not hasattr(block, "block_trigger_inputs"):
+                if hasattr(block, "blocks"):
                     # sequential
                     for block_name, block in block.blocks.items():
                         blocks_to_update = fn_recursive_traverse(block, block_name, active_triggers)
@@ -925,10 +968,10 @@ def fn_recursive_traverse(block, block_name, active_triggers):
                     # PipelineBlock
                     result_blocks[block_name] = block
                     # Add this block's output names to active triggers if defined
-                    if hasattr(block, 'outputs'):
+                    if hasattr(block, "outputs"):
                         active_triggers.update(out.name for out in block.outputs)
                 return result_blocks
-            
+
             # auto
             else:
                 # Find first block_trigger_input that matches any value in our active_triggers
@@ -939,36 +982,35 @@ def fn_recursive_traverse(block, block_name, active_triggers):
                         this_block = block.trigger_to_block_map[trigger_input]
                         matching_trigger = trigger_input
                         break
-                
+
                 # If no matches found, try to get the default (None) block
                 if this_block is None and None in block.block_trigger_inputs:
                     this_block = block.trigger_to_block_map[None]
                     matching_trigger = None
-                
+
                 if this_block is not None:
                     # sequential/auto
-                    if hasattr(this_block, 'blocks'):
+                    if hasattr(this_block, "blocks"):
                         result_blocks.update(fn_recursive_traverse(this_block, block_name, active_triggers))
                     else:
                         # PipelineBlock
                         result_blocks[block_name] = this_block
                         # Add this block's output names to active triggers if defined
-                        if hasattr(this_block, 'outputs'):
+                        if hasattr(this_block, "outputs"):
                             active_triggers.update(out.name for out in this_block.outputs)
 
             return result_blocks
-        
+
         all_blocks = OrderedDict()
         for block_name, block in self.blocks.items():
             blocks_to_update = fn_recursive_traverse(block, block_name, active_triggers)
             all_blocks.update(blocks_to_update)
         return all_blocks
-    
+
     def get_execution_blocks(self, *trigger_inputs):
         trigger_inputs_all = self.trigger_inputs
 
         if trigger_inputs is not None:
-
             if not isinstance(trigger_inputs, (list, tuple, set)):
                 trigger_inputs = [trigger_inputs]
             invalid_inputs = [x for x in trigger_inputs if x not in trigger_inputs_all]
@@ -977,7 +1019,7 @@ def get_execution_blocks(self, *trigger_inputs):
                     f"The following trigger inputs will be ignored as they are not supported: {invalid_inputs}"
                 )
                 trigger_inputs = [x for x in trigger_inputs if x in trigger_inputs_all]
-        
+
         if trigger_inputs is None:
             if None in trigger_inputs_all:
                 trigger_inputs = [None]
@@ -985,17 +1027,14 @@ def get_execution_blocks(self, *trigger_inputs):
                 trigger_inputs = [trigger_inputs_all[0]]
         blocks_triggered = self._traverse_trigger_blocks(trigger_inputs)
         return SequentialPipelineBlocks.from_blocks_dict(blocks_triggered)
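Because the auto branches are only resolved against the inputs at call time, `get_execution_blocks` is a convenient way to preview which path a given trigger would select without running anything. A sketch, reusing the hypothetical `AutoImageBlocks` from the earlier example inside a sequential pipeline:

```python
class MyPipelineBlocks(SequentialPipelineBlocks):
    block_classes = [TextEncoderBlock, AutoImageBlocks, DecodeBlock]  # hypothetical
    block_names = ["text_encoder", "image", "decode"]

blocks = MyPipelineBlocks()
print(blocks.trigger_inputs)             # e.g. {"mask", "image", None}
inpaint_path = blocks.get_execution_blocks("mask")
print(list(inpaint_path.blocks.keys()))  # the blocks that would run for inpainting
print(inpaint_path.doc)                  # docs for just that path
```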
-    
+
     def __repr__(self):
         class_name = self.__class__.__name__
         base_class = self.__class__.__bases__[0].__name__
         header = (
-            f"{class_name}(\n  Class: {base_class}\n"
-            if base_class and base_class != "object"
-            else f"{class_name}(\n"
+            f"{class_name}(\n  Class: {base_class}\n" if base_class and base_class != "object" else f"{class_name}(\n"
         )
 
-
         if self.trigger_inputs:
             header += "\n"
             header += "  " + "=" * 100 + "\n"
@@ -1007,19 +1046,19 @@ def __repr__(self):
             header += "  " + "=" * 100 + "\n\n"
 
         # Format description with proper indentation
-        desc_lines = self.description.split('\n')
+        desc_lines = self.description.split("\n")
         desc = []
         # First line with "Description:" label
         desc.append(f"  Description: {desc_lines[0]}")
         # Subsequent lines with proper indentation
         if len(desc_lines) > 1:
             desc.extend(f"      {line}" for line in desc_lines[1:])
-        desc = '\n'.join(desc) + '\n'
+        desc = "\n".join(desc) + "\n"
 
         # Components section - focus only on expected components
         expected_components = getattr(self, "expected_components", [])
         components_str = format_components(expected_components, indent_level=2, add_empty_lines=False)
-        
+
         # Configs section - use format_configs with add_empty_lines=False
         expected_configs = getattr(self, "expected_configs", [])
         configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False)
@@ -1029,7 +1068,7 @@ def __repr__(self):
         for i, (name, block) in enumerate(self.blocks.items()):
             # Get trigger input for this block
             trigger = None
-            if hasattr(self, 'block_to_trigger_map'):
+            if hasattr(self, "block_to_trigger_map"):
                 trigger = self.block_to_trigger_map.get(name)
             # Format the trigger info
             if trigger is None:
@@ -1043,39 +1082,30 @@ def __repr__(self):
             else:
                 # For SequentialPipelineBlocks, show execution order
                 blocks_str += f"    [{i}] {name} ({block.__class__.__name__})\n"
-            
+
             # Add block description
-            desc_lines = block.description.split('\n')
+            desc_lines = block.description.split("\n")
             indented_desc = desc_lines[0]
             if len(desc_lines) > 1:
-                indented_desc += '\n' + '\n'.join('                   ' + line for line in desc_lines[1:])
+                indented_desc += "\n" + "\n".join("                   " + line for line in desc_lines[1:])
             blocks_str += f"       Description: {indented_desc}\n\n"
 
-        return (
-            f"{header}\n"
-            f"{desc}\n\n"
-            f"{components_str}\n\n"
-            f"{configs_str}\n\n"
-            f"{blocks_str}"
-            f")"
-        )
-
+        return f"{header}\n" f"{desc}\n\n" f"{components_str}\n\n" f"{configs_str}\n\n" f"{blocks_str}" f")"
 
     @property
     def doc(self):
         return make_doc_string(
-            self.inputs, 
-            self.intermediates_inputs, 
-            self.outputs, 
+            self.inputs,
+            self.intermediates_inputs,
+            self.outputs,
             self.description,
             class_name=self.__class__.__name__,
             expected_components=self.expected_components,
-            expected_configs=self.expected_configs
+            expected_configs=self.expected_configs,
         )
 
-
-# YiYi TODO: 
+# YiYi TODO:
 # 1. look into the serialization of modular_model_index.json, make sure the items are properly ordered like model_index.json (currently a mess)
 # 2. do we need ConfigSpec? seems pretty unnecessary for the loader, can just add kwargs to the loader
 # 3. add validator for methods where we accept kwargs to be passed to from_pretrained()
 class ModularLoader(ConfigMixin, PushToHubMixin):
     """
     Base class for all Modular pipelines loaders.
     """
 
-    config_name = "modular_model_index.json" 
+    config_name = "modular_model_index.json"
 
     def register_components(self, **kwargs):
         """
-        Register components with their corresponding specs. 
+        Register components with their corresponding specs.
         This method is called when a component changes or `__init__` is called.
 
         Args:
             **kwargs: Keyword arguments where keys are component names and values are component objects.
-            
+
         """
         for name, module in kwargs.items():
-
             # current component spec
             component_spec = self._component_specs.get(name)
             if component_spec is None:
                 logger.warning(f"ModularLoader.register_components: skipping unknown component '{name}'")
                 continue
-            
+
             is_registered = hasattr(self, name)
 
             if module is not None and not hasattr(module, "_diffusers_load_id"):
-                raise ValueError(f"`ModularLoader` only supports components created from `ComponentSpec`.")
+                raise ValueError("`ModularLoader` only supports components created from `ComponentSpec`.")
 
             # actual library and class name of the module
@@ -1115,10 +1144,10 @@ def register_components(self, **kwargs):
                 library, class_name = _fetch_class_library_tuple(module)
                 new_component_spec = ComponentSpec.from_component(name, module)
                 component_spec_dict = self._component_spec_to_dict(new_component_spec)
-            
+
             else:
                 library, class_name = None, None
-                # if module is None, we do not update the spec, 
+                # if module is None, we do not update the spec,
                 # but we still need to update the config to make sure it's synced with the component spec
                 # (in the case of the first-time registration, we initialize the object with the component spec, and then we call register_components() to register it to the config)
                 new_component_spec = component_spec
@@ -1139,16 +1168,24 @@ def register_components(self, **kwargs):
                 if module is not None and self._component_manager is not None:
                     self._component_manager.add(name, module, self._collection)
                 continue
-            
+
             current_module = getattr(self, name, None)
 
             # skip if the component is already registered with the same object
             if current_module is module:
-                logger.info(f"ModularLoader.register_components: {name} is already registered with same object, skipping")
+                logger.info(
+                    f"ModularLoader.register_components: {name} is already registered with same object, skipping"
+                )
                 continue
-            
+
             # if module is not an instance of the expected type, still register it, but with a warning
-            if module is not None and component_spec.type_hint is not None and not isinstance(module, component_spec.type_hint):
-                logger.warning(f"ModularLoader.register_components: adding {name} with new type: {module.__class__.__name__}, previous type: {component_spec.type_hint.__name__}")
+            if (
+                module is not None
+                and component_spec.type_hint is not None
+                and not isinstance(module, component_spec.type_hint)
+            ):
+                logger.warning(
+                    f"ModularLoader.register_components: adding {name} with new type: {module.__class__.__name__}, previous type: {component_spec.type_hint.__name__}"
+                )
 
             # warn if unregister
             if current_module is not None and module is None:
@@ -1157,10 +1194,12 @@ def register_components(self, **kwargs):
                     f"(was {current_module.__class__.__name__})"
                 )
             # same type, new instance → debug
-            elif current_module is not None \
-                and module is not None \
-                and isinstance(module, current_module.__class__) \
-                and current_module != module:
+            elif (
+                current_module is not None
+                and module is not None
+                and isinstance(module, current_module.__class__)
+                and current_module != module
+            ):
                 logger.debug(
                     f"ModularLoader.register_components: replacing existing '{name}' "
                     f"(same type {type(current_module).__name__}, new instance)"
@@ -1175,46 +1214,51 @@ def register_components(self, **kwargs):
             if module is not None and self._component_manager is not None:
                 self._component_manager.add(name, module, self._collection)
 
-
    # YiYi TODO: add warning for passing multiple ComponentSpec/ConfigSpec with the same name
-    def __init__(self, specs: List[Union[ComponentSpec, ConfigSpec]], modular_repo: Optional[str] = None, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None, **kwargs):
+    def __init__(
+        self,
+        specs: List[Union[ComponentSpec, ConfigSpec]],
+        modular_repo: Optional[str] = None,
+        component_manager: Optional[ComponentsManager] = None,
+        collection: Optional[str] = None,
+        **kwargs,
+    ):
         """
         Initialize the loader with a list of component specs and config specs.
         """
         self._component_manager = component_manager
         self._collection = collection
-        self._component_specs = {
-            spec.name: deepcopy(spec) for spec in specs if isinstance(spec, ComponentSpec)
-        }
-        self._config_specs = {
-            spec.name: deepcopy(spec) for spec in specs if isinstance(spec, ConfigSpec)
-        }
+        self._component_specs = {spec.name: deepcopy(spec) for spec in specs if isinstance(spec, ComponentSpec)}
+        self._config_specs = {spec.name: deepcopy(spec) for spec in specs if isinstance(spec, ConfigSpec)}
         # update component_specs and config_specs from modular_repo
         if modular_repo is not None:
             config_dict = self.load_config(modular_repo, **kwargs)
 
             for name, value in config_dict.items():
-                if name in self._component_specs and self._component_specs[name].default_creation_method == "from_pretrained" and isinstance(value, (tuple, list)) and len(value) == 3:
+                if (
+                    name in self._component_specs
+                    and self._component_specs[name].default_creation_method == "from_pretrained"
+                    and isinstance(value, (tuple, list))
+                    and len(value) == 3
+                ):
                     library, class_name, component_spec_dict = value
                     component_spec = self._dict_to_component_spec(name, component_spec_dict)
                     self._component_specs[name] = component_spec
 
                 elif name in self._config_specs:
                     self._config_specs[name].default = value
-        
+
         register_components_dict = {}
         for name, component_spec in self._component_specs.items():
             register_components_dict[name] = None
         self.register_components(**register_components_dict)
-        
+
         default_configs = {}
         for name, config_spec in self._config_specs.items():
             default_configs[name] = config_spec.default
         self.register_to_config(**default_configs)
 
-
     @property
     def device(self) -> torch.device:
         r"""
@@ -1251,7 +1295,7 @@ def _execution_device(self):
             ):
                 return torch.device(module._hf_hook.execution_device)
         return self.device
-    
+
     @property
     def device(self) -> torch.device:
         r"""
@@ -1280,23 +1324,18 @@ def dtype(self) -> torch.dtype:
 
         return torch.float32
 
-
     @property
     def components(self) -> Dict[str, Any]:
         # return only components we've actually set as attributes on self
-        return {
-            name: getattr(self, name)
-            for name in self._component_specs.keys()
-            if hasattr(self, name)
-        }
+        return {name: getattr(self, name) for name in self._component_specs.keys() if hasattr(self, name)}
 
     def update(self, **kwargs):
         """
-        Update components and configs after instance creation.
-
-        Args:
-        """
-        """
         Update components and configuration values after the loader has been instantiated.
@@ -1332,7 +1371,7 @@ def update(self, **kwargs):
                 requires_safety_checker=False
             )
         ```
-        """ 
+        """
 
         # extract component_specs_updates & config_specs_updates from `specs`
         passed_components = {k: kwargs.pop(k) for k in self._component_specs if k in kwargs}
@@ -1340,29 +1379,25 @@ def update(self, **kwargs):
 
         for name, component in passed_components.items():
             if not hasattr(component, "_diffusers_load_id"):
-                raise ValueError(f"`ModularLoader` only supports components created from `ComponentSpec`.")
-        
+                raise ValueError("`ModularLoader` only supports components created from `ComponentSpec`.")
+
         if len(kwargs) > 0:
             logger.warning(f"Unexpected keyword arguments, will be ignored: {kwargs.keys()}")
 
-
         self.register_components(**passed_components)
 
-
         config_to_register = {}
         for name, new_value in passed_config_values.items():
-
             # e.g. requires_aesthetics_score = False
             self._config_specs[name].default = new_value
             config_to_register[name] = new_value
 
         self.register_to_config(**config_to_register)
 
-
     # YiYi TODO: support map for additional from_pretrained kwargs
     def load(self, component_names: Optional[List[str]] = None, **kwargs):
         """
         Load selected components from specs.
-        
+
         Args:
             component_names: List of component names to load
             **kwargs: additional kwargs to be passed to `from_pretrained()`. Can be:
@@ -1379,7 +1414,7 @@ def load(self, component_names: Optional[List[str]] = None, **kwargs):
         unknown_component_names = set([name for name in component_names if name not in self._component_specs])
         if len(unknown_component_names) > 0:
             logger.warning(f"Unknown components will be ignored: {unknown_component_names}")
-        
+
         components_to_register = {}
         for name in components_to_load:
             spec = self._component_specs[name]
@@ -1399,7 +1434,7 @@ def load(self, component_names: Optional[List[str]] = None, **kwargs):
                     components_to_register[name] = spec.create(**component_load_kwargs)
             except Exception as e:
                 logger.warning(f"Failed to create component '{name}': {e}")
-        
+
         # Register all components at once
         self.register_components(**components_to_register)
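`update` and `load` split the post-init lifecycle: `update` swaps in already-created components or new config values, while `load` instantiates components from their specs on demand. A usage sketch of `load` (component names are illustrative; per the docstring, extra kwargs are forwarded to `from_pretrained()`):

```python
loader.load()                                      # load everything declared in the specs
loader.load(["unet", "vae"])                       # or only a subset
loader.load(["unet"], torch_dtype=torch.float16)   # kwargs forwarded to from_pretrained()
```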
@@ -1407,11 +1442,12 @@ def to(self, *args, **kwargs):
         pass
 
-    # YiYi TODO: 
+    # YiYi TODO:
     # 1. should support saving some components too! currently only modular_model_index.json is saved
     # 2. maybe order the json file to make it more readable: configs first, then components
-    def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, spec_only: bool = True, **kwargs):
-
+    def save_pretrained(
+        self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, spec_only: bool = True, **kwargs
+    ):
         component_names = list(self._component_specs.keys())
         config_names = list(self._config_specs.keys())
         self.register_to_config(_components_names=component_names, _configs_names=config_names)
@@ -1421,11 +1457,11 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub:
         config.pop("_configs_names", None)
         self._internal_dict = FrozenDict(config)
 
-
     @classmethod
     @validate_hf_hub_args
-    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], spec_only: bool = True, **kwargs):
-
+    def from_pretrained(
+        cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], spec_only: bool = True, **kwargs
+    ):
         config_dict = cls.load_config(pretrained_model_name_or_path, **kwargs)
         expected_component = set(config_dict.pop("_components_names"))
         expected_config = set(config_dict.pop("_configs_names"))
@@ -1440,7 +1476,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             elif name in expected_config:
                 config_specs.append(ConfigSpec(name=name, default=value))
 
-        
+
         for name in expected_component:
             for spec in component_specs:
                 if spec.name == name:
@@ -1450,7 +1486,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 component_specs.append(ComponentSpec(name=name, default_creation_method="from_config"))
 
         return cls(component_specs + config_specs)
 
-
     @staticmethod
     def _component_spec_to_dict(component_spec: ComponentSpec) -> Any:
         """
@@ -1533,4 +1568,4 @@ def _dict_to_component_spec(
             name=name,
             type_hint=type_hint,
             **spec_dict,
-        )
\ No newline at end of file
+        )
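Since `save_pretrained` currently persists only the spec file (see the TODO above), a round trip recreates the loader's specs but not the component weights; those are re-fetched via `load`. A sketch:

```python
loader.save_pretrained("./my-modular-pipe")           # writes modular_model_index.json only
restored = ModularLoader.from_pretrained("./my-modular-pipe")
restored.load(torch_dtype=torch.float16)              # components are re-downloaded/created
```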
diff --git a/src/diffusers/utils/dynamic_modules_utils.py b/src/diffusers/utils/dynamic_modules_utils.py
index 5d0752af8983..5d5eb23969ab 100644
--- a/src/diffusers/utils/dynamic_modules_utils.py
+++ b/src/diffusers/utils/dynamic_modules_utils.py
@@ -15,13 +15,16 @@
 """Utilities to dynamically load objects from the Hub."""
 
 import importlib
 import inspect
 import json
 import os
 import re
 import shutil
+import signal
 import sys
+import threading
 from pathlib import Path
+from types import ModuleType
 from typing import Dict, Optional, Union
 from urllib import request
 
@@ -37,6 +40,8 @@
 # See https://huggingface.co/datasets/diffusers/community-pipelines-mirror
 COMMUNITY_PIPELINES_MIRROR_ID = "diffusers/community-pipelines-mirror"
 
+TIME_OUT_REMOTE_CODE = int(os.getenv("DIFFUSERS_TIMEOUT_REMOTE_CODE", 15))
+_HF_REMOTE_CODE_LOCK = threading.Lock()
 
 
 def get_diffusers_versions():
@@ -154,15 +159,87 @@ def check_imports(filename):
     return get_relative_imports(filename)
 
 
-def get_class_in_module(class_name, module_path):
+def _raise_timeout_error(signum, frame):
+    raise ValueError(
+        "Loading this model requires you to execute custom code contained in the model repository on your local "
+        "machine. Please set the option `trust_remote_code=True` to permit loading of this model."
+    )
+
+
+def resolve_trust_remote_code(trust_remote_code, model_name, has_remote_code):
+    if trust_remote_code is None:
+        if has_remote_code and TIME_OUT_REMOTE_CODE > 0:
+            prev_sig_handler = None
+            try:
+                prev_sig_handler = signal.signal(signal.SIGALRM, _raise_timeout_error)
+                signal.alarm(TIME_OUT_REMOTE_CODE)
+                while trust_remote_code is None:
+                    answer = input(
+                        f"The repository for {model_name} contains custom code which must be executed to correctly "
+                        f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n"
+                        f"You can avoid this prompt in future by passing the argument `trust_remote_code=True`.\n\n"
+                        f"Do you wish to run the custom code? [y/N] "
+                    )
+                    if answer.lower() in ["yes", "y", "1"]:
+                        trust_remote_code = True
+                    elif answer.lower() in ["no", "n", "0", ""]:
+                        trust_remote_code = False
+                signal.alarm(0)
+            except Exception:
+                # OS which does not support signal.SIGALRM
+                raise ValueError(
+                    f"The repository for {model_name} contains custom code which must be executed to correctly "
+                    f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n"
+                    f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
+                )
+            finally:
+                if prev_sig_handler is not None:
+                    signal.signal(signal.SIGALRM, prev_sig_handler)
+                    signal.alarm(0)
+        elif has_remote_code:
+            # For the CI which puts the timeout at 0
+            _raise_timeout_error(None, None)
+
+    if has_remote_code and not trust_remote_code:
+        raise ValueError(
+            f"Loading {model_name} requires you to execute the configuration file in that"
+            " repo on your local machine. Make sure you have read the code there to avoid malicious use, then"
+            " set the option `trust_remote_code=True` to remove this error."
+        )
+
+    return trust_remote_code
+
+
+def get_class_in_module(class_name, module_path, force_reload=False):
     """
     Import a module on the cache directory for modules and extract a class from it.
     """
-    module_path = module_path.replace(os.path.sep, ".")
-    module = importlib.import_module(module_path)
+    name = os.path.normpath(module_path)
+    if name.endswith(".py"):
+        name = name[:-3]
+    name = name.replace(os.path.sep, ".")
+    module_file: Path = Path(HF_MODULES_CACHE) / module_path
+
+    with _HF_REMOTE_CODE_LOCK:
+        if force_reload:
+            sys.modules.pop(name, None)
+            importlib.invalidate_caches()
+        cached_module: Optional[ModuleType] = sys.modules.get(name)
+        module_spec = importlib.util.spec_from_file_location(name, location=module_file)
+
+        module: ModuleType
+        if cached_module is None:
+            module = importlib.util.module_from_spec(module_spec)
+            # insert it into sys.modules before any loading begins
+            sys.modules[name] = module
+        else:
+            module = cached_module
+
+        module_spec.loader.exec_module(module)
 
     if class_name is None:
         return find_pipeline_class(module)
+
     return getattr(module, class_name)
 
 
@@ -454,4 +531,4 @@ def get_class_from_dynamic_module(
         revision=revision,
         local_files_only=local_files_only,
     )
-    return get_class_in_module(class_name, final_module.replace(".py", ""))
+    return get_class_in_module(class_name, final_module)
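The timeout plumbing above follows transformers' remote-code flow: `DIFFUSERS_TIMEOUT_REMOTE_CODE` is read once at import time, so the interactive prompt window (15 s by default; 0 disables prompting, as in CI) must be set in the environment before `diffusers` is imported. A sketch of the explicit, non-interactive paths (repo name hypothetical):

```python
from diffusers.utils.dynamic_modules_utils import resolve_trust_remote_code

# Explicit opt-in: no prompt, returns True
assert resolve_trust_remote_code(True, "my-org/my-block", has_remote_code=True)

# Explicit opt-out on a repo that requires remote code: raises ValueError
try:
    resolve_trust_remote_code(False, "my-org/my-block", has_remote_code=True)
except ValueError as e:
    print(e)
```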