 )
 
 from autointent import Context
-from autointent.configs import EmbedderConfig
+from autointent.configs import HFModelConfig
 from autointent.custom_types import ListOfLabels
 from autointent.modules.base import BaseScorer
 
 
-class TokenizerConfig:
-    """Configuration for tokenizer parameters."""
-
-    def __init__(
-        self,
-        max_length: int = 128,
-        padding: str = "max_length",
-        truncation: bool = True,
-    ) -> None:
-        self.max_length = max_length
-        self.padding = padding
-        self.truncation = truncation
-
-
 class BERTLoRAScorer(BaseScorer):
     name = "lora"
     supports_multiclass = True
@@ -46,40 +32,31 @@ class BERTLoRAScorer(BaseScorer):
 
     def __init__(
         self,
-        model_config: EmbedderConfig | str | dict[str, Any] | None = None,
+        model_config: HFModelConfig | str | dict[str, Any] | None = None,
         num_train_epochs: int = 3,
         batch_size: int = 8,
         learning_rate: float = 5e-5,
         seed: int = 0,
-        tokenizer_config: TokenizerConfig | None = None,
-        lora_rank: int = 16,
-        lora_alpha: int = 32,
-        lora_dropout: float = 0.1,
+        **lora_kwargs: Any,
     ) -> None:
-        self.model_config = EmbedderConfig.from_search_config(model_config)
+        self.model_config = HFModelConfig.from_search_config(model_config)
         self.num_train_epochs = num_train_epochs
         self.batch_size = batch_size
         self.learning_rate = learning_rate
         self.seed = seed
         self._multilabel = False
-        self.tokenizer_config = tokenizer_config or TokenizerConfig()
-        self.lora_rank = lora_rank
-        self.lora_alpha = lora_alpha
-        self.lora_dropout = lora_dropout
+        self._lora_config = LoraConfig(**lora_kwargs)
 
     @classmethod
     def from_context(
         cls,
         context: Context,
-        model_config: EmbedderConfig | str | dict[str, Any] | None = None,
+        model_config: HFModelConfig | str | dict[str, Any] | None = None,
         num_train_epochs: int = 10,
         batch_size: int = 8,
         learning_rate: float = 5e-5,
         seed: int = 0,
-        tokenizer_config: TokenizerConfig | None = None,
-        lora_rank: int = 8,
-        lora_alpha: int = 32,
-        lora_dropout: float = 0.1,
+        **lora_kwargs: Any,
     ) -> "BERTLoRAScorer":
         if model_config is None:
             model_config = context.resolve_embedder()
@@ -89,10 +66,7 @@ def from_context(
             batch_size=batch_size,
             learning_rate=learning_rate,
             seed=seed,
-            tokenizer_config=tokenizer_config,
-            lora_rank=lora_rank,
-            lora_alpha=lora_alpha,
-            lora_dropout=lora_dropout,
+            **lora_kwargs,
         )
 
     def get_embedder_config(self) -> dict[str, Any]:
@@ -123,26 +97,14 @@ def fit(
         self._tokenizer = AutoTokenizer.from_pretrained(model_name)
         self._model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)
 
-        # Configure LoRA
-        lora_config = LoraConfig(
-            r=self.lora_rank,  # Rank of the low-rank matrices
-            lora_alpha=self.lora_alpha,  # Scaling factor
-            target_modules=["query", "value"],  # Target modules to apply LoRA
-            lora_dropout=self.lora_dropout,  # Dropout rate for LoRA layers
-            bias="none",  # Whether to add bias to LoRA layers
-        )
-
         # Apply LoRA to the model
-        self._model = get_peft_model(self._model, lora_config)
+        self._model = get_peft_model(self._model, self._lora_config)
 
         use_cpu = hasattr(self.model_config, "device") and self.model_config.device == "cpu"
 
         def tokenize_function(examples: dict[str, Any]) -> dict[str, Any]:
-            return self._tokenizer(
-                examples["text"],
-                padding=self.tokenizer_config.padding,
-                truncation=self.tokenizer_config.truncation,
-                max_length=self.tokenizer_config.max_length,
+            return self._tokenizer(  # type: ignore[no-any-return]
+                examples["text"], return_tensors="pt", **self.model_config.tokenizer_config.model_dump()
             )
 
         dataset = Dataset.from_dict({"text": utterances, "labels": labels})
@@ -180,7 +142,7 @@ def predict(self, utterances: list[str]) -> npt.NDArray[Any]:
             raise RuntimeError(msg)
 
         inputs = self._tokenizer(
-            utterances, padding=True, truncation=True, max_length=self.tokenizer_config.max_length, return_tensors="pt"
+            utterances, padding=True, truncation=True, max_length=self.model_config.tokenizer_config.max_length, return_tensors="pt"
         )
 
         with torch.no_grad():
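
After this change, tokenizer settings come from the model config's tokenizer_config, and any LoRA hyperparameters are passed through **lora_kwargs straight to peft.LoraConfig. A minimal usage sketch under those assumptions; the import path of BERTLoRAScorer, the model id, and the example data are illustrative rather than taken from this diff, while r, lora_alpha, lora_dropout, target_modules, and bias are standard peft.LoraConfig parameters:

# Usage sketch, assuming BERTLoRAScorer is exposed by autointent's scoring modules
# (exact import path not shown in this diff).
scorer = BERTLoRAScorer(
    model_config="bert-base-uncased",  # resolved via HFModelConfig.from_search_config
    num_train_epochs=3,
    batch_size=8,
    # Everything below is forwarded unchanged to peft.LoraConfig:
    r=16,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=["query", "value"],
    bias="none",
)

# Illustrative training data; fit() tokenizes with the model config's tokenizer settings.
utterances = ["book me a flight to berlin", "play some jazz", "what is the weather today"]
labels = [0, 1, 2]
scorer.fit(utterances, labels)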