
Commit 5bf5de0

Update lora.py
1 parent 44f650e commit 5bf5de0

File tree

  • autointent/modules/scoring/_lora

1 file changed: +12 -12 lines changed


autointent/modules/scoring/_lora/lora.py

Lines changed: 12 additions & 12 deletions
@@ -40,7 +40,7 @@ def __init__(
         report_to: REPORTERS_NAMES | None = None,  # type: ignore[no-any-return]
         **lora_kwargs: dict[str, Any],
     ) -> None:
-        self.model_config = HFModelConfig.from_search_config(transformer_config)
+        self.transformer_config = HFModelConfig.from_search_config(transformer_config)
         self.num_train_epochs = num_train_epochs
         self.batch_size = batch_size
         self.learning_rate = learning_rate
@@ -52,17 +52,17 @@ def __init__(
     def from_context(
         cls,
         context: Context,
-        model_config: HFModelConfig | str | dict[str, Any] | None = None,
+        transformer_config: HFModelConfig | str | dict[str, Any] | None = None,
         num_train_epochs: int = 3,
         batch_size: int = 8,
         learning_rate: float = 5e-5,
         seed: int = 0,
         **lora_kwargs: dict[str, Any],
     ) -> "BERTLoRAScorer":
-        if model_config is None:
-            model_config = context.resolve_embedder()
+        if transformer_config is None:
+            transformer_config = context.resolve_embedder()
         return cls(
-            model_config=model_config,
+            transformer_config=transformer_config,
             num_train_epochs=num_train_epochs,
             batch_size=batch_size,
             learning_rate=learning_rate,
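
For callers, the effect of this hunk is that the keyword accepted by from_context (and the constructor) is now transformer_config instead of model_config. A minimal usage sketch, assuming the import path matches the file location and that a Context instance is created elsewhere (not shown in this diff):

# Hypothetical usage sketch; everything beyond the names in the diff is an assumption.
from autointent.modules.scoring._lora.lora import BERTLoRAScorer

# Before this commit:
# scorer = BERTLoRAScorer.from_context(context, model_config="bert-base-uncased")

# After this commit:
scorer = BERTLoRAScorer.from_context(
    context,                                 # autointent Context, assumed to be built elsewhere
    transformer_config="bert-base-uncased",  # str, dict, HFModelConfig, or None (None falls back to context.resolve_embedder())
    num_train_epochs=3,
    batch_size=8,
    learning_rate=5e-5,
)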
@@ -72,7 +72,7 @@ def from_context(
         )
 
     def get_embedder_config(self) -> dict[str, Any]:
-        return self.model_config.model_dump()
+        return self.transformer_config.model_dump()
 
     def fit(
         self,
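
model_dump() is presumably the pydantic v2 serializer, so get_embedder_config() now returns the stored transformer_config as a plain dict. A rough sketch of what a caller would see (field names are assumptions, not taken from this file):

embedder_cfg = scorer.get_embedder_config()
# e.g. {"model_name": "bert-base-uncased", "device": "cpu", ...}  -- keys depend on HFModelConfig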
@@ -84,7 +84,7 @@ def fit(
 
         self._validate_task(labels)
 
-        model_name = self.model_config.model_name
+        model_name = self.transformer_config.model_name
         self._tokenizer = AutoTokenizer.from_pretrained(model_name)
         self._model = AutoModelForSequenceClassification.from_pretrained(
             model_name,
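
In fit(), the model name now comes from transformer_config, and the loaded sequence-classification model is wrapped with LoRA adapters in the next hunk. A standalone sketch of that load-then-wrap pattern; the LoraConfig hyperparameters and num_labels below are illustrative assumptions, not this file's settings:

# Sketch only: values are placeholders, not autointent's defaults.
from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_name = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=5)

lora_config = LoraConfig(task_type=TaskType.SEQ_CLS, r=8, lora_alpha=16, lora_dropout=0.1)
model = get_peft_model(model, lora_config)  # only the injected LoRA adapters remain trainable
model.print_trainable_parameters()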
@@ -94,14 +94,14 @@ def fit(
         )
         self._model = get_peft_model(self._model, self._lora_config)
 
-        device = torch.device(self.model_config.device if self.model_config.device else "cpu")
+        device = torch.device(self.transformer_config.device if self.transformer_config.device else "cpu")
         self._model = self._model.to(device)
 
-        use_cpu = self.model_config.device == "cpu"
+        use_cpu = self.transformer_config.device == "cpu"
 
         def tokenize_function(examples: dict[str, Any]) -> dict[str, Any]:
             return self._tokenizer(  # type: ignore[no-any-return]
-                examples["text"], return_tensors="pt", **self.model_config.tokenizer_config.model_dump()
+                examples["text"], return_tensors="pt", **self.transformer_config.tokenizer_config.model_dump()
             )
 
         dataset = Dataset.from_dict({"text": utterances, "labels": labels})
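
tokenize_function now unpacks self.transformer_config.tokenizer_config.model_dump() into the tokenizer call, and the training data is built with Dataset.from_dict. A self-contained sketch of that tokenize-and-map pattern; the tokenizer kwargs stand in for whatever tokenizer_config actually contains:

# Sketch: padding/truncation/max_length are assumptions, not this file's tokenizer_config.
from datasets import Dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
tokenizer_kwargs = {"padding": "max_length", "truncation": True, "max_length": 128}

def tokenize_function(examples):
    return tokenizer(examples["text"], return_tensors="pt", **tokenizer_kwargs)

dataset = Dataset.from_dict({"text": ["hello", "goodbye"], "labels": [0, 1]})
dataset = dataset.map(tokenize_function, batched=True)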
@@ -143,13 +143,13 @@ def predict(self, utterances: list[str]) -> npt.NDArray[Any]:
             msg = "Model is not trained. Call fit() first."
             raise RuntimeError(msg)
 
-        device = torch.device(self.model_config.device if self.model_config.device else "cpu")
+        device = torch.device(self.transformer_config.device if self.transformer_config.device else "cpu")
         self._model = self._model.to(device)
 
         all_predictions = []
         for i in range(0, len(utterances), self.batch_size):
             batch = utterances[i : i + self.batch_size]
-            inputs = self._tokenizer(batch, return_tensors="pt", **self.model_config.tokenizer_config.model_dump())
+            inputs = self._tokenizer(batch, return_tensors="pt", **self.transformer_config.tokenizer_config.model_dump())
             inputs = {k: v.to(device) for k, v in inputs.items()}
             with torch.no_grad():
                 outputs = self._model(**inputs)
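
The diff stops at the forward pass, so the rest of predict() is not visible here. A hypothetical continuation of the batched loop (softmax over logits, then stacking batches into one array) is sketched below; it is an assumption about the typical pattern, not this file's code:

# Hypothetical continuation -- not shown in this diff.
import numpy as np
import torch

probs = torch.softmax(outputs.logits, dim=-1)           # per-class probabilities for the batch
all_predictions.append(probs.cpu().numpy())

# after the loop:
predictions = np.concatenate(all_predictions, axis=0)    # shape: (len(utterances), num_classes)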
