22# pylint: disable=missing-docstring
33
44import unittest
5+ import warnings
56from unittest .mock import Mock
67
78import numpy as np
1213from Orange .evaluation import CrossValidation , CA
1314
1415
16+ # This class is used to force predict_storage to fall back to the slower
17+ # procedure instead of calling `predict`
18+ class NotATable (Table ): # pylint: disable=too-many-ancestors,abstract-method
19+ pass
20+
21+
1522class TestNaiveBayesLearner (unittest .TestCase ):
1623 @classmethod
1724 def setUpClass (cls ):
@@ -32,20 +39,6 @@ def test_NaiveBayes(self):
3239 ca = CA (results )
3340 self .assertGreater (ca , 0.7 )
3441
35- def test_predict_single_instance (self ):
36- for ins in self .table :
37- self .model (ins )
38- val , prob = self .model (ins , self .model .ValueProbs )
39-
40- def test_predict_table (self ):
41- self .model (self .table )
42- vals , probs = self .model (self .table , self .model .ValueProbs )
43-
44- def test_predict_numpy (self ):
45- X = self .table .X [::20 ]
46- self .model (X )
47- vals , probs = self .model (X , self .model .ValueProbs )
48-
4942 def test_degenerate (self ):
5043 d = Domain ((ContinuousVariable (name = "A" ),
5144 ContinuousVariable (name = "B" ),
@@ -64,15 +57,6 @@ def test_allnan_cv(self):
6457 results = CrossValidation (data , [self .learner ])
6558 self .assertFalse (any (results .failed ))
6659
67- def test_sparse (self ):
68- _ , dense_p = self .model .predict (self .data .X )
69-
70- _ , csc_p = self .model .predict (sp .csc_matrix (self .data .X ))
71- np .testing .assert_almost_equal (dense_p , csc_p )
72-
73- _ , csr_p = self .model .predict (sp .csr_matrix (self .data .X ))
74- np .testing .assert_almost_equal (dense_p , csr_p )
75-
7660 def test_prediction_routing (self ):
7761 data = self .data
7862 predict = self .model .predict = Mock (return_value = (data .Y , None ))
@@ -92,6 +76,132 @@ def test_prediction_routing(self):
9276 self .model .predict_storage (data [0 ])
9377 predict .assert_not_called ()
9478
79+ def test_compare_results_of_storage_and_predict_storage (self ):
80+ data2 = NotATable ("titanic" )
81+
82+ self .model = self .learner (self .data [:50 ])
83+ values , probs = self .model .predict_storage (self .data [50 :])
84+ values2 , probs2 = self .model .predict_storage (data2 [50 :])
85+ np .testing .assert_equal (values , values2 )
86+ np .testing .assert_equal (probs , probs2 )
87+
88+ def test_predictions (self ):
89+ self ._test_predictions (sparse = None )
90+
91+ def test_predictions_csr_matrix (self ):
92+ with warnings .catch_warnings ():
93+ warnings .filterwarnings (
94+ "ignore" , ".*the matrix subclass.*" , PendingDeprecationWarning )
95+ self ._test_predictions (sparse = sp .csr_matrix )
96+
97+ def test_predictions_csc_matrix (self ):
98+ with warnings .catch_warnings ():
99+ warnings .filterwarnings (
100+ "ignore" , ".*the matrix subclass.*" , PendingDeprecationWarning )
101+ self ._test_predictions (sparse = sp .csc_matrix )
102+
103+ def _test_predictions (self , sparse ):
104+ x = np .array ([
105+ [1 , 0 , 0 ],
106+ [0 , np .nan , 0 ],
107+ [0 , 1 , 0 ],
108+ [0 , 0 , 0 ],
109+ [1 , 2 , 0 ],
110+ [1 , 1 , 0 ],
111+ [1 , 2 , 0 ],
112+ [0 , 1 , 0 ]])
113+ if sparse is not None :
114+ x = sparse (x )
115+
116+ y = np .array ([0 , 0 , 0 , 1 , 1 , 1 , 2 , 2 ])
117+ domain = Domain (
118+ [DiscreteVariable ("a" , values = "ab" ),
119+ DiscreteVariable ("b" , values = "abc" ),
120+ DiscreteVariable ("c" , values = "a" )],
121+ DiscreteVariable ("y" , values = "abc" ))
122+ data = Table .from_numpy (domain , x , y )
123+
124+ model = self .learner (data )
125+ np .testing .assert_almost_equal (
126+ model .class_prob ,
127+ [4 / 11 , 4 / 11 , 3 / 11 ]
128+ )
129+ np .testing .assert_almost_equal (
130+ np .exp (model .log_cont_prob [0 ]) * model .class_prob [:, None ],
131+ [[3 / 7 , 2 / 7 ], [2 / 7 , 3 / 7 ], [2 / 7 , 2 / 7 ]])
132+ np .testing .assert_almost_equal (
133+ np .exp (model .log_cont_prob [1 ]) * model .class_prob [:, None ],
134+ [[2 / 5 , 1 / 3 , 1 / 5 ], [2 / 5 , 1 / 3 , 2 / 5 ], [1 / 5 , 1 / 3 , 2 / 5 ]])
135+ np .testing .assert_almost_equal (
136+ np .exp (model .log_cont_prob [2 ]) * model .class_prob [:, None ],
137+ [[4 / 11 ], [4 / 11 ], [3 / 11 ]])
138+
139+ test_x = np .array ([[a , b , 0 ] for a in [0 , 1 ] for b in [0 , 1 , 2 ]])
140+ # Model.__call__ does not accept csc matrices
141+ # We however test the classifier with csc_matrix (below)
142+ if sparse is not None and sparse is not sp .csc_matrix :
143+ test_x = sparse (test_x )
144+ test_y = np .full ((6 , ), np .nan )
145+ # The following was computed manually, too
146+ exp_probs = np .array ([
147+ [0.47368421052632 , 0.31578947368421 , 0.21052631578947 ],
148+ [0.39130434782609 , 0.26086956521739 , 0.34782608695652 ],
149+ [0.24324324324324 , 0.32432432432432 , 0.43243243243243 ],
150+ [0.31578947368421 , 0.47368421052632 , 0.21052631578947 ],
151+ [0.26086956521739 , 0.39130434782609 , 0.34782608695652 ],
152+ [0.15000000000000 , 0.45000000000000 , 0.40000000000000 ]
153+ ])
154+
155+ # Test the faster algorithm for Table (numpy matrices)
156+ test_data = Table .from_numpy (domain , test_x , test_y )
157+ probs = model (test_data , ret = model .Probs )
158+ np .testing .assert_almost_equal (exp_probs , probs )
159+ values = model (test_data )
160+ np .testing .assert_equal (values , np .argmax (exp_probs , axis = 1 ))
161+ values , probs = model (test_data , ret = model .ValueProbs )
162+ np .testing .assert_almost_equal (exp_probs , probs )
163+ np .testing .assert_equal (values , np .argmax (exp_probs , axis = 1 ))
164+
165+ # Test the slower algorithm for non-Table data (iteration in Python)
166+ test_data = NotATable .from_numpy (domain , test_x , test_y )
167+ probs = model (test_data , ret = model .Probs )
168+ np .testing .assert_almost_equal (exp_probs , probs )
169+ values = model (test_data )
170+ np .testing .assert_equal (values , np .argmax (exp_probs , axis = 1 ))
171+ values , probs = model (test_data , ret = model .ValueProbs )
172+ np .testing .assert_almost_equal (exp_probs , probs )
173+ np .testing .assert_equal (values , np .argmax (exp_probs , axis = 1 ))
174+
175+ # Test prediction directly on numpy
176+ probs = model (test_x , ret = model .Probs )
177+ np .testing .assert_almost_equal (exp_probs , probs )
178+ values = model (test_x )
179+ np .testing .assert_equal (values , np .argmax (exp_probs , axis = 1 ))
180+ values , probs = model (test_x , ret = model .ValueProbs )
181+ np .testing .assert_almost_equal (exp_probs , probs )
182+ np .testing .assert_equal (values , np .argmax (exp_probs , axis = 1 ))
183+
184+ # Test prediction on instances
185+ for inst , exp_prob in zip (test_data , exp_probs ):
186+ np .testing .assert_almost_equal (
187+ model (inst , ret = model .Probs )[0 ],
188+ exp_prob )
189+ self .assertEqual (model (inst ), np .argmax (exp_prob ))
190+ value , prob = model (inst , ret = model .ValueProbs )
191+ np .testing .assert_almost_equal (prob [0 ], exp_prob )
192+ self .assertEqual (value , np .argmax (exp_prob ))
193+
194+ def test_no_attributes (self ):
195+ y = np .array ([0 , 0 , 0 , 1 , 1 , 1 , 2 , 2 ])
196+ domain = Domain ([], DiscreteVariable ("y" , values = "abc" ))
197+ data = Table .from_numpy (domain , np .zeros ((len (y ), 0 )), y .T )
198+ test_data = Table .from_numpy (domain , np .zeros ((5 , 0 )), np .zeros ((5 , 1 )))
199+ model = self .learner (data )
200+ np .testing .assert_almost_equal (
201+ model .predict_storage (test_data )[1 ],
202+ [[4 / 11 , 4 / 11 , 3 / 11 ]] * 5
203+ )
204+
95205
96206if __name__ == "__main__" :
97207 unittest .main ()
0 commit comments