Skip to content

Commit 3cfc122

Browse files
committed
classification: Add ModelWithThreshold
1 parent 3ee1cb1 commit 3cfc122

File tree

2 files changed

+23
-0
lines changed

2 files changed

+23
-0
lines changed

Orange/classification/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,4 @@
1919
from .rules import *
2020
from .sgd import *
2121
from .neural_network import *
22+
from .calibration import *
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
from Orange.classification import Model
2+
3+
__all__ = ["ModelWithThreshold"]
4+
5+
6+
class ModelWithThreshold(Model):
7+
def __init__(self, wrapped_model, threshold, target_class=1):
8+
super().__init__(wrapped_model.domain, wrapped_model.original_domain)
9+
self.name = f"{wrapped_model.name}, thresh={threshold:.2f}"
10+
self.wrapped_model = wrapped_model
11+
self.threshold = threshold
12+
self.target_class = target_class
13+
14+
def __call__(self, data, ret=Model.Value):
15+
probs = self.wrapped_model(data, ret=Model.Probs)
16+
if ret == Model.Probs:
17+
return probs
18+
vals = probs[:, self.target_class].flatten() > self.threshold
19+
if ret == Model.Value:
20+
return vals
21+
else:
22+
return vals, probs

0 commit comments

Comments
 (0)