2626from keras .layers import Activation , Dense , Flatten
2727from keras .models import Sequential
2828import numpy as np
29+ import six
2930
3031from pyspark .ml .evaluation import BinaryClassificationEvaluator
3132import pyspark .ml .linalg as spla
@@ -52,17 +53,20 @@ def _load_image_from_uri(local_uri):
5253
5354class KerasEstimatorsTest (SparkDLTestCase ):
5455
55- def _create_train_image_uris_and_labels (self , repeat_factor = 1 , cardinality = 100 ):
56+ def _create_train_image_uris_and_labels (self , repeat_factor = 1 , cardinality = 100 , dense = True ):
5657 image_uris = getSampleImagePaths () * repeat_factor
5758 # Create image categorical labels (integer IDs)
5859 local_rows = []
5960 for uri in image_uris :
6061 label = np .random .randint (low = 0 , high = cardinality , size = 1 )[0 ]
61- label_inds = np .zeros (cardinality )
62- label_inds [label ] = 1.0
63- label_inds = label_inds .ravel ()
64- assert label_inds .shape [0 ] == cardinality , label_inds .shape
65- one_hot_vec = spla .Vectors .dense (label_inds .tolist ())
62+ if dense :
63+ label_inds = np .zeros (cardinality )
64+ label_inds [label ] = 1.0
65+ label_inds = label_inds .ravel ()
66+ assert label_inds .shape [0 ] == cardinality , label_inds .shape
67+ one_hot_vec = spla .Vectors .dense (label_inds .tolist ())
68+ else : # sparse
69+ one_hot_vec = spla .Vectors .sparse (cardinality , {label : 1 })
6670 _row_struct = {self .input_col : uri , self .one_hot_col : one_hot_vec ,
6771 self .one_hot_label_col : float (label )}
6872 row = sptyp .Row (** _row_struct )
@@ -71,7 +75,8 @@ def _create_train_image_uris_and_labels(self, repeat_factor=1, cardinality=100):
7175 image_uri_df = self .session .createDataFrame (local_rows )
7276 return image_uri_df
7377
74- def _get_model (self , label_cardinality ):
78+ @staticmethod
79+ def _get_model (label_cardinality ):
7580 # We need a small model so that machines with limited resources can run it
7681 model = Sequential ()
7782 model .add (Flatten (input_shape = (299 , 299 , 3 )))
@@ -109,23 +114,43 @@ def test_validate_params(self):
109114
110115 # should raise an error to define required parameters
111116 # assuming at least one param without default value
112- self . assertRaisesRegexp ( ValueError , 'defined' , kifest ._validateParams , {})
117+ six . assertRaisesRegex ( self , ValueError , 'defined' , kifest ._validateParams , {})
113118 kifest .setParams (imageLoader = _load_image_from_uri , inputCol = 'c1' , labelCol = 'c2' )
114119 kifest .setParams (modelFile = '/path/to/file.ext' )
115120
116121 # should raise an error to define or tune parameters
117122 # assuming at least one tunable param without default value
118- self . assertRaisesRegexp ( ValueError , 'tuned' , kifest ._validateParams , {})
123+ six . assertRaisesRegex ( self , ValueError , 'tuned' , kifest ._validateParams , {})
119124 kifest .setParams (kerasOptimizer = 'adam' , kerasLoss = 'mse' , kerasFitParams = {})
120125 kifest .setParams (outputCol = 'c3' , outputMode = 'vector' )
121126
122127 # should raise an error to not override
123- self . assertRaisesRegexp ( ValueError , 'not tuned' , kifest . _validateParams ,
124- {kifest .imageLoader : None })
128+ six . assertRaisesRegex (
129+ self , ValueError , 'not tuned' , kifest . _validateParams , {kifest .imageLoader : None })
125130
126131 # should pass test on supplying all parameters
127132 self .assertTrue (kifest ._validateParams ({}))
128133
def test_get_numpy_features_and_labels(self):
    """Check the arrays produced by `KerasImageFileEstimator._getNumpyFeaturesAndLabels`.

    For every combination of label cardinality, repeat factor and label
    encoding (dense or sparse one-hot vectors), the method must return two
    numpy arrays with the same leading dimension, where each feature entry
    has three channels and each label entry sums to one.
    """
    sizes = (1, 2, 4)
    for num_classes in sizes:
        # One estimator per cardinality; it is reused for all dataframes below.
        kifest = self._get_estimator(self._get_model(num_classes))
        for num_repeats in sizes:
            for use_dense in (True, False):
                image_df = self._create_train_image_uris_and_labels(
                    repeat_factor=num_repeats, cardinality=num_classes, dense=use_dense)
                features, labels = kifest._getNumpyFeaturesAndLabels(image_df)

                # Both outputs are numpy arrays describing the same number of rows.
                self.assertIsInstance(features, np.ndarray)
                self.assertIsInstance(labels, np.ndarray)
                self.assertEqual(features.shape[0], labels.shape[0])

                # Every feature entry unpacks to exactly three dimensions,
                # the last of which is the channel count.
                for image in features:
                    (_, _, channel_count) = image.shape
                    self.assertEqual(channel_count, 3, "2 dimensional image with 3 channels")

                # Every label entry must add up to exactly one.
                for label_vector in labels:
                    self.assertEqual(label_vector.sum(), 1, "vector should be one hot")
153+
129154 def test_single_training (self ):
130155 """Test that single model fitting works well"""
131156 # Create image URI dataframe
@@ -140,9 +165,11 @@ def test_single_training(self):
140165
141166 transformer = estimator .fit (image_uri_df )
142167 self .assertIsInstance (transformer , KerasImageFileTransformer , "output should be KIFT" )
143- for p in map (lambda p : p .name , transformer .params ):
144- self .assertEqual (transformer .getOrDefault (p ), estimator .getOrDefault (p ),
145- str (transformer .getOrDefault (p )))
168+ for param in transformer .params :
169+ param_name = param .name
170+ self .assertEqual (
171+ transformer .getOrDefault (param_name ), estimator .getOrDefault (param_name ),
172+ "Param should be equal for transformer generated from estimator: " + str (param ))
146173
147174 def test_tuning (self ):
148175 """Test that multiple model fitting using `CrossValidator` works well"""
@@ -161,12 +188,12 @@ def test_tuning(self):
161188 .build ()
162189 )
163190
164- bc = BinaryClassificationEvaluator (rawPredictionCol = self . output_col ,
165- labelCol = self .one_hot_label_col )
166- cv = CrossValidator (estimator = estimator , estimatorParamMaps = paramGrid , evaluator = bc ,
167- numFolds = 2 )
191+ evaluator = BinaryClassificationEvaluator (
192+ rawPredictionCol = self . output_col , labelCol = self .one_hot_label_col )
193+ validator = CrossValidator (
194+ estimator = estimator , estimatorParamMaps = paramGrid , evaluator = evaluator , numFolds = 2 )
168195
169- transformer = cv .fit (image_uri_df )
196+ transformer = validator .fit (image_uri_df )
170197 self .assertIsInstance (transformer .bestModel , KerasImageFileTransformer ,
171198 "best model should be an instance of KerasImageFileTransformer" )
172199 self .assertIn ('batch_size' , transformer .bestModel .getKerasFitParams (),
0 commit comments