1- """TransformerScorer class for transformer-based classification."""
1+ """BertScorer class for transformer-based classification."""
22
33import tempfile
44from typing import Any
2121from autointent .modules .base import BaseScorer
2222
2323
24- class TransformerScorer (BaseScorer ):
24+ class TokenizerConfig :
25+ """Configuration for tokenizer parameters."""
26+
27+ def __init__ (
28+ self ,
29+ max_length : int = 128 ,
30+ padding : str = "max_length" ,
31+ truncation : bool = True ,
32+ ) -> None :
33+ self .max_length = max_length
34+ self .padding = padding
35+ self .truncation = truncation
36+
37+
38+ class BertScorer (BaseScorer ):
2539 name = "transformer"
2640 supports_multiclass = True
2741 supports_multilabel = True
@@ -36,26 +50,46 @@ def __init__(
         batch_size: int = 8,
         learning_rate: float = 5e-5,
         seed: int = 0,
+        tokenizer_config: TokenizerConfig | None = None,
     ) -> None:
         self.model_config = EmbedderConfig.from_search_config(model_config)
         self.num_train_epochs = num_train_epochs
         self.batch_size = batch_size
         self.learning_rate = learning_rate
         self.seed = seed
+        self.tokenizer_config = tokenizer_config or TokenizerConfig()
+        self._multilabel = False
 
     @classmethod
     def from_context(
         cls,
         context: Context,
         model_config: EmbedderConfig | str | None = None,
-    ) -> "TransformerScorer":
+        num_train_epochs: int = 3,
+        batch_size: int = 8,
+        learning_rate: float = 5e-5,
+        seed: int = 0,
+        tokenizer_config: TokenizerConfig | None = None,
+    ) -> "BertScorer":
         if model_config is None:
             model_config = context.resolve_embedder()
-        return cls(model_config=model_config)
+        return cls(
+            model_config=model_config,
+            num_train_epochs=num_train_epochs,
+            batch_size=batch_size,
+            learning_rate=learning_rate,
+            seed=seed,
+            tokenizer_config=tokenizer_config,
+        )
 
     def get_embedder_config(self) -> dict[str, Any]:
         return self.model_config.model_dump()
 
+    def _validate_task(self, labels: ListOfLabels) -> None:
+        """Validate the task and set _multilabel flag."""
+        super()._validate_task(labels)
+        self._multilabel = isinstance(labels[0], list)
+
     def fit(
         self,
         utterances: list[str],
@@ -67,7 +101,7 @@ def fit(
         self._validate_task(labels)
 
         if self._multilabel:
-            labels_array = np.array(labels) if not isinstance(labels, np.ndarray) else labels
+            labels_array = np.array(labels)
             num_labels = labels_array.shape[1]
         else:
             num_labels = len(set(labels))
@@ -76,8 +110,15 @@ def fit(
         self._tokenizer = AutoTokenizer.from_pretrained(model_name)
         self._model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)
 
-        def tokenize_function(examples: dict[str, Any]) -> dict[str, Any]:
-            return self._tokenizer(examples["text"], padding="max_length", truncation=True, max_length=128)
+        use_cpu = hasattr(self.model_config, "device") and self.model_config.device == "cpu"
+
+        def tokenize_function(examples: dict[str, Any]) -> dict[str, Any]:  # type: ignore[no-any-return]
+            return self._tokenizer(
+                examples["text"],
+                padding=self.tokenizer_config.padding,
+                truncation=self.tokenizer_config.truncation,
+                max_length=self.tokenizer_config.max_length,
+            )
 
         dataset = Dataset.from_dict({"text": utterances, "labels": labels})
         tokenized_dataset = dataset.map(tokenize_function, batched=True)
@@ -90,8 +131,10 @@ def tokenize_function(examples: dict[str, Any]) -> dict[str, Any]:
             learning_rate=self.learning_rate,
             seed=self.seed,
             save_strategy="no",
-            logging_strategy="no",
-            report_to="none",
+            logging_strategy="steps",
+            logging_steps=10,
+            report_to="wandb",
+            use_cpu=use_cpu,
         )
 
         trainer = Trainer(
@@ -111,7 +154,9 @@ def predict(self, utterances: list[str]) -> npt.NDArray[Any]:
             msg = "Model is not trained. Call fit() first."
             raise RuntimeError(msg)
 
-        inputs = self._tokenizer(utterances, padding=True, truncation=True, max_length=128, return_tensors="pt")
+        inputs = self._tokenizer(
+            utterances, padding=True, truncation=True, max_length=self.tokenizer_config.max_length, return_tensors="pt"
+        )
 
         with torch.no_grad():
             outputs = self._model(**inputs)
@@ -121,7 +166,6 @@ def predict(self, utterances: list[str]) -> npt.NDArray[Any]:
             return torch.sigmoid(logits).numpy()
         return torch.softmax(logits, dim=1).numpy()
 
-
     def clear_cache(self) -> None:
         if hasattr(self, "_model"):
             del self._model
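For reference, a minimal usage sketch of the new `TokenizerConfig` / `BertScorer` pair as defined in this commit. The training data and the `"bert-base-uncased"` model name are made up, and it is an assumption that `model_config` accepts a plain model-name string via `EmbedderConfig.from_search_config`:

```python
# Minimal usage sketch (not part of the commit): wire a TokenizerConfig into BertScorer.
# Assumption: model_config accepts a plain model-name string; adjust to your EmbedderConfig setup.
tokenizer_config = TokenizerConfig(max_length=64, padding="max_length", truncation=True)

scorer = BertScorer(
    model_config="bert-base-uncased",  # assumed to be resolved by EmbedderConfig.from_search_config
    num_train_epochs=1,
    batch_size=8,
    learning_rate=5e-5,
    seed=0,
    tokenizer_config=tokenizer_config,
)

utterances = ["book a flight to paris", "what's the weather today"]
labels = [0, 1]  # integer labels -> multiclass (softmax); lists of 0/1 would trigger the multilabel (sigmoid) path

scorer.fit(utterances, labels)
probs = scorer.predict(["is it going to rain"])  # probability array of shape (1, num_labels)
scorer.clear_cache()
```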