Skip to content

Commit a928839

Browse files
authored
use .toArray() instead of .array to allow both Sparse and Dense Vectors (#125)
* use .toArray() instead of .array to allow both Sparse and Dense Vectors
* add test to check logic for both SparseVectors and DenseVectors
* update style in tests file
1 parent b09974b commit a928839

File tree

2 files changed

+47 -20 lines changed

python/sparkdl/estimators/keras_image_file_estimator.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -225,7 +225,7 @@ def _getNumpyFeaturesAndLabels(self, dataset):
225225
localLabels = []
226226
for row in rows:
227227
try:
228-
_keras_label = row[label_col].array
228+
_keras_label = row[label_col].toArray()
229229
except ValueError:
230230
raise ValueError("Cannot extract encoded label array")
231231
localLabels.append(_keras_label)

python/tests/estimators/test_keras_estimators.py

Lines changed: 46 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -26,6 +26,7 @@
2626
from keras.layers import Activation, Dense, Flatten
2727
from keras.models import Sequential
2828
import numpy as np
29+
import six
2930

3031
from pyspark.ml.evaluation import BinaryClassificationEvaluator
3132
import pyspark.ml.linalg as spla
@@ -52,17 +53,20 @@ def _load_image_from_uri(local_uri):
5253

5354
class KerasEstimatorsTest(SparkDLTestCase):
5455

55-
def _create_train_image_uris_and_labels(self, repeat_factor=1, cardinality=100):
56+
def _create_train_image_uris_and_labels(self, repeat_factor=1, cardinality=100, dense=True):
5657
image_uris = getSampleImagePaths() * repeat_factor
5758
# Create image categorical labels (integer IDs)
5859
local_rows = []
5960
for uri in image_uris:
6061
label = np.random.randint(low=0, high=cardinality, size=1)[0]
61-
label_inds = np.zeros(cardinality)
62-
label_inds[label] = 1.0
63-
label_inds = label_inds.ravel()
64-
assert label_inds.shape[0] == cardinality, label_inds.shape
65-
one_hot_vec = spla.Vectors.dense(label_inds.tolist())
62+
if dense:
63+
label_inds = np.zeros(cardinality)
64+
label_inds[label] = 1.0
65+
label_inds = label_inds.ravel()
66+
assert label_inds.shape[0] == cardinality, label_inds.shape
67+
one_hot_vec = spla.Vectors.dense(label_inds.tolist())
68+
else: # sparse
69+
one_hot_vec = spla.Vectors.sparse(cardinality, {label: 1})
6670
_row_struct = {self.input_col: uri, self.one_hot_col: one_hot_vec,
6771
self.one_hot_label_col: float(label)}
6872
row = sptyp.Row(**_row_struct)
@@ -71,7 +75,8 @@ def _create_train_image_uris_and_labels(self, repeat_factor=1, cardinality=100):
7175
image_uri_df = self.session.createDataFrame(local_rows)
7276
return image_uri_df
7377

74-
def _get_model(self, label_cardinality):
78+
@staticmethod
79+
def _get_model(label_cardinality):
7580
# We need a small model so that machines with limited resources can run it
7681
model = Sequential()
7782
model.add(Flatten(input_shape=(299, 299, 3)))
@@ -109,23 +114,43 @@ def test_validate_params(self):
109114

110115
# should raise an error to define required parameters
111116
# assuming at least one param without default value
112-
self.assertRaisesRegexp(ValueError, 'defined', kifest._validateParams, {})
117+
six.assertRaisesRegex(self, ValueError, 'defined', kifest._validateParams, {})
113118
kifest.setParams(imageLoader=_load_image_from_uri, inputCol='c1', labelCol='c2')
114119
kifest.setParams(modelFile='/path/to/file.ext')
115120

116121
# should raise an error to define or tune parameters
117122
# assuming at least one tunable param without default value
118-
self.assertRaisesRegexp(ValueError, 'tuned', kifest._validateParams, {})
123+
six.assertRaisesRegex(self, ValueError, 'tuned', kifest._validateParams, {})
119124
kifest.setParams(kerasOptimizer='adam', kerasLoss='mse', kerasFitParams={})
120125
kifest.setParams(outputCol='c3', outputMode='vector')
121126

122127
# should raise an error to not override
123-
self.assertRaisesRegexp(ValueError, 'not tuned', kifest._validateParams,
124-
{kifest.imageLoader: None})
128+
six.assertRaisesRegex(
129+
self, ValueError, 'not tuned', kifest._validateParams, {kifest.imageLoader: None})
125130

126131
# should pass test on supplying all parameters
127132
self.assertTrue(kifest._validateParams({}))
128133

134+
def test_get_numpy_features_and_labels(self):
135+
"""Test that `KerasImageFileEstimator._getNumpyFeaturesAndLabels` method returns
136+
the right kind of features and labels for all kinds of inputs"""
137+
for cardinality in (1, 2, 4):
138+
model = self._get_model(cardinality)
139+
estimator = self._get_estimator(model)
140+
for repeat_factor in (1, 2, 4):
141+
for dense in (True, False):
142+
df = self._create_train_image_uris_and_labels(
143+
repeat_factor=repeat_factor, cardinality=cardinality, dense=dense)
144+
local_features, local_labels = estimator._getNumpyFeaturesAndLabels(df)
145+
self.assertIsInstance(local_features, np.ndarray)
146+
self.assertIsInstance(local_labels, np.ndarray)
147+
self.assertEqual(local_features.shape[0], local_labels.shape[0])
148+
for img_array in local_features:
149+
(_, _, num_channels) = img_array.shape
150+
self.assertEqual(num_channels, 3, "2 dimensional image with 3 channels")
151+
for one_hot_vector in local_labels:
152+
self.assertEqual(one_hot_vector.sum(), 1, "vector should be one hot")
153+
129154
def test_single_training(self):
130155
"""Test that single model fitting works well"""
131156
# Create image URI dataframe
@@ -140,9 +165,11 @@ def test_single_training(self):
140165

141166
transformer = estimator.fit(image_uri_df)
142167
self.assertIsInstance(transformer, KerasImageFileTransformer, "output should be KIFT")
143-
for p in map(lambda p: p.name, transformer.params):
144-
self.assertEqual(transformer.getOrDefault(p), estimator.getOrDefault(p),
145-
str(transformer.getOrDefault(p)))
168+
for param in transformer.params:
169+
param_name = param.name
170+
self.assertEqual(
171+
transformer.getOrDefault(param_name), estimator.getOrDefault(param_name),
172+
"Param should be equal for transformer generated from estimator: " + str(param))
146173

147174
def test_tuning(self):
148175
"""Test that multiple model fitting using `CrossValidator` works well"""
@@ -161,12 +188,12 @@ def test_tuning(self):
161188
.build()
162189
)
163190

164-
bc = BinaryClassificationEvaluator(rawPredictionCol=self.output_col,
165-
labelCol=self.one_hot_label_col)
166-
cv = CrossValidator(estimator=estimator, estimatorParamMaps=paramGrid, evaluator=bc,
167-
numFolds=2)
191+
evaluator = BinaryClassificationEvaluator(
192+
rawPredictionCol=self.output_col, labelCol=self.one_hot_label_col)
193+
validator = CrossValidator(
194+
estimator=estimator, estimatorParamMaps=paramGrid, evaluator=evaluator, numFolds=2)
168195

169-
transformer = cv.fit(image_uri_df)
196+
transformer = validator.fit(image_uri_df)
170197
self.assertIsInstance(transformer.bestModel, KerasImageFileTransformer,
171198
"best model should be an instance of KerasImageFileTransformer")
172199
self.assertIn('batch_size', transformer.bestModel.getKerasFitParams(),

0 commit comments

Comments
 (0)