@@ -69,7 +69,12 @@ def get_intermediate(self, key: str, default: Any = None) -> Any:
6969 return self .intermediates .get (key , default )
7070
7171 def get_output (self , key : str , default : Any = None ) -> Any :
72- return self .outputs .get (key , default )
72+ if key in self .outputs :
73+ return self .outputs [key ]
74+ elif key in self .intermediates :
75+ return self .intermediates [key ]
76+ else :
77+ return default
7378
7479 def to_dict (self ) -> Dict [str , Any ]:
7580 return {** self .__dict__ , "inputs" : self .inputs , "intermediates" : self .intermediates , "outputs" : self .outputs }
@@ -1132,7 +1137,7 @@ def replace_blocks(self, pipeline_blocks, at: int):
11321137 # Remove the old blocks
11331138 self .remove_blocks (indices_to_remove )
11341139
1135- def run_blocks (self , state : PipelineState = None , ** kwargs ):
1140+ def run_blocks (self , state : PipelineState = None , output : Union [ str , List [ str ]] = None , ** kwargs ):
11361141 """
11371142 Run one or more blocks in sequence, optionally you can pass a previous pipeline state.
11381143 """
@@ -1174,7 +1179,18 @@ def run_blocks(self, state: PipelineState = None, **kwargs):
11741179 raise
11751180 self .maybe_free_model_hooks ()
11761181
1177- return state
1182+ if output is None :
1183+ return state
1184+
1185+ if isinstance (output , str ):
1186+ return state .get_output (output )
1187+ elif isinstance (output , (list , tuple )):
1188+ outputs = {}
1189+ for output_name in output :
1190+ outputs [output_name ] = state .get_output (output_name )
1191+ return outputs
1192+ else :
1193+ raise ValueError (f"Output '{ output } ' is not a valid output type" )
11781194
11791195 def run_pipeline (self , ** kwargs ):
11801196 state = PipelineState ()
0 commit comments