@@ -115,9 +115,21 @@ def make_components(self, path_to_components_file, custom_file_dir=None):
115115 used. (Default: None)
116116 """
117117 with open (path_to_components_file , 'r' ) as file :
118- components = json .load (file )
118+ componentsConfig = json .load (file )
119+
120+ parameterMap = {}
121+ components = componentsConfig ["components" ]
122+ if "hyperparameters" in componentsConfig .keys ():
123+ for component in components :
124+ for pKey , pValue in componentsConfig ["hyperparameters" ].items ():
125+ for cKey , cValue in component .items ():
126+ if pKey == cValue :
127+ component [cKey ] = pValue
128+ parameterMap [cKey ] = pKey
129+
119130 for component in components :
120- self .add_component (** component , directory = custom_file_dir )
131+ self .add_component (** component , directory = custom_file_dir ,
132+ parameterMap = parameterMap )
121133
122134 def make_steps (self , path_to_steps_file ):
123135 """
@@ -197,13 +209,11 @@ def add_component(self, component_type, match_case=False, absolute_path=False, *
197209 raise RuntimeError ("Given component type " + str (component_type )
198210 + " is not callable" )
199211
200-
201-
202- count = call .__code__ .co_argcount - 1
203- named_args = call .__code__ .co_varnames [1 :count ]
204212 try :
205213 component = Component_class (** kwargs )
206214 except TypeError as E :
215+ count = call .__code__ .co_argcount - 1
216+ named_args = call .__code__ .co_varnames [1 :count ]
207217 print (E )
208218 raise RuntimeError (str (E ) + "\n Provided keyword arguments:\t " + str (list (kwargs .keys ())) +
209219 "\n Required keyword arguments:\t " + str (list (named_args )))
@@ -217,6 +227,9 @@ def add_component(self, component_type, match_case=False, absolute_path=False, *
217227 del obj [key ]
218228 print ("Failed to serialize \" " + str (key ) + "\" in " + component .name )
219229
230+ if "directory" in obj .keys ():
231+ del obj ["directory" ]
232+
220233 self ._json_objects ['components' ].append (obj )
221234
222235 return component
@@ -347,7 +360,45 @@ def save_to_json(self, directory, model_name=None, custom_save=True):
347360 json .dump (self ._json_objects ['steps' ], fp , indent = 4 )
348361
349362 with open (path + "/components.json" , 'w' ) as fp :
350- json .dump (self ._json_objects ['components' ], fp , indent = 4 )
363+ hyperparameters = {}
364+
365+ for idx , component in enumerate (self ._json_objects ['components' ]):
366+ if component .get ('parameterMap' , None ) is not None :
367+ for cKey , pKey in component ['parameterMap' ].items ():
368+ pVal = component [cKey ]
369+ if pKey not in hyperparameters .keys ():
370+ hyperparameters [pKey ] = []
371+ hyperparameters [pKey ].append ((idx , cKey , pVal ))
372+
373+ hp = {}
374+ for param in hyperparameters .keys ():
375+ matched = True
376+ hp [param ] = None
377+ for _ , _ , pVal in hyperparameters [param ]:
378+ if hp [param ] is None :
379+ hp [param ] = pVal
380+ elif hp [param ] != pVal :
381+ del hp [param ]
382+ matched = False
383+ break
384+
385+ for idx , cKey , _ in hyperparameters [param ]:
386+ if matched :
387+ self ._json_objects ['components' ][idx ][cKey ] = param
388+
389+ else :
390+ warnings .warn ("Unable to extract hyperparameter " + str (param ) +
391+ " as it is mismatched between components. Parameter will not be extracted" )
392+
393+ for component in self ._json_objects ['components' ]:
394+ if "parameterMap" in component .keys ():
395+ del component ["parameterMap" ]
396+
397+ obj = {"components" : self ._json_objects ['components' ]}
398+ if len (hp .keys ()) != 0 :
399+ obj ["hyperparameters" ] = hp
400+
401+ json .dump (obj , fp , indent = 4 )
351402
352403 with open (path + "/connections.json" , 'w' ) as fp :
353404 json .dump (self ._json_objects ['connections' ], fp , indent = 4 )
0 commit comments