Skip to content

Commit 496bf0b

Browse files
committed
update
1 parent 1db6365 commit 496bf0b

File tree

2 files changed

+14
-213
lines changed

2 files changed

+14
-213
lines changed

src/diffusers/modular_pipelines/modular_pipeline.py

Lines changed: 13 additions & 212 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,6 @@
4545
OutputParam,
4646
format_components,
4747
format_configs,
48-
format_inputs_short,
49-
format_intermediates_short,
5048
make_doc_string,
5149
)
5250

@@ -142,12 +140,8 @@ def format_value(v):
142140
values_str = "\n".join(f" {k}: {format_value(v)}" for k, v in self.values.items())
143141
kwargs_mapping_str = "\n".join(f" {k}: {v}" for k, v in self.kwargs_mapping.items())
144142

145-
return (
146-
f"PipelineState(\n"
147-
f" values={{\n{values_str}\n }},\n"
148-
f" kwargs_mapping={{\n{kwargs_mapping_str}\n }}\n"
149-
f")"
150-
)
143+
return f"PipelineState(\n values={{\n{values_str}\n }},\n kwargs_mapping={{\n{kwargs_mapping_str}\n }}\n)"
144+
151145

152146
@dataclass
153147
class BlockState:
@@ -402,20 +396,21 @@ def set_block_state(self, state: PipelineState, block_state: BlockState):
402396
current_value = state.get(input_param.name)
403397
if current_value is not param: # Using identity comparison to check if object was modified
404398
state.set(input_param.name, param, input_param.kwargs_type)
399+
405400
elif input_param.kwargs_type:
406-
import ipdb; ipdb.set_trace()
407401
# if it is a kwargs type, e.g. "guider_input_fields", it is likely to be a list of parameters
408402
# we need to first find out which inputs are and loop through them.
409403
intermediate_kwargs = state.get_by_kwargs(input_param.kwargs_type)
410404
for param_name, current_value in intermediate_kwargs.items():
411-
try:
412-
if not hasattr(block_state, param_name):
413-
continue
414-
param = getattr(block_state, param_name)
415-
if current_value is not param: # Using identity comparison to check if object was modified
416-
state.set(param_name, param, input_param.kwargs_type)
417-
except:
418-
import ipdb; ipdb.set_trace()
405+
if param_name is None:
406+
continue
407+
408+
if not hasattr(block_state, param_name):
409+
continue
410+
411+
param = getattr(block_state, param_name)
412+
if current_value is not param: # Using identity comparison to check if object was modified
413+
state.set(param_name, param, input_param.kwargs_type)
419414

420415
@staticmethod
421416
def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> List[InputParam]:
@@ -496,200 +491,6 @@ def output_names(self) -> List[str]:
496491
return [output_param.name for output_param in self.outputs]
497492

498493

499-
class PipelineBlock(ModularPipelineBlocks):
500-
"""
501-
A Pipeline Block is the basic building block of a Modular Pipeline.
502-
503-
This class inherits from [`ModularPipelineBlocks`]. Check the superclass documentation for the generic methods the
504-
library implements for all the pipeline blocks (such as loading or saving etc.)
505-
506-
<Tip warning={true}>
507-
508-
This is an experimental feature and is likely to change in the future.
509-
510-
</Tip>
511-
512-
Args:
513-
description (str, optional): A description of the block, defaults to None. Define as a property in subclasses.
514-
expected_components (List[ComponentSpec], optional):
515-
A list of components that are expected to be used in the block, defaults to []. To override, define as a
516-
property in subclasses.
517-
expected_configs (List[ConfigSpec], optional):
518-
A list of configs that are expected to be used in the block, defaults to []. To override, define as a
519-
property in subclasses.
520-
inputs (List[InputParam], optional):
521-
A list of inputs that are expected to be used in the block, defaults to []. To override, define as a
522-
property in subclasses.
523-
intermediate_inputs (List[InputParam], optional):
524-
A list of intermediate inputs that are expected to be used in the block, defaults to []. To override,
525-
define as a property in subclasses.
526-
intermediate_outputs (List[OutputParam], optional):
527-
A list of intermediate outputs that are expected to be used in the block, defaults to []. To override,
528-
define as a property in subclasses.
529-
outputs (List[OutputParam], optional):
530-
A list of outputs that are expected to be used in the block, defaults to []. To override, define as a
531-
property in subclasses.
532-
required_inputs (List[str], optional):
533-
A list of required inputs that are expected to be used in the block, defaults to []. To override, define as
534-
a property in subclasses.
535-
required_intermediate_inputs (List[str], optional):
536-
A list of required intermediate inputs that are expected to be used in the block, defaults to []. To
537-
override, define as a property in subclasses.
538-
required_intermediate_outputs (List[str], optional):
539-
A list of required intermediate outputs that are expected to be used in the block, defaults to []. To
540-
override, define as a property in subclasses.
541-
"""
542-
543-
model_name = None
544-
545-
def __init__(self):
546-
self.sub_blocks = InsertableDict()
547-
548-
@property
549-
def description(self) -> str:
550-
"""Description of the block. Must be implemented by subclasses."""
551-
# raise NotImplementedError("description method must be implemented in subclasses")
552-
return "TODO: add a description"
553-
554-
@property
555-
def expected_components(self) -> List[ComponentSpec]:
556-
return []
557-
558-
@property
559-
def expected_configs(self) -> List[ConfigSpec]:
560-
return []
561-
562-
@property
563-
def inputs(self) -> List[InputParam]:
564-
"""List of input parameters. Must be implemented by subclasses."""
565-
return []
566-
567-
@property
568-
def intermediate_inputs(self) -> List[InputParam]:
569-
"""List of intermediate input parameters. Must be implemented by subclasses."""
570-
return []
571-
572-
@property
573-
def intermediate_outputs(self) -> List[OutputParam]:
574-
"""List of intermediate output parameters. Must be implemented by subclasses."""
575-
return []
576-
577-
def _get_outputs(self):
578-
return self.intermediate_outputs
579-
580-
# YiYi TODO: is it too easy for user to unintentionally override these properties?
581-
# Adding outputs attributes here for consistency between PipelineBlock/AutoPipelineBlocks/SequentialPipelineBlocks
582-
@property
583-
def outputs(self) -> List[OutputParam]:
584-
return self._get_outputs()
585-
586-
def _get_required_inputs(self):
587-
input_names = []
588-
for input_param in self.inputs:
589-
if input_param.required:
590-
input_names.append(input_param.name)
591-
return input_names
592-
593-
@property
594-
def required_inputs(self) -> List[str]:
595-
return self._get_required_inputs()
596-
597-
def _get_required_intermediate_inputs(self):
598-
input_names = []
599-
for input_param in self.intermediate_inputs:
600-
if input_param.required:
601-
input_names.append(input_param.name)
602-
return input_names
603-
604-
# YiYi TODO: maybe we do not need this, it is only used in docstring,
605-
# intermediate_inputs is by default required, unless you manually handle it inside the block
606-
@property
607-
def required_intermediate_inputs(self) -> List[str]:
608-
return self._get_required_intermediate_inputs()
609-
610-
def __call__(self, pipeline, state: PipelineState) -> PipelineState:
611-
raise NotImplementedError("__call__ method must be implemented in subclasses")
612-
613-
def __repr__(self):
614-
class_name = self.__class__.__name__
615-
base_class = self.__class__.__bases__[0].__name__
616-
617-
# Format description with proper indentation
618-
desc_lines = self.description.split("\n")
619-
desc = []
620-
# First line with "Description:" label
621-
desc.append(f" Description: {desc_lines[0]}")
622-
# Subsequent lines with proper indentation
623-
if len(desc_lines) > 1:
624-
desc.extend(f" {line}" for line in desc_lines[1:])
625-
desc = "\n".join(desc) + "\n"
626-
627-
# Components section - use format_components with add_empty_lines=False
628-
expected_components = getattr(self, "expected_components", [])
629-
components_str = format_components(expected_components, indent_level=2, add_empty_lines=False)
630-
components = " " + components_str.replace("\n", "\n ")
631-
632-
# Configs section - use format_configs with add_empty_lines=False
633-
expected_configs = getattr(self, "expected_configs", [])
634-
configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False)
635-
configs = " " + configs_str.replace("\n", "\n ")
636-
637-
# Inputs section
638-
inputs_str = format_inputs_short(self.inputs)
639-
inputs = "Inputs:\n " + inputs_str
640-
641-
# Intermediates section
642-
intermediates_str = format_intermediates_short(
643-
self.intermediate_inputs, self.required_intermediate_inputs, self.intermediate_outputs
644-
)
645-
intermediates = f"Intermediates:\n{intermediates_str}"
646-
647-
return f"{class_name}(\n Class: {base_class}\n{desc}{components}\n{configs}\n {inputs}\n {intermediates}\n)"
648-
649-
@property
650-
def doc(self):
651-
return make_doc_string(
652-
self.inputs,
653-
self.intermediate_inputs,
654-
self.outputs,
655-
self.description,
656-
class_name=self.__class__.__name__,
657-
expected_components=self.expected_components,
658-
expected_configs=self.expected_configs,
659-
)
660-
661-
def set_block_state(self, state: PipelineState, block_state: BlockState):
662-
for output_param in self.intermediate_outputs:
663-
if not hasattr(block_state, output_param.name):
664-
raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state")
665-
param = getattr(block_state, output_param.name)
666-
state.set(output_param.name, param, output_param.kwargs_type)
667-
668-
for input_param in self.intermediate_inputs:
669-
if hasattr(block_state, input_param.name):
670-
param = getattr(block_state, input_param.name)
671-
# Only add if the value is different from what's in the state
672-
current_value = state.get(input_param.name)
673-
if current_value is not param: # Using identity comparison to check if object was modified
674-
state.set(input_param.name, param, input_param.kwargs_type)
675-
676-
for input_param in self.intermediate_inputs:
677-
if input_param.name and hasattr(block_state, input_param.name):
678-
param = getattr(block_state, input_param.name)
679-
# Only add if the value is different from what's in the state
680-
current_value = state.get(input_param.name)
681-
if current_value is not param: # Using identity comparison to check if object was modified
682-
state.set(input_param.name, param, input_param.kwargs_type)
683-
elif input_param.kwargs_type:
684-
# if it is a kwargs type, e.g. "guider_input_fields", it is likely to be a list of parameters
685-
# we need to first find out which inputs are and loop through them.
686-
intermediate_kwargs = state.get_kwargs(input_param.kwargs_type)
687-
for param_name, current_value in intermediate_kwargs.items():
688-
param = getattr(block_state, param_name)
689-
if current_value is not param: # Using identity comparison to check if object was modified
690-
state.set(param_name, param, input_param.kwargs_type)
691-
692-
693494
class AutoPipelineBlocks(ModularPipelineBlocks):
694495
"""
695496
A Pipeline Blocks that automatically selects a block to run based on the inputs.
@@ -1042,7 +843,7 @@ def _get_inputs(self):
1042843
if inp.name not in outputs and inp.name not in {input.name for input in inputs}:
1043844
inputs.append(inp)
1044845

1045-
# Only add outputs if the block cannot be skipped
846+
# Only add outputs if the block cannot be skipped
1046847
should_add_outputs = True
1047848
if hasattr(block, "block_trigger_inputs") and None not in block.block_trigger_inputs:
1048849
should_add_outputs = False

src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def inputs(self) -> List[Tuple[str, Any]]:
6161
required=True,
6262
type_hint=torch.Tensor,
6363
description="The denoised latents from the denoising step",
64-
)
64+
),
6565
]
6666

6767
@property

0 commit comments

Comments
 (0)