@@ -145,8 +145,11 @@ def _get_block_outputs(self, block_name):
145145 """Get the list of output variables for the given block."""
146146 block = self .blocks [block_name ]
147147 outputs = deepcopy (block .produce_output )
148+ output_names = self .output_names .get (block_name , dict ())
148149 for output in outputs :
149- output ['variable' ] = '{}.{}' .format (block_name , output ['name' ])
150+ name = output ['name' ]
151+ context_name = output_names .get (name , name )
152+ output ['variable' ] = '{}.{}' .format (block_name , context_name )
150153
151154 return outputs
152155
@@ -606,7 +609,7 @@ def fit(self, X=None, y=None, output_=None, start_=None, **kwargs):
606609 if y is not None :
607610 context ['y' ] = y
608611
609- if output_ is not None :
612+ if isinstance ( output_ , str ) :
610613 output_variables = self .get_output_variables (output_ )
611614 outputs = output_variables .copy ()
612615 output_blocks = {variable .rsplit ('.' , 1 )[0 ] for variable in output_variables }
@@ -615,6 +618,9 @@ def fit(self, X=None, y=None, output_=None, start_=None, **kwargs):
615618 outputs = None
616619 output_blocks = set ()
617620
621+ if isinstance (output_ , int ):
622+ output_ = self ._get_block_name (output_ )
623+
618624 if isinstance (start_ , int ):
619625 start_ = self ._get_block_name (start_ )
620626
@@ -628,16 +634,19 @@ def fit(self, X=None, y=None, output_=None, start_=None, **kwargs):
628634
629635 self ._fit_block (block , block_name , context )
630636
631- if (block_name != self ._last_block_name ) or (block_name in output_blocks ):
637+ last_block = block_name != self ._last_block_name
638+ if last_block or (block_name == output_ ) or (block_name in output_blocks ):
632639 self ._produce_block (block , block_name , context , output_variables , outputs )
633640
634641 # We already captured the output from this block
635642 if block_name in output_blocks :
636643 output_blocks .remove (block_name )
644+ elif block_name == output_ :
645+ return context
637646
638647 # If there was an output_ but there are no pending
639648 # outputs we are done.
640- if output_ is not None and not output_blocks :
649+ if output_variables is not None and not output_blocks :
641650 if len (outputs ) > 1 :
642651 return tuple (outputs )
643652 else :
0 commit comments