@@ -40,7 +40,7 @@ def __init__(
         report_to: REPORTERS_NAMES | None = None,  # type: ignore[no-any-return]
         **lora_kwargs: dict[str, Any],
     ) -> None:
-        self.model_config = HFModelConfig.from_search_config(transformer_config)
+        self.transformer_config = HFModelConfig.from_search_config(transformer_config)
         self.num_train_epochs = num_train_epochs
         self.batch_size = batch_size
         self.learning_rate = learning_rate
@@ -52,17 +52,17 @@ def __init__(
     def from_context(
         cls,
         context: Context,
-        model_config: HFModelConfig | str | dict[str, Any] | None = None,
+        transformer_config: HFModelConfig | str | dict[str, Any] | None = None,
         num_train_epochs: int = 3,
         batch_size: int = 8,
         learning_rate: float = 5e-5,
         seed: int = 0,
         **lora_kwargs: dict[str, Any],
     ) -> "BERTLoRAScorer":
-        if model_config is None:
-            model_config = context.resolve_embedder()
+        if transformer_config is None:
+            transformer_config = context.resolve_embedder()
         return cls(
-            model_config=model_config,
+            transformer_config=transformer_config,
             num_train_epochs=num_train_epochs,
             batch_size=batch_size,
             learning_rate=learning_rate,
@@ -72,7 +72,7 @@ def from_context(
         )

     def get_embedder_config(self) -> dict[str, Any]:
-        return self.model_config.model_dump()
+        return self.transformer_config.model_dump()

     def fit(
         self,
@@ -84,7 +84,7 @@ def fit(

         self._validate_task(labels)

-        model_name = self.model_config.model_name
+        model_name = self.transformer_config.model_name
         self._tokenizer = AutoTokenizer.from_pretrained(model_name)
         self._model = AutoModelForSequenceClassification.from_pretrained(
             model_name,
@@ -94,14 +94,14 @@ def fit(
         )
         self._model = get_peft_model(self._model, self._lora_config)

-        device = torch.device(self.model_config.device if self.model_config.device else "cpu")
+        device = torch.device(self.transformer_config.device if self.transformer_config.device else "cpu")
         self._model = self._model.to(device)

-        use_cpu = self.model_config.device == "cpu"
+        use_cpu = self.transformer_config.device == "cpu"

         def tokenize_function(examples: dict[str, Any]) -> dict[str, Any]:
             return self._tokenizer(  # type: ignore[no-any-return]
-                examples["text"], return_tensors="pt", **self.model_config.tokenizer_config.model_dump()
+                examples["text"], return_tensors="pt", **self.transformer_config.tokenizer_config.model_dump()
             )

         dataset = Dataset.from_dict({"text": utterances, "labels": labels})
@@ -143,13 +143,13 @@ def predict(self, utterances: list[str]) -> npt.NDArray[Any]:
             msg = "Model is not trained. Call fit() first."
             raise RuntimeError(msg)

-        device = torch.device(self.model_config.device if self.model_config.device else "cpu")
+        device = torch.device(self.transformer_config.device if self.transformer_config.device else "cpu")
         self._model = self._model.to(device)

         all_predictions = []
         for i in range(0, len(utterances), self.batch_size):
             batch = utterances[i : i + self.batch_size]
-            inputs = self._tokenizer(batch, return_tensors="pt", **self.model_config.tokenizer_config.model_dump())
+            inputs = self._tokenizer(batch, return_tensors="pt", **self.transformer_config.tokenizer_config.model_dump())
             inputs = {k: v.to(device) for k, v in inputs.items()}
             with torch.no_grad():
                 outputs = self._model(**inputs)
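For reference, a minimal usage sketch of the renamed parameter. The construction values, the dict form of the config, and the fit/predict call shapes below are assumptions inferred from this diff (from_context accepts a dict, fit builds a Dataset from utterances and labels, predict takes a list of strings), not a verified API.

# Hypothetical usage after the rename from model_config to transformer_config.
# The dict keys mirror the HFModelConfig fields read in this diff (model_name,
# device); whether __init__ accepts a plain dict is assumed from the
# from_context type hints.
scorer = BERTLoRAScorer(
    transformer_config={"model_name": "bert-base-uncased", "device": "cpu"},
    num_train_epochs=3,
    batch_size=8,
    learning_rate=5e-5,
)
scorer.fit(["turn on the lights", "what is the weather"], [0, 1])
probabilities = scorer.predict(["switch the lamp on"])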