@@ -123,7 +123,9 @@ def test_buildtfgraphforname(self):
123123 tfPredict = sess .run (outputTensor , {inputTensor : imageArray })
124124
125125 self .assertEqual (kerasPredict .shape , tfPredict .shape )
126- np .testing .assert_array_almost_equal (kerasPredict , tfPredict , decimal = 5 )
126+ np .testing .assert_array_almost_equal (kerasPredict ,
127+ tfPredict ,
128+ decimal = self .featurizerCompareDigitsExact )
127129
128130 def _rowWithImage (self , img ):
129131 row = imageIO .imageArrayToStruct (img .astype ('uint8' ))
@@ -152,7 +154,9 @@ def test_DeepImagePredictorNoReshape(self):
152154 dfPredict = np .array ([i .prediction for i in dfPredict ])
153155
154156 self .assertEqual (kerasPredict .shape , dfPredict .shape )
155- np .testing .assert_array_almost_equal (kerasPredict , dfPredict )
157+ np .testing .assert_array_almost_equal (kerasPredict ,
158+ dfPredict ,
159+ decimal = self .featurizerCompareDigitsExact )
156160
157161 def test_DeepImagePredictor (self ):
158162 """
@@ -164,7 +168,9 @@ def test_DeepImagePredictor(self):
164168 fullPredict = self ._sortByFileOrder (transformer .transform (self .imageDF ).collect ())
165169 fullPredict = np .array ([i .prediction for i in fullPredict ])
166170 self .assertEqual (kerasPredict .shape , fullPredict .shape )
167- np .testing .assert_array_almost_equal (kerasPredict , fullPredict , decimal = 6 )
171+ np .testing .assert_array_almost_equal (kerasPredict ,
172+ fullPredict ,
173+ decimal = self .featurizerCompareDigitsExact )
168174
169175 def test_prediction_decoded (self ):
170176 """
@@ -200,7 +206,9 @@ def test_featurization_no_reshape(self):
200206 dfFeatures = transformer .transform (imageDf ).collect ()
201207 dfFeatures = np .array ([i .features for i in dfFeatures ])
202208 kerasReshaped = self .kerasFeatures .reshape (self .kerasFeatures .shape [0 ], - 1 )
203- np .testing .assert_array_almost_equal (kerasReshaped , dfFeatures , decimal = self .featurizerCompareDigitsExact )
209+ np .testing .assert_array_almost_equal (kerasReshaped ,
210+ dfFeatures ,
211+ decimal = self .featurizerCompareDigitsExact )
204212
205213
206214 def test_featurization (self ):
0 commit comments