@@ -77,12 +77,12 @@ class InterestClassifier(InterestNoveltyKnowledgeBaseClassifier):
7777 ... interest_classifier.predict_proba(event)
7878 ... )
7979 ...
80- True 0.88450 ...
81- True 0.81079 ...
82- True 0.95872 ...
80+ True 0.87299 ...
81+ True 0.69146 ...
82+ True 0.91941 ...
8383 >>> interest_classifier.get_params() # doctest:+ELLIPSIS
8484 {..., 'learner_model': LearnerModel(knowledge=Knowledge(knowledge=\
85- {1: KnowledgeComponent(mean=0.99556 ..., variance=0.10483 ..., ...), ...}), ...}
85+ {1: KnowledgeComponent(mean=1.3651 ..., variance=0.07128 ..., ...), ...}), ...}
8686 """
8787
8888 _parameter_constraints : Dict [str , Any ] = {
@@ -196,6 +196,25 @@ def __get_decay_func(self) -> Callable[[float], float]:
196196
197197 return lambda t_delta : min (math .exp (- self ._decay_func_factor * t_delta ), 1.0 )
198198
199+ @staticmethod
200+ def _content_kc_masks (
201+ content_kcs : Iterable [BaseKnowledgeComponent ],
202+ ) -> Iterable [BaseKnowledgeComponent ]:
203+ """Return a new iterable of content's knowledge components.
204+
205+ Args:
206+ content_kcs: An iterable of content's knowledge components.
207+
208+ Returns:
209+ A new iterable of content's knowledge components, where
210+ the mean of each knowledge component is set to 1,
211+ based on the assumption of the TrueLearn Interest model.
212+ """
213+ return (
214+ kc .clone (mean = 1.0 , variance = kc .variance , timestamp = kc .timestamp )
215+ for kc in content_kcs
216+ )
217+
199218 def _generate_ratings (
200219 self ,
201220 env : trueskill .TrueSkill ,
@@ -257,7 +276,9 @@ def __apply_interest_decay(
257276 learner_kcs_decayed = map (__apply_interest_decay , learner_kcs )
258277
259278 team_learner = gather_trueskill_team (env , learner_kcs_decayed )
260- team_content = gather_trueskill_team (env , content_kcs )
279+ team_content = gather_trueskill_team (
280+ env , InterestClassifier ._content_kc_masks (content_kcs )
281+ )
261282
262283 # learner always wins in interest
263284 updated_team_learner , _ = env .rate ([team_learner , team_content ], ranks = [0 , 1 ])
@@ -269,4 +290,6 @@ def _eval_matching_quality(
269290 learner_kcs : Iterable [BaseKnowledgeComponent ],
270291 content_kcs : Iterable [BaseKnowledgeComponent ],
271292 ) -> float :
272- return team_sum_quality_from_kcs (learner_kcs , content_kcs , self ._beta )
293+ return team_sum_quality_from_kcs (
294+ learner_kcs , InterestClassifier ._content_kc_masks (content_kcs ), self ._beta
295+ )
0 commit comments