@@ -87,16 +87,21 @@ def _build_blocks(self):
8787
8888 block_names_count = Counter ()
8989 for primitive in self .primitives :
90+ if isinstance (primitive , str ):
91+ primitive_name = primitive
92+ else :
93+ primitive_name = primitive ['name' ]
94+
9095 try :
91- block_names_count .update ([primitive ])
92- block_count = block_names_count [primitive ]
93- block_name = '{}#{}' .format (primitive , block_count )
96+ block_names_count .update ([primitive_name ])
97+ block_count = block_names_count [primitive_name ]
98+ block_name = '{}#{}' .format (primitive_name , block_count )
9499 block_params = self .init_params .get (block_name , dict ())
95100 if not block_params :
96- block_params = self .init_params .get (primitive , dict ())
101+ block_params = self .init_params .get (primitive_name , dict ())
97102 if block_params and block_count > 1 :
98103 LOGGER .warning (("Non-numbered init_params are being used "
99- "for more than one block %s." ), primitive )
104+ "for more than one block %s." ), primitive_name )
100105
101106 block = MLBlock (primitive , ** block_params )
102107 blocks [block_name ] = block
@@ -330,10 +335,6 @@ def _get_block_args(self, block_name, block_args, context):
330335
331336 if variable in context :
332337 kwargs [name ] = context [variable ]
333- elif 'default' in arg :
334- kwargs [name ] = arg ['default' ]
335- elif arg .get ('required' , True ):
336- raise ValueError ('Input variable {} not found in context' .format (variable ))
337338
338339 return kwargs
339340
@@ -517,11 +518,12 @@ def fit(self, X=None, y=None, output_=None, start_=None, **kwargs):
517518 the value of that variable from the context will extracted and returned
518519 after the produce method of that block has been called.
519520 """
520- context = {
521- 'X' : X ,
522- 'y' : y
523- }
524- context .update (kwargs )
521+ context = kwargs .copy ()
522+ if X is not None :
523+ context ['X' ] = X
524+
525+ if y is not None :
526+ context ['y' ] = y
525527
526528 output_block , output_variable = self ._get_output_spec (output_ )
527529 last_block_name = self ._get_block_name (- 1 )
@@ -624,10 +626,9 @@ def predict(self, X=None, output_=None, start_=None, **kwargs):
624626 the value of that variable from the context will extracted and returned
625627 after the produce method of that block has been called.
626628 """
627- context = {
628- 'X' : X
629- }
630- context .update (kwargs )
629+ context = kwargs .copy ()
630+ if X is not None :
631+ context ['X' ] = X
631632
632633 output_block , output_variable = self ._get_output_spec (output_ )
633634
0 commit comments