
Commit 29de65d

SeBorgey and voorhs authored
bert-scorer ending (#172)
* batches
* tests check
* fix
* return to torch
* fix for tests
* Fix/bert scorer (#174)
  * fix str and float issue and shrink search space
  * update `inference node config` overriding logic
  * fix typing
  * fix codestyle
  * fix multilabel issue
  * attempt to fix `inference node config` bugs
  * another attempt

---------

Co-authored-by: Алексеев Илья <[email protected]>
1 parent 644a849 commit 29de65d

File tree: 5 files changed (+87, -33 lines)
Lines changed: 46 additions & 25 deletions

```diff
@@ -1,39 +1,60 @@
 """Configuration for the nodes."""
 
-from dataclasses import asdict, dataclass
 from typing import Any
 
 from autointent.custom_types import NodeType
 
 from ._transformers import CrossEncoderConfig, EmbedderConfig
 
 
-@dataclass
 class InferenceNodeConfig:
     """Configuration for the inference node."""
 
-    node_type: NodeType
-    """Type of the node."""
-    module_name: str
-    """Name of module which is specified as :py:attr:`autointent.modules.base.BaseModule.name`."""
-    module_config: dict[str, Any]
-    """Hyperparameters of underlying module."""
-    load_path: str
-    """Path to the module dump."""
-    embedder_config: EmbedderConfig | None = None
-    """One can override presaved embedder config while loading from file system."""
-    cross_encoder_config: CrossEncoderConfig | None = None
-    """One can override presaved cross encoder config while loading from file system."""
+    def __init__(
+        self,
+        node_type: NodeType,
+        module_name: str,
+        module_config: dict[str, Any],
+        load_path: str,
+        embedder_config: EmbedderConfig | None = None,
+        cross_encoder_config: CrossEncoderConfig | None = None,
+    ) -> None:
+        """Initialize the InferenceNodeConfig.
+
+        Args:
+            node_type: Type of the node.
+            module_name: Name of module which is specified as :py:attr:`autointent.modules.base.BaseModule.name`.
+            module_config: Hyperparameters of underlying module.
+            load_path: Path to the module dump.
+            embedder_config: One can override presaved embedder config while loading from file system.
+            cross_encoder_config: One can override presaved cross encoder config while loading from file system.
+        """
+        self.node_type = node_type
+        self.module_name = module_name
+        self.module_config = module_config
+        self.load_path = load_path
+
+        if embedder_config is not None:
+            self.embedder_config = embedder_config
+        if cross_encoder_config is not None:
+            self.cross_encoder_config = cross_encoder_config
 
     def asdict(self) -> dict[str, Any]:
-        """Convert config to dict format."""
-        res = asdict(self)
-        if self.embedder_config is not None:
-            res["embedder_config"] = self.embedder_config.model_dump()
-        else:
-            res.pop("embedder_config")
-        if self.cross_encoder_config is not None:
-            res["cross_encoder_config"] = self.cross_encoder_config.model_dump()
-        else:
-            res.pop("cross_encoder_config")
-        return res
+        """Convert the InferenceNodeConfig to a dictionary.
+
+        Returns:
+            A dictionary representation of the InferenceNodeConfig.
+        """
+        result = {
+            "node_type": self.node_type,
+            "module_name": self.module_name,
+            "module_config": self.module_config,
+            "load_path": self.load_path,
+        }
+
+        if hasattr(self, "embedder_config"):
+            result["embedder_config"] = self.embedder_config.model_dump()
+        if hasattr(self, "cross_encoder_config"):
+            result["cross_encoder_config"] = self.cross_encoder_config.model_dump()
+
+        return result
```
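For context on the change above: the optional override attributes now exist only when an override is actually passed, so `asdict()` probes them with `hasattr` instead of popping `None` entries. A minimal self-contained sketch of that pattern (stand-in classes, not the real autointent API):

```python
from typing import Any


class StubEmbedderConfig:
    """Stand-in for EmbedderConfig with a pydantic-style model_dump()."""

    def __init__(self, model_name: str) -> None:
        self.model_name = model_name

    def model_dump(self) -> dict[str, Any]:
        return {"model_name": self.model_name}


class StubConfig:
    def __init__(self, load_path: str, embedder_config: StubEmbedderConfig | None = None) -> None:
        self.load_path = load_path
        # The attribute is created only when an override is actually given.
        if embedder_config is not None:
            self.embedder_config = embedder_config

    def asdict(self) -> dict[str, Any]:
        result: dict[str, Any] = {"load_path": self.load_path}
        # hasattr() distinguishes "never set" from "set to an override".
        if hasattr(self, "embedder_config"):
            result["embedder_config"] = self.embedder_config.model_dump()
        return result


print(StubConfig("dump/").asdict())
# {'load_path': 'dump/'} -- no embedder_config key at all
print(StubConfig("dump/", StubEmbedderConfig("some-model")).asdict())
# {'load_path': 'dump/', 'embedder_config': {'model_name': 'some-model'}}
```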

autointent/modules/scoring/_bert.py

Lines changed: 25 additions & 6 deletions

```diff
@@ -23,7 +23,7 @@
 
 
 class BertScorer(BaseScorer):
-    name = "transformer"
+    name = "bert"
     supports_multiclass = True
     supports_multilabel = True
     _model: Any
@@ -79,12 +79,21 @@ def fit(
     ) -> None:
         if hasattr(self, "_model"):
             self.clear_cache()
-
         self._validate_task(labels)
 
         model_name = self.model_config.model_name
        self._tokenizer = AutoTokenizer.from_pretrained(model_name)
-        self._model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=self._n_classes)
+
+        label2id = {i: i for i in range(self._n_classes)}
+        id2label = {i: i for i in range(self._n_classes)}
+
+        self._model = AutoModelForSequenceClassification.from_pretrained(
+            model_name,
+            num_labels=self._n_classes,
+            label2id=label2id,
+            id2label=id2label,
+            problem_type="multi_label_classification" if self._multilabel else "single_label_classification",
+        )
 
         use_cpu = self.model_config.device == "cpu"
 
@@ -94,7 +103,15 @@ def tokenize_function(examples: dict[str, Any]) -> dict[str, Any]:
            )
 
         dataset = Dataset.from_dict({"text": utterances, "labels": labels})
-        tokenized_dataset = dataset.map(tokenize_function, batched=True)
+
+        if self._multilabel:
+            # hugging face uses F.binary_cross_entropy_with_logits under the hood
+            # which requires target labels to be of float type
+            dataset = dataset.map(
+                lambda example: {"label": torch.tensor(example["labels"], dtype=torch.float)}, remove_columns="labels"
+            )
+
+        tokenized_dataset = dataset.map(tokenize_function, batched=True, batch_size=self.batch_size)
 
         with tempfile.TemporaryDirectory() as tmp_dir:
             training_args = TrainingArguments(
@@ -127,17 +144,19 @@ def predict(self, utterances: list[str]) -> npt.NDArray[Any]:
             msg = "Model is not trained. Call fit() first."
             raise RuntimeError(msg)
 
+        device = next(self._model.parameters()).device
         all_predictions = []
         for i in range(0, len(utterances), self.batch_size):
             batch = utterances[i : i + self.batch_size]
             inputs = self._tokenizer(batch, return_tensors="pt", **self.model_config.tokenizer_config.model_dump())
+            inputs = {k: v.to(device) for k, v in inputs.items()}
             with torch.no_grad():
                 outputs = self._model(**inputs)
             logits = outputs.logits
             if self._multilabel:
-                batch_predictions = torch.sigmoid(logits).numpy()
+                batch_predictions = torch.sigmoid(logits).cpu().numpy()
             else:
-                batch_predictions = torch.softmax(logits, dim=1).numpy()
+                batch_predictions = torch.softmax(logits, dim=1).cpu().numpy()
             all_predictions.append(batch_predictions)
         return np.vstack(all_predictions) if all_predictions else np.array([])
```
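Two of the fixes above are easy to reproduce in isolation: Hugging Face's `multi_label_classification` problem type trains with `F.binary_cross_entropy_with_logits`, which rejects integer targets, and CUDA tensors cannot be converted with `.numpy()` until they are moved back to CPU. A minimal sketch with plain torch (not the real BertScorer):

```python
import torch
import torch.nn.functional as F

# 1) Multilabel targets must be float for binary_cross_entropy_with_logits.
logits = torch.randn(2, 3)
targets = torch.tensor([[1, 0, 1], [0, 1, 0]])  # integer multi-hot labels

try:
    F.binary_cross_entropy_with_logits(logits, targets)
except RuntimeError as err:
    print(f"integer targets rejected: {err}")

loss = F.binary_cross_entropy_with_logits(logits, targets.float())  # the fix
print(f"float targets ok, loss={loss.item():.4f}")

# 2) Inputs must live on the model's device, and outputs must return to CPU
#    before .numpy() -- mirroring the predict() changes.
model = torch.nn.Linear(4, 3)
device = next(model.parameters()).device       # works for CPU and CUDA alike
x = torch.randn(2, 4).to(device)
probs = torch.sigmoid(model(x)).detach().cpu().numpy()
print(probs.shape)  # (2, 3)
```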

autointent/nodes/_inference_node.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -34,8 +34,8 @@ def from_config(cls, config: InferenceNodeConfig) -> "InferenceNode":
         module = node_info.modules_available[config.module_name](**config.module_config)
         module.load(
             config.load_path,
-            embedder_config=config.embedder_config,
-            cross_encoder_config=config.cross_encoder_config,
+            embedder_config=getattr(config, "embedder_config", None),
+            cross_encoder_config=getattr(config, "cross_encoder_config", None),
         )
         return cls(module, config.node_type)
```
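This pairs with the InferenceNodeConfig rewrite above: since the override attributes may not exist at all, `getattr` with a `None` default replaces direct attribute access. A tiny illustration with a hypothetical stand-in class:

```python
class Cfg:
    """Stand-in whose override attribute may never be created."""


cfg = Cfg()
print(getattr(cfg, "embedder_config", None))  # None -- attribute was never set
# direct access (cfg.embedder_config) would raise AttributeError here

cfg.embedder_config = "override"
print(getattr(cfg, "embedder_config", None))  # 'override'
```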

tests/assets/configs/multiclass.yaml

Lines changed: 7 additions & 0 deletions

```diff
@@ -28,6 +28,13 @@
       - module_name: sklearn
         clf_name: [RandomForestClassifier]
         n_estimators: [5, 10]
+      - module_name: bert
+        model_config:
+          - model_name: avsolatorio/GIST-small-Embedding-v0
+        num_train_epochs: [1]
+        batch_size: [8, 16]
+        learning_rate: [5.0e-5]
+        seed: [0]
   - node_type: decision
     target_metric: decision_accuracy
     search_space:
```
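The bracketed lists follow the file's existing convention (`n_estimators: [5, 10]`) of enumerating candidate hyperparameter values, so the new bert entry contributes two candidates via `batch_size: [8, 16]`. A hypothetical sketch of how such a grid expands into trial configurations (not AutoIntent's actual optimizer code):

```python
from itertools import product

# List-valued keys from the new bert entry above.
space = {
    "num_train_epochs": [1],
    "batch_size": [8, 16],
    "learning_rate": [5.0e-5],
    "seed": [0],
}

# Full grid: 1 * 2 * 1 * 1 = 2 candidate configurations.
for values in product(*space.values()):
    print(dict(zip(space.keys(), values)))
```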

tests/assets/configs/multilabel.yaml

Lines changed: 7 additions & 0 deletions

```diff
@@ -24,6 +24,13 @@
       - module_name: sklearn
         clf_name: [RandomForestClassifier]
         n_estimators: [5, 10]
+      - module_name: bert
+        model_config:
+          - model_name: avsolatorio/GIST-small-Embedding-v0
+        num_train_epochs: [1]
+        batch_size: [8]
+        learning_rate: [5.0e-5]
+        seed: [0]
   - node_type: decision
     target_metric: decision_accuracy
     search_space:
```

0 commit comments