@@ -29,7 +29,49 @@ class ThresholdPredictorDumpMetadata(BaseMetadataDict):
2929
3030
3131class ThresholdPredictor (PredictionModule ):
32- """Threshold predictor module."""
32+ """
33+ Threshold predictor module.
34+
35+ ThresholdPredictor uses a predefined threshold (or array of thresholds) to predict
36+ labels for single-label or multi-label classification tasks.
37+
38+ :ivar metadata_dict_name: Filename for saving metadata to disk.
39+ :ivar multilabel: If True, the model supports multi-label classification.
40+ :ivar n_classes: Number of classes in the dataset.
41+ :ivar tags: Tags for predictions (if any).
42+ :ivar name: Name of the predictor, defaults to "adaptive".
43+
44+ Examples
45+ --------
46+ Single-label classification example:
47+ >>> from autointent.modules import ThresholdPredictor
48+ >>> import numpy as np
49+ >>> scores = np.array([[0.2, 0.8], [0.6, 0.4], [0.1, 0.9]])
50+ >>> labels = [1, 0, 1]
51+ >>> threshold = 0.5
52+ >>> predictor = ThresholdPredictor(thresh=threshold)
53+ >>> predictor.fit(scores, labels)
54+ >>> test_scores = np.array([[0.3, 0.7], [0.5, 0.5]])
55+ >>> predictions = predictor.predict(test_scores)
56+ >>> print(predictions)
57+ [1 0]
58+
59+ Multi-label classification example:
60+ >>> labels = [[1, 0], [0, 1], [1, 1]]
61+ >>> predictor = ThresholdPredictor(thresh=[0.5, 0.5])
62+ >>> predictor.fit(scores, labels)
63+ >>> test_scores = np.array([[0.3, 0.7], [0.6, 0.4]])
64+ >>> predictions = predictor.predict(test_scores)
65+ >>> print(predictions)
66+ [[0 1] [1 0]]
67+
68+ Save and load the model:
69+ >>> predictor.dump("outputs/")
70+ >>> loaded_predictor = ThresholdPredictor(thresh=0.5)
71+ >>> loaded_predictor.load("outputs/")
72+ >>> print(loaded_predictor.thresh)
73+ 0.5
74+ """
3375
3476 metadata : ThresholdPredictorDumpMetadata
3577 multilabel : bool
@@ -45,9 +87,6 @@ def __init__(
4587 Initialize threshold predictor.
4688
4789 :param thresh: Threshold for the scores, shape (n_classes,) or float
48- :param multilabel: If multilabel classification, default False
49- :param n_classes: Number of classes, default None
50- :param tags: Tags for predictions, default None
5190 """
5291 self .thresh = thresh
5392
0 commit comments