44
55import json
66import logging
7+ import re
78from collections import Counter , OrderedDict , defaultdict
89from copy import deepcopy
910
@@ -198,12 +199,15 @@ def __init__(self, pipeline=None, primitives=None, init_params=None,
198199 if hyperparameters :
199200 self .set_hyperparameters (hyperparameters )
200201
202+ self ._re_block_name = re .compile (r'(^[^#]+#\d+)(\..*)?' )
203+
201204 def _get_str_output (self , output ):
202205 """Get the outputs that correspond to the str specification."""
203206 if output in self .outputs :
204207 return self .outputs [output ]
205208 elif output in self .blocks :
206- return self ._get_block_outputs (output )
209+ return [{'name' : output , 'variable' : output }]
210+ # return self._get_block_outputs(output)
207211 elif '.' in output :
208212 block_name , variable_name = output .rsplit ('.' , 1 )
209213 block = self .blocks .get (block_name )
@@ -260,11 +264,11 @@ def get_outputs(self, outputs='default'):
260264
261265 computed = list ()
262266 for output in outputs :
267+ if isinstance (output , int ):
268+ output = self ._get_block_name (output )
269+
263270 if isinstance (output , str ):
264271 computed .extend (self ._get_str_output (output ))
265- elif isinstance (output , int ):
266- block_name = self ._get_block_name (output )
267- computed .extend (self ._get_block_outputs (block_name ))
268272 else :
269273 raise TypeError ('Output Specification can only be str or int' )
270274
@@ -316,6 +320,18 @@ def get_output_variables(self, outputs='default'):
316320 outputs = self .get_outputs (outputs )
317321 return [output ['variable' ] for output in outputs ]
318322
323+ def _extract_block_name (self , variable_name ):
324+ return self ._re_block_name .search (variable_name ).group (1 )
325+
326+ def _prepare_outputs (self , outputs ):
327+ output_variables = self .get_output_variables (outputs )
328+ outputs = output_variables .copy ()
329+ output_blocks = {
330+ self ._extract_block_name (variable )
331+ for variable in output_variables
332+ }
333+ return output_variables , outputs , output_blocks
334+
319335 @staticmethod
320336 def _flatten_dict (hyperparameters ):
321337 return {
@@ -519,13 +535,11 @@ def _extract_outputs(self, block_name, outputs, block_outputs):
519535
520536 return output_dict
521537
522- def _update_outputs (self , block_name , output_variables , outputs , outputs_dict ):
538+ def _update_outputs (self , variable_name , output_variables , outputs , value ):
523539 """Set the requested block outputs into the outputs list in the right place."""
524- for key , value in outputs_dict .items ():
525- variable_name = '{}.{}' .format (block_name , key )
526- if variable_name in output_variables :
527- index = output_variables .index (variable_name )
528- outputs [index ] = deepcopy (value )
540+ if variable_name in output_variables :
541+ index = output_variables .index (variable_name )
542+ outputs [index ] = deepcopy (value )
529543
530544 def _fit_block (self , block , block_name , context ):
531545 """Get the block args from the context and fit the block."""
@@ -554,7 +568,12 @@ def _produce_block(self, block, block_name, context, output_variables, outputs):
554568 context .update (outputs_dict )
555569
556570 if output_variables :
557- self ._update_outputs (block_name , output_variables , outputs , outputs_dict )
571+ if block_name in output_variables :
572+ self ._update_outputs (block_name , output_variables , outputs , context )
573+ else :
574+ for key , value in outputs_dict .items ():
575+ variable_name = '{}.{}' .format (block_name , key )
576+ self ._update_outputs (variable_name , output_variables , outputs , value )
558577
559578 except Exception :
560579 if self .verbose :
@@ -609,17 +628,12 @@ def fit(self, X=None, y=None, output_=None, start_=None, **kwargs):
609628 if y is not None :
610629 context ['y' ] = y
611630
612- if isinstance (output_ , str ):
613- output_variables = self .get_output_variables (output_ )
614- outputs = output_variables .copy ()
615- output_blocks = {variable .rsplit ('.' , 1 )[0 ] for variable in output_variables }
616- else :
631+ if output_ is None :
617632 output_variables = None
618633 outputs = None
619634 output_blocks = set ()
620-
621- if isinstance (output_ , int ):
622- output_ = self ._get_block_name (output_ )
635+ else :
636+ output_variables , outputs , output_blocks = self ._prepare_outputs (output_ )
623637
624638 if isinstance (start_ , int ):
625639 start_ = self ._get_block_name (start_ )
@@ -634,15 +648,12 @@ def fit(self, X=None, y=None, output_=None, start_=None, **kwargs):
634648
635649 self ._fit_block (block , block_name , context )
636650
637- last_block = block_name != self ._last_block_name
638- if last_block or (block_name == output_ ) or (block_name in output_blocks ):
651+ if (block_name != self ._last_block_name ) or (block_name in output_blocks ):
639652 self ._produce_block (block , block_name , context , output_variables , outputs )
640653
641654 # We already captured the output from this block
642655 if block_name in output_blocks :
643656 output_blocks .remove (block_name )
644- elif block_name == output_ :
645- return context
646657
647658 # If there was an output_ but there are no pending
648659 # outputs we are done.
@@ -695,9 +706,7 @@ def predict(self, X=None, output_='default', start_=None, **kwargs):
695706 if X is not None :
696707 context ['X' ] = X
697708
698- output_variables = self .get_output_variables (output_ )
699- outputs = output_variables .copy ()
700- output_blocks = {variable .rsplit ('.' , 1 )[0 ] for variable in output_variables }
709+ output_variables , outputs , output_blocks = self ._prepare_outputs (output_ )
701710
702711 if isinstance (start_ , int ):
703712 start_ = self ._get_block_name (start_ )
0 commit comments