-
Notifications
You must be signed in to change notification settings - Fork 11
full tuning #165
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
full tuning #165
Changes from 6 commits
Commits
Show all changes
11 commits
Select commit
Hold shift + click to select a range
75d9b31
Added code for full tuning
SeBorgey 5890d45
work on review
SeBorgey 7d1d0f8
renaming
SeBorgey 13c82c6
fix ruff
SeBorgey c44b80f
mypy test
SeBorgey cb2e610
ignote mypy
SeBorgey 5e08b44
Feat/bert scorer config refactoring (#168)
voorhs cf4167d
delete validate_task
SeBorgey 92b7f61
report_to
SeBorgey 30cc0ce
batches
SeBorgey 01bd051
Fix/docs building for bert scorer (#171)
voorhs File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,173 @@ | ||
| """BertScorer class for transformer-based classification.""" | ||
|
|
||
| import tempfile | ||
| from typing import Any | ||
|
|
||
| import numpy as np | ||
| import numpy.typing as npt | ||
| import torch | ||
| from datasets import Dataset | ||
| from transformers import ( | ||
| AutoModelForSequenceClassification, | ||
| AutoTokenizer, | ||
| DataCollatorWithPadding, | ||
| Trainer, | ||
| TrainingArguments, | ||
| ) | ||
|
|
||
| from autointent import Context | ||
| from autointent.configs import EmbedderConfig | ||
| from autointent.custom_types import ListOfLabels | ||
| from autointent.modules.base import BaseScorer | ||
|
|
||
|
|
||
| class TokenizerConfig: | ||
| """Configuration for tokenizer parameters.""" | ||
|
|
||
| def __init__( | ||
| self, | ||
| max_length: int = 128, | ||
| padding: str = "max_length", | ||
| truncation: bool = True, | ||
| ) -> None: | ||
| self.max_length = max_length | ||
| self.padding = padding | ||
| self.truncation = truncation | ||
|
|
||
|
|
||
| class BertScorer(BaseScorer): | ||
| name = "transformer" | ||
| supports_multiclass = True | ||
| supports_multilabel = True | ||
| _multilabel: bool | ||
| _model: Any | ||
| _tokenizer: Any | ||
|
|
||
| def __init__( | ||
| self, | ||
| model_config: EmbedderConfig | str | dict[str, Any] | None = None, | ||
| num_train_epochs: int = 3, | ||
| batch_size: int = 8, | ||
| learning_rate: float = 5e-5, | ||
| seed: int = 0, | ||
| tokenizer_config: TokenizerConfig | None = None, | ||
| ) -> None: | ||
| self.model_config = EmbedderConfig.from_search_config(model_config) | ||
| self.num_train_epochs = num_train_epochs | ||
| self.batch_size = batch_size | ||
| self.learning_rate = learning_rate | ||
| self.seed = seed | ||
| self.tokenizer_config = tokenizer_config or TokenizerConfig() | ||
| self._multilabel = False | ||
voorhs marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| @classmethod | ||
| def from_context( | ||
| cls, | ||
| context: Context, | ||
| model_config: EmbedderConfig | str | None = None, | ||
| num_train_epochs: int = 3, | ||
| batch_size: int = 8, | ||
| learning_rate: float = 5e-5, | ||
| seed: int = 0, | ||
| tokenizer_config: TokenizerConfig | None = None, | ||
| ) -> "BertScorer": | ||
| if model_config is None: | ||
| model_config = context.resolve_embedder() | ||
| return cls( | ||
| model_config=model_config, | ||
| num_train_epochs=num_train_epochs, | ||
| batch_size=batch_size, | ||
| learning_rate=learning_rate, | ||
| seed=seed, | ||
| tokenizer_config=tokenizer_config, | ||
| ) | ||
|
|
||
| def get_embedder_config(self) -> dict[str, Any]: | ||
| return self.model_config.model_dump() | ||
|
|
||
| def _validate_task(self, labels: ListOfLabels) -> None: | ||
| """Validate the task and set _multilabel flag.""" | ||
| super()._validate_task(labels) | ||
| self._multilabel = isinstance(labels[0], list) | ||
voorhs marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| def fit( | ||
| self, | ||
| utterances: list[str], | ||
| labels: ListOfLabels, | ||
| ) -> None: | ||
| if hasattr(self, "_model"): | ||
| self.clear_cache() | ||
|
|
||
| self._validate_task(labels) | ||
|
|
||
| if self._multilabel: | ||
| labels_array = np.array(labels) | ||
| num_labels = labels_array.shape[1] | ||
| else: | ||
| num_labels = len(set(labels)) | ||
voorhs marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| model_name = self.model_config.model_name | ||
| self._tokenizer = AutoTokenizer.from_pretrained(model_name) | ||
| self._model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_labels) | ||
|
|
||
| use_cpu = hasattr(self.model_config, "device") and self.model_config.device == "cpu" | ||
voorhs marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| def tokenize_function(examples: dict[str, Any]) -> dict[str, Any]: | ||
| return self._tokenizer( # type: ignore[no-any-return] | ||
| examples["text"], | ||
| padding=self.tokenizer_config.padding, | ||
| truncation=self.tokenizer_config.truncation, | ||
| max_length=self.tokenizer_config.max_length, | ||
| ) | ||
voorhs marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| dataset = Dataset.from_dict({"text": utterances, "labels": labels}) | ||
| tokenized_dataset = dataset.map(tokenize_function, batched=True) | ||
|
|
||
| with tempfile.TemporaryDirectory() as tmp_dir: | ||
| training_args = TrainingArguments( | ||
| output_dir=tmp_dir, | ||
| num_train_epochs=self.num_train_epochs, | ||
| per_device_train_batch_size=self.batch_size, | ||
| learning_rate=self.learning_rate, | ||
| seed=self.seed, | ||
| save_strategy="no", | ||
| logging_strategy="steps", | ||
| logging_steps=10, | ||
| report_to="wandb", | ||
voorhs marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| use_cpu=use_cpu, | ||
| ) | ||
|
|
||
| trainer = Trainer( | ||
| model=self._model, | ||
| args=training_args, | ||
| train_dataset=tokenized_dataset, | ||
| tokenizer=self._tokenizer, | ||
| data_collator=DataCollatorWithPadding(tokenizer=self._tokenizer), | ||
| ) | ||
|
|
||
| trainer.train() | ||
|
|
||
| self._model.eval() | ||
|
|
||
| def predict(self, utterances: list[str]) -> npt.NDArray[Any]: | ||
| if not hasattr(self, "_model") or not hasattr(self, "_tokenizer"): | ||
| msg = "Model is not trained. Call fit() first." | ||
| raise RuntimeError(msg) | ||
|
|
||
| inputs = self._tokenizer( | ||
| utterances, padding=True, truncation=True, max_length=self.tokenizer_config.max_length, return_tensors="pt" | ||
voorhs marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| ) | ||
|
|
||
| with torch.no_grad(): | ||
| outputs = self._model(**inputs) | ||
| logits = outputs.logits | ||
voorhs marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| if self._multilabel: | ||
| return torch.sigmoid(logits).numpy() | ||
| return torch.softmax(logits, dim=1).numpy() | ||
|
|
||
| def clear_cache(self) -> None: | ||
| if hasattr(self, "_model"): | ||
| del self._model | ||
| if hasattr(self, "_tokenizer"): | ||
| del self._tokenizer | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,67 @@ | ||
| import numpy as np | ||
| import pytest | ||
|
|
||
| from autointent.context.data_handler import DataHandler | ||
| from autointent.modules import BertScorer | ||
|
|
||
|
|
||
| def test_bert_prediction(dataset): | ||
| """Test that the transformer model can fit and make predictions.""" | ||
| data_handler = DataHandler(dataset) | ||
|
|
||
| scorer = BertScorer(model_config="prajjwal1/bert-tiny", num_train_epochs=1, batch_size=8) | ||
|
|
||
| scorer.fit(data_handler.train_utterances(0), data_handler.train_labels(0)) | ||
|
|
||
| test_data = [ | ||
| "why is there a hold on my american saving bank account", | ||
| "i am nost sure why my account is blocked", | ||
| "why is there a hold on my capital one checking account", | ||
| "i think my account is blocked but i do not know the reason", | ||
| "can you tell me why is my bank account frozen", | ||
| ] | ||
|
|
||
| predictions = scorer.predict(test_data) | ||
|
|
||
| # Verify prediction shape | ||
| assert predictions.shape[0] == len(test_data) | ||
| assert predictions.shape[1] == len(set(data_handler.train_labels(0))) | ||
|
|
||
| # Verify predictions are probabilities | ||
| assert 0.0 <= np.min(predictions) <= np.max(predictions) <= 1.0 | ||
|
|
||
| # Verify probabilities sum to 1 for multiclass | ||
| if not scorer._multilabel: | ||
| for pred_row in predictions: | ||
| np.testing.assert_almost_equal(np.sum(pred_row), 1.0, decimal=5) | ||
|
|
||
| # Test metadata function if available | ||
| if hasattr(scorer, "predict_with_metadata"): | ||
| predictions, metadata = scorer.predict_with_metadata(test_data) | ||
| assert len(predictions) == len(test_data) | ||
| assert metadata is None | ||
|
|
||
|
|
||
| def test_bert_cache_clearing(dataset): | ||
| """Test that the transformer model properly handles cache clearing.""" | ||
| data_handler = DataHandler(dataset) | ||
|
|
||
| scorer = BertScorer(model_config="prajjwal1/bert-tiny", num_train_epochs=1, batch_size=8) | ||
|
|
||
| scorer.fit(data_handler.train_utterances(0), data_handler.train_labels(0)) | ||
|
|
||
| test_data = ["test text"] | ||
|
|
||
| # Should work before clearing cache | ||
| scorer.predict(test_data) | ||
|
|
||
| # Clear the cache | ||
| scorer.clear_cache() | ||
|
|
||
| # Verify model and tokenizer are removed | ||
| assert not hasattr(scorer, "_model") or scorer._model is None | ||
| assert not hasattr(scorer, "_tokenizer") or scorer._tokenizer is None | ||
|
|
||
| # Should raise exception after clearing cache | ||
| with pytest.raises(RuntimeError): | ||
| scorer.predict(test_data) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.