Skip to content

Commit a36b705

Browse files
yogeshgsmurching
authored and committed
broadcast only tunable parameters (#120)
* Broadcast only tunable parameters in KerasImageFileEstimator's fitMultiple implementation to mitigate issues where some untunable parameters (e.g. imageLoader) were unpicklable, causing the broadcast to fail.
1 parent 9f05fc2 commit a36b705

File tree

1 file changed

+14
-13
lines changed

1 file changed

+14
-13
lines changed

python/sparkdl/estimators/keras_image_file_estimator.py

Lines changed: 14 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -148,6 +148,8 @@ def __init__(self, inputCol=None, outputCol=None, outputMode="vector", labelCol=
148148
super(KerasImageFileEstimator, self).__init__()
149149
kwargs = self._input_kwargs
150150
self.setParams(**kwargs)
151+
self._tunable_params = [self.kerasOptimizer, self.kerasLoss, self.kerasFitParams,
152+
self.outputCol, self.outputMode] # model params and output params
151153

152154
@keyword_only
153155
def setParams(self, inputCol=None, outputCol=None, outputMode="vector", labelCol=None,
@@ -168,14 +170,11 @@ def _validateParams(self, paramMap):
168170
:return: True if parameters are valid
169171
"""
170172

171-
tunable_params = [self.kerasOptimizer, self.kerasLoss, self.kerasFitParams, # model params
172-
self.outputCol, self.outputMode] # output params
173-
174173
undefined = set([p for p in self.params if not self.isDefined(p)])
175-
undefined_tunable = undefined.intersection(tunable_params)
174+
undefined_tunable = undefined.intersection(self._tunable_params)
176175
failed_define = [p.name for p in undefined.difference(undefined_tunable)]
177176
failed_tune = [p.name for p in undefined_tunable if p not in paramMap]
178-
untunable_overrides = [p.name for p in paramMap if p not in tunable_params]
177+
untunable_overrides = [p.name for p in paramMap if p not in self._tunable_params]
179178

180179
if failed_define or failed_tune or untunable_overrides:
181180
msg = "Following Params must be"
@@ -266,14 +265,16 @@ def fitMultiple(self, dataset, paramMaps):
266265
"""
267266
[self._validateParams(pm) for pm in paramMaps]
268267

269-
def _name_value_map(paramMap):
270-
"""takes a dictionary {param -> value} and returns a map of {param.name -> value}"""
271-
return {param.name: val for param, val in paramMap.items()}
268+
def _get_tunable_name_value_map(param_map, tunable):
269+
"""takes a dictionary {`Param` -> value} and a list [`Param`], select keys that are
270+
present in both and returns a map of {Param.name -> value}"""
271+
return {param.name: val for param, val in param_map.items() if param in tunable}
272272

273273
sc = JVMAPI._curr_sc()
274-
paramNameMaps = list(enumerate(map(_name_value_map, paramMaps)))
275-
num_models = len(paramNameMaps)
276-
paramNameMapsRDD = sc.parallelize(paramNameMaps, numSlices=num_models)
274+
param_name_maps = [(i, _get_tunable_name_value_map(pm, self._tunable_params))
275+
for (i, pm) in enumerate(paramMaps)]
276+
num_models = len(param_name_maps)
277+
paramNameMapsRDD = sc.parallelize(param_name_maps, numSlices=num_models)
277278

278279
# Extract image URI from provided dataset and create features as numpy arrays
279280
localFeatures, localLabels = self._getNumpyFeaturesAndLabels(dataset)
@@ -285,8 +286,8 @@ def _name_value_map(paramMap):
285286
modelBytesBc = sc.broadcast(modelBytes)
286287

287288
# Obtain params for this estimator instance
288-
baseParams = _name_value_map(self.extractParamMap())
289-
baseParamsBc = sc.broadcast(baseParams)
289+
base_params = _get_tunable_name_value_map(self.extractParamMap(), self._tunable_params)
290+
baseParamsBc = sc.broadcast(base_params)
290291

291292
def _local_fit(row):
292293
"""

0 commit comments

Comments
 (0)