71 changes: 46 additions & 25 deletions autointent/configs/_inference_node.py
@@ -1,39 +1,60 @@
 """Configuration for the nodes."""
 
-from dataclasses import asdict, dataclass
 from typing import Any
 
 from autointent.custom_types import NodeType
 
 from ._transformers import CrossEncoderConfig, EmbedderConfig
 
 
-@dataclass
 class InferenceNodeConfig:
     """Configuration for the inference node."""
 
-    node_type: NodeType
-    """Type of the node."""
-    module_name: str
-    """Name of module which is specified as :py:attr:`autointent.modules.base.BaseModule.name`."""
-    module_config: dict[str, Any]
-    """Hyperparameters of underlying module."""
-    load_path: str
-    """Path to the module dump."""
-    embedder_config: EmbedderConfig | None = None
-    """One can override presaved embedder config while loading from file system."""
-    cross_encoder_config: CrossEncoderConfig | None = None
-    """One can override presaved cross encoder config while loading from file system."""
+    def __init__(
+        self,
+        node_type: NodeType,
+        module_name: str,
+        module_config: dict[str, Any],
+        load_path: str,
+        embedder_config: EmbedderConfig | None = None,
+        cross_encoder_config: CrossEncoderConfig | None = None,
+    ) -> None:
+        """Initialize the InferenceNodeConfig.
+
+        Args:
+            node_type: Type of the node.
+            module_name: Name of module which is specified as :py:attr:`autointent.modules.base.BaseModule.name`.
+            module_config: Hyperparameters of underlying module.
+            load_path: Path to the module dump.
+            embedder_config: One can override presaved embedder config while loading from file system.
+            cross_encoder_config: One can override presaved cross encoder config while loading from file system.
+        """
+        self.node_type = node_type
+        self.module_name = module_name
+        self.module_config = module_config
+        self.load_path = load_path
+
+        if embedder_config is not None:
+            self.embedder_config = embedder_config
+        if cross_encoder_config is not None:
+            self.cross_encoder_config = cross_encoder_config
 
     def asdict(self) -> dict[str, Any]:
-        """Convert config to dict format."""
-        res = asdict(self)
-        if self.embedder_config is not None:
-            res["embedder_config"] = self.embedder_config.model_dump()
-        else:
-            res.pop("embedder_config")
-        if self.cross_encoder_config is not None:
-            res["cross_encoder_config"] = self.cross_encoder_config.model_dump()
-        else:
-            res.pop("cross_encoder_config")
-        return res
+        """Convert the InferenceNodeConfig to a dictionary.
+
+        Returns:
+            A dictionary representation of the InferenceNodeConfig.
+        """
+        result = {
+            "node_type": self.node_type,
+            "module_name": self.module_name,
+            "module_config": self.module_config,
+            "load_path": self.load_path,
+        }
+
+        if hasattr(self, "embedder_config"):
+            result["embedder_config"] = self.embedder_config.model_dump()
+        if hasattr(self, "cross_encoder_config"):
+            result["cross_encoder_config"] = self.cross_encoder_config.model_dump()
+
+        return result
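Note on the change above: because `__init__` only sets `embedder_config` and `cross_encoder_config` when a value is actually passed, `hasattr` in `asdict` (and the `getattr(..., None)` calls in `InferenceNode.from_config` below) can tell an explicit override apart from "fall back to the presaved config". A minimal standalone sketch of that pattern, assuming nothing beyond the stdlib; `NodeConfigSketch` and its fields are illustrative, not part of the autointent API:

# Sketch of the optional-attribute pattern: set the attribute only when a
# value is given, then use hasattr()/getattr() to distinguish "explicitly
# overridden" from "use the presaved config". Illustrative names only.
from typing import Any


class NodeConfigSketch:
    def __init__(self, module_name: str, embedder_config: dict[str, Any] | None = None) -> None:
        self.module_name = module_name
        if embedder_config is not None:
            self.embedder_config = embedder_config

    def asdict(self) -> dict[str, Any]:
        result: dict[str, Any] = {"module_name": self.module_name}
        if hasattr(self, "embedder_config"):
            result["embedder_config"] = self.embedder_config
        return result


print(NodeConfigSketch("bert").asdict())
# {'module_name': 'bert'} -- the key is absent rather than None
print(getattr(NodeConfigSketch("bert"), "embedder_config", None))
# None -- mirrors the fallback used in InferenceNode.from_config
print(NodeConfigSketch("bert", {"model_name": "my-encoder"}).asdict())
# {'module_name': 'bert', 'embedder_config': {'model_name': 'my-encoder'}}

The trade-off is that the attributes may not exist at all on an instance, which is why the dataclass-generated fields and the `dataclasses.asdict` call had to go.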
31 changes: 25 additions & 6 deletions autointent/modules/scoring/_bert.py
@@ -23,7 +23,7 @@
 
 
 class BertScorer(BaseScorer):
-    name = "transformer"
+    name = "bert"
     supports_multiclass = True
     supports_multilabel = True
     _model: Any
@@ -79,12 +79,21 @@ def fit(
    ) -> None:
        if hasattr(self, "_model"):
            self.clear_cache()
 
        self._validate_task(labels)
 
        model_name = self.model_config.model_name
        self._tokenizer = AutoTokenizer.from_pretrained(model_name)
-        self._model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=self._n_classes)
-
+        label2id = {i: i for i in range(self._n_classes)}
+        id2label = {i: i for i in range(self._n_classes)}
+
+        self._model = AutoModelForSequenceClassification.from_pretrained(
+            model_name,
+            num_labels=self._n_classes,
+            label2id=label2id,
+            id2label=id2label,
+            problem_type="multi_label_classification" if self._multilabel else "single_label_classification",
+        )
+
        use_cpu = self.model_config.device == "cpu"
 
@@ -94,7 +103,15 @@ def tokenize_function(examples: dict[str, Any]) -> dict[str, Any]:
            )
 
        dataset = Dataset.from_dict({"text": utterances, "labels": labels})
-        tokenized_dataset = dataset.map(tokenize_function, batched=True)
+
+        if self._multilabel:
+            # hugging face uses F.binary_cross_entropy_with_logits under the hood
+            # which requires target labels to be of float type
+            dataset = dataset.map(
+                lambda example: {"label": torch.tensor(example["labels"], dtype=torch.float)}, remove_columns="labels"
+            )
+
+        tokenized_dataset = dataset.map(tokenize_function, batched=True, batch_size=self.batch_size)
 
        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = TrainingArguments(
@@ -127,17 +144,19 @@ def predict(self, utterances: list[str]) -> npt.NDArray[Any]:
            msg = "Model is not trained. Call fit() first."
            raise RuntimeError(msg)
 
+        device = next(self._model.parameters()).device
        all_predictions = []
        for i in range(0, len(utterances), self.batch_size):
            batch = utterances[i : i + self.batch_size]
            inputs = self._tokenizer(batch, return_tensors="pt", **self.model_config.tokenizer_config.model_dump())
+            inputs = {k: v.to(device) for k, v in inputs.items()}
            with torch.no_grad():
                outputs = self._model(**inputs)
            logits = outputs.logits
            if self._multilabel:
-                batch_predictions = torch.sigmoid(logits).numpy()
+                batch_predictions = torch.sigmoid(logits).cpu().numpy()
            else:
-                batch_predictions = torch.softmax(logits, dim=1).numpy()
+                batch_predictions = torch.softmax(logits, dim=1).cpu().numpy()
            all_predictions.append(batch_predictions)
        return np.vstack(all_predictions) if all_predictions else np.array([])
 
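Two behavioral fixes in this file are worth spelling out: multilabel targets are cast to float because `binary_cross_entropy_with_logits` (what transformers uses when `problem_type="multi_label_classification"`) rejects integer targets, and logits are moved to CPU before `.numpy()`, which otherwise fails when the model sits on a GPU. A quick sketch of the dtype requirement, with illustrative shapes:

# Sketch: F.binary_cross_entropy_with_logits requires float targets;
# integer (int64) label tensors raise a RuntimeError.
import torch
import torch.nn.functional as F

logits = torch.randn(2, 3)  # batch of 2 utterances, 3 intent classes (illustrative)
labels = torch.tensor([[1, 0, 1], [0, 1, 0]])  # int64 by default

try:
    F.binary_cross_entropy_with_logits(logits, labels)
except RuntimeError as err:
    print(f"integer targets rejected: {err}")

# Casting to float, as the dataset.map above does, makes the loss computable.
loss = F.binary_cross_entropy_with_logits(logits, labels.float())
print(f"loss: {loss.item():.4f}")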
4 changes: 2 additions & 2 deletions autointent/nodes/_inference_node.py
@@ -34,8 +34,8 @@ def from_config(cls, config: InferenceNodeConfig) -> "InferenceNode":
         module = node_info.modules_available[config.module_name](**config.module_config)
         module.load(
             config.load_path,
-            embedder_config=config.embedder_config,
-            cross_encoder_config=config.cross_encoder_config,
+            embedder_config=getattr(config, "embedder_config", None),
+            cross_encoder_config=getattr(config, "cross_encoder_config", None),
         )
         return cls(module, config.node_type)
 
7 changes: 7 additions & 0 deletions tests/assets/configs/multiclass.yaml
@@ -28,6 +28,13 @@
       - module_name: sklearn
         clf_name: [RandomForestClassifier]
         n_estimators: [5, 10]
+      - module_name: bert
+        model_config:
+          - model_name: avsolatorio/GIST-small-Embedding-v0
+        num_train_epochs: [1]
+        batch_size: [8, 16]
+        learning_rate: [5.0e-5]
+        seed: [0]
   - node_type: decision
     target_metric: decision_accuracy
     search_space:
7 changes: 7 additions & 0 deletions tests/assets/configs/multilabel.yaml
@@ -24,6 +24,13 @@
       - module_name: sklearn
         clf_name: [RandomForestClassifier]
         n_estimators: [5, 10]
+      - module_name: bert
+        model_config:
+          - model_name: avsolatorio/GIST-small-Embedding-v0
+        num_train_epochs: [1]
+        batch_size: [8]
+        learning_rate: [5.0e-5]
+        seed: [0]
   - node_type: decision
     target_metric: decision_accuracy
     search_space: