Commit 2a63d4f

yogeshgsueann authored and committed
add better parameter validation for KIFEst (#116)
* add better param validation; add tests
* remove unused params from project
* test for all params and avoid dependencies on strings
1 parent 841f905 commit 2a63d4f

File tree

5 files changed: +50, -54 lines changed


python/sparkdl/estimators/keras_image_file_estimator.py

Lines changed: 25 additions & 23 deletions
@@ -20,14 +20,13 @@
 import threading
 import numpy as np
 
-import pyspark
 from pyspark.ml import Estimator
 import pyspark.ml.linalg as spla
 
 from sparkdl.image.imageIO import imageStructToArray
 from sparkdl.param import (
     keyword_only, CanLoadImage, HasKerasModel, HasKerasOptimizer, HasKerasLoss, HasOutputMode,
-    HasInputCol, HasInputImageNodeName, HasLabelCol, HasOutputNodeName, HasOutputCol)
+    HasInputCol, HasLabelCol, HasOutputCol)
 from sparkdl.transformers.keras_image import KerasImageFileTransformer
 import sparkdl.utils.jvmapi as JVMAPI
 import sparkdl.utils.keras_model as kmutil
@@ -74,10 +73,8 @@ def next(self):
         return self.__next__()
 
 
-class KerasImageFileEstimator(Estimator, HasInputCol, HasInputImageNodeName,
-                              HasOutputCol, HasOutputNodeName, HasLabelCol,
-                              HasKerasModel, HasKerasOptimizer, HasKerasLoss,
-                              CanLoadImage, HasOutputMode):
+class KerasImageFileEstimator(Estimator, HasInputCol, HasOutputCol, HasLabelCol, HasKerasModel,
+                              HasKerasOptimizer, HasKerasLoss, CanLoadImage, HasOutputMode):
     """
     Build a Estimator from a Keras model.
@@ -138,13 +135,11 @@ def load_image_and_process(uri):
     """
 
     @keyword_only
-    def __init__(self, inputCol=None, inputImageNodeName=None, outputCol=None,
-                 outputNodeName=None, outputMode="vector", labelCol=None,
+    def __init__(self, inputCol=None, outputCol=None, outputMode="vector", labelCol=None,
                  modelFile=None, imageLoader=None, kerasOptimizer=None, kerasLoss=None,
                  kerasFitParams=None):
         """
-        __init__(self, inputCol=None, inputImageNodeName=None, outputCol=None,
-                 outputNodeName=None, outputMode="vector", labelCol=None,
+        __init__(self, inputCol=None, outputCol=None, outputMode="vector", labelCol=None,
                  modelFile=None, imageLoader=None, kerasOptimizer=None, kerasLoss=None,
                  kerasFitParams=None)
         """
@@ -155,13 +150,11 @@ def __init__(self, inputCol=None, inputImageNodeName=None, outputCol=None,
         self.setParams(**kwargs)
 
     @keyword_only
-    def setParams(self, inputCol=None, inputImageNodeName=None, outputCol=None,
-                  outputNodeName=None, outputMode="vector", labelCol=None,
+    def setParams(self, inputCol=None, outputCol=None, outputMode="vector", labelCol=None,
                   modelFile=None, imageLoader=None, kerasOptimizer=None, kerasLoss=None,
                   kerasFitParams=None):
         """
-        setParams(self, inputCol=None, inputImageNodeName=None, outputCol=None,
-                  outputNodeName=None, outputMode="vector", labelCol=None,
+        setParams(self, inputCol=None, outputCol=None, outputMode="vector", labelCol=None,
                   modelFile=None, imageLoader=None, kerasOptimizer=None, kerasLoss=None,
                   kerasFitParams=None)
         """
@@ -174,12 +167,23 @@ def _validateParams(self, paramMap):
         :param paramMap: Dict[pyspark.ml.param.Param, object]
         :return: True if parameters are valid
         """
-        if not self.isDefined(self.inputCol):
-            raise ValueError("Input column must be defined")
-        if not self.isDefined(self.outputCol):
-            raise ValueError("Output column must be defined")
-        if self.inputCol in paramMap:
-            raise ValueError("Input column can not be fine tuned")
+        model_params = [self.kerasOptimizer, self.kerasLoss, self.kerasFitParams]
+        output_params = [self.outputCol, self.outputMode]
+
+        params = self.params
+        undefined = set([p for p in params if not self.isDefined(p)])
+        undefined_tunable = undefined.intersection(model_params + output_params)
+        failed_define = [p.name for p in undefined.difference(undefined_tunable)]
+        failed_tune = [p.name for p in undefined_tunable if p not in paramMap]
+
+        if failed_define or failed_tune:
+            msg = "Following Params must be"
+            if failed_define:
+                msg += " defined: [" + ", ".join(failed_define) + "]"
+            if failed_tune:
+                msg += " defined or tuned: [" + ", ".join(failed_tune) + "]"
+            raise ValueError(msg)
+
         return True
 
     def _getNumpyFeaturesAndLabels(self, dataset):
@@ -236,8 +240,7 @@ def _collectModels(self, kerasModelBytesRDD):
         """
         Collect Keras models on workers to MLlib Models on the driver.
         :param kerasModelBytesRDD: RDD of (param_map, model_bytes) tuples
-        :param paramMaps: list of ParamMaps matching the maps in `kerasModelsRDD`
-        :return: list of MLlib models
+        :return: generator of (index, MLlib model) tuples
         """
         for (i, param_map, model_bytes) in kerasModelBytesRDD.collect():
             model_filename = kmutil.bytes_to_h5file(model_bytes)
@@ -264,7 +267,6 @@ def _name_value_map(paramMap):
             """takes a dictionary {param -> value} and returns a map of {param.name -> value}"""
             return {param.name: val for param, val in paramMap.items()}
 
-
         sc = JVMAPI._curr_sc()
         paramNameMaps = list(enumerate(map(_name_value_map, paramMaps)))
         num_models = len(paramNameMaps)
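For orientation, here is a minimal sketch of how the reworked `_validateParams` behaves from a caller's point of view, mirroring the checks in the new `test_validate_params` test further down. The column names, file path, and loader function are illustrative placeholders, not values taken from this commit.

from sparkdl.estimators.keras_image_file_estimator import KerasImageFileEstimator


def dummy_image_loader(uris):
    # Hypothetical placeholder loader; a real one yields preprocessed image
    # tensors per URI. Validation only checks that the param is defined.
    return None


estimator = KerasImageFileEstimator()

# Nothing is set yet, so validation fails and the message lists the Params
# that "must be defined" and those that may instead be "defined or tuned".
try:
    estimator._validateParams({})
except ValueError as err:
    print(err)

# Define the params that cannot be supplied through tuning.
estimator.setParams(inputCol='uri', labelCol='label', imageLoader=dummy_image_loader,
                    modelFile='/tmp/model.h5')

# The tunable params (kerasOptimizer, kerasLoss, kerasFitParams, outputCol,
# outputMode) may either be set on the estimator or supplied via the paramMap
# passed in during tuning; once each is covered, validation returns True.
estimator.setParams(kerasOptimizer='adam', kerasLoss='mse', kerasFitParams={},
                    outputCol='prediction', outputMode='vector')
assert estimator._validateParams({})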

python/sparkdl/param/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -18,7 +18,7 @@
     # TFTransformer Params
     HasInputMapping, HasOutputMapping, HasTFInputGraph, HasTFHParams,
     # Keras Estimator Params
-    HasKerasModel, HasKerasLoss, HasKerasOptimizer, HasOutputNodeName)
+    HasKerasModel, HasKerasLoss, HasKerasOptimizer)
 from sparkdl.param.converters import SparkDLTypeConverters
 from sparkdl.param.image_params import (
-    CanLoadImage, HasInputImageNodeName, HasOutputMode, OUTPUT_MODES)
+    CanLoadImage, HasOutputMode, OUTPUT_MODES)

python/sparkdl/param/image_params.py

Lines changed: 0 additions & 13 deletions
@@ -29,19 +29,6 @@
 OUTPUT_MODES = ["vector", "image"]
 
 
-class HasInputImageNodeName(Params):
-    # TODO: docs
-    inputImageNodeName = Param(Params._dummy(), "inputImageNodeName",
-                               "name of the graph element/node corresponding to the input",
-                               typeConverter=TypeConverters.toString)
-
-    def setInputImageNodeName(self, value):
-        return self._set(inputImageNodeName=value)
-
-    def getInputImageNodeName(self):
-        return self.getOrDefault(self.inputImageNodeName)
-
-
 class CanLoadImage(Params):
     """
     In standard Keras workflow, we use provides an image loading function

python/sparkdl/param/shared_params.py

Lines changed: 0 additions & 13 deletions
@@ -105,19 +105,6 @@ def getOutputCol(self):
 ########################################################
 
 
-class HasOutputNodeName(Params):
-    # TODO: docs
-    outputNodeName = Param(Params._dummy(), "outputNodeName",
-                           "name of the graph element/node corresponding to the output",
-                           typeConverter=TypeConverters.toString)
-
-    def setOutputNodeName(self, value):
-        return self._set(outputNodeName=value)
-
-    def getOutputNodeName(self):
-        return self.getOrDefault(self.outputNodeName)
-
-
 class HasLabelCol(Params):
     """
     When training Keras image models in a supervised learning setting,

python/tests/estimators/test_keras_estimators.py

Lines changed: 23 additions & 3 deletions
@@ -80,9 +80,7 @@ def _get_model(self, label_cardinality):
         return model
 
     def _get_estimator(self, model):
-        """
-        Create a :py:obj:`KerasImageFileEstimator` from an existing Keras model
-        """
+        """Create a :py:obj:`KerasImageFileEstimator` from an existing Keras model"""
         _random_filename_suffix = str(uuid.uuid4())
         model_filename = os.path.join(self.temp_dir, 'model-{}.h5'.format(_random_filename_suffix))
         model.save(model_filename)
@@ -105,7 +103,27 @@ def setUp(self):
     def tearDown(self):
         shutil.rmtree(self.temp_dir, ignore_errors=True)
 
+    def test_validate_params(self):
+        """Test that `KerasImageFileEstimator._validateParams` method works as expected"""
+        kifest = KerasImageFileEstimator()
+
+        # should raise an error to define required parameters
+        # assuming at least one param without default value
+        self.assertRaisesRegexp(ValueError, 'defined', kifest._validateParams, {})
+        kifest.setParams(imageLoader=_load_image_from_uri, inputCol='c1', labelCol='c2')
+        kifest.setParams(modelFile='/path/to/file.ext')
+
+        # should raise an error to define or tune parameters
+        # assuming at least one tunable param without default value
+        self.assertRaisesRegexp(ValueError, 'tuned', kifest._validateParams, {})
+        kifest.setParams(kerasOptimizer='adam', kerasLoss='mse', kerasFitParams={})
+        kifest.setParams(outputCol='c3', outputMode='vector')
+
+        # should pass test on supplying all parameters
+        self.assertTrue(kifest._validateParams({}))
+
     def test_single_training(self):
+        """Test that single model fitting works well"""
         # Create image URI dataframe
         label_cardinality = 10
         image_uri_df = self._create_train_image_uris_and_labels(repeat_factor=3,
@@ -123,6 +141,7 @@ def test_single_training(self):
                          str(transformer.getOrDefault(p)))
 
     def test_tuning(self):
+        """Test that multiple model fitting using `CrossValidator` works well"""
         # Create image URI dataframe
         label_cardinality = 2
         image_uri_df = self._create_train_image_uris_and_labels(repeat_factor=3,
@@ -150,6 +169,7 @@ def test_tuning(self):
                          "fit params must be copied")
 
     def test_keras_training_utils(self):
+        """Test some Keras training utils"""
         self.assertTrue(kmutil.is_valid_optimizer('adam'))
         self.assertFalse(kmutil.is_valid_optimizer('noSuchOptimizer'))
         self.assertTrue(kmutil.is_valid_loss_function('mse'))
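The "defined or tuned" split matters mostly for hyperparameter search: a param such as kerasFitParams can be left unset on the estimator and supplied through the param maps that a tuner passes to fit(). A rough sketch of that pattern, assuming an `estimator` configured as in the earlier snippet (kerasFitParams may be left unset, since the grid provides it) and an `image_uri_df` DataFrame of image URIs and labels; the evaluator, column names, and grid values are illustrative and not taken from this commit's tests.

from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

# kerasFitParams is supplied per model through the param grid, which
# satisfies the "defined or tuned" rule in _validateParams.
param_grid = (ParamGridBuilder()
              .addGrid(estimator.kerasFitParams, [{'batch_size': 32},
                                                  {'batch_size': 64}])
              .build())

cv = CrossValidator(estimator=estimator,
                    estimatorParamMaps=param_grid,
                    evaluator=BinaryClassificationEvaluator(rawPredictionCol='prediction'),
                    numFolds=2)

# The estimator validates its params against each map when fitting, so a
# param that is neither defined nor tuned raises a ValueError naming it.
cv_model = cv.fit(image_uri_df)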
