Skip to content

Commit cfe9a7f

Browse files
yogeshgsueann
authored andcommitted
error for untunable params, expose to root name space (#117)
* expose KerasImageFileEstimator at top level * raise error for trying to tune untunable params
1 parent 2a63d4f commit cfe9a7f

File tree

3 files changed

+17
-10
lines changed

3 files changed

+17
-10
lines changed

python/sparkdl/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,10 @@
1919
from .transformers.tf_image import TFImageTransformer
2020
from .transformers.tf_tensor import TFTransformer
2121
from .transformers.utils import imageInputPlaceholder
22-
22+
from .estimators.keras_image_file_estimator import KerasImageFileEstimator
2323

2424
__all__ = [
2525
'imageSchema', 'imageType', 'readImages',
2626
'TFImageTransformer', 'TFInputGraph', 'TFTransformer',
2727
'DeepImagePredictor', 'DeepImageFeaturizer', 'KerasImageFileTransformer', 'KerasTransformer',
28-
'imageInputPlaceholder']
28+
'imageInputPlaceholder', 'KerasImageFileEstimator']

python/sparkdl/estimators/keras_image_file_estimator.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -167,21 +167,24 @@ def _validateParams(self, paramMap):
167167
:param paramMap: Dict[pyspark.ml.param.Param, object]
168168
:return: True if parameters are valid
169169
"""
170-
model_params = [self.kerasOptimizer, self.kerasLoss, self.kerasFitParams]
171-
output_params = [self.outputCol, self.outputMode]
172170

173-
params = self.params
174-
undefined = set([p for p in params if not self.isDefined(p)])
175-
undefined_tunable = undefined.intersection(model_params + output_params)
171+
tunable_params = [self.kerasOptimizer, self.kerasLoss, self.kerasFitParams, # model params
172+
self.outputCol, self.outputMode] # output params
173+
174+
undefined = set([p for p in self.params if not self.isDefined(p)])
175+
undefined_tunable = undefined.intersection(tunable_params)
176176
failed_define = [p.name for p in undefined.difference(undefined_tunable)]
177177
failed_tune = [p.name for p in undefined_tunable if p not in paramMap]
178+
untunable_overrides = [p.name for p in paramMap if p not in tunable_params]
178179

179-
if failed_define or failed_tune:
180+
if failed_define or failed_tune or untunable_overrides:
180181
msg = "Following Params must be"
181182
if failed_define:
182-
msg += " defined: [" + ", ".join(failed_define) + "]"
183+
msg += " defined: " + str(failed_define)
183184
if failed_tune:
184-
msg += " defined or tuned: [" + ", ".join(failed_tune) + "]"
185+
msg += " defined or tuned: " + str(failed_tune)
186+
if untunable_overrides:
187+
msg += " not tuned: " + str(untunable_overrides)
185188
raise ValueError(msg)
186189

187190
return True

python/tests/estimators/test_keras_estimators.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,10 @@ def test_validate_params(self):
119119
kifest.setParams(kerasOptimizer='adam', kerasLoss='mse', kerasFitParams={})
120120
kifest.setParams(outputCol='c3', outputMode='vector')
121121

122+
# should raise an error to not override
123+
self.assertRaisesRegexp(ValueError, 'not tuned', kifest._validateParams,
124+
{kifest.imageLoader: None})
125+
122126
# should pass test on supplying all parameters
123127
self.assertTrue(kifest._validateParams({}))
124128

0 commit comments

Comments
 (0)