22# pylint: disable=missing-docstring
33
44import unittest
5- import warnings
65from unittest .mock import Mock
6+ import warnings
77
88import numpy as np
99import scipy .sparse as sp
@@ -92,18 +92,21 @@ def test_compare_results_of_predict_and_predict_storage(self):
9292
9393 def test_predictions (self ):
9494 self ._test_predictions (sparse = None )
95+ self ._test_predictions_with_absent_class (sparse = None )
9596
9697 def test_predictions_csr_matrix (self ):
9798 with warnings .catch_warnings ():
9899 warnings .filterwarnings (
99100 "ignore" , ".*the matrix subclass.*" , PendingDeprecationWarning )
100101 self ._test_predictions (sparse = sp .csr_matrix )
102+ self ._test_predictions_with_absent_class (sparse = sp .csr_matrix )
101103
102104 def test_predictions_csc_matrix (self ):
103105 with warnings .catch_warnings ():
104106 warnings .filterwarnings (
105107 "ignore" , ".*the matrix subclass.*" , PendingDeprecationWarning )
106108 self ._test_predictions (sparse = sp .csc_matrix )
109+ self ._test_predictions_with_absent_class (sparse = sp .csc_matrix )
107110
108111 def _test_predictions (self , sparse ):
109112 x = np .array ([
@@ -205,6 +208,107 @@ def _test_predictions(self, sparse):
205208 np .testing .assert_almost_equal (exp_probs , probs )
206209 np .testing .assert_equal (values , np .argmax (exp_probs , axis = 1 ))
207210
211+ def _test_predictions_with_absent_class (self , sparse ):
212+ """Empty classes should not affect predictions"""
213+ x = np .array ([
214+ [1 , 0 , 0 ],
215+ [0 , np .nan , 0 ],
216+ [0 , 1 , 0 ],
217+ [0 , 0 , 0 ],
218+ [1 , 2 , 0 ],
219+ [1 , 1 , 0 ],
220+ [1 , 2 , 0 ],
221+ [0 , 1 , 0 ]])
222+ if sparse is not None :
223+ x = sparse (x )
224+
225+ y = np .array ([0 , 0 , 0 , 2 , 2 , 2 , 3 , 3 ])
226+ domain = Domain (
227+ [DiscreteVariable ("a" , values = "ab" ),
228+ DiscreteVariable ("b" , values = "abc" ),
229+ DiscreteVariable ("c" , values = "a" )],
230+ DiscreteVariable ("y" , values = "abcd" ))
231+ data = Table .from_numpy (domain , x , y )
232+
233+ model = self .learner (data )
234+ np .testing .assert_almost_equal (
235+ model .class_prob ,
236+ [4 / 11 , 0 , 4 / 11 , 3 / 11 ]
237+ )
238+ np .testing .assert_almost_equal (
239+ np .exp (model .log_cont_prob [0 ]) * model .class_prob [:, None ],
240+ [[3 / 7 , 2 / 7 ], [0 , 0 ], [2 / 7 , 3 / 7 ], [2 / 7 , 2 / 7 ]])
241+ np .testing .assert_almost_equal (
242+ np .exp (model .log_cont_prob [1 ]) * model .class_prob [:, None ],
243+ [[2 / 5 , 1 / 3 , 1 / 5 ], [0 , 0 , 0 ], [2 / 5 , 1 / 3 , 2 / 5 ], [1 / 5 , 1 / 3 , 2 / 5 ]])
244+ np .testing .assert_almost_equal (
245+ np .exp (model .log_cont_prob [2 ]) * model .class_prob [:, None ],
246+ [[4 / 11 ], [0 ], [4 / 11 ], [3 / 11 ]])
247+
248+ test_x = np .array ([[a , b , 0 ] for a in [0 , 1 ] for b in [0 , 1 , 2 ]])
249+ # Classifiers reject csc matrices in the base class
250+ # Naive bayesian classifier supports them if predict_storage is
251+ # called directly, which we do below
252+ if sparse is not None and sparse is not sp .csc_matrix :
253+ test_x = sparse (test_x )
254+ test_y = np .full ((6 , ), np .nan )
255+ # The following was computed manually, too
256+ exp_probs = np .array ([
257+ [0.47368421052632 , 0 , 0.31578947368421 , 0.21052631578947 ],
258+ [0.39130434782609 , 0 , 0.26086956521739 , 0.34782608695652 ],
259+ [0.24324324324324 , 0 , 0.32432432432432 , 0.43243243243243 ],
260+ [0.31578947368421 , 0 , 0.47368421052632 , 0.21052631578947 ],
261+ [0.26086956521739 , 0 , 0.39130434782609 , 0.34782608695652 ],
262+ [0.15000000000000 , 0 , 0.45000000000000 , 0.40000000000000 ]
263+ ])
264+
265+ # Test the faster algorithm for Table (numpy matrices)
266+ test_data = Table .from_numpy (domain , test_x , test_y )
267+ probs = model (test_data , ret = model .Probs )
268+ np .testing .assert_almost_equal (exp_probs , probs )
269+ values = model (test_data )
270+ np .testing .assert_equal (values , np .argmax (exp_probs , axis = 1 ))
271+ values , probs = model (test_data , ret = model .ValueProbs )
272+ np .testing .assert_almost_equal (exp_probs , probs )
273+ np .testing .assert_equal (values , np .argmax (exp_probs , axis = 1 ))
274+
275+ # Test the slower algorithm for non-Table data (iteration in Python)
276+ test_data = NotATable .from_numpy (domain , test_x , test_y )
277+ probs = model (test_data , ret = model .Probs )
278+ np .testing .assert_almost_equal (exp_probs , probs )
279+ values = model (test_data )
280+ np .testing .assert_equal (values , np .argmax (exp_probs , axis = 1 ))
281+ values , probs = model (test_data , ret = model .ValueProbs )
282+ np .testing .assert_almost_equal (exp_probs , probs )
283+ np .testing .assert_equal (values , np .argmax (exp_probs , axis = 1 ))
284+
285+ # Test prediction directly on numpy
286+ probs = model (test_x , ret = model .Probs )
287+ np .testing .assert_almost_equal (exp_probs , probs )
288+ values = model (test_x )
289+ np .testing .assert_equal (values , np .argmax (exp_probs , axis = 1 ))
290+ values , probs = model (test_x , ret = model .ValueProbs )
291+ np .testing .assert_almost_equal (exp_probs , probs )
292+ np .testing .assert_equal (values , np .argmax (exp_probs , axis = 1 ))
293+
294+ # Test prediction on instances
295+ for inst , exp_prob in zip (test_data , exp_probs ):
296+ np .testing .assert_almost_equal (
297+ model (inst , ret = model .Probs )[0 ],
298+ exp_prob )
299+ self .assertEqual (model (inst ), np .argmax (exp_prob ))
300+ value , prob = model (inst , ret = model .ValueProbs )
301+ np .testing .assert_almost_equal (prob [0 ], exp_prob )
302+ self .assertEqual (value , np .argmax (exp_prob ))
303+
304+ # Test prediction by directly calling predict. This is needed to test
305+ # csc_matrix, but doesn't hurt others
306+ if sparse is sp .csc_matrix :
307+ test_x = sparse (test_x )
308+ values , probs = model .predict (test_x )
309+ np .testing .assert_almost_equal (exp_probs , probs )
310+ np .testing .assert_equal (values , np .argmax (exp_probs , axis = 1 ))
311+
208312 def test_no_attributes (self ):
209313 y = np .array ([0 , 0 , 0 , 1 , 1 , 1 , 2 , 2 ])
210314 domain = Domain ([], DiscreteVariable ("y" , values = "abc" ))
@@ -215,6 +319,5 @@ def test_no_attributes(self):
215319 [[4 / 11 , 4 / 11 , 3 / 11 ]] * 5
216320 )
217321
218-
219322if __name__ == "__main__" :
220323 unittest .main ()
0 commit comments