|
5 | 5 | import numpy as np |
6 | 6 | from AnyQt.QtCore import Qt |
7 | 7 | from AnyQt.QtWidgets import QRadioButton |
| 8 | +from sklearn.metrics import silhouette_score |
8 | 9 |
|
9 | 10 | import Orange.clustering |
10 | 11 | from Orange.data import Table, Domain |
@@ -199,24 +200,22 @@ def test_data_on_output(self): |
199 | 200 |
|
200 | 201 | @patch("Orange.clustering.kmeans.KMeansModel.__call__") |
201 | 202 | def test_centroids_on_output(self, km_call): |
202 | | - ret = km_call.return_value = Mock() |
203 | | - ret.X = np.array([0] * 50 + [1] * 100) |
204 | | - ret.silhouette_samples = np.arange(150) / 150 |
| 203 | + km_call.return_value = np.array([0] * 50 + [1] * 100).flatten() |
205 | 204 |
|
206 | 205 | widget = self.widget |
207 | 206 | widget.optimize_k = False |
208 | 207 | widget.k = 4 |
209 | 208 | self.send_signal(widget.Inputs.data, self.iris) |
210 | 209 | self.commit_and_wait() |
211 | 210 |
|
212 | | - widget.clusterings[4].silhouette_samples = np.arange(150) / 150 |
| 211 | + widget.samples_scores = lambda x: np.arctan( |
| 212 | + np.arange(150) / 150) / np.pi + 0.5 |
213 | 213 | widget.send_data() |
214 | 214 | out = self.get_output(widget.Outputs.centroids) |
215 | | - np.testing.assert_almost_equal( |
216 | | - out.metas, |
217 | | - [[0, np.mean(np.arctan(np.arange(50) / 150)) / np.pi + 0.5], |
| 215 | + np.testing.assert_array_almost_equal( |
| 216 | + np.array([[0, np.mean(np.arctan(np.arange(50) / 150)) / np.pi + 0.5], |
218 | 217 | [1, np.mean(np.arctan(np.arange(50, 150) / 150)) / np.pi + 0.5], |
219 | | - [2, 0], [3, 0]]) |
| 218 | + [2, 0], [3, 0]]), out.metas.astype(float)) |
220 | 219 | self.assertEqual(out.name, "iris centroids") |
221 | 220 |
|
222 | 221 | def test_centroids_domain_on_output(self): |
@@ -262,12 +261,14 @@ def test_optimization_fails(self): |
262 | 261 | self.KMeansFail.fail_on = {3, 5, 7} |
263 | 262 | model = widget.table_view.model() |
264 | 263 |
|
265 | | - with patch.object(model, "set_scores", wraps=model.set_scores) as set_scores: |
| 264 | + with patch.object( |
| 265 | + model, "set_scores", wraps=model.set_scores) as set_scores: |
266 | 266 | self.send_signal(self.widget.Inputs.data, self.iris, wait=5000) |
267 | 267 | scores, start_k = set_scores.call_args[0] |
268 | 268 | self.assertEqual( |
269 | 269 | scores, |
270 | | - [km if isinstance(km, str) else km.silhouette |
| 270 | + [km if isinstance(km, str) else silhouette_score( |
| 271 | + self.iris.X, km(self.iris)) |
271 | 272 | for km in (widget.clusterings[k] for k in range(3, 9))] |
272 | 273 | ) |
273 | 274 | self.assertEqual(start_k, 3) |
@@ -312,15 +313,14 @@ def test_run_fails(self): |
312 | 313 | self.assertIsNotNone(self.get_output(self.widget.Outputs.annotated_data)) |
313 | 314 |
|
314 | 315 | def test_select_best_row(self): |
315 | | - class Cluster: |
316 | | - def __init__(self, n): |
317 | | - self.silhouette = n |
318 | | - |
319 | 316 | widget = self.widget |
320 | 317 | widget.k_from, widget.k_to = 2, 6 |
321 | | - widget.clusterings = {k: Cluster(5 - (k - 4) ** 2) for k in range(2, 7)} |
| 318 | + widget.optimize_k = True |
| 319 | + self.send_signal(self.widget.Inputs.data, Table("housing"), wait=5000) |
| 320 | + self.commit_and_wait() |
322 | 321 | widget.update_results() |
323 | | - self.assertEqual(widget.selected_row(), 2) |
| 322 | + # for the housing dataset the best selection is 3 clusters, i.e. row no. 1 |
| 323 | + self.assertEqual(widget.selected_row(), 1) |
324 | 324 |
|
325 | 325 | widget.clusterings = {k: "error" for k in range(2, 7)} |
326 | 326 | widget.update_results() |
@@ -394,7 +394,9 @@ def test_silhouette_column(self): |
394 | 394 | # Avoid randomness in the test |
395 | 395 | random = np.random.RandomState(0) # pylint: disable=no-member |
396 | 396 | table = Table(random.rand(110, 2)) |
397 | | - with patch("Orange.clustering.kmeans.SILHOUETTE_MAX_SAMPLES", 100): |
| 397 | + with patch( |
| 398 | + "Orange.widgets.unsupervised.owkmeans.SILHOUETTE_MAX_SAMPLES", |
| 399 | + 100): |
398 | 400 | self.send_signal(self.widget.Inputs.data, table) |
399 | 401 | outtable = self.get_output(widget.Outputs.annotated_data) |
400 | 402 | outtable = outtable.get_column_view("Silhouette")[0] |
|
0 commit comments