From 60d1b81023b95952bdf2807b0748541397002ab4 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Mon, 21 Jul 2025 18:44:44 +0200 Subject: [PATCH 01/14] update --- .../modular_pipelines/modular_pipeline.py | 179 +++++++----------- 1 file changed, 64 insertions(+), 115 deletions(-) diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index 6f1c617bc261..5b5fded3b8fc 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -342,6 +342,16 @@ def expected_components(self) -> List[ComponentSpec]: def expected_configs(self) -> List[ConfigSpec]: return [] + @property + def intermediate_inputs(self) -> List[OutputParam]: + """List of intermediate output parameters. Must be implemented by subclasses.""" + return [] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + """List of intermediate output parameters. Must be implemented by subclasses.""" + return [] + @classmethod def from_pretrained( cls, @@ -423,6 +433,60 @@ def init_pipeline( ) return modular_pipeline + def get_block_state(self, state: PipelineState) -> dict: + """Get all inputs and intermediates in one dictionary""" + data = {} + state_inputs = self.inputs + self.intermediate_inputs + + # Check inputs + for input_param in state_inputs: + if input_param.name: + value = state.get_input(input_param.name) or state.get_intermediate(input_param.name) + if input_param.required and value is None: + raise ValueError(f"Required input '{input_param.name}' is missing") + elif value is not None or (value is None and input_param.name not in data): + data[input_param.name] = value + + elif input_param.kwargs_type: + # if kwargs_type is provided, get all inputs with matching kwargs_type + if input_param.kwargs_type not in data: + data[input_param.kwargs_type] = {} + inputs_kwargs = state.get_inputs_kwargs(input_param.kwargs_type) or state.get_intermediate_kwargs( + input_param.kwargs_type + ) + if inputs_kwargs: + for k, v in inputs_kwargs.items(): + if v is not None: + data[k] = v + data[input_param.kwargs_type][k] = v + + return BlockState(**data) + + def set_block_state(self, state: PipelineState, block_state: BlockState): + for output_param in self.intermediate_outputs: + if not hasattr(block_state, output_param.name): + raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state") + param = getattr(block_state, output_param.name) + state.set_intermediate(output_param.name, param, output_param.kwargs_type) + + for input_param in self.intermediate_inputs: + if input_param.name and hasattr(block_state, input_param.name): + param = getattr(block_state, input_param.name) + # Only add if the value is different from what's in the state + current_value = state.get_intermediate(input_param.name) + if current_value is not param: # Using identity comparison to check if object was modified + state.set_intermediate(input_param.name, param, input_param.kwargs_type) + elif input_param.kwargs_type: + # if it is a kwargs type, e.g. "guider_input_fields", it is likely to be a list of parameters + # we need to first find out which inputs are and loop through them. + intermediate_kwargs = state.get_intermediate_kwargs(input_param.kwargs_type) + for param_name, current_value in intermediate_kwargs.items(): + if not hasattr(block_state, param_name): + continue + param = getattr(block_state, param_name) + if current_value is not param: # Using identity comparison to check if object was modified + state.set_intermediate(param_name, param, input_param.kwargs_type) + @staticmethod def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> List[InputParam]: """ @@ -652,51 +716,6 @@ def doc(self): expected_configs=self.expected_configs, ) - # YiYi TODO: input and inteermediate inputs with same name? should warn? - def get_block_state(self, state: PipelineState) -> dict: - """Get all inputs and intermediates in one dictionary""" - data = {} - - # Check inputs - for input_param in self.inputs: - if input_param.name: - value = state.get_input(input_param.name) - if input_param.required and value is None: - raise ValueError(f"Required input '{input_param.name}' is missing") - elif value is not None or (value is None and input_param.name not in data): - data[input_param.name] = value - elif input_param.kwargs_type: - # if kwargs_type is provided, get all inputs with matching kwargs_type - if input_param.kwargs_type not in data: - data[input_param.kwargs_type] = {} - inputs_kwargs = state.get_inputs_kwargs(input_param.kwargs_type) - if inputs_kwargs: - for k, v in inputs_kwargs.items(): - if v is not None: - data[k] = v - data[input_param.kwargs_type][k] = v - - # Check intermediates - for input_param in self.intermediate_inputs: - if input_param.name: - value = state.get_intermediate(input_param.name) - if input_param.required and value is None: - raise ValueError(f"Required intermediate input '{input_param.name}' is missing") - elif value is not None or (value is None and input_param.name not in data): - data[input_param.name] = value - elif input_param.kwargs_type: - # if kwargs_type is provided, get all intermediates with matching kwargs_type - if input_param.kwargs_type not in data: - data[input_param.kwargs_type] = {} - intermediate_kwargs = state.get_intermediate_kwargs(input_param.kwargs_type) - if intermediate_kwargs: - for k, v in intermediate_kwargs.items(): - if v is not None: - if k not in data: - data[k] = v - data[input_param.kwargs_type][k] = v - return BlockState(**data) - def set_block_state(self, state: PipelineState, block_state: BlockState): for output_param in self.intermediate_outputs: if not hasattr(block_state, output_param.name): @@ -1633,75 +1652,6 @@ def loop_step(self, components, state: PipelineState, **kwargs): def __call__(self, components, state: PipelineState) -> PipelineState: raise NotImplementedError("`__call__` method needs to be implemented by the subclass") - def get_block_state(self, state: PipelineState) -> dict: - """Get all inputs and intermediates in one dictionary""" - data = {} - - # Check inputs - for input_param in self.inputs: - if input_param.name: - value = state.get_input(input_param.name) - if input_param.required and value is None: - raise ValueError(f"Required input '{input_param.name}' is missing") - elif value is not None or (value is None and input_param.name not in data): - data[input_param.name] = value - elif input_param.kwargs_type: - # if kwargs_type is provided, get all inputs with matching kwargs_type - if input_param.kwargs_type not in data: - data[input_param.kwargs_type] = {} - inputs_kwargs = state.get_inputs_kwargs(input_param.kwargs_type) - if inputs_kwargs: - for k, v in inputs_kwargs.items(): - if v is not None: - data[k] = v - data[input_param.kwargs_type][k] = v - - # Check intermediates - for input_param in self.intermediate_inputs: - if input_param.name: - value = state.get_intermediate(input_param.name) - if input_param.required and value is None: - raise ValueError(f"Required intermediate input '{input_param.name}' is missing") - elif value is not None or (value is None and input_param.name not in data): - data[input_param.name] = value - elif input_param.kwargs_type: - # if kwargs_type is provided, get all intermediates with matching kwargs_type - if input_param.kwargs_type not in data: - data[input_param.kwargs_type] = {} - intermediate_kwargs = state.get_intermediate_kwargs(input_param.kwargs_type) - if intermediate_kwargs: - for k, v in intermediate_kwargs.items(): - if v is not None: - if k not in data: - data[k] = v - data[input_param.kwargs_type][k] = v - return BlockState(**data) - - def set_block_state(self, state: PipelineState, block_state: BlockState): - for output_param in self.intermediate_outputs: - if not hasattr(block_state, output_param.name): - raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state") - param = getattr(block_state, output_param.name) - state.set_intermediate(output_param.name, param, output_param.kwargs_type) - - for input_param in self.intermediate_inputs: - if input_param.name and hasattr(block_state, input_param.name): - param = getattr(block_state, input_param.name) - # Only add if the value is different from what's in the state - current_value = state.get_intermediate(input_param.name) - if current_value is not param: # Using identity comparison to check if object was modified - state.set_intermediate(input_param.name, param, input_param.kwargs_type) - elif input_param.kwargs_type: - # if it is a kwargs type, e.g. "guider_input_fields", it is likely to be a list of parameters - # we need to first find out which inputs are and loop through them. - intermediate_kwargs = state.get_intermediate_kwargs(input_param.kwargs_type) - for param_name, current_value in intermediate_kwargs.items(): - if not hasattr(block_state, param_name): - continue - param = getattr(block_state, param_name) - if current_value is not param: # Using identity comparison to check if object was modified - state.set_intermediate(param_name, param, input_param.kwargs_type) - @property def doc(self): return make_doc_string( @@ -1974,7 +1924,6 @@ def __call__(self, state: PipelineState = None, output: Union[str, List[str]] = # Add inputs to state, using defaults if not provided in the kwargs or the state # if same input already in the state, will override it if provided in the kwargs - intermediate_inputs = [inp.name for inp in self.blocks.intermediate_inputs] for expected_input_param in self.blocks.inputs: name = expected_input_param.name From 4423097b238ede21eab3ac26f6462300eea31b5a Mon Sep 17 00:00:00 2001 From: DN6 Date: Tue, 22 Jul 2025 19:31:22 +0530 Subject: [PATCH 02/14] update --- src/diffusers/modular_pipelines/modular_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index 5b5fded3b8fc..6056623d7feb 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -322,7 +322,7 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin): """ - config_name = "config.json" + config_name = "modular_config.json" model_name = None @classmethod From 966a2ff8df60e1b76a19d994386ce19be4efa53d Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Tue, 29 Jul 2025 21:06:40 +0200 Subject: [PATCH 03/14] update --- .../modular_pipelines/modular_pipeline.py | 756 +++++------------- .../stable_diffusion_xl/before_denoise.py | 78 +- .../stable_diffusion_xl/decoders.py | 13 +- .../stable_diffusion_xl/denoise.py | 40 +- .../stable_diffusion_xl/encoders.py | 20 +- .../stable_diffusion_xl/modular_pipeline.py | 11 - 6 files changed, 234 insertions(+), 684 deletions(-) diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index 5b5fded3b8fc..d7c80c65159e 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -45,8 +45,6 @@ OutputParam, format_components, format_configs, - format_inputs_short, - format_intermediates_short, make_doc_string, ) @@ -76,139 +74,59 @@ class PipelineState: [`PipelineState`] stores the state of a pipeline. It is used to pass data between pipeline blocks. """ - inputs: Dict[str, Any] = field(default_factory=dict) - intermediates: Dict[str, Any] = field(default_factory=dict) - input_kwargs: Dict[str, List[str]] = field(default_factory=dict) - intermediate_kwargs: Dict[str, List[str]] = field(default_factory=dict) + values: Dict[str, Any] = field(default_factory=dict) + kwargs_mapping: Dict[str, List[str]] = field(default_factory=dict) - def set_input(self, key: str, value: Any, kwargs_type: str = None): + def set(self, key: str, value: Any, kwargs_type: str = None): """ - Add an input to the immutable pipeline state, i.e, pipeline_state.inputs. - - The kwargs_type parameter allows you to associate inputs with specific input types. For example, if you call - set_input(prompt_embeds=..., kwargs_type="guider_kwargs"), this input will be automatically fetched when a - pipeline block has "guider_kwargs" in its expected_inputs list. + Add a value to the pipeline state. Args: - key (str): The key for the input - value (Any): The input value - kwargs_type (str): The kwargs_type with which the input is associated - """ - self.inputs[key] = value - if kwargs_type is not None: - if kwargs_type not in self.input_kwargs: - self.input_kwargs[kwargs_type] = [key] - else: - self.input_kwargs[kwargs_type].append(key) - - def set_intermediate(self, key: str, value: Any, kwargs_type: str = None): + key (str): The key for the value + value (Any): The value to store + kwargs_type (str): The kwargs_type with which the value is associated """ - Add an intermediate value to the mutable pipeline state, i.e, pipeline_state.intermediates. - - The kwargs_type parameter allows you to associate intermediate values with specific input types. For example, - if you call set_intermediate(latents=..., kwargs_type="latents_kwargs"), this intermediate value will be - automatically fetched when a pipeline block has "latents_kwargs" in its expected_intermediate_inputs list. + self.values[key] = value - Args: - key (str): The key for the intermediate value - value (Any): The intermediate value - kwargs_type (str): The kwargs_type with which the intermediate value is associated - """ - self.intermediates[key] = value if kwargs_type is not None: - if kwargs_type not in self.intermediate_kwargs: - self.intermediate_kwargs[kwargs_type] = [key] + if kwargs_type not in self.kwargs_mapping: + self.kwargs_mapping[kwargs_type] = [key] else: - self.intermediate_kwargs[kwargs_type].append(key) - - def get_input(self, key: str, default: Any = None) -> Any: - """ - Get an input from the pipeline state. - - Args: - key (str): The key for the input - default (Any): The default value to return if the input is not found - - Returns: - Any: The input value - """ - value = self.inputs.get(key, default) - if value is not None: - return deepcopy(value) + self.kwargs_mapping[kwargs_type].append(key) - def get_inputs(self, keys: List[str], default: Any = None) -> Dict[str, Any]: + def get(self, keys: Union[str, List[str]], default: Any = None) -> Union[Any, Dict[str, Any]]: """ - Get multiple inputs from the pipeline state. + Get one or multiple values from the pipeline state. Args: - keys (List[str]): The keys for the inputs - default (Any): The default value to return if the input is not found + keys (Union[str, List[str]]): Key or list of keys for the values + default (Any): The default value to return if not found Returns: - Dict[str, Any]: Dictionary of inputs with matching keys + Union[Any, Dict[str, Any]]: Single value if keys is str, dictionary of values if keys is list """ - return {key: self.inputs.get(key, default) for key in keys} + if isinstance(keys, str): + return self.values.get(keys, default) + return {key: self.values.get(key, default) for key in keys} - def get_inputs_kwargs(self, kwargs_type: str) -> Dict[str, Any]: + def get_by_kwargs(self, kwargs_type: str) -> Dict[str, Any]: """ - Get all inputs with matching kwargs_type. + Get all values with matching kwargs_type. Args: kwargs_type (str): The kwargs_type to filter by Returns: - Dict[str, Any]: Dictionary of inputs with matching kwargs_type - """ - input_names = self.input_kwargs.get(kwargs_type, []) - return self.get_inputs(input_names) - - def get_intermediate_kwargs(self, kwargs_type: str) -> Dict[str, Any]: - """ - Get all intermediates with matching kwargs_type. - - Args: - kwargs_type (str): The kwargs_type to filter by - - Returns: - Dict[str, Any]: Dictionary of intermediates with matching kwargs_type - """ - intermediate_names = self.intermediate_kwargs.get(kwargs_type, []) - return self.get_intermediates(intermediate_names) - - def get_intermediate(self, key: str, default: Any = None) -> Any: - """ - Get an intermediate value from the pipeline state. - - Args: - key (str): The key for the intermediate value - default (Any): The default value to return if the intermediate value is not found - - Returns: - Any: The intermediate value - """ - return self.intermediates.get(key, default) - - def get_intermediates(self, keys: List[str], default: Any = None) -> Dict[str, Any]: - """ - Get multiple intermediate values from the pipeline state. - - Args: - keys (List[str]): The keys for the intermediate values - default (Any): The default value to return if the intermediate value is not found - - Returns: - Dict[str, Any]: Dictionary of intermediate values with matching keys + Dict[str, Any]: Dictionary of values with matching kwargs_type """ - return {key: self.intermediates.get(key, default) for key in keys} + value_names = self.kwargs_mapping.get(kwargs_type, []) + return self.get(value_names) def to_dict(self) -> Dict[str, Any]: """ Convert PipelineState to a dictionary. - - Returns: - Dict[str, Any]: Dictionary containing all attributes of the PipelineState """ - return {**self.__dict__, "inputs": self.inputs, "intermediates": self.intermediates} + return {**self.__dict__} def __repr__(self): def format_value(v): @@ -219,21 +137,10 @@ def format_value(v): else: return repr(v) - inputs = "\n".join(f" {k}: {format_value(v)}" for k, v in self.inputs.items()) - intermediates = "\n".join(f" {k}: {format_value(v)}" for k, v in self.intermediates.items()) + values_str = "\n".join(f" {k}: {format_value(v)}" for k, v in self.values.items()) + kwargs_mapping_str = "\n".join(f" {k}: {v}" for k, v in self.kwargs_mapping.items()) - # Format input_kwargs and intermediate_kwargs - input_kwargs_str = "\n".join(f" {k}: {v}" for k, v in self.input_kwargs.items()) - intermediate_kwargs_str = "\n".join(f" {k}: {v}" for k, v in self.intermediate_kwargs.items()) - - return ( - f"PipelineState(\n" - f" inputs={{\n{inputs}\n }},\n" - f" intermediates={{\n{intermediates}\n }},\n" - f" input_kwargs={{\n{input_kwargs_str}\n }},\n" - f" intermediate_kwargs={{\n{intermediate_kwargs_str}\n }}\n" - f")" - ) + return f"PipelineState(\n values={{\n{values_str}\n }},\n kwargs_mapping={{\n{kwargs_mapping_str}\n }}\n)" @dataclass @@ -322,7 +229,7 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin): """ - config_name = "config.json" + config_name = "modular_config.json" model_name = None @classmethod @@ -334,6 +241,14 @@ def _get_signature_keys(cls, obj): return expected_modules, optional_parameters + def __init__(self): + self.sub_blocks = InsertableDict() + + @property + def description(self) -> str: + """Description of the block. Must be implemented by subclasses.""" + return "" + @property def expected_components(self) -> List[ComponentSpec]: return [] @@ -343,8 +258,8 @@ def expected_configs(self) -> List[ConfigSpec]: return [] @property - def intermediate_inputs(self) -> List[OutputParam]: - """List of intermediate output parameters. Must be implemented by subclasses.""" + def inputs(self) -> List[InputParam]: + """List of input parameters. Must be implemented by subclasses.""" return [] @property @@ -352,6 +267,13 @@ def intermediate_outputs(self) -> List[OutputParam]: """List of intermediate output parameters. Must be implemented by subclasses.""" return [] + def _get_outputs(self): + return self.intermediate_outputs + + @property + def outputs(self) -> List[OutputParam]: + return self._get_outputs() + @classmethod def from_pretrained( cls, @@ -436,12 +358,12 @@ def init_pipeline( def get_block_state(self, state: PipelineState) -> dict: """Get all inputs and intermediates in one dictionary""" data = {} - state_inputs = self.inputs + self.intermediate_inputs + state_inputs = self.inputs # Check inputs for input_param in state_inputs: if input_param.name: - value = state.get_input(input_param.name) or state.get_intermediate(input_param.name) + value = state.get(input_param.name) if input_param.required and value is None: raise ValueError(f"Required input '{input_param.name}' is missing") elif value is not None or (value is None and input_param.name not in data): @@ -451,9 +373,7 @@ def get_block_state(self, state: PipelineState) -> dict: # if kwargs_type is provided, get all inputs with matching kwargs_type if input_param.kwargs_type not in data: data[input_param.kwargs_type] = {} - inputs_kwargs = state.get_inputs_kwargs(input_param.kwargs_type) or state.get_intermediate_kwargs( - input_param.kwargs_type - ) + inputs_kwargs = state.get_by_kwargs(input_param.kwargs_type) if inputs_kwargs: for k, v in inputs_kwargs.items(): if v is not None: @@ -467,25 +387,30 @@ def set_block_state(self, state: PipelineState, block_state: BlockState): if not hasattr(block_state, output_param.name): raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state") param = getattr(block_state, output_param.name) - state.set_intermediate(output_param.name, param, output_param.kwargs_type) + state.set(output_param.name, param, output_param.kwargs_type) - for input_param in self.intermediate_inputs: + for input_param in self.inputs: if input_param.name and hasattr(block_state, input_param.name): param = getattr(block_state, input_param.name) # Only add if the value is different from what's in the state - current_value = state.get_intermediate(input_param.name) + current_value = state.get(input_param.name) if current_value is not param: # Using identity comparison to check if object was modified - state.set_intermediate(input_param.name, param, input_param.kwargs_type) + state.set(input_param.name, param, input_param.kwargs_type) + elif input_param.kwargs_type: # if it is a kwargs type, e.g. "guider_input_fields", it is likely to be a list of parameters # we need to first find out which inputs are and loop through them. - intermediate_kwargs = state.get_intermediate_kwargs(input_param.kwargs_type) + intermediate_kwargs = state.get_by_kwargs(input_param.kwargs_type) for param_name, current_value in intermediate_kwargs.items(): + if param_name is None: + continue + if not hasattr(block_state, param_name): continue + param = getattr(block_state, param_name) if current_value is not param: # Using identity comparison to check if object was modified - state.set_intermediate(param_name, param, input_param.kwargs_type) + state.set(param_name, param, input_param.kwargs_type) @staticmethod def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> List[InputParam]: @@ -553,199 +478,17 @@ def combine_outputs(*named_output_lists: List[Tuple[str, List[OutputParam]]]) -> return list(combined_dict.values()) - -class PipelineBlock(ModularPipelineBlocks): - """ - A Pipeline Block is the basic building block of a Modular Pipeline. - - This class inherits from [`ModularPipelineBlocks`]. Check the superclass documentation for the generic methods the - library implements for all the pipeline blocks (such as loading or saving etc.) - - - - This is an experimental feature and is likely to change in the future. - - - - Args: - description (str, optional): A description of the block, defaults to None. Define as a property in subclasses. - expected_components (List[ComponentSpec], optional): - A list of components that are expected to be used in the block, defaults to []. To override, define as a - property in subclasses. - expected_configs (List[ConfigSpec], optional): - A list of configs that are expected to be used in the block, defaults to []. To override, define as a - property in subclasses. - inputs (List[InputParam], optional): - A list of inputs that are expected to be used in the block, defaults to []. To override, define as a - property in subclasses. - intermediate_inputs (List[InputParam], optional): - A list of intermediate inputs that are expected to be used in the block, defaults to []. To override, - define as a property in subclasses. - intermediate_outputs (List[OutputParam], optional): - A list of intermediate outputs that are expected to be used in the block, defaults to []. To override, - define as a property in subclasses. - outputs (List[OutputParam], optional): - A list of outputs that are expected to be used in the block, defaults to []. To override, define as a - property in subclasses. - required_inputs (List[str], optional): - A list of required inputs that are expected to be used in the block, defaults to []. To override, define as - a property in subclasses. - required_intermediate_inputs (List[str], optional): - A list of required intermediate inputs that are expected to be used in the block, defaults to []. To - override, define as a property in subclasses. - required_intermediate_outputs (List[str], optional): - A list of required intermediate outputs that are expected to be used in the block, defaults to []. To - override, define as a property in subclasses. - """ - - model_name = None - - def __init__(self): - self.sub_blocks = InsertableDict() - - @property - def description(self) -> str: - """Description of the block. Must be implemented by subclasses.""" - # raise NotImplementedError("description method must be implemented in subclasses") - return "TODO: add a description" - - @property - def expected_components(self) -> List[ComponentSpec]: - return [] - - @property - def expected_configs(self) -> List[ConfigSpec]: - return [] - - @property - def inputs(self) -> List[InputParam]: - """List of input parameters. Must be implemented by subclasses.""" - return [] - - @property - def intermediate_inputs(self) -> List[InputParam]: - """List of intermediate input parameters. Must be implemented by subclasses.""" - return [] - - @property - def intermediate_outputs(self) -> List[OutputParam]: - """List of intermediate output parameters. Must be implemented by subclasses.""" - return [] - - def _get_outputs(self): - return self.intermediate_outputs - - # YiYi TODO: is it too easy for user to unintentionally override these properties? - # Adding outputs attributes here for consistency between PipelineBlock/AutoPipelineBlocks/SequentialPipelineBlocks - @property - def outputs(self) -> List[OutputParam]: - return self._get_outputs() - - def _get_required_inputs(self): - input_names = [] - for input_param in self.inputs: - if input_param.required: - input_names.append(input_param.name) - return input_names - @property - def required_inputs(self) -> List[str]: - return self._get_required_inputs() + def input_names(self) -> List[str]: + return [input_param.name for input_param in self.inputs] - def _get_required_intermediate_inputs(self): - input_names = [] - for input_param in self.intermediate_inputs: - if input_param.required: - input_names.append(input_param.name) - return input_names - - # YiYi TODO: maybe we do not need this, it is only used in docstring, - # intermediate_inputs is by default required, unless you manually handle it inside the block @property - def required_intermediate_inputs(self) -> List[str]: - return self._get_required_intermediate_inputs() - - def __call__(self, pipeline, state: PipelineState) -> PipelineState: - raise NotImplementedError("__call__ method must be implemented in subclasses") - - def __repr__(self): - class_name = self.__class__.__name__ - base_class = self.__class__.__bases__[0].__name__ - - # Format description with proper indentation - desc_lines = self.description.split("\n") - desc = [] - # First line with "Description:" label - desc.append(f" Description: {desc_lines[0]}") - # Subsequent lines with proper indentation - if len(desc_lines) > 1: - desc.extend(f" {line}" for line in desc_lines[1:]) - desc = "\n".join(desc) + "\n" - - # Components section - use format_components with add_empty_lines=False - expected_components = getattr(self, "expected_components", []) - components_str = format_components(expected_components, indent_level=2, add_empty_lines=False) - components = " " + components_str.replace("\n", "\n ") - - # Configs section - use format_configs with add_empty_lines=False - expected_configs = getattr(self, "expected_configs", []) - configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False) - configs = " " + configs_str.replace("\n", "\n ") - - # Inputs section - inputs_str = format_inputs_short(self.inputs) - inputs = "Inputs:\n " + inputs_str - - # Intermediates section - intermediates_str = format_intermediates_short( - self.intermediate_inputs, self.required_intermediate_inputs, self.intermediate_outputs - ) - intermediates = f"Intermediates:\n{intermediates_str}" - - return f"{class_name}(\n Class: {base_class}\n{desc}{components}\n{configs}\n {inputs}\n {intermediates}\n)" + def intermediate_output_names(self) -> List[str]: + return [output_param.name for output_param in self.intermediate_outputs] @property - def doc(self): - return make_doc_string( - self.inputs, - self.intermediate_inputs, - self.outputs, - self.description, - class_name=self.__class__.__name__, - expected_components=self.expected_components, - expected_configs=self.expected_configs, - ) - - def set_block_state(self, state: PipelineState, block_state: BlockState): - for output_param in self.intermediate_outputs: - if not hasattr(block_state, output_param.name): - raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state") - param = getattr(block_state, output_param.name) - state.set_intermediate(output_param.name, param, output_param.kwargs_type) - - for input_param in self.intermediate_inputs: - if hasattr(block_state, input_param.name): - param = getattr(block_state, input_param.name) - # Only add if the value is different from what's in the state - current_value = state.get_intermediate(input_param.name) - if current_value is not param: # Using identity comparison to check if object was modified - state.set_intermediate(input_param.name, param, input_param.kwargs_type) - - for input_param in self.intermediate_inputs: - if input_param.name and hasattr(block_state, input_param.name): - param = getattr(block_state, input_param.name) - # Only add if the value is different from what's in the state - current_value = state.get_intermediate(input_param.name) - if current_value is not param: # Using identity comparison to check if object was modified - state.set_intermediate(input_param.name, param, input_param.kwargs_type) - elif input_param.kwargs_type: - # if it is a kwargs type, e.g. "guider_input_fields", it is likely to be a list of parameters - # we need to first find out which inputs are and loop through them. - intermediate_kwargs = state.get_intermediate_kwargs(input_param.kwargs_type) - for param_name, current_value in intermediate_kwargs.items(): - param = getattr(block_state, param_name) - if current_value is not param: # Using identity comparison to check if object was modified - state.set_intermediate(param_name, param, input_param.kwargs_type) + def output_names(self) -> List[str]: + return [output_param.name for output_param in self.outputs] class AutoPipelineBlocks(ModularPipelineBlocks): @@ -836,22 +579,6 @@ def required_inputs(self) -> List[str]: return list(required_by_all) - # YiYi TODO: maybe we do not need this, it is only used in docstring, - # intermediate_inputs is by default required, unless you manually handle it inside the block - @property - def required_intermediate_inputs(self) -> List[str]: - if None not in self.block_trigger_inputs: - return [] - first_block = next(iter(self.sub_blocks.values())) - required_by_all = set(getattr(first_block, "required_intermediate_inputs", set())) - - # Intersect with required inputs from all other blocks - for block in list(self.sub_blocks.values())[1:]: - block_required = set(getattr(block, "required_intermediate_inputs", set())) - required_by_all.intersection_update(block_required) - - return list(required_by_all) - # YiYi TODO: add test for this @property def inputs(self) -> List[Tuple[str, Any]]: @@ -865,18 +592,6 @@ def inputs(self) -> List[Tuple[str, Any]]: input_param.required = False return combined_inputs - @property - def intermediate_inputs(self) -> List[str]: - named_inputs = [(name, block.intermediate_inputs) for name, block in self.sub_blocks.items()] - combined_inputs = self.combine_inputs(*named_inputs) - # mark Required inputs only if that input is required by all the blocks - for input_param in combined_inputs: - if input_param.name in self.required_intermediate_inputs: - input_param.required = True - else: - input_param.required = False - return combined_inputs - @property def intermediate_outputs(self) -> List[str]: named_outputs = [(name, block.intermediate_outputs) for name, block in self.sub_blocks.items()] @@ -895,10 +610,10 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: block = self.trigger_to_block_map.get(None) for input_name in self.block_trigger_inputs: - if input_name is not None and state.get_input(input_name) is not None: + if input_name is not None and state.get(input_name) is not None: block = self.trigger_to_block_map[input_name] break - elif input_name is not None and state.get_intermediate(input_name) is not None: + elif input_name is not None and state.get(input_name) is not None: block = self.trigger_to_block_map[input_name] break @@ -1117,61 +832,16 @@ def __init__(self): sub_blocks[block_name] = block_cls() self.sub_blocks = sub_blocks - @property - def required_inputs(self) -> List[str]: - # Get the first block from the dictionary - first_block = next(iter(self.sub_blocks.values())) - required_by_any = set(getattr(first_block, "required_inputs", set())) - - # Union with required inputs from all other blocks - for block in list(self.sub_blocks.values())[1:]: - block_required = set(getattr(block, "required_inputs", set())) - required_by_any.update(block_required) - - return list(required_by_any) - - # YiYi TODO: maybe we do not need this, it is only used in docstring, - # intermediate_inputs is by default required, unless you manually handle it inside the block - @property - def required_intermediate_inputs(self) -> List[str]: - required_intermediate_inputs = [] - for input_param in self.intermediate_inputs: - if input_param.required: - required_intermediate_inputs.append(input_param.name) - return required_intermediate_inputs - - # YiYi TODO: add test for this - @property - def inputs(self) -> List[Tuple[str, Any]]: - return self.get_inputs() - - def get_inputs(self): - named_inputs = [(name, block.inputs) for name, block in self.sub_blocks.items()] - combined_inputs = self.combine_inputs(*named_inputs) - # mark Required inputs only if that input is required any of the blocks - for input_param in combined_inputs: - if input_param.name in self.required_inputs: - input_param.required = True - else: - input_param.required = False - return combined_inputs - - @property - def intermediate_inputs(self) -> List[str]: - return self.get_intermediate_inputs() - - def get_intermediate_inputs(self): + def _get_inputs(self): inputs = [] outputs = set() - added_inputs = set() # Go through all blocks in order for block in self.sub_blocks.values(): # Add inputs that aren't in outputs yet - for inp in block.intermediate_inputs: - if inp.name not in outputs and inp.name not in added_inputs: + for inp in block.inputs: + if inp.name not in outputs and inp.name not in {input.name for input in inputs}: inputs.append(inp) - added_inputs.add(inp.name) # Only add outputs if the block cannot be skipped should_add_outputs = True @@ -1182,13 +852,32 @@ def get_intermediate_inputs(self): # Add this block's outputs block_intermediate_outputs = [out.name for out in block.intermediate_outputs] outputs.update(block_intermediate_outputs) + return inputs + # YiYi TODO: add test for this + @property + def inputs(self) -> List[Tuple[str, Any]]: + return self._get_inputs() + + @property + def required_inputs(self) -> List[str]: + # Get the first block from the dictionary + first_block = next(iter(self.sub_blocks.values())) + required_by_any = set(getattr(first_block, "required_inputs", set())) + + # Union with required inputs from all other blocks + for block in list(self.sub_blocks.values())[1:]: + block_required = set(getattr(block, "required_inputs", set())) + required_by_any.update(block_required) + + return list(required_by_any) + @property def intermediate_outputs(self) -> List[str]: named_outputs = [] for name, block in self.sub_blocks.items(): - inp_names = {inp.name for inp in block.intermediate_inputs} + inp_names = {inp.name for inp in block.inputs} # so we only need to list new variables as intermediate_outputs, but if user wants to list these they modified it's still fine (a.k.a we don't enforce) # filter out them here so they do not end up as intermediate_outputs if name not in inp_names: @@ -1406,7 +1095,6 @@ def __repr__(self): def doc(self): return make_doc_string( self.inputs, - self.intermediate_inputs, self.outputs, self.description, class_name=self.__class__.__name__, @@ -1456,16 +1144,6 @@ def loop_inputs(self) -> List[InputParam]: """List of input parameters. Must be implemented by subclasses.""" return [] - @property - def loop_intermediate_inputs(self) -> List[InputParam]: - """List of intermediate input parameters. Must be implemented by subclasses.""" - return [] - - @property - def loop_intermediate_outputs(self) -> List[OutputParam]: - """List of intermediate output parameters. Must be implemented by subclasses.""" - return [] - @property def loop_required_inputs(self) -> List[str]: input_names = [] @@ -1475,12 +1153,9 @@ def loop_required_inputs(self) -> List[str]: return input_names @property - def loop_required_intermediate_inputs(self) -> List[str]: - input_names = [] - for input_param in self.loop_intermediate_inputs: - if input_param.required: - input_names.append(input_param.name) - return input_names + def loop_intermediate_outputs(self) -> List[OutputParam]: + """List of intermediate output parameters. Must be implemented by subclasses.""" + return [] # modified from SequentialPipelineBlocks to include loop_expected_components @property @@ -1508,43 +1183,16 @@ def expected_configs(self): expected_configs.append(config) return expected_configs - # modified from SequentialPipelineBlocks to include loop_inputs - def get_inputs(self): - named_inputs = [(name, block.inputs) for name, block in self.sub_blocks.items()] - named_inputs.append(("loop", self.loop_inputs)) - combined_inputs = self.combine_inputs(*named_inputs) - # mark Required inputs only if that input is required any of the blocks - for input_param in combined_inputs: - if input_param.name in self.required_inputs: - input_param.required = True - else: - input_param.required = False - return combined_inputs - - @property - # Copied from diffusers.modular_pipelines.modular_pipeline.SequentialPipelineBlocks.inputs - def inputs(self): - return self.get_inputs() - - # modified from SequentialPipelineBlocks to include loop_intermediate_inputs - @property - def intermediate_inputs(self): - intermediates = self.get_intermediate_inputs() - intermediate_names = [input.name for input in intermediates] - for loop_intermediate_input in self.loop_intermediate_inputs: - if loop_intermediate_input.name not in intermediate_names: - intermediates.append(loop_intermediate_input) - return intermediates - - # modified from SequentialPipelineBlocks - def get_intermediate_inputs(self): + def _get_inputs(self): inputs = [] + inputs.extend(self.loop_inputs) outputs = set() - # Go through all blocks in order - for block in self.sub_blocks.values(): + for name, block in self.sub_blocks.items(): # Add inputs that aren't in outputs yet - inputs.extend(input_name for input_name in block.intermediate_inputs if input_name.name not in outputs) + for inp in block.inputs: + if inp.name not in outputs and inp not in inputs: + inputs.append(inp) # Only add outputs if the block cannot be skipped should_add_outputs = True @@ -1555,8 +1203,20 @@ def get_intermediate_inputs(self): # Add this block's outputs block_intermediate_outputs = [out.name for out in block.intermediate_outputs] outputs.update(block_intermediate_outputs) + + for input_param in inputs: + if input_param.name in self.required_inputs: + input_param.required = True + else: + input_param.required = False + return inputs + @property + # Copied from diffusers.modular_pipelines.modular_pipeline.SequentialPipelineBlocks.inputs + def inputs(self): + return self._get_inputs() + # modified from SequentialPipelineBlocks, if any additionan input required by the loop is required by the block @property def required_inputs(self) -> List[str]: @@ -1574,19 +1234,6 @@ def required_inputs(self) -> List[str]: return list(required_by_any) - # YiYi TODO: maybe we do not need this, it is only used in docstring, - # intermediate_inputs is by default required, unless you manually handle it inside the block - @property - def required_intermediate_inputs(self) -> List[str]: - required_intermediate_inputs = [] - for input_param in self.intermediate_inputs: - if input_param.required: - required_intermediate_inputs.append(input_param.name) - for input_param in self.loop_intermediate_inputs: - if input_param.required: - required_intermediate_inputs.append(input_param.name) - return required_intermediate_inputs - # YiYi TODO: this need to be thought about more # modified from SequentialPipelineBlocks to include loop_intermediate_outputs @property @@ -1876,96 +1523,6 @@ def default_call_parameters(self) -> Dict[str, Any]: params[input_param.name] = input_param.default return params - def __call__(self, state: PipelineState = None, output: Union[str, List[str]] = None, **kwargs): - """ - Execute the pipeline by running the pipeline blocks with the given inputs. - - Args: - state (`PipelineState`, optional): - PipelineState instance contains inputs and intermediate values. If None, a new `PipelineState` will be - created based on the user inputs and the pipeline blocks's requirement. - output (`str` or `List[str]`, optional): - Optional specification of what to return: - - None: Returns the complete `PipelineState` with all inputs and intermediates (default) - - str: Returns a specific intermediate value from the state (e.g. `output="image"`) - - List[str]: Returns a dictionary of specific intermediate values (e.g. `output=["image", - "latents"]`) - - - Examples: - ```python - # Get complete pipeline state - state = pipeline(prompt="A beautiful sunset", num_inference_steps=20) - print(state.intermediates) # All intermediate outputs - - # Get specific output - image = pipeline(prompt="A beautiful sunset", output="image") - - # Get multiple specific outputs - results = pipeline(prompt="A beautiful sunset", output=["image", "latents"]) - image, latents = results["image"], results["latents"] - - # Continue from previous state - state = pipeline(prompt="A beautiful sunset") - new_state = pipeline(state=state, output="image") # Continue processing - ``` - - Returns: - - If `output` is None: Complete `PipelineState` containing all inputs and intermediates - - If `output` is str: The specific intermediate value from the state (e.g. `output="image"`) - - If `output` is List[str]: Dictionary mapping output names to their values from the state (e.g. - `output=["image", "latents"]`) - """ - if state is None: - state = PipelineState() - - # Make a copy of the input kwargs - passed_kwargs = kwargs.copy() - - # Add inputs to state, using defaults if not provided in the kwargs or the state - # if same input already in the state, will override it if provided in the kwargs - intermediate_inputs = [inp.name for inp in self.blocks.intermediate_inputs] - for expected_input_param in self.blocks.inputs: - name = expected_input_param.name - default = expected_input_param.default - kwargs_type = expected_input_param.kwargs_type - if name in passed_kwargs: - if name not in intermediate_inputs: - state.set_input(name, passed_kwargs.pop(name), kwargs_type) - else: - state.set_input(name, passed_kwargs[name], kwargs_type) - elif name not in state.inputs: - state.set_input(name, default, kwargs_type) - - for expected_intermediate_param in self.blocks.intermediate_inputs: - name = expected_intermediate_param.name - kwargs_type = expected_intermediate_param.kwargs_type - if name in passed_kwargs: - state.set_intermediate(name, passed_kwargs.pop(name), kwargs_type) - - # Warn about unexpected inputs - if len(passed_kwargs) > 0: - warnings.warn(f"Unexpected input '{passed_kwargs.keys()}' provided. This input will be ignored.") - # Run the pipeline - with torch.no_grad(): - try: - _, state = self.blocks(self, state) - except Exception: - error_msg = f"Error in block: ({self.blocks.__class__.__name__}):\n" - logger.error(error_msg) - raise - - if output is None: - return state - - elif isinstance(output, str): - return state.get_intermediate(output) - - elif isinstance(output, (list, tuple)): - return state.get_intermediates(output) - else: - raise ValueError(f"Output '{output}' is not a valid output type") - def load_default_components(self, **kwargs): """ Load from_pretrained components using the loading specs in the config dict. @@ -2784,3 +2341,92 @@ def _dict_to_component_spec( type_hint=type_hint, **spec_dict, ) + + def set_progress_bar_config(self, **kwargs): + for sub_block_name, sub_block in self.blocks.sub_blocks.items(): + if hasattr(sub_block, "set_progress_bar_config"): + sub_block.set_progress_bar_config(**kwargs) + + def __call__(self, state: PipelineState = None, output: Union[str, List[str]] = None, **kwargs): + """ + Execute the pipeline by running the pipeline blocks with the given inputs. + + Args: + state (`PipelineState`, optional): + PipelineState instance contains inputs and intermediate values. If None, a new `PipelineState` will be + created based on the user inputs and the pipeline blocks's requirement. + output (`str` or `List[str]`, optional): + Optional specification of what to return: + - None: Returns the complete `PipelineState` with all inputs and intermediates (default) + - str: Returns a specific intermediate value from the state (e.g. `output="image"`) + - List[str]: Returns a dictionary of specific intermediate values (e.g. `output=["image", + "latents"]`) + + + Examples: + ```python + # Get complete pipeline state + state = pipeline(prompt="A beautiful sunset", num_inference_steps=20) + print(state.intermediates) # All intermediate outputs + + # Get specific output + image = pipeline(prompt="A beautiful sunset", output="image") + + # Get multiple specific outputs + results = pipeline(prompt="A beautiful sunset", output=["image", "latents"]) + image, latents = results["image"], results["latents"] + + # Continue from previous state + state = pipeline(prompt="A beautiful sunset") + new_state = pipeline(state=state, output="image") # Continue processing + ``` + + Returns: + - If `output` is None: Complete `PipelineState` containing all inputs and intermediates + - If `output` is str: The specific intermediate value from the state (e.g. `output="image"`) + - If `output` is List[str]: Dictionary mapping output names to their values from the state (e.g. + `output=["image", "latents"]`) + """ + if state is None: + state = PipelineState() + + # Make a copy of the input kwargs + passed_kwargs = kwargs.copy() + + # Add inputs to state, using defaults if not provided in the kwargs or the state + # if same input already in the state, will override it if provided in the kwargs + intermediate_inputs = [inp.name for inp in self.blocks.inputs] + for expected_input_param in self.blocks.inputs: + name = expected_input_param.name + default = expected_input_param.default + kwargs_type = expected_input_param.kwargs_type + if name in passed_kwargs: + if name not in intermediate_inputs: + state.set(name, passed_kwargs.pop(name), kwargs_type) + else: + state.set(name, passed_kwargs[name], kwargs_type) + elif name not in state.values: + state.set(name, default, kwargs_type) + + # Warn about unexpected inputs + if len(passed_kwargs) > 0: + warnings.warn(f"Unexpected input '{passed_kwargs.keys()}' provided. This input will be ignored.") + # Run the pipeline + with torch.no_grad(): + try: + _, state = self.blocks(self, state) + except Exception: + error_msg = f"Error in block: ({self.blocks.__class__.__name__}):\n" + logger.error(error_msg) + raise + + if output is None: + return state + + if isinstance(output, str): + return state.get(output) + + elif isinstance(output, (list, tuple)): + return state.get(output) + else: + raise ValueError(f"Output '{output}' is not a valid output type") diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py index c56f4af1b8a5..61487cde15c1 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py @@ -27,7 +27,7 @@ from ...utils import logging from ...utils.torch_utils import randn_tensor, unwrap_module from ..modular_pipeline import ( - PipelineBlock, + ModularPipelineBlocks, PipelineState, ) from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam @@ -195,7 +195,7 @@ def prepare_latents_img2img( return latents -class StableDiffusionXLInputStep(PipelineBlock): +class StableDiffusionXLInputStep(ModularPipelineBlocks): model_name = "stable-diffusion-xl" @property @@ -213,11 +213,6 @@ def description(self) -> str: def inputs(self) -> List[InputParam]: return [ InputParam("num_images_per_prompt", default=1), - ] - - @property - def intermediate_inputs(self) -> List[str]: - return [ InputParam( "prompt_embeds", required=True, @@ -394,7 +389,7 @@ def __call__(self, components: StableDiffusionXLModularPipeline, state: Pipeline return components, state -class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock): +class StableDiffusionXLImg2ImgSetTimestepsStep(ModularPipelineBlocks): model_name = "stable-diffusion-xl" @property @@ -421,11 +416,6 @@ def inputs(self) -> List[InputParam]: InputParam("denoising_start"), # YiYi TODO: do we need num_images_per_prompt here? InputParam("num_images_per_prompt", default=1), - ] - - @property - def intermediate_inputs(self) -> List[str]: - return [ InputParam( "batch_size", required=True, @@ -543,7 +533,7 @@ def denoising_value_valid(dnv): return components, state -class StableDiffusionXLSetTimestepsStep(PipelineBlock): +class StableDiffusionXLSetTimestepsStep(ModularPipelineBlocks): model_name = "stable-diffusion-xl" @property @@ -611,7 +601,7 @@ def __call__(self, components: StableDiffusionXLModularPipeline, state: Pipeline return components, state -class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock): +class StableDiffusionXLInpaintPrepareLatentsStep(ModularPipelineBlocks): model_name = "stable-diffusion-xl" @property @@ -640,11 +630,6 @@ def inputs(self) -> List[Tuple[str, Any]]: "`num_inference_steps`. A value of 1, therefore, essentially ignores `image`. Note that in the case of " "`denoising_start` being declared as an integer, the value of `strength` will be ignored.", ), - ] - - @property - def intermediate_inputs(self) -> List[str]: - return [ InputParam("generator"), InputParam( "batch_size", @@ -744,8 +729,6 @@ def prepare_latents_inpaint( timestep=None, is_strength_max=True, add_noise=True, - return_noise=False, - return_image_latents=False, ): shape = ( batch_size, @@ -768,7 +751,7 @@ def prepare_latents_inpaint( if image.shape[1] == 4: image_latents = image.to(device=device, dtype=dtype) image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) - elif return_image_latents or (latents is None and not is_strength_max): + elif latents is None and not is_strength_max: image = image.to(device=device, dtype=dtype) image_latents = self._encode_vae_image(components, image=image, generator=generator) image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) @@ -786,13 +769,7 @@ def prepare_latents_inpaint( noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) latents = image_latents.to(device) - outputs = (latents,) - - if return_noise: - outputs += (noise,) - - if return_image_latents: - outputs += (image_latents,) + outputs = (latents, noise, image_latents) return outputs @@ -864,7 +841,7 @@ def __call__(self, components: StableDiffusionXLModularPipeline, state: Pipeline block_state.height = block_state.image_latents.shape[-2] * components.vae_scale_factor block_state.width = block_state.image_latents.shape[-1] * components.vae_scale_factor - block_state.latents, block_state.noise = self.prepare_latents_inpaint( + block_state.latents, block_state.noise, block_state.image_latents = self.prepare_latents_inpaint( components, block_state.batch_size * block_state.num_images_per_prompt, components.num_channels_latents, @@ -878,8 +855,6 @@ def __call__(self, components: StableDiffusionXLModularPipeline, state: Pipeline timestep=block_state.latent_timestep, is_strength_max=block_state.is_strength_max, add_noise=block_state.add_noise, - return_noise=True, - return_image_latents=False, ) # 7. Prepare mask latent variables @@ -900,7 +875,7 @@ def __call__(self, components: StableDiffusionXLModularPipeline, state: Pipeline return components, state -class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock): +class StableDiffusionXLImg2ImgPrepareLatentsStep(ModularPipelineBlocks): model_name = "stable-diffusion-xl" @property @@ -920,11 +895,6 @@ def inputs(self) -> List[Tuple[str, Any]]: InputParam("latents"), InputParam("num_images_per_prompt", default=1), InputParam("denoising_start"), - ] - - @property - def intermediate_inputs(self) -> List[InputParam]: - return [ InputParam("generator"), InputParam( "latent_timestep", @@ -981,7 +951,7 @@ def __call__(self, components: StableDiffusionXLModularPipeline, state: Pipeline return components, state -class StableDiffusionXLPrepareLatentsStep(PipelineBlock): +class StableDiffusionXLPrepareLatentsStep(ModularPipelineBlocks): model_name = "stable-diffusion-xl" @property @@ -1002,11 +972,6 @@ def inputs(self) -> List[InputParam]: InputParam("width"), InputParam("latents"), InputParam("num_images_per_prompt", default=1), - ] - - @property - def intermediate_inputs(self) -> List[InputParam]: - return [ InputParam("generator"), InputParam( "batch_size", @@ -1092,7 +1057,7 @@ def __call__(self, components: StableDiffusionXLModularPipeline, state: Pipeline return components, state -class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock): +class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(ModularPipelineBlocks): model_name = "stable-diffusion-xl" @property @@ -1129,11 +1094,6 @@ def inputs(self) -> List[Tuple[str, Any]]: InputParam("num_images_per_prompt", default=1), InputParam("aesthetic_score", default=6.0), InputParam("negative_aesthetic_score", default=2.0), - ] - - @property - def intermediate_inputs(self) -> List[InputParam]: - return [ InputParam( "latents", required=True, @@ -1316,7 +1276,7 @@ def __call__(self, components: StableDiffusionXLModularPipeline, state: Pipeline return components, state -class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock): +class StableDiffusionXLPrepareAdditionalConditioningStep(ModularPipelineBlocks): model_name = "stable-diffusion-xl" @property @@ -1345,11 +1305,6 @@ def inputs(self) -> List[Tuple[str, Any]]: InputParam("crops_coords_top_left", default=(0, 0)), InputParam("negative_crops_coords_top_left", default=(0, 0)), InputParam("num_images_per_prompt", default=1), - ] - - @property - def intermediate_inputs(self) -> List[InputParam]: - return [ InputParam( "latents", required=True, @@ -1499,7 +1454,7 @@ def __call__(self, components: StableDiffusionXLModularPipeline, state: Pipeline return components, state -class StableDiffusionXLControlNetInputStep(PipelineBlock): +class StableDiffusionXLControlNetInputStep(ModularPipelineBlocks): model_name = "stable-diffusion-xl" @property @@ -1527,11 +1482,6 @@ def inputs(self) -> List[Tuple[str, Any]]: InputParam("controlnet_conditioning_scale", default=1.0), InputParam("guess_mode", default=False), InputParam("num_images_per_prompt", default=1), - ] - - @property - def intermediate_inputs(self) -> List[str]: - return [ InputParam( "latents", required=True, @@ -1718,7 +1668,7 @@ def __call__(self, components: StableDiffusionXLModularPipeline, state: Pipeline return components, state -class StableDiffusionXLControlNetUnionInputStep(PipelineBlock): +class StableDiffusionXLControlNetUnionInputStep(ModularPipelineBlocks): model_name = "stable-diffusion-xl" @property diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py index e9f627636e8c..38fa3b5c5183 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py @@ -24,7 +24,7 @@ from ...models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor from ...utils import logging from ..modular_pipeline import ( - PipelineBlock, + ModularPipelineBlocks, PipelineState, ) from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam @@ -33,7 +33,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -class StableDiffusionXLDecodeStep(PipelineBlock): +class StableDiffusionXLDecodeStep(ModularPipelineBlocks): model_name = "stable-diffusion-xl" @property @@ -56,17 +56,12 @@ def description(self) -> str: def inputs(self) -> List[Tuple[str, Any]]: return [ InputParam("output_type", default="pil"), - ] - - @property - def intermediate_inputs(self) -> List[str]: - return [ InputParam( "latents", required=True, type_hint=torch.Tensor, description="The denoised latents from the denoising step", - ) + ), ] @property @@ -157,7 +152,7 @@ def __call__(self, components, state: PipelineState) -> PipelineState: return components, state -class StableDiffusionXLInpaintOverlayMaskStep(PipelineBlock): +class StableDiffusionXLInpaintOverlayMaskStep(ModularPipelineBlocks): model_name = "stable-diffusion-xl" @property diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py index 7fe4a472eec3..96df9711cc62 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py @@ -25,7 +25,7 @@ from ..modular_pipeline import ( BlockState, LoopSequentialPipelineBlocks, - PipelineBlock, + ModularPipelineBlocks, PipelineState, ) from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam @@ -37,7 +37,7 @@ # YiYi experimenting composible denoise loop # loop step (1): prepare latent input for denoiser -class StableDiffusionXLLoopBeforeDenoiser(PipelineBlock): +class StableDiffusionXLLoopBeforeDenoiser(ModularPipelineBlocks): model_name = "stable-diffusion-xl" @property @@ -55,7 +55,7 @@ def description(self) -> str: ) @property - def intermediate_inputs(self) -> List[str]: + def inputs(self) -> List[str]: return [ InputParam( "latents", @@ -73,7 +73,7 @@ def __call__(self, components: StableDiffusionXLModularPipeline, block_state: Bl # loop step (1): prepare latent input for denoiser (with inpainting) -class StableDiffusionXLInpaintLoopBeforeDenoiser(PipelineBlock): +class StableDiffusionXLInpaintLoopBeforeDenoiser(ModularPipelineBlocks): model_name = "stable-diffusion-xl" @property @@ -91,7 +91,7 @@ def description(self) -> str: ) @property - def intermediate_inputs(self) -> List[str]: + def inputs(self) -> List[str]: return [ InputParam( "latents", @@ -144,7 +144,7 @@ def __call__(self, components: StableDiffusionXLModularPipeline, block_state: Bl # loop step (2): denoise the latents with guidance -class StableDiffusionXLLoopDenoiser(PipelineBlock): +class StableDiffusionXLLoopDenoiser(ModularPipelineBlocks): model_name = "stable-diffusion-xl" @property @@ -171,11 +171,6 @@ def description(self) -> str: def inputs(self) -> List[Tuple[str, Any]]: return [ InputParam("cross_attention_kwargs"), - ] - - @property - def intermediate_inputs(self) -> List[str]: - return [ InputParam( "num_inference_steps", required=True, @@ -249,7 +244,7 @@ def __call__( # loop step (2): denoise the latents with guidance (with controlnet) -class StableDiffusionXLControlNetLoopDenoiser(PipelineBlock): +class StableDiffusionXLControlNetLoopDenoiser(ModularPipelineBlocks): model_name = "stable-diffusion-xl" @property @@ -277,11 +272,6 @@ def description(self) -> str: def inputs(self) -> List[Tuple[str, Any]]: return [ InputParam("cross_attention_kwargs"), - ] - - @property - def intermediate_inputs(self) -> List[str]: - return [ InputParam( "controlnet_cond", required=True, @@ -449,7 +439,7 @@ def __call__(self, components: StableDiffusionXLModularPipeline, block_state: Bl # loop step (3): scheduler step to update latents -class StableDiffusionXLLoopAfterDenoiser(PipelineBlock): +class StableDiffusionXLLoopAfterDenoiser(ModularPipelineBlocks): model_name = "stable-diffusion-xl" @property @@ -470,11 +460,6 @@ def description(self) -> str: def inputs(self) -> List[Tuple[str, Any]]: return [ InputParam("eta", default=0.0), - ] - - @property - def intermediate_inputs(self) -> List[str]: - return [ InputParam("generator"), ] @@ -520,7 +505,7 @@ def __call__(self, components: StableDiffusionXLModularPipeline, block_state: Bl # loop step (3): scheduler step to update latents (with inpainting) -class StableDiffusionXLInpaintLoopAfterDenoiser(PipelineBlock): +class StableDiffusionXLInpaintLoopAfterDenoiser(ModularPipelineBlocks): model_name = "stable-diffusion-xl" @property @@ -542,11 +527,6 @@ def description(self) -> str: def inputs(self) -> List[Tuple[str, Any]]: return [ InputParam("eta", default=0.0), - ] - - @property - def intermediate_inputs(self) -> List[str]: - return [ InputParam("generator"), InputParam( "timesteps", @@ -660,7 +640,7 @@ def loop_expected_components(self) -> List[ComponentSpec]: ] @property - def loop_intermediate_inputs(self) -> List[InputParam]: + def loop_inputs(self) -> List[InputParam]: return [ InputParam( "timesteps", diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py index bd0e962140e8..8926d6c1fb79 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py @@ -35,7 +35,7 @@ scale_lora_layers, unscale_lora_layers, ) -from ..modular_pipeline import PipelineBlock, PipelineState +from ..modular_pipeline import ModularPipelineBlocks, PipelineState from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam from .modular_pipeline import StableDiffusionXLModularPipeline @@ -57,7 +57,7 @@ def retrieve_latents( raise AttributeError("Could not access latents of provided encoder_output") -class StableDiffusionXLIPAdapterStep(PipelineBlock): +class StableDiffusionXLIPAdapterStep(ModularPipelineBlocks): model_name = "stable-diffusion-xl" @property @@ -215,7 +215,7 @@ def __call__(self, components: StableDiffusionXLModularPipeline, state: Pipeline return components, state -class StableDiffusionXLTextEncoderStep(PipelineBlock): +class StableDiffusionXLTextEncoderStep(ModularPipelineBlocks): model_name = "stable-diffusion-xl" @property @@ -576,7 +576,7 @@ def __call__(self, components: StableDiffusionXLModularPipeline, state: Pipeline return components, state -class StableDiffusionXLVaeEncoderStep(PipelineBlock): +class StableDiffusionXLVaeEncoderStep(ModularPipelineBlocks): model_name = "stable-diffusion-xl" @property @@ -601,11 +601,6 @@ def inputs(self) -> List[InputParam]: InputParam("image", required=True), InputParam("height"), InputParam("width"), - ] - - @property - def intermediate_inputs(self) -> List[InputParam]: - return [ InputParam("generator"), InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"), InputParam( @@ -691,7 +686,7 @@ def __call__(self, components: StableDiffusionXLModularPipeline, state: Pipeline return components, state -class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock): +class StableDiffusionXLInpaintVaeEncoderStep(ModularPipelineBlocks): model_name = "stable-diffusion-xl" @property @@ -726,11 +721,6 @@ def inputs(self) -> List[InputParam]: InputParam("image", required=True), InputParam("mask_image", required=True), InputParam("padding_mask_crop"), - ] - - @property - def intermediate_inputs(self) -> List[InputParam]: - return [ InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"), InputParam("generator"), ] diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py index fc030fae56fb..0ee37f520135 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py @@ -247,10 +247,6 @@ def num_channels_latents(self): "control_mode": InputParam( "control_mode", type_hint=List[int], required=True, description="Control mode for union controlnet" ), -} - - -SDXL_INTERMEDIATE_INPUTS_SCHEMA = { "prompt_embeds": InputParam( "prompt_embeds", type_hint=torch.Tensor, @@ -271,13 +267,6 @@ def num_channels_latents(self): "preprocess_kwargs": InputParam( "preprocess_kwargs", type_hint=Optional[dict], description="Kwargs for ImageProcessor" ), - "latents": InputParam( - "latents", type_hint=torch.Tensor, required=True, description="Initial latents for denoising process" - ), - "timesteps": InputParam("timesteps", type_hint=torch.Tensor, required=True, description="Timesteps for inference"), - "num_inference_steps": InputParam( - "num_inference_steps", type_hint=int, required=True, description="Number of denoising steps" - ), "latent_timestep": InputParam( "latent_timestep", type_hint=torch.Tensor, required=True, description="Initial noise level timestep" ), From 4524d432791fbb0b46c195e913973b4687fac6b8 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Wed, 30 Jul 2025 08:24:25 +0200 Subject: [PATCH 04/14] update --- .../modular_pipelines/wan/before_denoise.py | 8 ++++---- src/diffusers/modular_pipelines/wan/decoders.py | 4 ++-- src/diffusers/modular_pipelines/wan/denoise.py | 14 +++++++------- src/diffusers/modular_pipelines/wan/encoders.py | 4 ++-- .../modular_pipelines/wan/modular_blocks.py | 12 ++++++------ 5 files changed, 21 insertions(+), 21 deletions(-) diff --git a/src/diffusers/modular_pipelines/wan/before_denoise.py b/src/diffusers/modular_pipelines/wan/before_denoise.py index ef65b6453725..2b9889f8778a 100644 --- a/src/diffusers/modular_pipelines/wan/before_denoise.py +++ b/src/diffusers/modular_pipelines/wan/before_denoise.py @@ -20,7 +20,7 @@ from ...schedulers import UniPCMultistepScheduler from ...utils import logging from ...utils.torch_utils import randn_tensor -from ..modular_pipeline import PipelineBlock, PipelineState +from ..modular_pipeline import ModularPipelineBlocks, PipelineState from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam from .modular_pipeline import WanModularPipeline @@ -94,7 +94,7 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class WanInputStep(PipelineBlock): +class WanInputStep(ModularPipelineBlocks): model_name = "wan" @property @@ -194,7 +194,7 @@ def __call__(self, components: WanModularPipeline, state: PipelineState) -> Pipe return components, state -class WanSetTimestepsStep(PipelineBlock): +class WanSetTimestepsStep(ModularPipelineBlocks): model_name = "wan" @property @@ -243,7 +243,7 @@ def __call__(self, components: WanModularPipeline, state: PipelineState) -> Pipe return components, state -class WanPrepareLatentsStep(PipelineBlock): +class WanPrepareLatentsStep(ModularPipelineBlocks): model_name = "wan" @property diff --git a/src/diffusers/modular_pipelines/wan/decoders.py b/src/diffusers/modular_pipelines/wan/decoders.py index 4fadeed4b954..8c751172d858 100644 --- a/src/diffusers/modular_pipelines/wan/decoders.py +++ b/src/diffusers/modular_pipelines/wan/decoders.py @@ -22,14 +22,14 @@ from ...models import AutoencoderKLWan from ...utils import logging from ...video_processor import VideoProcessor -from ..modular_pipeline import PipelineBlock, PipelineState +from ..modular_pipeline import ModularPipelineBlocks, PipelineState from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam logger = logging.get_logger(__name__) # pylint: disable=invalid-name -class WanDecodeStep(PipelineBlock): +class WanDecodeStep(ModularPipelineBlocks): model_name = "wan" @property diff --git a/src/diffusers/modular_pipelines/wan/denoise.py b/src/diffusers/modular_pipelines/wan/denoise.py index 76c5cda5f95e..9528f7e548f3 100644 --- a/src/diffusers/modular_pipelines/wan/denoise.py +++ b/src/diffusers/modular_pipelines/wan/denoise.py @@ -23,8 +23,8 @@ from ...utils import logging from ..modular_pipeline import ( BlockState, - LoopSequentialPipelineBlocks, - PipelineBlock, + LoopSequentialModularPipelineBlockss, + ModularPipelineBlocks, PipelineState, ) from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam @@ -34,7 +34,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -class WanLoopDenoiser(PipelineBlock): +class WanLoopDenoiser(ModularPipelineBlocks): model_name = "wan" @property @@ -53,7 +53,7 @@ def expected_components(self) -> List[ComponentSpec]: def description(self) -> str: return ( "Step within the denoising loop that denoise the latents with guidance. " - "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialModularPipelineBlockss` " "object (e.g. `WanDenoiseLoopWrapper`)" ) @@ -132,7 +132,7 @@ def __call__( return components, block_state -class WanLoopAfterDenoiser(PipelineBlock): +class WanLoopAfterDenoiser(ModularPipelineBlocks): model_name = "wan" @property @@ -145,7 +145,7 @@ def expected_components(self) -> List[ComponentSpec]: def description(self) -> str: return ( "step within the denoising loop that update the latents. " - "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialModularPipelineBlockss` " "object (e.g. `WanDenoiseLoopWrapper`)" ) @@ -181,7 +181,7 @@ def __call__(self, components: WanModularPipeline, block_state: BlockState, i: i return components, block_state -class WanDenoiseLoopWrapper(LoopSequentialPipelineBlocks): +class WanDenoiseLoopWrapper(LoopSequentialModularPipelineBlockss): model_name = "wan" @property diff --git a/src/diffusers/modular_pipelines/wan/encoders.py b/src/diffusers/modular_pipelines/wan/encoders.py index b2ecfd1aa61a..a0bf76b99b55 100644 --- a/src/diffusers/modular_pipelines/wan/encoders.py +++ b/src/diffusers/modular_pipelines/wan/encoders.py @@ -22,7 +22,7 @@ from ...configuration_utils import FrozenDict from ...guiders import ClassifierFreeGuidance from ...utils import is_ftfy_available, logging -from ..modular_pipeline import PipelineBlock, PipelineState +from ..modular_pipeline import ModularPipelineBlocks, PipelineState from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam from .modular_pipeline import WanModularPipeline @@ -51,7 +51,7 @@ def prompt_clean(text): return text -class WanTextEncoderStep(PipelineBlock): +class WanTextEncoderStep(ModularPipelineBlocks): model_name = "wan" @property diff --git a/src/diffusers/modular_pipelines/wan/modular_blocks.py b/src/diffusers/modular_pipelines/wan/modular_blocks.py index 5f4c1a983566..a2c6646b972a 100644 --- a/src/diffusers/modular_pipelines/wan/modular_blocks.py +++ b/src/diffusers/modular_pipelines/wan/modular_blocks.py @@ -13,7 +13,7 @@ # limitations under the License. from ...utils import logging -from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks +from ..modular_pipeline import AutoModularPipelineBlockss, SequentialModularPipelineBlockss from ..modular_pipeline_utils import InsertableDict from .before_denoise import ( WanInputStep, @@ -29,7 +29,7 @@ # before_denoise: text2vid -class WanBeforeDenoiseStep(SequentialPipelineBlocks): +class WanBeforeDenoiseStep(SequentialModularPipelineBlockss): block_classes = [ WanInputStep, WanSetTimestepsStep, @@ -49,7 +49,7 @@ def description(self): # before_denoise: all task (text2vid,) -class WanAutoBeforeDenoiseStep(AutoPipelineBlocks): +class WanAutoBeforeDenoiseStep(AutoModularPipelineBlockss): block_classes = [ WanBeforeDenoiseStep, ] @@ -66,7 +66,7 @@ def description(self): # denoise: text2vid -class WanAutoDenoiseStep(AutoPipelineBlocks): +class WanAutoDenoiseStep(AutoModularPipelineBlockss): block_classes = [ WanDenoiseStep, ] @@ -83,7 +83,7 @@ def description(self) -> str: # decode: all task (text2img, img2img, inpainting) -class WanAutoDecodeStep(AutoPipelineBlocks): +class WanAutoDecodeStep(AutoModularPipelineBlockss): block_classes = [WanDecodeStep] block_names = ["non-inpaint"] block_trigger_inputs = [None] @@ -94,7 +94,7 @@ def description(self): # text2vid -class WanAutoBlocks(SequentialPipelineBlocks): +class WanAutoBlocks(SequentialModularPipelineBlockss): block_classes = [ WanTextEncoderStep, WanAutoBeforeDenoiseStep, From 255c5742aa50ea25567d6d7bffc5bc9e7e76f36a Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Wed, 30 Jul 2025 08:33:51 +0200 Subject: [PATCH 05/14] update --- src/diffusers/modular_pipelines/modular_pipeline.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index 72d1cd07a8c4..81cf51917084 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -615,9 +615,6 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState: if input_name is not None and state.get(input_name) is not None: block = self.trigger_to_block_map[input_name] break - elif input_name is not None and state.get(input_name) is not None: - block = self.trigger_to_block_map[input_name] - break if block is None: logger.warning(f"skipping auto block: {self.__class__.__name__}") From ea77fdc4b4c50aaa7d5e0d619aa43457c277a603 Mon Sep 17 00:00:00 2001 From: DN6 Date: Wed, 6 Aug 2025 17:17:51 +0530 Subject: [PATCH 06/14] update --- .../modular_pipelines/modular_pipeline.py | 30 +++++++++++++++---- .../modular_pipeline_utils.py | 3 +- 2 files changed, 25 insertions(+), 8 deletions(-) diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index 81cf51917084..3f7436acfc94 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -264,6 +264,18 @@ def inputs(self) -> List[InputParam]: """List of input parameters. Must be implemented by subclasses.""" return [] + def _get_required_inputs(self): + input_names = [] + for input_param in self.inputs: + if input_param.required: + input_names.append(input_param.name) + + return input_names + + @property + def required_inputs(self) -> List[InputParam]: + return self._get_required_inputs() + @property def intermediate_outputs(self) -> List[OutputParam]: """List of intermediate output parameters. Must be implemented by subclasses.""" @@ -492,6 +504,17 @@ def intermediate_output_names(self) -> List[str]: def output_names(self) -> List[str]: return [output_param.name for output_param in self.outputs] + @property + def doc(self): + return make_doc_string( + self.inputs, + self.outputs, + self.description, + class_name=self.__class__.__name__, + expected_components=self.expected_components, + expected_configs=self.expected_configs, + ) + class AutoPipelineBlocks(ModularPipelineBlocks): """ @@ -743,7 +766,6 @@ def __repr__(self): def doc(self): return make_doc_string( self.inputs, - self.intermediate_inputs, self.outputs, self.description, class_name=self.__class__.__name__, @@ -2394,16 +2416,12 @@ def __call__(self, state: PipelineState = None, output: Union[str, List[str]] = # Add inputs to state, using defaults if not provided in the kwargs or the state # if same input already in the state, will override it if provided in the kwargs - intermediate_inputs = [inp.name for inp in self.blocks.inputs] for expected_input_param in self.blocks.inputs: name = expected_input_param.name default = expected_input_param.default kwargs_type = expected_input_param.kwargs_type if name in passed_kwargs: - if name not in intermediate_inputs: - state.set(name, passed_kwargs.pop(name), kwargs_type) - else: - state.set(name, passed_kwargs[name], kwargs_type) + state.set(name, passed_kwargs.pop(name), kwargs_type) elif name not in state.values: state.set(name, default, kwargs_type) diff --git a/src/diffusers/modular_pipelines/modular_pipeline_utils.py b/src/diffusers/modular_pipelines/modular_pipeline_utils.py index f2fc015e948f..9118f13aa071 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline_utils.py +++ b/src/diffusers/modular_pipelines/modular_pipeline_utils.py @@ -618,7 +618,6 @@ def format_configs(configs, indent_level=4, max_line_length=115, add_empty_lines def make_doc_string( inputs, - intermediate_inputs, outputs, description="", class_name=None, @@ -664,7 +663,7 @@ def make_doc_string( output += configs_str + "\n\n" # Add inputs section - output += format_input_params(inputs + intermediate_inputs, indent_level=2) + output += format_input_params(inputs, indent_level=2) # Add outputs section output += "\n\n" From 1b4af6b7ef2df5087ca19c1057a8a214ef069bb8 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Wed, 6 Aug 2025 17:43:21 +0200 Subject: [PATCH 07/14] update --- .../stable_diffusion_xl/before_denoise.py | 5 ---- .../stable_diffusion_xl/decoders.py | 5 ---- .../stable_diffusion_xl/encoders.py | 29 ++++++++----------- 3 files changed, 12 insertions(+), 27 deletions(-) diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py index 61487cde15c1..fbe0d22a52f9 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py @@ -1697,11 +1697,6 @@ def inputs(self) -> List[Tuple[str, Any]]: InputParam("controlnet_conditioning_scale", default=1.0), InputParam("guess_mode", default=False), InputParam("num_images_per_prompt", default=1), - ] - - @property - def intermediate_inputs(self) -> List[InputParam]: - return [ InputParam( "latents", required=True, diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py index 38fa3b5c5183..feb78e1ef11b 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py @@ -179,11 +179,6 @@ def inputs(self) -> List[Tuple[str, Any]]: InputParam("image"), InputParam("mask_image"), InputParam("padding_mask_crop"), - ] - - @property - def intermediate_inputs(self) -> List[str]: - return [ InputParam( "images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py index 8926d6c1fb79..1e8921d363c1 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py @@ -663,12 +663,11 @@ def __call__(self, components: StableDiffusionXLModularPipeline, state: Pipeline block_state.device = components._execution_device block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype - block_state.image = components.image_processor.preprocess( + image = components.image_processor.preprocess( block_state.image, height=block_state.height, width=block_state.width, **block_state.preprocess_kwargs ) - block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype) - - block_state.batch_size = block_state.image.shape[0] + image = image.to(device=block_state.device, dtype=block_state.dtype) + block_state.batch_size = image.shape[0] # if generator is a list, make sure the length of it matches the length of images (both should be batch_size) if isinstance(block_state.generator, list) and len(block_state.generator) != block_state.batch_size: @@ -677,9 +676,7 @@ def __call__(self, components: StableDiffusionXLModularPipeline, state: Pipeline f" size of {block_state.batch_size}. Make sure the batch size matches the length of the generators." ) - block_state.image_latents = self._encode_vae_image( - components, image=block_state.image, generator=block_state.generator - ) + block_state.image_latents = self._encode_vae_image(components, image=image, generator=block_state.generator) self.set_block_state(state, block_state) @@ -850,34 +847,32 @@ def __call__(self, components: StableDiffusionXLModularPipeline, state: Pipeline block_state.crops_coords = None block_state.resize_mode = "default" - block_state.image = components.image_processor.preprocess( + image = components.image_processor.preprocess( block_state.image, height=block_state.height, width=block_state.width, crops_coords=block_state.crops_coords, resize_mode=block_state.resize_mode, ) - block_state.image = block_state.image.to(dtype=torch.float32) + image = image.to(dtype=torch.float32) - block_state.mask = components.mask_processor.preprocess( + mask = components.mask_processor.preprocess( block_state.mask_image, height=block_state.height, width=block_state.width, resize_mode=block_state.resize_mode, crops_coords=block_state.crops_coords, ) - block_state.masked_image = block_state.image * (block_state.mask < 0.5) + block_state.masked_image = image * (mask < 0.5) - block_state.batch_size = block_state.image.shape[0] - block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype) - block_state.image_latents = self._encode_vae_image( - components, image=block_state.image, generator=block_state.generator - ) + block_state.batch_size = image.shape[0] + image = image.to(device=block_state.device, dtype=block_state.dtype) + block_state.image_latents = self._encode_vae_image(components, image=image, generator=block_state.generator) # 7. Prepare mask latent variables block_state.mask, block_state.masked_image_latents = self.prepare_mask_latents( components, - block_state.mask, + mask, block_state.masked_image, block_state.batch_size, block_state.height, From 9a0cc463ee89326598ea611fb7ced9cef34ecfe2 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Wed, 6 Aug 2025 19:32:23 +0200 Subject: [PATCH 08/14] update --- src/diffusers/modular_pipelines/modular_pipeline.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index 924754e6dc4f..47fa52adfa38 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -1326,7 +1326,6 @@ def __call__(self, components, state: PipelineState) -> PipelineState: def doc(self): return make_doc_string( self.inputs, - self.intermediate_inputs, self.outputs, self.description, class_name=self.__class__.__name__, From d1342d7464955c71c9ecacca41a6f1b2a502197e Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Fri, 8 Aug 2025 12:10:06 +0200 Subject: [PATCH 09/14] update --- .../modular_pipelines/flux/before_denoise.py | 25 ++++--------------- .../modular_pipelines/flux/decoders.py | 11 +++----- .../modular_pipelines/flux/denoise.py | 13 ++++------ .../modular_pipelines/flux/encoders.py | 4 +-- 4 files changed, 15 insertions(+), 38 deletions(-) diff --git a/src/diffusers/modular_pipelines/flux/before_denoise.py b/src/diffusers/modular_pipelines/flux/before_denoise.py index ffc77bb24fdb..c51061b78f9d 100644 --- a/src/diffusers/modular_pipelines/flux/before_denoise.py +++ b/src/diffusers/modular_pipelines/flux/before_denoise.py @@ -21,7 +21,7 @@ from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import logging from ...utils.torch_utils import randn_tensor -from ..modular_pipeline import PipelineBlock, PipelineState +from ..modular_pipeline import ModularPipelineBlocks, PipelineState from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam from .modular_pipeline import FluxModularPipeline @@ -125,7 +125,7 @@ def _prepare_latent_image_ids(batch_size, height, width, device, dtype): return latent_image_ids.to(device=device, dtype=dtype) -class FluxInputStep(PipelineBlock): +class FluxInputStep(ModularPipelineBlocks): model_name = "flux" @property @@ -143,11 +143,6 @@ def description(self) -> str: def inputs(self) -> List[InputParam]: return [ InputParam("num_images_per_prompt", default=1), - ] - - @property - def intermediate_inputs(self) -> List[str]: - return [ InputParam( "prompt_embeds", required=True, @@ -216,7 +211,7 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip return components, state -class FluxSetTimestepsStep(PipelineBlock): +class FluxSetTimestepsStep(ModularPipelineBlocks): model_name = "flux" @property @@ -235,17 +230,12 @@ def inputs(self) -> List[InputParam]: InputParam("sigmas"), InputParam("guidance_scale", default=3.5), InputParam("latents", type_hint=torch.Tensor), - ] - - @property - def intermediate_inputs(self) -> List[str]: - return [ InputParam( "latents", required=True, type_hint=torch.Tensor, description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.", - ) + ), ] @property @@ -296,7 +286,7 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip return components, state -class FluxPrepareLatentsStep(PipelineBlock): +class FluxPrepareLatentsStep(ModularPipelineBlocks): model_name = "flux" @property @@ -314,11 +304,6 @@ def inputs(self) -> List[InputParam]: InputParam("width", type_hint=int), InputParam("latents", type_hint=Optional[torch.Tensor]), InputParam("num_images_per_prompt", type_hint=int, default=1), - ] - - @property - def intermediate_inputs(self) -> List[InputParam]: - return [ InputParam("generator"), InputParam( "batch_size", diff --git a/src/diffusers/modular_pipelines/flux/decoders.py b/src/diffusers/modular_pipelines/flux/decoders.py index 8d561d38c6f2..846549b1a376 100644 --- a/src/diffusers/modular_pipelines/flux/decoders.py +++ b/src/diffusers/modular_pipelines/flux/decoders.py @@ -22,7 +22,7 @@ from ...models import AutoencoderKL from ...utils import logging from ...video_processor import VaeImageProcessor -from ..modular_pipeline import PipelineBlock, PipelineState +from ..modular_pipeline import ModularPipelineBlocks, PipelineState from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam @@ -45,7 +45,7 @@ def _unpack_latents(latents, height, width, vae_scale_factor): return latents -class FluxDecodeStep(PipelineBlock): +class FluxDecodeStep(ModularPipelineBlocks): model_name = "flux" @property @@ -70,17 +70,12 @@ def inputs(self) -> List[Tuple[str, Any]]: InputParam("output_type", default="pil"), InputParam("height", default=1024), InputParam("width", default=1024), - ] - - @property - def intermediate_inputs(self) -> List[str]: - return [ InputParam( "latents", required=True, type_hint=torch.Tensor, description="The denoised latents from the denoising step", - ) + ), ] @property diff --git a/src/diffusers/modular_pipelines/flux/denoise.py b/src/diffusers/modular_pipelines/flux/denoise.py index c4619c17fb0e..062f1ad49f6a 100644 --- a/src/diffusers/modular_pipelines/flux/denoise.py +++ b/src/diffusers/modular_pipelines/flux/denoise.py @@ -22,7 +22,7 @@ from ..modular_pipeline import ( BlockState, LoopSequentialPipelineBlocks, - PipelineBlock, + ModularPipelineBlocks, PipelineState, ) from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam @@ -32,7 +32,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -class FluxLoopDenoiser(PipelineBlock): +class FluxLoopDenoiser(ModularPipelineBlocks): model_name = "flux" @property @@ -49,11 +49,8 @@ def description(self) -> str: @property def inputs(self) -> List[Tuple[str, Any]]: - return [InputParam("joint_attention_kwargs")] - - @property - def intermediate_inputs(self) -> List[str]: return [ + InputParam("joint_attention_kwargs"), InputParam( "latents", required=True, @@ -113,7 +110,7 @@ def __call__( return components, block_state -class FluxLoopAfterDenoiser(PipelineBlock): +class FluxLoopAfterDenoiser(ModularPipelineBlocks): model_name = "flux" @property @@ -175,7 +172,7 @@ def loop_expected_components(self) -> List[ComponentSpec]: ] @property - def loop_intermediate_inputs(self) -> List[InputParam]: + def loop_inputs(self) -> List[InputParam]: return [ InputParam( "timesteps", diff --git a/src/diffusers/modular_pipelines/flux/encoders.py b/src/diffusers/modular_pipelines/flux/encoders.py index 9bf2f54eece3..9681c1e9d98a 100644 --- a/src/diffusers/modular_pipelines/flux/encoders.py +++ b/src/diffusers/modular_pipelines/flux/encoders.py @@ -21,7 +21,7 @@ from ...loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin from ...utils import USE_PEFT_BACKEND, is_ftfy_available, logging, scale_lora_layers, unscale_lora_layers -from ..modular_pipeline import PipelineBlock, PipelineState +from ..modular_pipeline import ModularPipelineBlocks, PipelineState from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam from .modular_pipeline import FluxModularPipeline @@ -50,7 +50,7 @@ def prompt_clean(text): return text -class FluxTextEncoderStep(PipelineBlock): +class FluxTextEncoderStep(ModularPipelineBlocks): model_name = "flux" @property From c678e8a445bf041d850cbd253304be8ff98748d8 Mon Sep 17 00:00:00 2001 From: DN6 Date: Fri, 8 Aug 2025 19:47:50 +0530 Subject: [PATCH 10/14] update --- src/diffusers/modular_pipelines/__init__.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/diffusers/modular_pipelines/__init__.py b/src/diffusers/modular_pipelines/__init__.py index e0f2e31388ce..ceee942b3d56 100644 --- a/src/diffusers/modular_pipelines/__init__.py +++ b/src/diffusers/modular_pipelines/__init__.py @@ -25,7 +25,6 @@ _import_structure["modular_pipeline"] = [ "ModularPipelineBlocks", "ModularPipeline", - "PipelineBlock", "AutoPipelineBlocks", "SequentialPipelineBlocks", "LoopSequentialPipelineBlocks", @@ -59,7 +58,6 @@ LoopSequentialPipelineBlocks, ModularPipeline, ModularPipelineBlocks, - PipelineBlock, PipelineState, SequentialPipelineBlocks, ) From 085e9cba36ebe7fe8f69ea8732af4931bb666e53 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Fri, 8 Aug 2025 18:22:44 +0200 Subject: [PATCH 11/14] update --- .../modular_pipelines/wan/modular_blocks.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/diffusers/modular_pipelines/wan/modular_blocks.py b/src/diffusers/modular_pipelines/wan/modular_blocks.py index a2c6646b972a..5f4c1a983566 100644 --- a/src/diffusers/modular_pipelines/wan/modular_blocks.py +++ b/src/diffusers/modular_pipelines/wan/modular_blocks.py @@ -13,7 +13,7 @@ # limitations under the License. from ...utils import logging -from ..modular_pipeline import AutoModularPipelineBlockss, SequentialModularPipelineBlockss +from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks from ..modular_pipeline_utils import InsertableDict from .before_denoise import ( WanInputStep, @@ -29,7 +29,7 @@ # before_denoise: text2vid -class WanBeforeDenoiseStep(SequentialModularPipelineBlockss): +class WanBeforeDenoiseStep(SequentialPipelineBlocks): block_classes = [ WanInputStep, WanSetTimestepsStep, @@ -49,7 +49,7 @@ def description(self): # before_denoise: all task (text2vid,) -class WanAutoBeforeDenoiseStep(AutoModularPipelineBlockss): +class WanAutoBeforeDenoiseStep(AutoPipelineBlocks): block_classes = [ WanBeforeDenoiseStep, ] @@ -66,7 +66,7 @@ def description(self): # denoise: text2vid -class WanAutoDenoiseStep(AutoModularPipelineBlockss): +class WanAutoDenoiseStep(AutoPipelineBlocks): block_classes = [ WanDenoiseStep, ] @@ -83,7 +83,7 @@ def description(self) -> str: # decode: all task (text2img, img2img, inpainting) -class WanAutoDecodeStep(AutoModularPipelineBlockss): +class WanAutoDecodeStep(AutoPipelineBlocks): block_classes = [WanDecodeStep] block_names = ["non-inpaint"] block_trigger_inputs = [None] @@ -94,7 +94,7 @@ def description(self): # text2vid -class WanAutoBlocks(SequentialModularPipelineBlockss): +class WanAutoBlocks(SequentialPipelineBlocks): block_classes = [ WanTextEncoderStep, WanAutoBeforeDenoiseStep, From 6c85fcd8997fde6999f491bbaa4ab69e4f774707 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Fri, 8 Aug 2025 18:52:55 +0200 Subject: [PATCH 12/14] update --- src/diffusers/modular_pipelines/wan/denoise.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/modular_pipelines/wan/denoise.py b/src/diffusers/modular_pipelines/wan/denoise.py index 9528f7e548f3..9871d4ad618c 100644 --- a/src/diffusers/modular_pipelines/wan/denoise.py +++ b/src/diffusers/modular_pipelines/wan/denoise.py @@ -23,7 +23,7 @@ from ...utils import logging from ..modular_pipeline import ( BlockState, - LoopSequentialModularPipelineBlockss, + LoopSequentialPipelineBlocks, ModularPipelineBlocks, PipelineState, ) @@ -53,7 +53,7 @@ def expected_components(self) -> List[ComponentSpec]: def description(self) -> str: return ( "Step within the denoising loop that denoise the latents with guidance. " - "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialModularPipelineBlockss` " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " "object (e.g. `WanDenoiseLoopWrapper`)" ) @@ -145,7 +145,7 @@ def expected_components(self) -> List[ComponentSpec]: def description(self) -> str: return ( "step within the denoising loop that update the latents. " - "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialModularPipelineBlockss` " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " "object (e.g. `WanDenoiseLoopWrapper`)" ) @@ -181,7 +181,7 @@ def __call__(self, components: WanModularPipeline, block_state: BlockState, i: i return components, block_state -class WanDenoiseLoopWrapper(LoopSequentialModularPipelineBlockss): +class WanDenoiseLoopWrapper(LoopSequentialPipelineBlocks): model_name = "wan" @property From 512044c5ea8ef1f3c80dddf119d4b901ee14d536 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Sat, 9 Aug 2025 15:06:18 +0200 Subject: [PATCH 13/14] update --- .../test_modular_pipeline_stable_diffusion_xl.py | 4 ---- tests/modular_pipelines/test_modular_pipelines_common.py | 2 -- 2 files changed, 6 deletions(-) diff --git a/tests/modular_pipelines/stable_diffusion_xl/test_modular_pipeline_stable_diffusion_xl.py b/tests/modular_pipelines/stable_diffusion_xl/test_modular_pipeline_stable_diffusion_xl.py index 4127d00c8e1a..044cdd57daea 100644 --- a/tests/modular_pipelines/stable_diffusion_xl/test_modular_pipeline_stable_diffusion_xl.py +++ b/tests/modular_pipelines/stable_diffusion_xl/test_modular_pipeline_stable_diffusion_xl.py @@ -117,13 +117,9 @@ def test_pipeline_inputs_and_blocks(self): _ = blocks.sub_blocks.pop("ip_adapter") parameters = blocks.input_names - intermediate_parameters = blocks.intermediate_input_names assert "ip_adapter_image" not in parameters, ( "`ip_adapter_image` argument must be removed from the `__call__` method" ) - assert "ip_adapter_image_embeds" not in intermediate_parameters, ( - "`ip_adapter_image_embeds` argument must be supported by the `__call__` method" - ) def _get_dummy_image_embeds(self, cross_attention_dim: int = 32): return torch.randn((1, 1, cross_attention_dim), device=torch_device) diff --git a/tests/modular_pipelines/test_modular_pipelines_common.py b/tests/modular_pipelines/test_modular_pipelines_common.py index 6240797742d4..36595b02a24c 100644 --- a/tests/modular_pipelines/test_modular_pipelines_common.py +++ b/tests/modular_pipelines/test_modular_pipelines_common.py @@ -139,7 +139,6 @@ def tearDown(self): def test_pipeline_call_signature(self): pipe = self.get_pipeline() input_parameters = pipe.blocks.input_names - intermediate_parameters = pipe.blocks.intermediate_input_names optional_parameters = pipe.default_call_parameters def _check_for_parameters(parameters, expected_parameters, param_type): @@ -149,7 +148,6 @@ def _check_for_parameters(parameters, expected_parameters, param_type): ) _check_for_parameters(self.params, input_parameters, "input") - _check_for_parameters(self.intermediate_params, intermediate_parameters, "intermediate") _check_for_parameters(self.optional_params, optional_parameters, "optional") def test_inference_batch_consistent(self, batch_sizes=[2], batch_generator=True): From fb8722e9abe18235930c688e599a4032bb392058 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Sat, 9 Aug 2025 16:00:24 +0200 Subject: [PATCH 14/14] update --- src/diffusers/modular_pipelines/modular_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index 99e12749dd16..8a05cce209c5 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -800,7 +800,7 @@ def description(self): @property def model_name(self): - return next(iter(self.sub_blocks.values())).model_name + return next((block.model_name for block in self.sub_blocks.values() if block.model_name is not None), None) @property def expected_components(self):