Skip to content

Commit 27dde51

Browse files
committed
add output arg to run_blocks
1 parent 10d4a77 commit 27dde51

File tree

1 file changed

+19
-3
lines changed

1 file changed

+19
-3
lines changed

src/diffusers/pipelines/modular_pipeline_builder.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,12 @@ def get_intermediate(self, key: str, default: Any = None) -> Any:
6969
return self.intermediates.get(key, default)
7070

7171
def get_output(self, key: str, default: Any = None) -> Any:
72-
return self.outputs.get(key, default)
72+
if key in self.outputs:
73+
return self.outputs[key]
74+
elif key in self.intermediates:
75+
return self.intermediates[key]
76+
else:
77+
return default
7378

7479
def to_dict(self) -> Dict[str, Any]:
7580
return {**self.__dict__, "inputs": self.inputs, "intermediates": self.intermediates, "outputs": self.outputs}
@@ -1132,7 +1137,7 @@ def replace_blocks(self, pipeline_blocks, at: int):
11321137
# Remove the old blocks
11331138
self.remove_blocks(indices_to_remove)
11341139

1135-
def run_blocks(self, state: PipelineState = None, **kwargs):
1140+
def run_blocks(self, state: PipelineState = None, output: Union[str, List[str]] = None, **kwargs):
11361141
"""
11371142
Run one or more blocks in sequence, optionally you can pass a previous pipeline state.
11381143
"""
@@ -1174,7 +1179,18 @@ def run_blocks(self, state: PipelineState = None, **kwargs):
11741179
raise
11751180
self.maybe_free_model_hooks()
11761181

1177-
return state
1182+
if output is None:
1183+
return state
1184+
1185+
if isinstance(output, str):
1186+
return state.get_output(output)
1187+
elif isinstance(output, (list, tuple)):
1188+
outputs = {}
1189+
for output_name in output:
1190+
outputs[output_name] = state.get_output(output_name)
1191+
return outputs
1192+
else:
1193+
raise ValueError(f"Output '{output}' is not a valid output type")
11781194

11791195
def run_pipeline(self, **kwargs):
11801196
state = PipelineState()

0 commit comments

Comments
 (0)