22# pylint: disable=missing-docstring
33
44import unittest
5+ import warnings
56
67import numpy as np
7- from scipy .sparse import csc_matrix
8+ from scipy .sparse import csc_matrix , csr_matrix
89
910import Orange
10- from Orange .clustering .kmeans import KMeans
11+ from Orange .clustering .kmeans import KMeans , KMeansModel
12+ from Orange .data import Table , Domain , ContinuousVariable
13+ from Orange .data .table import DomainTransformationError
1114
1215
1316class TestKMeans (unittest .TestCase ):
@@ -18,25 +21,115 @@ def setUp(self):
1821 def test_kmeans (self ):
1922 c = self .kmeans (self .iris )
2023 # First 20 iris belong to one cluster
24+ self .assertEqual (np .ndarray , type (c ))
25+ self .assertEqual (len (self .iris ), len (c ))
2126 self .assertEqual (1 , len (set (c [:20 ].ravel ())))
2227
2328 def test_kmeans_parameters (self ):
2429 kmeans = KMeans (n_clusters = 10 , max_iter = 10 , random_state = 42 , tol = 0.001 ,
2530 init = 'random' )
26- kmeans (self .iris )
31+ c = kmeans (self .iris )
32+ self .assertEqual (np .ndarray , type (c ))
33+ self .assertEqual (len (self .iris ), len (c ))
2734
2835 def test_predict_table (self ):
29- kmeans = KMeans ()
30- c = kmeans (self .iris )
36+ c = self .kmeans (self .iris )
3137 self .assertEqual (np .ndarray , type (c ))
38+ self .assertEqual (len (self .iris ), len (c ))
3239
3340 def test_predict_numpy (self ):
34- kmeans = KMeans ( )
35- c = kmeans . fit ( self .iris . X )
41+ c = self . kmeans . fit ( self . iris . X )
42+ self .assertEqual ( KMeansModel , type ( c ) )
3643 self .assertEqual (np .ndarray , type (c .labels ))
44+ self .assertEqual (len (self .iris ), len (c .labels ))
3745
3846 def test_predict_sparse (self ):
39- kmeans = KMeans ()
4047 self .iris .X = csc_matrix (self .iris .X [::20 ])
41- c = kmeans (self .iris )
48+ c = self . kmeans (self .iris )
4249 self .assertEqual (np .ndarray , type (c ))
50+ self .assertEqual (len (self .iris ), len (c ))
51+
52+ def test_model (self ):
53+ c = self .kmeans .get_model (self .iris )
54+ self .assertEqual (KMeansModel , type (c ))
55+ self .assertEqual (len (self .iris ), len (c .labels ))
56+
57+ c1 = c (self .iris )
58+ # prediction of the model must be same since data are same
59+ np .testing .assert_array_almost_equal (c .labels , c1 )
60+
61+ def test_model_np (self ):
62+ """
63+ Test with numpy array as an input in model.
64+ """
65+ c = self .kmeans .get_model (self .iris )
66+ c1 = c (self .iris .X )
67+ # prediction of the model must be same since data are same
68+ np .testing .assert_array_almost_equal (c .labels , c1 )
69+
70+ def test_model_sparse (self ):
71+ """
72+ Test with sparse array as an input in model.
73+ """
74+ c = self .kmeans .get_model (self .iris )
75+ c1 = c (csr_matrix (self .iris .X ))
76+ # prediction of the model must be same since data are same
77+ np .testing .assert_array_almost_equal (c .labels , c1 )
78+
79+ def test_model_instance (self ):
80+ """
81+ Test with instance as an input in model.
82+ """
83+ c = self .kmeans .get_model (self .iris )
84+ c1 = c (self .iris [0 ])
85+ # prediction of the model must be same since data are same
86+ self .assertEqual (c1 , c .labels [0 ])
87+
88+ def test_model_list (self ):
89+ """
90+ Test with list as an input in model.
91+ """
92+ c = self .kmeans .get_model (self .iris )
93+ c1 = c (self .iris .X .tolist ())
94+ # prediction of the model must be same since data are same
95+ np .testing .assert_array_almost_equal (c .labels , c1 )
96+
97+ # example with a list of only one data item
98+ c1 = c (self .iris .X .tolist ()[0 ])
99+ # prediction of the model must be same since data are same
100+ np .testing .assert_array_almost_equal (c .labels [0 ], c1 )
101+
102+ def test_model_bad_datatype (self ):
103+ """
104+ Check model with data-type that is not supported.
105+ """
106+ c = self .kmeans .get_model (self .iris )
107+ self .assertRaises (TypeError , c , 10 )
108+
109+ def test_model_data_table_domain (self ):
110+ """
111+ Check model with data-type that is not supported.
112+ """
113+ # ok domain
114+ data = Table (Domain (
115+ list (self .iris .domain .attributes ) + [ContinuousVariable ("a" )]),
116+ np .concatenate ((self .iris .X , np .ones ((len (self .iris ), 1 ))), axis = 1 ))
117+ c = self .kmeans .get_model (self .iris )
118+ res = c (data )
119+ np .testing .assert_array_almost_equal (c .labels , res )
120+
121+ # totally different domain - should fail
122+ self .assertRaises (DomainTransformationError , c , Table ("housing" ))
123+
124+ def test_deprecated_silhouette (self ):
125+ with warnings .catch_warnings (record = True ) as w :
126+ KMeans (compute_silhouette_score = True )
127+
128+ assert len (w ) == 1
129+ assert issubclass (w [- 1 ].category , DeprecationWarning )
130+
131+ with warnings .catch_warnings (record = True ) as w :
132+ KMeans (compute_silhouette_score = False )
133+
134+ assert len (w ) == 1
135+ assert issubclass (w [- 1 ].category , DeprecationWarning )
0 commit comments