Skip to content

Commit ea77fdc

Browse files
committed
update
1 parent 255c574 commit ea77fdc

File tree

2 files changed

+25
-8
lines changed

2 files changed

+25
-8
lines changed

src/diffusers/modular_pipelines/modular_pipeline.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,18 @@ def inputs(self) -> List[InputParam]:
264264
"""List of input parameters. Must be implemented by subclasses."""
265265
return []
266266

267+
def _get_required_inputs(self):
268+
input_names = []
269+
for input_param in self.inputs:
270+
if input_param.required:
271+
input_names.append(input_param.name)
272+
273+
return input_names
274+
275+
@property
276+
def required_inputs(self) -> List[InputParam]:
277+
return self._get_required_inputs()
278+
267279
@property
268280
def intermediate_outputs(self) -> List[OutputParam]:
269281
"""List of intermediate output parameters. Must be implemented by subclasses."""
@@ -492,6 +504,17 @@ def intermediate_output_names(self) -> List[str]:
492504
def output_names(self) -> List[str]:
493505
return [output_param.name for output_param in self.outputs]
494506

507+
@property
508+
def doc(self):
509+
return make_doc_string(
510+
self.inputs,
511+
self.outputs,
512+
self.description,
513+
class_name=self.__class__.__name__,
514+
expected_components=self.expected_components,
515+
expected_configs=self.expected_configs,
516+
)
517+
495518

496519
class AutoPipelineBlocks(ModularPipelineBlocks):
497520
"""
@@ -743,7 +766,6 @@ def __repr__(self):
743766
def doc(self):
744767
return make_doc_string(
745768
self.inputs,
746-
self.intermediate_inputs,
747769
self.outputs,
748770
self.description,
749771
class_name=self.__class__.__name__,
@@ -2394,16 +2416,12 @@ def __call__(self, state: PipelineState = None, output: Union[str, List[str]] =
23942416

23952417
# Add inputs to state, using defaults if not provided in the kwargs or the state
23962418
# if same input already in the state, will override it if provided in the kwargs
2397-
intermediate_inputs = [inp.name for inp in self.blocks.inputs]
23982419
for expected_input_param in self.blocks.inputs:
23992420
name = expected_input_param.name
24002421
default = expected_input_param.default
24012422
kwargs_type = expected_input_param.kwargs_type
24022423
if name in passed_kwargs:
2403-
if name not in intermediate_inputs:
2404-
state.set(name, passed_kwargs.pop(name), kwargs_type)
2405-
else:
2406-
state.set(name, passed_kwargs[name], kwargs_type)
2424+
state.set(name, passed_kwargs.pop(name), kwargs_type)
24072425
elif name not in state.values:
24082426
state.set(name, default, kwargs_type)
24092427

src/diffusers/modular_pipelines/modular_pipeline_utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -618,7 +618,6 @@ def format_configs(configs, indent_level=4, max_line_length=115, add_empty_lines
618618

619619
def make_doc_string(
620620
inputs,
621-
intermediate_inputs,
622621
outputs,
623622
description="",
624623
class_name=None,
@@ -664,7 +663,7 @@ def make_doc_string(
664663
output += configs_str + "\n\n"
665664

666665
# Add inputs section
667-
output += format_input_params(inputs + intermediate_inputs, indent_level=2)
666+
output += format_input_params(inputs, indent_level=2)
668667

669668
# Add outputs section
670669
output += "\n\n"

0 commit comments

Comments
 (0)