@@ -106,6 +106,7 @@ def __init__(
106106 self .loss_function = loss_function
107107 self .verbose = verbose
108108 self .catboost_kwargs = catboost_kwargs or {}
109+ self .catboost_kwargs |= self .get_extra_params ()
109110
110111 @classmethod
111112 def from_context (
@@ -125,6 +126,7 @@ def from_context(
125126 loss_function = loss_function ,
126127 verbose = verbose ,
127128 features_type = features_type ,
129+ use_embedding_features = use_embedding_features ,
128130 ** catboost_kwargs ,
129131 )
130132
@@ -144,12 +146,26 @@ def _prepare_data_for_fit(
144146 if self .use_embedding_features :
145147 data = pd .DataFrame ({"embedding" : encoded_utterances })
146148 else :
147- data = pd .DataFrame (encoded_utterances )
149+ data = pd .DataFrame (np . array ( encoded_utterances ) )
148150 if self .features_type == FeaturesType .BOTH :
149151 data ["text" ] = utterances
150152 return data
151153 return pd .DataFrame ({"text" : utterances })
152154
def get_extra_params(self) -> dict[str, Any]:
    """Build the feature-column keyword arguments implied by the configuration.

    The result depends only on ``self.features_type`` and
    ``self.use_embedding_features``:

    * ``"text_features": ["text"]`` is set for the ``TEXT`` and ``BOTH``
      feature types.
    * ``"embedding_features": ["embedding"]`` is set for the ``EMBEDDING``
      and ``BOTH`` feature types, but only when ``use_embedding_features``
      is truthy.

    Returns:
        A new dictionary holding zero, one or both of the keys above.

    Raises:
        ValueError: If ``self.features_type`` is not one of ``EMBEDDING``,
            ``TEXT`` or ``BOTH``.
    """
    kind = self.features_type
    supported = {FeaturesType.EMBEDDING, FeaturesType.TEXT, FeaturesType.BOTH}
    if kind not in supported:
        msg = f"Unsupported features type: {self.features_type}"
        raise ValueError(msg)

    params: dict[str, Any] = {}
    if kind != FeaturesType.EMBEDDING:
        # TEXT and BOTH always declare the raw-text column.
        params["text_features"] = ["text"]
    if kind != FeaturesType.TEXT and self.use_embedding_features:
        # Declare the embedding column only when embedding features are
        # enabled, so that embeddings used without them do not raise an error.
        params["embedding_features"] = ["embedding"]
    return params
168+
153169 def fit (
154170 self ,
155171 utterances : list [str ],
@@ -167,19 +183,6 @@ def fit(
167183 else ("MultiClass" if self ._n_classes > BINARY_CLASS_THRESHOLD else "Logloss" )
168184 )
169185
170- extra_params = {}
171- if self .features_type == FeaturesType .EMBEDDING :
172- if self .use_embedding_features : # to not raise error if embedding witout embedding_features
173- extra_params ["embedding_features" ] = ["embedding" ]
174- elif self .features_type in {FeaturesType .TEXT , FeaturesType .BOTH }:
175- extra_params ["text_features" ] = ["text" ]
176- if self .features_type == FeaturesType .BOTH and self .use_embedding_features :
177- extra_params ["embedding_features" ] = ["embedding" ]
178- else :
179- msg = f"Unsupported features type: { self .features_type } "
180- raise ValueError (msg )
181- self .catboost_kwargs .update (extra_params )
182-
183186 self ._model = CatBoostClassifier (
184187 loss_function = self .loss_function or default_loss ,
185188 verbose = self .verbose ,
0 commit comments