@@ -148,6 +148,8 @@ def __init__(self, inputCol=None, outputCol=None, outputMode="vector", labelCol=
148148 super (KerasImageFileEstimator , self ).__init__ ()
149149 kwargs = self ._input_kwargs
150150 self .setParams (** kwargs )
151+ self ._tunable_params = [self .kerasOptimizer , self .kerasLoss , self .kerasFitParams ,
152+ self .outputCol , self .outputMode ] # model params and output params
151153
152154 @keyword_only
153155 def setParams (self , inputCol = None , outputCol = None , outputMode = "vector" , labelCol = None ,
@@ -168,14 +170,11 @@ def _validateParams(self, paramMap):
168170 :return: True if parameters are valid
169171 """
170172
171- tunable_params = [self .kerasOptimizer , self .kerasLoss , self .kerasFitParams , # model params
172- self .outputCol , self .outputMode ] # output params
173-
174173 undefined = set ([p for p in self .params if not self .isDefined (p )])
175- undefined_tunable = undefined .intersection (tunable_params )
174+ undefined_tunable = undefined .intersection (self . _tunable_params )
176175 failed_define = [p .name for p in undefined .difference (undefined_tunable )]
177176 failed_tune = [p .name for p in undefined_tunable if p not in paramMap ]
178- untunable_overrides = [p .name for p in paramMap if p not in tunable_params ]
177+ untunable_overrides = [p .name for p in paramMap if p not in self . _tunable_params ]
179178
180179 if failed_define or failed_tune or untunable_overrides :
181180 msg = "Following Params must be"
@@ -266,14 +265,16 @@ def fitMultiple(self, dataset, paramMaps):
266265 """
267266 [self ._validateParams (pm ) for pm in paramMaps ]
268267
269- def _name_value_map (paramMap ):
270- """takes a dictionary {param -> value} and returns a map of {param.name -> value}"""
271- return {param .name : val for param , val in paramMap .items ()}
268+ def _get_tunable_name_value_map (param_map , tunable ):
269+ """takes a dictionary {`Param` -> value} and a list [`Param`], select keys that are
270+ present in both and returns a map of {Param.name -> value}"""
271+ return {param .name : val for param , val in param_map .items () if param in tunable }
272272
273273 sc = JVMAPI ._curr_sc ()
274- paramNameMaps = list (enumerate (map (_name_value_map , paramMaps )))
275- num_models = len (paramNameMaps )
276- paramNameMapsRDD = sc .parallelize (paramNameMaps , numSlices = num_models )
274+ param_name_maps = [(i , _get_tunable_name_value_map (pm , self ._tunable_params ))
275+ for (i , pm ) in enumerate (paramMaps )]
276+ num_models = len (param_name_maps )
277+ paramNameMapsRDD = sc .parallelize (param_name_maps , numSlices = num_models )
277278
278279 # Extract image URI from provided dataset and create features as numpy arrays
279280 localFeatures , localLabels = self ._getNumpyFeaturesAndLabels (dataset )
@@ -285,8 +286,8 @@ def _name_value_map(paramMap):
285286 modelBytesBc = sc .broadcast (modelBytes )
286287
287288 # Obtain params for this estimator instance
288- baseParams = _name_value_map (self .extractParamMap ())
289- baseParamsBc = sc .broadcast (baseParams )
289+ base_params = _get_tunable_name_value_map (self .extractParamMap (), self . _tunable_params )
290+ baseParamsBc = sc .broadcast (base_params )
290291
291292 def _local_fit (row ):
292293 """
0 commit comments