@@ -56,7 +56,7 @@ def __init__(
5656 embedder_name : str ,
5757 clf_name : str ,
5858 cv : int = 3 ,
59- clf_args : dict [str , Any ] = {}, # noqa: B006
59+ clf_args : dict [str , Any ] | None = None ,
6060 n_jobs : int = - 1 ,
6161 device : str = "cpu" ,
6262 seed : int = 0 ,
@@ -91,7 +91,7 @@ def from_context(
9191 cls ,
9292 context : Context ,
9393 clf_name : str ,
94- clf_args : dict [str , Any ] = {}, # noqa: B006
94+ clf_args : dict [str , Any ] | None = None ,
9595 embedder_name : str | None = None ,
9696 ) -> Self :
9797 """
@@ -136,7 +136,7 @@ def fit(
136136 self ._multilabel = isinstance (labels [0 ], list )
137137
138138 if self .precomputed_embeddings :
139- # this happens only when LinearScorer is within Pipeline opimization after RetrievalNode optimization
139+ # this happens only when SklearnScorer is within Pipeline opimization after RetrievalNode optimization
140140 vector_index_client = VectorIndexClient (self .device , self .db_dir , self .batch_size , self .max_length )
141141 vector_index = vector_index_client .get_index (self .embedder_name )
142142 features = vector_index .get_all_embeddings ()
@@ -152,7 +152,7 @@ def fit(
152152 max_length = self .max_length ,
153153 )
154154 features = embedder .embed (utterances )
155-
155+ self . clf_args = {} if self . clf_args is None else self . clf_args
156156 if AVAILIABLE_CLASSIFIERS .get (self .clf_name ):
157157 base_clf = AVAILIABLE_CLASSIFIERS [self .clf_name ](** self .clf_args )
158158 else :
0 commit comments