Skip to content

Commit 5e46734

Browse files
authored
Merge branch 'main' into optionalise-variance-plotters
2 parents bb560bb + 27d1834 commit 5e46734

File tree

2 files changed

+42
-7
lines changed

2 files changed

+42
-7
lines changed

truelearn/learning/_ink_classifier.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import math
2-
from typing import Any, Optional, Dict, Tuple
2+
from typing import Any, Optional, Dict, Tuple, Union
33
from typing_extensions import Self, Final
44

55
import trueskill
@@ -92,8 +92,8 @@ def __init__(
9292
self,
9393
*,
9494
learner_meta_weights: Optional[LearnerMetaWeights] = None,
95-
novelty_classifier: Optional[NoveltyClassifier] = None,
96-
interest_classifier: Optional[InterestClassifier] = None,
95+
novelty_classifier: Optional[Union[NoveltyClassifier, Dict]] = None,
96+
interest_classifier: Optional[Union[InterestClassifier, Dict]] = None,
9797
threshold: float = 0.5,
9898
tau: float = 0.0,
9999
greedy: bool = False,
@@ -106,9 +106,13 @@ def __init__(
106106
learner_meta_weights:
107107
The novelty/interest/bias weights.
108108
novelty_classifier:
109-
The NoveltyClassifier.
109+
The NoveltyClassifier. It can be a NoveltyClassifier object or
110+
a dictionary of parameters that can be used to instantiate a
111+
NoveltyClassifier object.
110112
interest_classifier:
111-
The InterestClassifier.
113+
The InterestClassifier. It can be an InterestClassifier object or
114+
a dictionary of parameters that can be used to instantiate an
115+
InterestClassifier object.
112116
threshold:
113117
A float that determines the classification threshold.
114118
tau:
@@ -125,8 +129,20 @@ def __init__(
125129
TrueLearnValueError:
126130
Values of parameters do not satisfy their constraints.
127131
"""
128-
self._novelty_classifier = novelty_classifier or NoveltyClassifier()
129-
self._interest_classifier = interest_classifier or InterestClassifier()
132+
if novelty_classifier is None:
133+
self._novelty_classifier = NoveltyClassifier()
134+
elif isinstance(novelty_classifier, dict):
135+
self._novelty_classifier = NoveltyClassifier(**novelty_classifier)
136+
else:
137+
self._novelty_classifier = novelty_classifier
138+
139+
if interest_classifier is None:
140+
self._interest_classifier = InterestClassifier()
141+
elif isinstance(interest_classifier, dict):
142+
self._interest_classifier = InterestClassifier(**interest_classifier)
143+
else:
144+
self._interest_classifier = interest_classifier
145+
130146
self._learner_meta_weights = learner_meta_weights or LearnerMetaWeights()
131147
self._threshold = threshold
132148
self._tau = tau

truelearn/tests/test_learning.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -798,6 +798,25 @@ def test_ink_classifier_customize(self, train_cases, test_events):
798798

799799
check_farray_close(actual_results, expected_results)
800800

801+
def test_ink_classifier_customize_via_dict(self, train_cases, test_events):
802+
classifier = learning.INKClassifier(
803+
novelty_classifier={"def_var": 0.4},
804+
interest_classifier={"beta": 0.2},
805+
)
806+
807+
train_events, train_labels = train_cases
808+
for event, label in zip(train_events, train_labels):
809+
classifier.fit(event, label)
810+
811+
expected_results = [
812+
0.36741905582712336,
813+
0.3247468970257655,
814+
0.33375911026514554,
815+
]
816+
actual_results = [classifier.predict_proba(event) for event in test_events]
817+
818+
check_farray_close(actual_results, expected_results)
819+
801820
def test_ink_get_set_params(self):
802821
classifier = learning.INKClassifier()
803822

0 commit comments

Comments
 (0)