Skip to content
This repository was archived by the owner on Dec 21, 2023. It is now read-only.

Commit 4d67860

Browse files
[6.1] AC invalid k for predict_topk. (#2856)
1 parent dee325d commit 4d67860

File tree

2 files changed

+15
-0
lines changed

2 files changed

+15
-0
lines changed

src/python/turicreate/test/test_activity_classifier.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -439,6 +439,18 @@ def test_predict_topk(self):
439439
expected_len = self._calc_expected_predictions_length(self.data.head(100), top_k=5)
440440
self.assertEqual(len(preds), expected_len)
441441

442+
def test_predict_topk_invalid_k(self):
443+
model = self.model
444+
with self.assertRaises(_ToolkitError):
445+
preds = model.predict_topk(self.data, k=-1)
446+
447+
with self.assertRaises(_ToolkitError):
448+
preds = model.predict_topk(self.data, k=0)
449+
450+
with self.assertRaises(TypeError):
451+
preds = model.predict_topk(self.data, k=[])
452+
453+
442454
def test_evaluate_with_incomplete_targets(self):
443455
"""
444456
Check that evaluation does not require the test data to span all labels.

src/python/turicreate/toolkits/activity_classifier/_activity_classifier.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -611,6 +611,9 @@ def predict_topk(self, dataset, output_type='probability', k=3, output_frequency
611611
| ... | ... | ... |
612612
+---------------+-------+-------------------+
613613
"""
614+
if not isinstance(k, int):
615+
raise TypeError('k must be of type int')
616+
_tkutl._numeric_param_check_range('k', k, 1, _six.MAXSIZE)
614617
return self.__proxy__.predict_topk(dataset, output_type, k, output_frequency);
615618

616619
def classify(self, dataset, output_frequency='per_row'):

0 commit comments

Comments
 (0)