Skip to content

Commit 0148279

Browse files
authored
Merge pull request #72 from TrueLearnAI/update-interest
fix: use binary skill representation in interest
2 parents 69a6e95 + e2a958b commit 0148279

File tree

3 files changed

+41
-28
lines changed

3 files changed

+41
-28
lines changed

truelearn/learning/_ink_classifier.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -64,14 +64,14 @@ class INKClassifier(BaseClassifier):
6464
... ink_classifier.predict_proba(event)
6565
... )
6666
...
67-
True 0.64839...
68-
False 0.43767...
69-
True 0.65660...
67+
True 0.64387...
68+
False 0.42658...
69+
True 0.65406...
7070
>>> ink_classifier.get_params(deep=False) # doctest:+ELLIPSIS
7171
{...'learner_meta_weights': LearnerMetaWeights(novelty_weights=Weights(\
72-
mean=0.20461..., variance=0.45871...), interest_weights=Weights(\
73-
mean=0.66315..., variance=0.42187...), bias_weights=Weights(\
74-
mean=0.12698..., variance=0.39796...))...}
72+
mean=0.20787..., variance=0.45787...), interest_weights=Weights(\
73+
mean=0.66924..., variance=0.42672...), bias_weights=Weights(\
74+
mean=0.13029..., variance=0.39582...))...}
7575
7676
"""
7777

truelearn/learning/_interest_classifier.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -77,12 +77,12 @@ class InterestClassifier(InterestNoveltyKnowledgeBaseClassifier):
7777
... interest_classifier.predict_proba(event)
7878
... )
7979
...
80-
True 0.88450...
81-
True 0.81079...
82-
True 0.95872...
80+
True 0.87299...
81+
True 0.69146...
82+
True 0.91941...
8383
>>> interest_classifier.get_params() # doctest:+ELLIPSIS
8484
{..., 'learner_model': LearnerModel(knowledge=Knowledge(knowledge=\
85-
{1: KnowledgeComponent(mean=0.99556..., variance=0.10483..., ...), ...}), ...}
85+
{1: KnowledgeComponent(mean=1.3651..., variance=0.07128..., ...), ...}), ...}
8686
"""
8787

8888
_parameter_constraints: Dict[str, Any] = {
@@ -196,6 +196,25 @@ def __get_decay_func(self) -> Callable[[float], float]:
196196

197197
return lambda t_delta: min(math.exp(-self._decay_func_factor * t_delta), 1.0)
198198

199+
@staticmethod
200+
def _content_kc_masks(
201+
content_kcs: Iterable[BaseKnowledgeComponent],
202+
) -> Iterable[BaseKnowledgeComponent]:
203+
"""Return a new iterable of content's knowledge components.
204+
205+
Args:
206+
content_kcs: An iterable of content's knowledge components.
207+
208+
Returns:
209+
A new iterable of content's knowledge components, where
210+
the mean of each knowledge component is set to 1,
211+
based on the assumption of the TrueLearn Interest model.
212+
"""
213+
return (
214+
kc.clone(mean=1.0, variance=kc.variance, timestamp=kc.timestamp)
215+
for kc in content_kcs
216+
)
217+
199218
def _generate_ratings(
200219
self,
201220
env: trueskill.TrueSkill,
@@ -257,7 +276,9 @@ def __apply_interest_decay(
257276
learner_kcs_decayed = map(__apply_interest_decay, learner_kcs)
258277

259278
team_learner = gather_trueskill_team(env, learner_kcs_decayed)
260-
team_content = gather_trueskill_team(env, content_kcs)
279+
team_content = gather_trueskill_team(
280+
env, InterestClassifier._content_kc_masks(content_kcs)
281+
)
261282

262283
# learner always wins in interest
263284
updated_team_learner, _ = env.rate([team_learner, team_content], ranks=[0, 1])
@@ -269,4 +290,6 @@ def _eval_matching_quality(
269290
learner_kcs: Iterable[BaseKnowledgeComponent],
270291
content_kcs: Iterable[BaseKnowledgeComponent],
271292
) -> float:
272-
return team_sum_quality_from_kcs(learner_kcs, content_kcs, self._beta)
293+
return team_sum_quality_from_kcs(
294+
learner_kcs, InterestClassifier._content_kc_masks(content_kcs), self._beta
295+
)

truelearn/tests/test_learning.py

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -618,16 +618,6 @@ def test_interest_get_set_params(self):
618618
"the draw_proba_static should not be None."
619619
)
620620

621-
def test_interest_positive_easy(self):
622-
classifier = learning.InterestClassifier(init_skill=0.0, def_var=0.5)
623-
624-
knowledge = models.Knowledge(
625-
{1: models.KnowledgeComponent(mean=0.0, variance=0.5)}
626-
)
627-
event = models.EventModel(knowledge)
628-
629-
assert classifier.predict_proba(event) == 0.5
630-
631621
def test_interest_throw(self):
632622
with pytest.raises(TrueLearnTypeError) as excinfo:
633623
learning.InterestClassifier(threshold=0)
@@ -749,7 +739,7 @@ def test_interest_classifier(self, train_cases, test_events):
749739
for event, label in zip(train_events, train_labels):
750740
classifier.fit(event, label)
751741

752-
expected_results = [0.8648794445446283, 0.8438279621999456, 0.7777471206958368]
742+
expected_results = [0.8245402410711562, 0.7833295255047532, 0.9194176141732581]
753743
actual_results = [classifier.predict_proba(event) for event in test_events]
754744

755745
check_farray_close(actual_results, expected_results)
@@ -800,9 +790,9 @@ def test_ink_classifier_customize(self, train_cases, test_events):
800790
classifier.fit(event, label)
801791

802792
expected_results = [
803-
0.40575267541878457,
804-
0.36519542301026875,
805-
0.33362493980730495,
793+
0.39583121564200274,
794+
0.3542200117164174,
795+
0.36177605375601996,
806796
]
807797
actual_results = [classifier.predict_proba(event) for event in test_events]
808798

@@ -853,7 +843,7 @@ def test_ink_classifier(self, train_cases, test_events):
853843
for event, label in zip(train_events, train_labels):
854844
classifier.fit(event, label)
855845

856-
expected_results = [0.3844070661899784, 0.3398805698754434, 0.3133264788862059]
846+
expected_results = [0.3807588746036166, 0.33099758411287944, 0.3461486335816942]
857847
actual_results = [classifier.predict_proba(event) for event in test_events]
858848

859849
check_farray_close(actual_results, expected_results)
@@ -869,7 +859,7 @@ def test_ink_classifier_greedy(self):
869859
event_time=0,
870860
)
871861
]
872-
train_labels = [True]
862+
train_labels = [False]
873863
for event, label in zip(train_events, train_labels):
874864
classifier.fit(event, label)
875865

0 commit comments

Comments
 (0)