22# pylint: disable=missing-docstring
33
44import unittest
5+ import warnings
6+ from unittest .mock import Mock
7+
8+ import numpy as np
9+ import scipy .sparse as sp
510
611from Orange .classification import NaiveBayesLearner
712from Orange .data import Table , Domain , DiscreteVariable , ContinuousVariable
813from Orange .evaluation import CrossValidation , CA
914
1015
16+ # This class is used to force predict_storage to fall back to the slower
17+ # procedure instead of calling `predict`
18+ class NotATable (Table ): # pylint: disable=too-many-ancestors,abstract-method
19+ pass
20+
21+
1122class TestNaiveBayesLearner (unittest .TestCase ):
1223 @classmethod
1324 def setUpClass (cls ):
14- data = Table ('titanic' )
25+ cls . data = data = Table ('titanic' )
1526 cls .learner = NaiveBayesLearner ()
16- cls .model = cls .learner (data )
1727 cls .table = data [::20 ]
1828
29+ def setUp (self ):
30+ self .model = self .learner (self .data )
31+
1932 def test_NaiveBayes (self ):
2033 results = CrossValidation (self .table , [self .learner ], k = 10 )
2134 ca = CA (results )
2235 self .assertGreater (ca , 0.7 )
2336 self .assertLess (ca , 0.9 )
2437
25- def test_predict_single_instance (self ):
26- for ins in self .table :
27- self .model (ins )
28- val , prob = self .model (ins , self .model .ValueProbs )
29-
30- def test_predict_table (self ):
31- self .model (self .table )
32- vals , probs = self .model (self .table , self .model .ValueProbs )
33-
34- def test_predict_numpy (self ):
35- X = self .table .X [::20 ]
36- self .model (X )
37- vals , probs = self .model (X , self .model .ValueProbs )
38+ results = CrossValidation (Table ("iris" ), [self .learner ], k = 10 )
39+ ca = CA (results )
40+ self .assertGreater (ca , 0.7 )
3841
3942 def test_degenerate (self ):
4043 d = Domain ((ContinuousVariable (name = "A" ),
@@ -53,3 +56,165 @@ def test_allnan_cv(self):
5356 data = Table ('voting' )
5457 results = CrossValidation (data , [self .learner ])
5558 self .assertFalse (any (results .failed ))
59+
60+ def test_prediction_routing (self ):
61+ data = self .data
62+ predict = self .model .predict = Mock (return_value = (data .Y , None ))
63+
64+ self .model (data )
65+ predict .assert_called ()
66+ predict .reset_mock ()
67+
68+ self .model (data .X )
69+ predict .assert_called ()
70+ predict .reset_mock ()
71+
72+ self .model .predict_storage (data )
73+ predict .assert_called ()
74+ predict .reset_mock ()
75+
76+ self .model .predict_storage (data [0 ])
77+ predict .assert_called ()
78+
79+ def test_compare_results_of_predict_and_predict_storage (self ):
80+ data2 = NotATable ("titanic" )
81+
82+ self .model = self .learner (self .data [:50 ])
83+ predict = self .model .predict = Mock (side_effect = self .model .predict )
84+ values , probs = self .model .predict_storage (self .data [50 :])
85+ predict .assert_called ()
86+ predict .reset_mock ()
87+ values2 , probs2 = self .model .predict_storage (data2 [50 :])
88+ predict .assert_not_called ()
89+
90+ np .testing .assert_equal (values , values2 )
91+ np .testing .assert_equal (probs , probs2 )
92+
93+ def test_predictions (self ):
94+ self ._test_predictions (sparse = None )
95+
96+ def test_predictions_csr_matrix (self ):
97+ with warnings .catch_warnings ():
98+ warnings .filterwarnings (
99+ "ignore" , ".*the matrix subclass.*" , PendingDeprecationWarning )
100+ self ._test_predictions (sparse = sp .csr_matrix )
101+
102+ def test_predictions_csc_matrix (self ):
103+ with warnings .catch_warnings ():
104+ warnings .filterwarnings (
105+ "ignore" , ".*the matrix subclass.*" , PendingDeprecationWarning )
106+ self ._test_predictions (sparse = sp .csc_matrix )
107+
108+ def _test_predictions (self , sparse ):
109+ x = np .array ([
110+ [1 , 0 , 0 ],
111+ [0 , np .nan , 0 ],
112+ [0 , 1 , 0 ],
113+ [0 , 0 , 0 ],
114+ [1 , 2 , 0 ],
115+ [1 , 1 , 0 ],
116+ [1 , 2 , 0 ],
117+ [0 , 1 , 0 ]])
118+ if sparse is not None :
119+ x = sparse (x )
120+
121+ y = np .array ([0 , 0 , 0 , 1 , 1 , 1 , 2 , 2 ])
122+ domain = Domain (
123+ [DiscreteVariable ("a" , values = "ab" ),
124+ DiscreteVariable ("b" , values = "abc" ),
125+ DiscreteVariable ("c" , values = "a" )],
126+ DiscreteVariable ("y" , values = "abc" ))
127+ data = Table .from_numpy (domain , x , y )
128+
129+ model = self .learner (data )
130+ np .testing .assert_almost_equal (
131+ model .class_prob ,
132+ [4 / 11 , 4 / 11 , 3 / 11 ]
133+ )
134+ np .testing .assert_almost_equal (
135+ np .exp (model .log_cont_prob [0 ]) * model .class_prob [:, None ],
136+ [[3 / 7 , 2 / 7 ], [2 / 7 , 3 / 7 ], [2 / 7 , 2 / 7 ]])
137+ np .testing .assert_almost_equal (
138+ np .exp (model .log_cont_prob [1 ]) * model .class_prob [:, None ],
139+ [[2 / 5 , 1 / 3 , 1 / 5 ], [2 / 5 , 1 / 3 , 2 / 5 ], [1 / 5 , 1 / 3 , 2 / 5 ]])
140+ np .testing .assert_almost_equal (
141+ np .exp (model .log_cont_prob [2 ]) * model .class_prob [:, None ],
142+ [[4 / 11 ], [4 / 11 ], [3 / 11 ]])
143+
144+ test_x = np .array ([[a , b , 0 ] for a in [0 , 1 ] for b in [0 , 1 , 2 ]])
145+ # Classifiers reject csc matrices in the base class
146+ # Naive bayesian classifier supports them if predict_storage is
147+ # called directly, which we do below
148+ if sparse is not None and sparse is not sp .csc_matrix :
149+ test_x = sparse (test_x )
150+ test_y = np .full ((6 , ), np .nan )
151+ # The following was computed manually, too
152+ exp_probs = np .array ([
153+ [0.47368421052632 , 0.31578947368421 , 0.21052631578947 ],
154+ [0.39130434782609 , 0.26086956521739 , 0.34782608695652 ],
155+ [0.24324324324324 , 0.32432432432432 , 0.43243243243243 ],
156+ [0.31578947368421 , 0.47368421052632 , 0.21052631578947 ],
157+ [0.26086956521739 , 0.39130434782609 , 0.34782608695652 ],
158+ [0.15000000000000 , 0.45000000000000 , 0.40000000000000 ]
159+ ])
160+
161+ # Test the faster algorithm for Table (numpy matrices)
162+ test_data = Table .from_numpy (domain , test_x , test_y )
163+ probs = model (test_data , ret = model .Probs )
164+ np .testing .assert_almost_equal (exp_probs , probs )
165+ values = model (test_data )
166+ np .testing .assert_equal (values , np .argmax (exp_probs , axis = 1 ))
167+ values , probs = model (test_data , ret = model .ValueProbs )
168+ np .testing .assert_almost_equal (exp_probs , probs )
169+ np .testing .assert_equal (values , np .argmax (exp_probs , axis = 1 ))
170+
171+ # Test the slower algorithm for non-Table data (iteration in Python)
172+ test_data = NotATable .from_numpy (domain , test_x , test_y )
173+ probs = model (test_data , ret = model .Probs )
174+ np .testing .assert_almost_equal (exp_probs , probs )
175+ values = model (test_data )
176+ np .testing .assert_equal (values , np .argmax (exp_probs , axis = 1 ))
177+ values , probs = model (test_data , ret = model .ValueProbs )
178+ np .testing .assert_almost_equal (exp_probs , probs )
179+ np .testing .assert_equal (values , np .argmax (exp_probs , axis = 1 ))
180+
181+ # Test prediction directly on numpy
182+ probs = model (test_x , ret = model .Probs )
183+ np .testing .assert_almost_equal (exp_probs , probs )
184+ values = model (test_x )
185+ np .testing .assert_equal (values , np .argmax (exp_probs , axis = 1 ))
186+ values , probs = model (test_x , ret = model .ValueProbs )
187+ np .testing .assert_almost_equal (exp_probs , probs )
188+ np .testing .assert_equal (values , np .argmax (exp_probs , axis = 1 ))
189+
190+ # Test prediction on instances
191+ for inst , exp_prob in zip (test_data , exp_probs ):
192+ np .testing .assert_almost_equal (
193+ model (inst , ret = model .Probs )[0 ],
194+ exp_prob )
195+ self .assertEqual (model (inst ), np .argmax (exp_prob ))
196+ value , prob = model (inst , ret = model .ValueProbs )
197+ np .testing .assert_almost_equal (prob [0 ], exp_prob )
198+ self .assertEqual (value , np .argmax (exp_prob ))
199+
200+ # Test prediction by directly calling predict. This is needed to test
201+ # csc_matrix, but doesn't hurt others
202+ if sparse is sp .csc_matrix :
203+ test_x = sparse (test_x )
204+ values , probs = model .predict (test_x )
205+ np .testing .assert_almost_equal (exp_probs , probs )
206+ np .testing .assert_equal (values , np .argmax (exp_probs , axis = 1 ))
207+
208+ def test_no_attributes (self ):
209+ y = np .array ([0 , 0 , 0 , 1 , 1 , 1 , 2 , 2 ])
210+ domain = Domain ([], DiscreteVariable ("y" , values = "abc" ))
211+ data = Table .from_numpy (domain , np .zeros ((len (y ), 0 )), y .T )
212+ model = self .learner (data )
213+ np .testing .assert_almost_equal (
214+ model .predict_storage (np .zeros ((5 , 0 )))[1 ],
215+ [[4 / 11 , 4 / 11 , 3 / 11 ]] * 5
216+ )
217+
218+
219+ if __name__ == "__main__" :
220+ unittest .main ()
0 commit comments