@@ -42,34 +42,41 @@ class ThresholdPredictor(PredictionModule):
4242
4343 Examples
4444 --------
45- Single-label classification example:
46- >>> from autointent.modules import ThresholdPredictor
47- >>> import numpy as np
48- >>> scores = np.array([[0.2, 0.8], [0.6, 0.4], [0.1, 0.9]])
49- >>> labels = [1, 0, 1]
50- >>> threshold = 0.5
51- >>> predictor = ThresholdPredictor(thresh=threshold)
52- >>> predictor.fit(scores, labels)
53- >>> test_scores = np.array([[0.3, 0.7], [0.5, 0.5]])
54- >>> predictions = predictor.predict(test_scores)
55- >>> print(predictions)
56- [1 0]
57-
58- Multi-label classification example:
59- >>> labels = [[1, 0], [0, 1], [1, 1]]
60- >>> predictor = ThresholdPredictor(thresh=[0.5, 0.5])
61- >>> predictor.fit(scores, labels)
62- >>> test_scores = np.array([[0.3, 0.7], [0.6, 0.4]])
63- >>> predictions = predictor.predict(test_scores)
64- >>> print(predictions)
65- [[0 1] [1 0]]
66-
67- Save and load the model:
68- >>> predictor.dump("outputs/")
69- >>> loaded_predictor = ThresholdPredictor(thresh=0.5)
70- >>> loaded_predictor.load("outputs/")
71- >>> print(loaded_predictor.thresh)
72- 0.5
45+ Single-label classification
46+ ===========================
47+ .. testcode::
48+
49+ from autointent.modules import ThresholdPredictor
50+ import numpy as np
51+ scores = np.array([[0.2, 0.8], [0.6, 0.4], [0.1, 0.9]])
52+ labels = [1, 0, 1]
53+ threshold = 0.5
54+ predictor = ThresholdPredictor(thresh=threshold)
55+ predictor.fit(scores, labels)
56+ test_scores = np.array([[0.3, 0.7], [0.5, 0.5]])
57+ predictions = predictor.predict(test_scores)
58+ print(predictions)
59+
60+ .. testoutput::
61+
62+ [1 0]
63+
64+ Multi-label classification
65+ ==========================
66+ .. testcode::
67+
68+ labels = [[1, 0], [0, 1], [1, 1]]
69+ predictor = ThresholdPredictor(thresh=[0.5, 0.5])
70+ predictor.fit(scores, labels)
71+ test_scores = np.array([[0.3, 0.7], [0.6, 0.4]])
72+ predictions = predictor.predict(test_scores)
73+ print(predictions)
74+
75+ .. testoutput::
76+
77+ [[0 1]
78+ [1 0]]
79+
7380 """
7481
7582 metadata : ThresholdPredictorDumpMetadata
0 commit comments