diff --git a/autointent/configs/_inference_node.py b/autointent/configs/_inference_node.py
index 755ee3b00..7a8580924 100644
--- a/autointent/configs/_inference_node.py
+++ b/autointent/configs/_inference_node.py
@@ -1,6 +1,5 @@
 """Configuration for the nodes."""
 
-from dataclasses import asdict, dataclass
 from typing import Any
 
 from autointent.custom_types import NodeType
@@ -8,32 +7,54 @@
 from ._transformers import CrossEncoderConfig, EmbedderConfig
 
 
-@dataclass
 class InferenceNodeConfig:
     """Configuration for the inference node."""
 
-    node_type: NodeType
-    """Type of the node."""
-    module_name: str
-    """Name of module which is specified as :py:attr:`autointent.modules.base.BaseModule.name`."""
-    module_config: dict[str, Any]
-    """Hyperparameters of underlying module."""
-    load_path: str
-    """Path to the module dump."""
-    embedder_config: EmbedderConfig | None = None
-    """One can override presaved embedder config while loading from file system."""
-    cross_encoder_config: CrossEncoderConfig | None = None
-    """One can override presaved cross encoder config while loading from file system."""
+    def __init__(
+        self,
+        node_type: NodeType,
+        module_name: str,
+        module_config: dict[str, Any],
+        load_path: str,
+        embedder_config: EmbedderConfig | None = None,
+        cross_encoder_config: CrossEncoderConfig | None = None,
+    ) -> None:
+        """Initialize the InferenceNodeConfig.
+
+        Args:
+            node_type: Type of the node.
+            module_name: Name of the module, as specified by :py:attr:`autointent.modules.base.BaseModule.name`.
+            module_config: Hyperparameters of the underlying module.
+            load_path: Path to the module dump.
+            embedder_config: Optionally overrides the pre-saved embedder config when loading from the file system.
+            cross_encoder_config: Optionally overrides the pre-saved cross-encoder config when loading from the file system.
+        """
+        self.node_type = node_type
+        self.module_name = module_name
+        self.module_config = module_config
+        self.load_path = load_path
+
+        if embedder_config is not None:
+            self.embedder_config = embedder_config
+        if cross_encoder_config is not None:
+            self.cross_encoder_config = cross_encoder_config
 
     def asdict(self) -> dict[str, Any]:
-        """Convert config to dict format."""
-        res = asdict(self)
-        if self.embedder_config is not None:
-            res["embedder_config"] = self.embedder_config.model_dump()
-        else:
-            res.pop("embedder_config")
-        if self.cross_encoder_config is not None:
-            res["cross_encoder_config"] = self.cross_encoder_config.model_dump()
-        else:
-            res.pop("cross_encoder_config")
-        return res
+        """Convert the InferenceNodeConfig to a dictionary.
+
+        Returns:
+            A dictionary representation of the InferenceNodeConfig.
+        """
+        result = {
+            "node_type": self.node_type,
+            "module_name": self.module_name,
+            "module_config": self.module_config,
+            "load_path": self.load_path,
+        }
+
+        if hasattr(self, "embedder_config"):
+            result["embedder_config"] = self.embedder_config.model_dump()
+        if hasattr(self, "cross_encoder_config"):
+            result["cross_encoder_config"] = self.cross_encoder_config.model_dump()
+
+        return result
diff --git a/autointent/modules/scoring/_bert.py b/autointent/modules/scoring/_bert.py
index a2e3eb946..5fd075ebc 100644
--- a/autointent/modules/scoring/_bert.py
+++ b/autointent/modules/scoring/_bert.py
@@ -23,7 +23,7 @@
 
 
 class BertScorer(BaseScorer):
-    name = "transformer"
+    name = "bert"
     supports_multiclass = True
     supports_multilabel = True
     _model: Any
@@ -79,12 +79,21 @@ def fit(
     ) -> None:
         if hasattr(self, "_model"):
             self.clear_cache()
 
-        self._validate_task(labels)
 
         model_name = self.model_config.model_name
         self._tokenizer = AutoTokenizer.from_pretrained(model_name)
-        self._model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=self._n_classes)
+
+        label2id = {i: i for i in range(self._n_classes)}
+        id2label = {i: i for i in range(self._n_classes)}
+
+        self._model = AutoModelForSequenceClassification.from_pretrained(
+            model_name,
+            num_labels=self._n_classes,
+            label2id=label2id,
+            id2label=id2label,
+            problem_type="multi_label_classification" if self._multilabel else "single_label_classification",
+        )
 
         use_cpu = self.model_config.device == "cpu"
 
@@ -94,7 +103,15 @@ def tokenize_function(examples: dict[str, Any]) -> dict[str, Any]:
             )
 
         dataset = Dataset.from_dict({"text": utterances, "labels": labels})
-        tokenized_dataset = dataset.map(tokenize_function, batched=True)
+
+        if self._multilabel:
+            # Hugging Face uses F.binary_cross_entropy_with_logits under the hood,
+            # which requires target labels to be of float type
+            dataset = dataset.map(
+                lambda example: {"label": torch.tensor(example["labels"], dtype=torch.float)}, remove_columns="labels"
+            )
+
+        tokenized_dataset = dataset.map(tokenize_function, batched=True, batch_size=self.batch_size)
 
         with tempfile.TemporaryDirectory() as tmp_dir:
             training_args = TrainingArguments(
@@ -127,17 +144,19 @@ def predict(self, utterances: list[str]) -> npt.NDArray[Any]:
             msg = "Model is not trained. Call fit() first."
             raise RuntimeError(msg)
 
+        device = next(self._model.parameters()).device
         all_predictions = []
         for i in range(0, len(utterances), self.batch_size):
             batch = utterances[i : i + self.batch_size]
             inputs = self._tokenizer(batch, return_tensors="pt", **self.model_config.tokenizer_config.model_dump())
+            inputs = {k: v.to(device) for k, v in inputs.items()}
             with torch.no_grad():
                 outputs = self._model(**inputs)
                 logits = outputs.logits
 
             if self._multilabel:
-                batch_predictions = torch.sigmoid(logits).numpy()
+                batch_predictions = torch.sigmoid(logits).cpu().numpy()
             else:
-                batch_predictions = torch.softmax(logits, dim=1).numpy()
+                batch_predictions = torch.softmax(logits, dim=1).cpu().numpy()
             all_predictions.append(batch_predictions)
         return np.vstack(all_predictions) if all_predictions else np.array([])
diff --git a/autointent/nodes/_inference_node.py b/autointent/nodes/_inference_node.py
index 32d29f0ad..3fc7390fa 100644
--- a/autointent/nodes/_inference_node.py
+++ b/autointent/nodes/_inference_node.py
@@ -34,8 +34,8 @@ def from_config(cls, config: InferenceNodeConfig) -> "InferenceNode":
         module = node_info.modules_available[config.module_name](**config.module_config)
         module.load(
             config.load_path,
-            embedder_config=config.embedder_config,
-            cross_encoder_config=config.cross_encoder_config,
+            embedder_config=getattr(config, "embedder_config", None),
+            cross_encoder_config=getattr(config, "cross_encoder_config", None),
         )
 
         return cls(module, config.node_type)
diff --git a/tests/assets/configs/multiclass.yaml b/tests/assets/configs/multiclass.yaml
index 689813cd3..c21eb779a 100644
--- a/tests/assets/configs/multiclass.yaml
+++ b/tests/assets/configs/multiclass.yaml
@@ -28,6 +28,13 @@
     - module_name: sklearn
       clf_name: [RandomForestClassifier]
       n_estimators: [5, 10]
+    - module_name: bert
+      model_config:
+        - model_name: avsolatorio/GIST-small-Embedding-v0
+      num_train_epochs: [1]
+      batch_size: [8, 16]
+      learning_rate: [5.0e-5]
+      seed: [0]
 - node_type: decision
   target_metric: decision_accuracy
   search_space:
diff --git a/tests/assets/configs/multilabel.yaml b/tests/assets/configs/multilabel.yaml
index 241239b3c..f867c6109 100644
--- a/tests/assets/configs/multilabel.yaml
+++ b/tests/assets/configs/multilabel.yaml
@@ -24,6 +24,13 @@
     - module_name: sklearn
      clf_name: [RandomForestClassifier]
      n_estimators: [5, 10]
+    - module_name: bert
+      model_config:
+        - model_name: avsolatorio/GIST-small-Embedding-v0
+      num_train_epochs: [1]
+      batch_size: [8]
+      learning_rate: [5.0e-5]
+      seed: [0]
 - node_type: decision
   target_metric: decision_accuracy
   search_space:
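
A minimal usage sketch of the new optional-config behavior introduced above (not part of the patch; the import path and all argument values below are assumptions for illustration):

    from autointent.configs import InferenceNodeConfig  # import path assumed

    # Build a config without the optional embedder/cross-encoder overrides.
    config = InferenceNodeConfig(
        node_type="scoring",              # hypothetical node type value
        module_name="bert",
        module_config={"batch_size": 8},  # hypothetical hyperparameters
        load_path="runs/best/scoring",    # hypothetical dump location
    )

    # __init__ sets the optional attributes only when they are provided ...
    assert not hasattr(config, "embedder_config")

    # ... so asdict() omits their keys instead of serializing None.
    dumped = config.asdict()
    assert "embedder_config" not in dumped
    assert "cross_encoder_config" not in dumped

    # Consumers read them defensively, as InferenceNode.from_config now does.
    assert getattr(config, "cross_encoder_config", None) is None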