@@ -31,14 +31,14 @@ class BertScorer(BaseScorer):

     def __init__(
         self,
-        model_config: HFModelConfig | str | dict[str, Any] | None = None,
+        classification_model_config: HFModelConfig | str | dict[str, Any] | None = None,
         num_train_epochs: int = 3,
         batch_size: int = 8,
         learning_rate: float = 5e-5,
         seed: int = 0,
         report_to: REPORTERS_NAMES | None = None,  # type: ignore # noqa: PGH003
     ) -> None:
-        self.model_config = HFModelConfig.from_search_config(model_config)
+        self.classification_model_config = HFModelConfig.from_search_config(classification_model_config)
         self.num_train_epochs = num_train_epochs
         self.batch_size = batch_size
         self.learning_rate = learning_rate
@@ -49,19 +49,19 @@ def __init__(
     def from_context(
         cls,
         context: Context,
-        model_config: HFModelConfig | str | dict[str, Any] | None = None,
+        classification_model_config: HFModelConfig | str | dict[str, Any] | None = None,
         num_train_epochs: int = 3,
         batch_size: int = 8,
         learning_rate: float = 5e-5,
         seed: int = 0,
     ) -> "BertScorer":
-        if model_config is None:
-            model_config = context.resolve_embedder()
+        if classification_model_config is None:
+            classification_model_config = context.resolve_embedder()

         report_to = context.logging_config.report_to

         return cls(
-            model_config=model_config,
+            classification_model_config=classification_model_config,
             num_train_epochs=num_train_epochs,
             batch_size=batch_size,
             learning_rate=learning_rate,
@@ -70,7 +70,7 @@ def from_context(
         )

     def get_embedder_config(self) -> dict[str, Any]:
-        return self.model_config.model_dump()
+        return self.classification_model_config.model_dump()

     def fit(
         self,
@@ -81,7 +81,7 @@ def fit(
         self.clear_cache()
         self._validate_task(labels)

-        model_name = self.model_config.model_name
+        model_name = self.classification_model_config.model_name
         self._tokenizer = AutoTokenizer.from_pretrained(model_name)

         label2id = {i: i for i in range(self._n_classes)}
@@ -95,11 +95,11 @@ def fit(
             problem_type="multi_label_classification" if self._multilabel else "single_label_classification",
         )

-        use_cpu = self.model_config.device == "cpu"
+        use_cpu = self.classification_model_config.device == "cpu"

         def tokenize_function(examples: dict[str, Any]) -> dict[str, Any]:
             return self._tokenizer(  # type: ignore[no-any-return]
-                examples["text"], return_tensors="pt", **self.model_config.tokenizer_config.model_dump()
+                examples["text"], return_tensors="pt", **self.classification_model_config.tokenizer_config.model_dump()
             )

         dataset = Dataset.from_dict({"text": utterances, "labels": labels})
@@ -148,7 +148,9 @@ def predict(self, utterances: list[str]) -> npt.NDArray[Any]:
         all_predictions = []
         for i in range(0, len(utterances), self.batch_size):
             batch = utterances[i : i + self.batch_size]
-            inputs = self._tokenizer(batch, return_tensors="pt", **self.model_config.tokenizer_config.model_dump())
+            inputs = self._tokenizer(
+                batch, return_tensors="pt", **self.classification_model_config.tokenizer_config.model_dump()
+            )
             inputs = {k: v.to(device) for k, v in inputs.items()}
             with torch.no_grad():
                 outputs = self._model(**inputs)
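This is a pure rename of the constructor keyword and its backing attribute; behavior is unchanged (the predict hunk only re-wraps the tokenizer call for line length). A minimal call-site sketch of the new keyword, assuming an import path that does not appear in the diff and a hypothetical model name:

# Assumed import path; only the keyword names and defaults come from the diff.
from autointent.modules.scoring import BertScorer

scorer = BertScorer(
    classification_model_config="bert-base-uncased",  # hypothetical model; a plain str is allowed by the type hint
    num_train_epochs=3,
    batch_size=8,
    learning_rate=5e-5,
    seed=0,
)

# Call sites still using the old keyword now fail:
# BertScorer(model_config=...)  ->  TypeError: unexpected keyword argument 'model_config'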