|
17 | 17 | from collections import OrderedDict |
18 | 18 | from dataclasses import dataclass, field |
19 | 19 | from typing import Any, Dict, List, Tuple, Union, Optional, Type |
| 20 | +from copy import deepcopy |
20 | 21 |
|
21 | 22 |
|
22 | 23 | import torch |
@@ -109,7 +110,9 @@ def add_intermediate(self, key: str, value: Any, kwargs_type: str = None): |
109 | 110 | self.intermediate_kwargs[kwargs_type].append(key) |
110 | 111 |
|
111 | 112 | def get_input(self, key: str, default: Any = None) -> Any: |
112 | | - return self.inputs.get(key, default) |
| 113 | + value = self.inputs.get(key, default) |
| 114 | + if value is not None: |
| 115 | + return deepcopy(value) |
113 | 116 |
|
114 | 117 | def get_inputs(self, keys: List[str], default: Any = None) -> Dict[str, Any]: |
115 | 118 | return {key: self.inputs.get(key, default) for key in keys} |
@@ -483,6 +486,7 @@ def doc(self): |
483 | 486 | ) |
484 | 487 |
|
485 | 488 |
|
| 489 | + # YiYi TODO: input and inteermediate inputs with same name? should warn? |
486 | 490 | def get_block_state(self, state: PipelineState) -> dict: |
487 | 491 | """Get all inputs and intermediates in one dictionary""" |
488 | 492 | data = {} |
@@ -1032,14 +1036,21 @@ def get_intermediates_inputs(self): |
1032 | 1036 |
|
1033 | 1037 | @property |
1034 | 1038 | def intermediates_outputs(self) -> List[str]: |
1035 | | - named_outputs = [(name, block.intermediates_outputs) for name, block in self.blocks.items()] |
| 1039 | + named_outputs = [] |
| 1040 | + for name, block in self.blocks.items(): |
| 1041 | + inp_names = set([inp.name for inp in block.intermediates_inputs]) |
| 1042 | + # so we only need to list new variables as intermediates_outputs, but if user wants to list these they modified it's still fine (a.k.a we don't enforce) |
| 1043 | + # filter out them here so they do not end up as intermediates_outputs |
| 1044 | + if name not in inp_names: |
| 1045 | + named_outputs.append((name, block.intermediates_outputs)) |
1036 | 1046 | combined_outputs = combine_outputs(*named_outputs) |
1037 | 1047 | return combined_outputs |
1038 | 1048 |
|
| 1049 | + # YiYi TODO: I think we can remove the outputs property |
1039 | 1050 | @property |
1040 | 1051 | def outputs(self) -> List[str]: |
1041 | | - return next(reversed(self.blocks.values())).intermediates_outputs |
1042 | | - |
| 1052 | + # return next(reversed(self.blocks.values())).intermediates_outputs |
| 1053 | + return self.intermediates_outputs |
1043 | 1054 | @torch.no_grad() |
1044 | 1055 | def __call__(self, pipeline, state: PipelineState) -> PipelineState: |
1045 | 1056 | for block_name, block in self.blocks.items(): |
|
0 commit comments