11import math
2- from typing import Any , Optional , Dict , Tuple
2+ from typing import Any , Optional , Dict , Tuple , Union
33from typing_extensions import Self , Final
44
55import trueskill
@@ -92,8 +92,8 @@ def __init__(
9292 self ,
9393 * ,
9494 learner_meta_weights : Optional [LearnerMetaWeights ] = None ,
95- novelty_classifier : Optional [NoveltyClassifier ] = None ,
96- interest_classifier : Optional [InterestClassifier ] = None ,
95+ novelty_classifier : Optional [Union [ NoveltyClassifier , Dict ] ] = None ,
96+ interest_classifier : Optional [Union [ InterestClassifier , Dict ] ] = None ,
9797 threshold : float = 0.5 ,
9898 tau : float = 0.0 ,
9999 greedy : bool = False ,
@@ -106,9 +106,13 @@ def __init__(
106106 learner_meta_weights:
107107 The novelty/interest/bias weights.
108108 novelty_classifier:
109- The NoveltyClassifier.
109+ The NoveltyClassifier. It can be a NoveltyClassifier object or
110+ a dictionary of parameters that can be used to instantiate a
111+ NoveltyClassifier object.
110112 interest_classifier:
111- The InterestClassifier.
113+ The InterestClassifier. It can be an InterestClassifier object or
114+ a dictionary of parameters that can be used to instantiate an
115+ InterestClassifier object.
112116 threshold:
113117 A float that determines the classification threshold.
114118 tau:
@@ -125,8 +129,20 @@ def __init__(
125129 TrueLearnValueError:
126130 Values of parameters do not satisfy their constraints.
127131 """
128- self ._novelty_classifier = novelty_classifier or NoveltyClassifier ()
129- self ._interest_classifier = interest_classifier or InterestClassifier ()
132+ if novelty_classifier is None :
133+ self ._novelty_classifier = NoveltyClassifier ()
134+ elif isinstance (novelty_classifier , dict ):
135+ self ._novelty_classifier = NoveltyClassifier (** novelty_classifier )
136+ else :
137+ self ._novelty_classifier = novelty_classifier
138+
139+ if interest_classifier is None :
140+ self ._interest_classifier = InterestClassifier ()
141+ elif isinstance (interest_classifier , dict ):
142+ self ._interest_classifier = InterestClassifier (** interest_classifier )
143+ else :
144+ self ._interest_classifier = interest_classifier
145+
130146 self ._learner_meta_weights = learner_meta_weights or LearnerMetaWeights ()
131147 self ._threshold = threshold
132148 self ._tau = tau
0 commit comments