Commit 5f46d80

updated after mr #165
1 parent 1a1e89c commit 5f46d80

File tree

1 file changed: +12 -50 lines

  • autointent/modules/scoring/_lora


autointent/modules/scoring/_lora/lora.py

Lines changed: 12 additions & 50 deletions
@@ -17,25 +17,11 @@
 )
 
 from autointent import Context
-from autointent.configs import EmbedderConfig
+from autointent.configs import HFModelConfig
 from autointent.custom_types import ListOfLabels
 from autointent.modules.base import BaseScorer
 
 
-class TokenizerConfig:
-    """Configuration for tokenizer parameters."""
-
-    def __init__(
-        self,
-        max_length: int = 128,
-        padding: str = "max_length",
-        truncation: bool = True,
-    ) -> None:
-        self.max_length = max_length
-        self.padding = padding
-        self.truncation = truncation
-
-
 class BERTLoRAScorer(BaseScorer):
     name = "lora"
     supports_multiclass = True
@@ -46,40 +32,31 @@ class BERTLoRAScorer(BaseScorer):
 
     def __init__(
         self,
-        model_config: EmbedderConfig | str | dict[str, Any] | None = None,
+        model_config: HFModelConfig | str | dict[str, Any] | None = None,
         num_train_epochs: int = 3,
         batch_size: int = 8,
         learning_rate: float = 5e-5,
         seed: int = 0,
-        tokenizer_config: TokenizerConfig | None = None,
-        lora_rank: int = 16,
-        lora_alpha: int = 32,
-        lora_dropout: float = 0.1,
+        **lora_kwargs: Any,
     ) -> None:
-        self.model_config = EmbedderConfig.from_search_config(model_config)
+        self.model_config = HFModelConfig.from_search_config(model_config)
         self.num_train_epochs = num_train_epochs
         self.batch_size = batch_size
         self.learning_rate = learning_rate
         self.seed = seed
         self._multilabel = False
-        self.tokenizer_config = tokenizer_config or TokenizerConfig()
-        self.lora_rank = lora_rank
-        self.lora_alpha = lora_alpha
-        self.lora_dropout = lora_dropout
+        self._lora_config = LoraConfig(**lora_kwargs)
 
     @classmethod
     def from_context(
         cls,
         context: Context,
-        model_config: EmbedderConfig | str | dict[str, Any] | None = None,
+        model_config: HFModelConfig | str | dict[str, Any] | None = None,
         num_train_epochs: int = 10,
         batch_size: int = 8,
         learning_rate: float = 5e-5,
         seed: int = 0,
-        tokenizer_config: TokenizerConfig | None = None,
-        lora_rank: int = 8,
-        lora_alpha: int = 32,
-        lora_dropout: float = 0.1,
+        **lora_kwargs: Any,
     ) -> "BERTLoRAScorer":
         if model_config is None:
             model_config = context.resolve_embedder()
@@ -89,10 +66,7 @@ def from_context(
             batch_size=batch_size,
             learning_rate=learning_rate,
             seed=seed,
-            tokenizer_config=tokenizer_config,
-            lora_rank=lora_rank,
-            lora_alpha=lora_alpha,
-            lora_dropout=lora_dropout,
+            **lora_kwargs,
         )
 
     def get_embedder_config(self) -> dict[str, Any]:
@@ -123,26 +97,14 @@ def fit(
         self._tokenizer = AutoTokenizer.from_pretrained(model_name)
         self._model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)
 
-        # Configure LoRA
-        lora_config = LoraConfig(
-            r=self.lora_rank,  # Rank of the low-rank matrices
-            lora_alpha=self.lora_alpha,  # Scaling factor
-            target_modules=["query", "value"],  # Target modules to apply LoRA
-            lora_dropout=self.lora_dropout,  # Dropout rate for LoRA layers
-            bias="none",  # Whether to add bias to LoRA layers
-        )
-
         # Apply LoRA to the model
-        self._model = get_peft_model(self._model, lora_config)
+        self._model = get_peft_model(self._model, self._lora_config)
 
         use_cpu = hasattr(self.model_config, "device") and self.model_config.device == "cpu"
 
         def tokenize_function(examples: dict[str, Any]) -> dict[str, Any]:
-            return self._tokenizer(
-                examples["text"],
-                padding=self.tokenizer_config.padding,
-                truncation=self.tokenizer_config.truncation,
-                max_length=self.tokenizer_config.max_length,
+            return self._tokenizer(  # type: ignore[no-any-return]
+                examples["text"], return_tensors="pt", **self.model_config.tokenizer_config.model_dump()
             )
 
         dataset = Dataset.from_dict({"text": utterances, "labels": labels})
@@ -180,7 +142,7 @@ def predict(self, utterances: list[str]) -> npt.NDArray[Any]:
             raise RuntimeError(msg)
 
         inputs = self._tokenizer(
-            utterances, padding=True, truncation=True, max_length=self.tokenizer_config.max_length, return_tensors="pt"
+            utterances, padding=True, truncation=True, max_length=self.model_config.tokenizer_config.max_length, return_tensors="pt"
        )
 
         with torch.no_grad():
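For orientation, a minimal usage sketch of the refactored constructor. Only the BERTLoRAScorer signature and the fact that extra keyword arguments are forwarded verbatim to peft.LoraConfig come from this diff; the import path and the concrete model name are assumptions, and the LoRA settings shown simply reproduce the values that were hard-coded or defaulted before this commit.

from autointent.modules.scoring import BERTLoRAScorer  # import path is an assumption

# The first five arguments match the new __init__ signature; everything after
# `seed` is collected into **lora_kwargs and passed straight to peft.LoraConfig.
scorer = BERTLoRAScorer(
    model_config="bert-base-uncased",   # string form accepted per the type hint; resolved by HFModelConfig.from_search_config
    num_train_epochs=3,
    batch_size=8,
    learning_rate=5e-5,
    seed=0,
    r=16,                               # LoraConfig: rank of the low-rank matrices
    lora_alpha=32,                      # LoraConfig: scaling factor
    lora_dropout=0.1,                   # LoraConfig: dropout on LoRA layers
    target_modules=["query", "value"],  # LoraConfig: modules to adapt (was hard-coded before)
    bias="none",                        # LoraConfig: bias handling (was hard-coded before)
)

Note that target_modules and bias were previously fixed inside fit(); after this change they have to be supplied by the caller or left to LoraConfig's own defaults.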

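The tokenizer settings that used to live in the removed TokenizerConfig class are now taken from the model config: the new tokenize_function unpacks self.model_config.tokenizer_config.model_dump() directly into the tokenizer call. A rough standalone sketch of what that expands to, assuming the tokenizer_config fields mirror the old TokenizerConfig defaults (only max_length and the model_dump() call itself are visible in the diff):

from transformers import AutoTokenizer

# Assumed to mirror the removed TokenizerConfig defaults; in the real code this
# dict would come from self.model_config.tokenizer_config.model_dump().
tokenizer_kwargs = {"max_length": 128, "padding": "max_length", "truncation": True}

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
batch = tokenizer(["book a table for two"], return_tensors="pt", **tokenizer_kwargs)
print(batch["input_ids"].shape)  # torch.Size([1, 128]) due to max_length padding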