 )
 
 from autointent import Context
-from autointent.configs import EmbedderConfig
+from autointent.configs import HFModelConfig
 from autointent.custom_types import ListOfLabels
 from autointent.modules.base import BaseScorer
 
 
-class TokenizerConfig:
-    """Configuration for tokenizer parameters."""
-
-    def __init__(
-        self,
-        max_length: int = 128,
-        padding: str = "max_length",
-        truncation: bool = True,
-    ) -> None:
-        self.max_length = max_length
-        self.padding = padding
-        self.truncation = truncation
-
-
 class BertScorer(BaseScorer):
     name = "transformer"
     supports_multiclass = True
@@ -45,31 +31,28 @@ class BertScorer(BaseScorer):
 
     def __init__(
         self,
-        model_config: EmbedderConfig | str | dict[str, Any] | None = None,
+        model_config: HFModelConfig | str | dict[str, Any] | None = None,
         num_train_epochs: int = 3,
         batch_size: int = 8,
         learning_rate: float = 5e-5,
         seed: int = 0,
-        tokenizer_config: TokenizerConfig | None = None,
     ) -> None:
-        self.model_config = EmbedderConfig.from_search_config(model_config)
+        self.model_config = HFModelConfig.from_search_config(model_config)
         self.num_train_epochs = num_train_epochs
         self.batch_size = batch_size
         self.learning_rate = learning_rate
         self.seed = seed
-        self.tokenizer_config = tokenizer_config or TokenizerConfig()
         self._multilabel = False
 
     @classmethod
     def from_context(
         cls,
         context: Context,
-        model_config: EmbedderConfig | str | None = None,
+        model_config: HFModelConfig | str | dict[str, Any] | None = None,
         num_train_epochs: int = 3,
         batch_size: int = 8,
         learning_rate: float = 5e-5,
         seed: int = 0,
-        tokenizer_config: TokenizerConfig | None = None,
     ) -> "BertScorer":
         if model_config is None:
             model_config = context.resolve_embedder()
@@ -79,7 +62,6 @@ def from_context(
             batch_size=batch_size,
             learning_rate=learning_rate,
             seed=seed,
-            tokenizer_config=tokenizer_config,
         )
 
     def get_embedder_config(self) -> dict[str, Any]:
@@ -114,10 +96,7 @@ def fit(
 
         def tokenize_function(examples: dict[str, Any]) -> dict[str, Any]:
             return self._tokenizer(  # type: ignore[no-any-return]
-                examples["text"],
-                padding=self.tokenizer_config.padding,
-                truncation=self.tokenizer_config.truncation,
-                max_length=self.tokenizer_config.max_length,
+                examples["text"], return_tensors="pt", **self.model_config.tokenizer_config.model_dump()
             )
 
         dataset = Dataset.from_dict({"text": utterances, "labels": labels})
@@ -154,9 +133,7 @@ def predict(self, utterances: list[str]) -> npt.NDArray[Any]:
             msg = "Model is not trained. Call fit() first."
             raise RuntimeError(msg)
 
-        inputs = self._tokenizer(
-            utterances, padding=True, truncation=True, max_length=self.tokenizer_config.max_length, return_tensors="pt"
-        )
+        inputs = self._tokenizer(utterances, return_tensors="pt", **self.model_config.tokenizer_config.model_dump())
 
         with torch.no_grad():
             outputs = self._model(**inputs)
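
The net effect of this diff: the standalone `TokenizerConfig` is dropped, and tokenizer keyword arguments are instead unpacked from `self.model_config.tokenizer_config.model_dump()`, so tokenization settings travel with the model config. Below is a minimal standalone sketch of that call pattern using plain `transformers` (not autointent code); the kwarg values are illustrative and not taken from `HFModelConfig`'s defaults.

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Roughly what model_config.tokenizer_config.model_dump() would produce
# (illustrative values, not the library's actual defaults).
tokenizer_kwargs = {"max_length": 128, "padding": "max_length", "truncation": True}

# Same pattern as the updated tokenize_function / predict: unpack the
# config dict instead of spelling out padding/truncation/max_length.
inputs = tokenizer(["book a table for two"], return_tensors="pt", **tokenizer_kwargs)
print(inputs["input_ids"].shape)  # torch.Size([1, 128]) with max_length padding
```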