Skip to content

Commit 5e58a21

Browse files
committed
rename bert model
1 parent e51be88 commit 5e58a21

File tree

4 files changed

+15
-15
lines changed

4 files changed

+15
-15
lines changed

autointent/modules/scoring/_bert.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,14 @@ class BertScorer(BaseScorer):
31  31
32  32        def __init__(
33  33            self,
34      -         model_config: HFModelConfig | str | dict[str, Any] | None = None,
    34  +         classification_model_config: HFModelConfig | str | dict[str, Any] | None = None,
35  35            num_train_epochs: int = 3,
36  36            batch_size: int = 8,
37  37            learning_rate: float = 5e-5,
38  38            seed: int = 0,
39  39            report_to: REPORTERS_NAMES | None = None, # type: ignore # noqa: PGH003
40  40        ) -> None:
41      -         self.model_config = HFModelConfig.from_search_config(model_config)
    41  +         self.classification_model_config = HFModelConfig.from_search_config(classification_model_config)
42  42            self.num_train_epochs = num_train_epochs
43  43            self.batch_size = batch_size
44  44            self.learning_rate = learning_rate
@@ -49,19 +49,19 @@ def __init__(
49  49        def from_context(
50  50            cls,
51  51            context: Context,
52      -         model_config: HFModelConfig | str | dict[str, Any] | None = None,
    52  +         classification_model_config: HFModelConfig | str | dict[str, Any] | None = None,
53  53            num_train_epochs: int = 3,
54  54            batch_size: int = 8,
55  55            learning_rate: float = 5e-5,
56  56            seed: int = 0,
57  57        ) -> "BertScorer":
58      -         if model_config is None:
59      -             model_config = context.resolve_embedder()
    58  +         if classification_model_config is None:
    59  +             classification_model_config = context.resolve_embedder()
60  60
61  61            report_to = context.logging_config.report_to
62  62
63  63            return cls(
64      -             model_config=model_config,
    64  +             classification_model_config=classification_model_config,
65  65                num_train_epochs=num_train_epochs,
66  66                batch_size=batch_size,
67  67                learning_rate=learning_rate,
@@ -70,7 +70,7 @@ def from_context(
70  70            )
71  71
72  72        def get_embedder_config(self) -> dict[str, Any]:
73      -         return self.model_config.model_dump()
    73  +         return self.classification_model_config.model_dump()
74  74
75  75        def fit(
76  76            self,
@@ -81,7 +81,7 @@ def fit(
81  81            self.clear_cache()
82  82            self._validate_task(labels)
83  83
84      -         model_name = self.model_config.model_name
    84  +         model_name = self.classification_model_config.model_name
85  85            self._tokenizer = AutoTokenizer.from_pretrained(model_name)
86  86
87  87            label2id = {i: i for i in range(self._n_classes)}
@@ -95,11 +95,11 @@ def fit(
95  95                problem_type="multi_label_classification" if self._multilabel else "single_label_classification",
96  96            )
97  97
98      -         use_cpu = self.model_config.device == "cpu"
    98  +         use_cpu = self.classification_model_config.device == "cpu"
99  99
100 100           def tokenize_function(examples: dict[str, Any]) -> dict[str, Any]:
101 101               return self._tokenizer( # type: ignore[no-any-return]
102     -                 examples["text"], return_tensors="pt", **self.model_config.tokenizer_config.model_dump()
    102 +                 examples["text"], return_tensors="pt", **self.classification_model_config.tokenizer_config.model_dump()
103 103               )
104 104
105 105           dataset = Dataset.from_dict({"text": utterances, "labels": labels})
@@ -148,7 +148,7 @@ def predict(self, utterances: list[str]) -> npt.NDArray[Any]:
148 148           all_predictions = []
149 149           for i in range(0, len(utterances), self.batch_size):
150 150               batch = utterances[i : i + self.batch_size]
151     -             inputs = self._tokenizer(batch, return_tensors="pt", **self.model_config.tokenizer_config.model_dump())
    151 +             inputs = self._tokenizer(batch, return_tensors="pt", **self.classification_model_config.tokenizer_config.model_dump())
152 152               inputs = {k: v.to(device) for k, v in inputs.items()}
153 153               with torch.no_grad():
154 154                   outputs = self._model(**inputs)

tests/assets/configs/multiclass.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
29  29          clf_name: [RandomForestClassifier]
30  30          n_estimators: [5, 10]
31  31        - module_name: bert
32      -       model_config:
    32  +       classification_model_config:
33  33            - model_name: avsolatorio/GIST-small-Embedding-v0
34  34          num_train_epochs: [1]
35  35          batch_size: [8, 16]

tests/assets/configs/multilabel.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
25  25          clf_name: [RandomForestClassifier]
26  26          n_estimators: [5, 10]
27  27        - module_name: bert
28      -       model_config:
    28  +       classification_model_config:
29  29            - model_name: avsolatorio/GIST-small-Embedding-v0
30  30          num_train_epochs: [1]
31  31          batch_size: [8]

tests/modules/scoring/test_bert.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ def test_bert_prediction(dataset):
9   9         """Test that the transformer model can fit and make predictions."""
10  10        data_handler = DataHandler(dataset)
11  11
12      -     scorer = BertScorer(model_config="prajjwal1/bert-tiny", num_train_epochs=1, batch_size=8)
    12  +     scorer = BertScorer(classification_model_config="prajjwal1/bert-tiny", num_train_epochs=1, batch_size=8)
13  13
14  14        scorer.fit(data_handler.train_utterances(0), data_handler.train_labels(0))
15  15
@@ -46,7 +46,7 @@ def test_bert_cache_clearing(dataset):
46  46        """Test that the transformer model properly handles cache clearing."""
47  47        data_handler = DataHandler(dataset)
48  48
49      -     scorer = BertScorer(model_config="prajjwal1/bert-tiny", num_train_epochs=1, batch_size=8)
    49  +     scorer = BertScorer(classification_model_config="prajjwal1/bert-tiny", num_train_epochs=1, batch_size=8)
50  50
51  51        scorer.fit(data_handler.train_utterances(0), data_handler.train_labels(0))
52  52
0 commit comments

Comments
 (0)