Skip to content

Commit 5cde77f

Browse files
committed
make inputs truly immutable, remove the output logic in sequential pipeline, and update so that intermediates_outputs are only new variables
1 parent 522e827 commit 5cde77f

File tree

1 file changed

+15
-4
lines changed

1 file changed

+15
-4
lines changed

src/diffusers/modular_pipelines/modular_pipeline.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from collections import OrderedDict
1818
from dataclasses import dataclass, field
1919
from typing import Any, Dict, List, Tuple, Union, Optional, Type
20+
from copy import deepcopy
2021

2122

2223
import torch
@@ -109,7 +110,9 @@ def add_intermediate(self, key: str, value: Any, kwargs_type: str = None):
109110
self.intermediate_kwargs[kwargs_type].append(key)
110111

111112
def get_input(self, key: str, default: Any = None) -> Any:
112-
return self.inputs.get(key, default)
113+
value = self.inputs.get(key, default)
114+
if value is not None:
115+
return deepcopy(value)
113116

114117
def get_inputs(self, keys: List[str], default: Any = None) -> Dict[str, Any]:
115118
return {key: self.inputs.get(key, default) for key in keys}
@@ -483,6 +486,7 @@ def doc(self):
483486
)
484487

485488

489+
# YiYi TODO: input and inteermediate inputs with same name? should warn?
486490
def get_block_state(self, state: PipelineState) -> dict:
487491
"""Get all inputs and intermediates in one dictionary"""
488492
data = {}
@@ -1032,14 +1036,21 @@ def get_intermediates_inputs(self):
10321036

10331037
@property
10341038
def intermediates_outputs(self) -> List[str]:
1035-
named_outputs = [(name, block.intermediates_outputs) for name, block in self.blocks.items()]
1039+
named_outputs = []
1040+
for name, block in self.blocks.items():
1041+
inp_names = set([inp.name for inp in block.intermediates_inputs])
1042+
# so we only need to list new variables as intermediates_outputs, but if user wants to list these they modified it's still fine (a.k.a we don't enforce)
1043+
# filter out them here so they do not end up as intermediates_outputs
1044+
if name not in inp_names:
1045+
named_outputs.append((name, block.intermediates_outputs))
10361046
combined_outputs = combine_outputs(*named_outputs)
10371047
return combined_outputs
10381048

1049+
# YiYi TODO: I think we can remove the outputs property
10391050
@property
10401051
def outputs(self) -> List[str]:
1041-
return next(reversed(self.blocks.values())).intermediates_outputs
1042-
1052+
# return next(reversed(self.blocks.values())).intermediates_outputs
1053+
return self.intermediates_outputs
10431054
@torch.no_grad()
10441055
def __call__(self, pipeline, state: PipelineState) -> PipelineState:
10451056
for block_name, block in self.blocks.items():

0 commit comments

Comments
 (0)