44
55import json
66import logging
7+ import re
78from collections import Counter , OrderedDict , defaultdict
89from copy import deepcopy
910
@@ -145,8 +146,11 @@ def _get_block_outputs(self, block_name):
145146 """Get the list of output variables for the given block."""
146147 block = self .blocks [block_name ]
147148 outputs = deepcopy (block .produce_output )
149+ output_names = self .output_names .get (block_name , dict ())
148150 for output in outputs :
149- output ['variable' ] = '{}.{}' .format (block_name , output ['name' ])
151+ name = output ['name' ]
152+ context_name = output_names .get (name , name )
153+ output ['variable' ] = '{}.{}' .format (block_name , context_name )
150154
151155 return outputs
152156
@@ -195,12 +199,15 @@ def __init__(self, pipeline=None, primitives=None, init_params=None,
195199 if hyperparameters :
196200 self .set_hyperparameters (hyperparameters )
197201
202+ self ._re_block_name = re .compile (r'(^[^#]+#\d+)(\..*)?' )
203+
198204 def _get_str_output (self , output ):
199205 """Get the outputs that correspond to the str specification."""
200206 if output in self .outputs :
201207 return self .outputs [output ]
202208 elif output in self .blocks :
203- return self ._get_block_outputs (output )
209+ return [{'name' : output , 'variable' : output }]
210+ # return self._get_block_outputs(output)
204211 elif '.' in output :
205212 block_name , variable_name = output .rsplit ('.' , 1 )
206213 block = self .blocks .get (block_name )
@@ -257,11 +264,11 @@ def get_outputs(self, outputs='default'):
257264
258265 computed = list ()
259266 for output in outputs :
267+ if isinstance (output , int ):
268+ output = self ._get_block_name (output )
269+
260270 if isinstance (output , str ):
261271 computed .extend (self ._get_str_output (output ))
262- elif isinstance (output , int ):
263- block_name = self ._get_block_name (output )
264- computed .extend (self ._get_block_outputs (block_name ))
265272 else :
266273 raise TypeError ('Output Specification can only be str or int' )
267274
@@ -313,6 +320,18 @@ def get_output_variables(self, outputs='default'):
313320 outputs = self .get_outputs (outputs )
314321 return [output ['variable' ] for output in outputs ]
315322
323+ def _extract_block_name (self , variable_name ):
324+ return self ._re_block_name .search (variable_name ).group (1 )
325+
326+ def _prepare_outputs (self , outputs ):
327+ output_variables = self .get_output_variables (outputs )
328+ outputs = output_variables .copy ()
329+ output_blocks = {
330+ self ._extract_block_name (variable )
331+ for variable in output_variables
332+ }
333+ return output_variables , outputs , output_blocks
334+
316335 @staticmethod
317336 def _flatten_dict (hyperparameters ):
318337 return {
@@ -516,13 +535,11 @@ def _extract_outputs(self, block_name, outputs, block_outputs):
516535
517536 return output_dict
518537
519- def _update_outputs (self , block_name , output_variables , outputs , outputs_dict ):
538+ def _update_outputs (self , variable_name , output_variables , outputs , value ):
520539 """Set the requested block outputs into the outputs list in the right place."""
521- for key , value in outputs_dict .items ():
522- variable_name = '{}.{}' .format (block_name , key )
523- if variable_name in output_variables :
524- index = output_variables .index (variable_name )
525- outputs [index ] = deepcopy (value )
540+ if variable_name in output_variables :
541+ index = output_variables .index (variable_name )
542+ outputs [index ] = deepcopy (value )
526543
527544 def _fit_block (self , block , block_name , context ):
528545 """Get the block args from the context and fit the block."""
@@ -551,7 +568,12 @@ def _produce_block(self, block, block_name, context, output_variables, outputs):
551568 context .update (outputs_dict )
552569
553570 if output_variables :
554- self ._update_outputs (block_name , output_variables , outputs , outputs_dict )
571+ if block_name in output_variables :
572+ self ._update_outputs (block_name , output_variables , outputs , context )
573+ else :
574+ for key , value in outputs_dict .items ():
575+ variable_name = '{}.{}' .format (block_name , key )
576+ self ._update_outputs (variable_name , output_variables , outputs , value )
555577
556578 except Exception :
557579 if self .verbose :
@@ -606,14 +628,12 @@ def fit(self, X=None, y=None, output_=None, start_=None, **kwargs):
606628 if y is not None :
607629 context ['y' ] = y
608630
609- if output_ is not None :
610- output_variables = self .get_output_variables (output_ )
611- outputs = output_variables .copy ()
612- output_blocks = {variable .rsplit ('.' , 1 )[0 ] for variable in output_variables }
613- else :
631+ if output_ is None :
614632 output_variables = None
615633 outputs = None
616634 output_blocks = set ()
635+ else :
636+ output_variables , outputs , output_blocks = self ._prepare_outputs (output_ )
617637
618638 if isinstance (start_ , int ):
619639 start_ = self ._get_block_name (start_ )
@@ -637,7 +657,7 @@ def fit(self, X=None, y=None, output_=None, start_=None, **kwargs):
637657
638658 # If there was an output_ but there are no pending
639659 # outputs we are done.
640- if output_ is not None and not output_blocks :
660+ if output_variables is not None and not output_blocks :
641661 if len (outputs ) > 1 :
642662 return tuple (outputs )
643663 else :
@@ -686,9 +706,7 @@ def predict(self, X=None, output_='default', start_=None, **kwargs):
686706 if X is not None :
687707 context ['X' ] = X
688708
689- output_variables = self .get_output_variables (output_ )
690- outputs = output_variables .copy ()
691- output_blocks = {variable .rsplit ('.' , 1 )[0 ] for variable in output_variables }
709+ output_variables , outputs , output_blocks = self ._prepare_outputs (output_ )
692710
693711 if isinstance (start_ , int ):
694712 start_ = self ._get_block_name (start_ )
0 commit comments