diff --git a/.github/workflows/build-docs.yaml b/.github/workflows/build-docs.yaml
index 82a6896b6..c9d2d3152 100644
--- a/.github/workflows/build-docs.yaml
+++ b/.github/workflows/build-docs.yaml
@@ -36,9 +36,12 @@ jobs:
         with:
           python-version: "3.10"
 
-      - name: setup poetry
-        run: |
-          curl -sSL https://install.python-poetry.org | python -
+      - name: Cache Hugging Face
+        id: cache-hf
+        uses: actions/cache@v4
+        with:
+          path: ~/.cache/huggingface
+          key: docs-cache-hf
 
       - name: Install pandoc
         run: |
@@ -46,23 +49,23 @@
 
       - name: Install dependencies
         run: |
-          poetry install --with docs
+          pip install .[docs]
 
       - name: Run tests
         if: github.event_name != 'workflow_dispatch'
         run: |
           echo "Testing documentation build..."
-          make test-docs
+          python -m sphinx -b doctest docs/source docs/build/html
 
       - name: Build documentation
         if: ${{ github.ref == 'refs/heads/dev' }} && github.event_name != 'workflow_dispatch'
         run: |
-          make docs
+          python -m sphinx -b html docs/source docs/build/html
 
       - name: build multiversion documentation
         if: github.event_name == 'release' || github.event_name == 'workflow_dispatch'
         run: |
-          make multi-version-docs
+          sphinx-multiversion docs/source docs/build/html
 
       - name: Deploy to GitHub Pages
         uses: peaceiris/actions-gh-pages@v3
diff --git a/.github/workflows/reusable-test.yaml b/.github/workflows/reusable-test.yaml
new file mode 100644
index 000000000..ac8f0b169
--- /dev/null
+++ b/.github/workflows/reusable-test.yaml
@@ -0,0 +1,46 @@
+name: Reusable Test Workflow
+
+on:
+  workflow_call:
+    inputs:
+      test_command:
+        required: true
+        type: string
+        description: 'Command to run tests'
+
+jobs:
+  test:
+    runs-on: ${{ matrix.os }}
+    strategy:
+      fail-fast: false
+      matrix:
+        os: [ ubuntu-latest ]
+        python-version: [ "3.10", "3.11", "3.12" ]
+        include:
+          - os: windows-latest
+            python-version: "3.10"
+
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v4
+
+      - name: Cache Hugging Face
+        id: cache-hf
+        uses: actions/cache@v4
+        with:
+          path: ~/.cache/huggingface
+          key: ${{ runner.os }}-hf
+
+      - name: Setup Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v5
+        with:
+          python-version: ${{ matrix.python-version }}
+          cache: "pip"
+
+      - name: Install dependencies
+        run: |
+          pip install .[test]
+
+      - name: Run tests
+        run: |
+          ${{ inputs.test_command }}
\ No newline at end of file
diff --git a/.github/workflows/test-inference.yaml b/.github/workflows/test-inference.yaml
index a68ef07e8..d89e4503f 100644
--- a/.github/workflows/test-inference.yaml
+++ b/.github/workflows/test-inference.yaml
@@ -8,31 +8,6 @@ on:
 
 jobs:
   test:
-    runs-on: ${{ matrix.os }}
-    strategy:
-      fail-fast: false
-      matrix:
-        os: [ ubuntu-latest ]
-        python-version: [ "3.10", "3.11", "3.12" ]
-        include:
-          - os: windows-latest
-            python-version: "3.10"
-
-    steps:
-      - name: Checkout code
-        uses: actions/checkout@v4
-
-      - name: Setup Python ${{ matrix.python-version }}
-        uses: actions/setup-python@v5
-        with:
-          python-version: ${{ matrix.python-version }}
-          cache: "pip"
-
-      - name: Install dependencies
-        run: |
-          pip install .
- pip install pytest pytest-asyncio - - - name: Run tests - run: | - pytest tests/pipeline/test_inference.py + uses: ./.github/workflows/reusable-test.yaml + with: + test_command: pytest -n auto tests/pipeline/test_inference.py diff --git a/.github/workflows/test-nodes.yaml b/.github/workflows/test-nodes.yaml index b10161724..c1914913c 100644 --- a/.github/workflows/test-nodes.yaml +++ b/.github/workflows/test-nodes.yaml @@ -8,31 +8,6 @@ on: jobs: test: - runs-on: ${{ matrix.os }} - strategy: - fail-fast: false - matrix: - os: [ ubuntu-latest ] - python-version: [ "3.10", "3.11", "3.12" ] - include: - - os: windows-latest - python-version: "3.10" - - steps: - - name: Checkout code - uses: actions/checkout@v4 - - - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 - with: - python-version: ${{ matrix.python-version }} - cache: "pip" - - - name: Install dependencies - run: | - pip install . - pip install pytest pytest-asyncio - - - name: Run tests - run: | - pytest tests/nodes + uses: ./.github/workflows/reusable-test.yaml + with: + test_command: pytest -n auto tests/nodes diff --git a/.github/workflows/test-optimization.yaml b/.github/workflows/test-optimization.yaml index 4625f39d7..ad3168dd1 100644 --- a/.github/workflows/test-optimization.yaml +++ b/.github/workflows/test-optimization.yaml @@ -8,31 +8,6 @@ on: jobs: test: - runs-on: ${{ matrix.os }} - strategy: - fail-fast: false - matrix: - os: [ ubuntu-latest ] - python-version: [ "3.10", "3.11", "3.12" ] - include: - - os: windows-latest - python-version: "3.10" - - steps: - - name: Checkout code - uses: actions/checkout@v4 - - - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 - with: - python-version: ${{ matrix.python-version }} - cache: "pip" - - - name: Install dependencies - run: | - pip install . - pip install pytest pytest-asyncio - - - name: Run tests - run: | - pytest tests/pipeline/test_optimization.py + uses: ./.github/workflows/reusable-test.yaml + with: + test_command: pytest -n auto tests/pipeline/test_optimization.py diff --git a/.github/workflows/test-presets.yaml b/.github/workflows/test-presets.yaml index ab4a6723d..836c58fa9 100644 --- a/.github/workflows/test-presets.yaml +++ b/.github/workflows/test-presets.yaml @@ -8,31 +8,6 @@ on: jobs: test: - runs-on: ${{ matrix.os }} - strategy: - fail-fast: false - matrix: - os: [ ubuntu-latest ] - python-version: [ "3.10", "3.11", "3.12" ] - include: - - os: windows-latest - python-version: "3.10" - - steps: - - name: Checkout code - uses: actions/checkout@v4 - - - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 - with: - python-version: ${{ matrix.python-version }} - cache: "pip" - - - name: Install dependencies - run: | - pip install . 
- pip install pytest pytest-asyncio - - - name: Run tests - run: | - pytest tests/pipeline/test_presets.py + uses: ./.github/workflows/reusable-test.yaml + with: + test_command: pytest -n auto tests/pipeline/test_presets.py diff --git a/.github/workflows/typing.yml b/.github/workflows/typing.yml index eb0c374ff..dfe873e68 100644 --- a/.github/workflows/typing.yml +++ b/.github/workflows/typing.yml @@ -11,14 +11,9 @@ jobs: python-version: "3.10" cache: "pip" - - name: Install Poetry - run: | - curl -sSL https://install.python-poetry.org | python3 - - echo "$HOME/.poetry/bin" >> $GITHUB_PATH - - name: Install dependencies run: | - poetry install --with typing + pip install .[typing] - name: Run mypy - run: make typing + run: mypy autointent diff --git a/.github/workflows/unit-tests.yaml b/.github/workflows/unit-tests.yaml index 5883080eb..4d8164f26 100644 --- a/.github/workflows/unit-tests.yaml +++ b/.github/workflows/unit-tests.yaml @@ -8,31 +8,6 @@ on: jobs: test: - runs-on: ${{ matrix.os }} - strategy: - fail-fast: false - matrix: - os: [ ubuntu-latest ] - python-version: [ "3.10", "3.11", "3.12" ] - include: - - os: windows-latest - python-version: "3.10" - - steps: - - name: Checkout code - uses: actions/checkout@v4 - - - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 - with: - python-version: ${{ matrix.python-version }} - cache: "pip" - - - name: Install dependencies - run: | - pip install . - pip install pytest pytest-asyncio - - - name: Run tests - run: | - pytest --ignore=tests/nodes --ignore=tests/pipeline + uses: ./.github/workflows/reusable-test.yaml + with: + test_command: pytest -n auto --ignore=tests/nodes --ignore=tests/pipeline diff --git a/Makefile b/Makefile index 5606dc433..5cb4fdede 100644 --- a/Makefile +++ b/Makefile @@ -3,7 +3,7 @@ poetry = poetry run .PHONY: install install: - poetry install --with dev,test,typing,docs + poetry install --extras "dev test typing docs" .PHONY: test test: @@ -24,7 +24,7 @@ lint: .PHONY: sync sync: - poetry sync --with dev,test,typing,docs + poetry sync --extras "dev test typing docs" .PHONY: docs docs: diff --git a/autointent/_dump_tools.py b/autointent/_dump_tools.py index 13a7ba953..12772984f 100644 --- a/autointent/_dump_tools.py +++ b/autointent/_dump_tools.py @@ -33,6 +33,8 @@ class Dumper: estimators = "estimators" cross_encoders = "cross_encoders" pydantic_models: str = "pydantic" + hf_models = "hf_models" + hf_tokenizers = "hf_tokenizers" @staticmethod def make_subdirectories(path: Path) -> None: @@ -48,12 +50,14 @@ def make_subdirectories(path: Path) -> None: path / Dumper.estimators, path / Dumper.cross_encoders, path / Dumper.pydantic_models, + path / Dumper.hf_models, + path / Dumper.hf_tokenizers, ] for subdir in subdirectories: subdir.mkdir(parents=True, exist_ok=True) @staticmethod - def dump(obj: Any, path: Path) -> None: # noqa: ANN401, C901 + def dump(obj: Any, path: Path) -> None: # noqa: ANN401, C901, PLR0912, PLR0915 """Dump modules attributes to filestystem. 
Args: @@ -89,6 +93,28 @@ def dump(obj: Any, path: Path) -> None: # noqa: ANN401, C901 except Exception as e: msg = f"Error dumping pydantic model {key}: {e}" logging.exception(msg) + elif (key == "_model" or "model" in key.lower()) and hasattr(val, "save_pretrained"): + model_path = path / Dumper.hf_models / key + model_path.mkdir(parents=True, exist_ok=True) + try: + val.save_pretrained(model_path) + class_info = {"module": val.__class__.__module__, "name": val.__class__.__name__} + with (model_path / "class_info.json").open("w") as f: + json.dump(class_info, f) + except Exception as e: + msg = f"Error dumping HF model {key}: {e}" + logger.exception(msg) + elif (key == "_tokenizer" or "tokenizer" in key.lower()) and hasattr(val, "save_pretrained"): + tokenizer_path = path / Dumper.hf_tokenizers / key + tokenizer_path.mkdir(parents=True, exist_ok=True) + try: + val.save_pretrained(tokenizer_path) + class_info = {"module": val.__class__.__module__, "name": val.__class__.__name__} + with (tokenizer_path / "class_info.json").open("w") as f: + json.dump(class_info, f) + except Exception as e: + msg = f"Error dumping HF tokenizer {key}: {e}" + logger.exception(msg) else: msg = f"Attribute {key} of type {type(val)} cannot be dumped to file system." logger.error(msg) @@ -114,6 +140,8 @@ def load( # noqa: PLR0912, C901, PLR0915 estimators: dict[str, Any] = {} cross_encoders: dict[str, Any] = {} pydantic_models: dict[str, Any] = {} + hf_models: dict[str, Any] = {} + hf_tokenizers: dict[str, Any] = {} for child in path.iterdir(): if child.name == Dumper.tags: @@ -151,7 +179,6 @@ def load( # noqa: PLR0912, C901, PLR0915 sig = inspect.signature(obj.__init__) if variable_name in sig.parameters: model_type = sig.parameters[variable_name].annotation - if model_type is None: msg = f"No type annotation found for {variable_name}" logger.error(msg) @@ -174,9 +201,45 @@ def load( # noqa: PLR0912, C901, PLR0915 continue pydantic_models[variable_name] = model_type(**content) + elif child.name == Dumper.hf_models: + for model_dir in child.iterdir(): + try: + with (model_dir / "class_info.json").open("r") as f: + class_info = json.load(f) + + module = __import__(class_info["module"], fromlist=[class_info["name"]]) + model_class = getattr(module, class_info["name"]) + + hf_models[model_dir.name] = model_class.from_pretrained(model_dir) + except Exception as e: # noqa: PERF203 + msg = f"Error loading HF model {model_dir.name}: {e}" + logger.exception(msg) + elif child.name == Dumper.hf_tokenizers: + for tokenizer_dir in child.iterdir(): + try: + with (tokenizer_dir / "class_info.json").open("r") as f: + class_info = json.load(f) + + module = __import__(class_info["module"], fromlist=[class_info["name"]]) + tokenizer_class = getattr(module, class_info["name"]) + + hf_tokenizers[tokenizer_dir.name] = tokenizer_class.from_pretrained(tokenizer_dir) + except Exception as e: # noqa: PERF203 + msg = f"Error loading HF tokenizer {tokenizer_dir.name}: {e}" + logger.exception(msg) else: msg = f"Found unexpected child {child}" logger.error(msg) + obj.__dict__.update( - tags | simple_attrs | arrays | embedders | indexes | estimators | cross_encoders | pydantic_models + tags + | simple_attrs + | arrays + | embedders + | indexes + | estimators + | cross_encoders + | pydantic_models + | hf_models + | hf_tokenizers ) diff --git a/autointent/context/data_handler/_stratification.py b/autointent/context/data_handler/_stratification.py index 235f38e2b..9628924bc 100644 --- a/autointent/context/data_handler/_stratification.py +++ 
b/autointent/context/data_handler/_stratification.py @@ -12,7 +12,7 @@ from numpy import typing as npt from sklearn.model_selection import train_test_split from skmultilearn.model_selection import IterativeStratification -from transformers import set_seed +from transformers import set_seed # type: ignore[attr-defined] from autointent import Dataset from autointent.custom_types import LabelType @@ -128,7 +128,8 @@ def _split_multilabel(self, dataset: HFDataset, test_size: float) -> Sequence[np Returns: A sequence containing indices for train and test splits. """ - set_seed(self.random_seed) # workaround for buggy nature of IterativeStratification from skmultilearn + if self.random_seed is not None: + set_seed(self.random_seed) # workaround for buggy nature of IterativeStratification from skmultilearn splitter = IterativeStratification( n_splits=2, order=2, diff --git a/autointent/modules/__init__.py b/autointent/modules/__init__.py index 212d886b1..b8ebdf3da 100644 --- a/autointent/modules/__init__.py +++ b/autointent/modules/__init__.py @@ -54,4 +54,25 @@ def _create_modules_dict(modules: list[type[T]]) -> dict[str, type[T]]: ) -__all__ = [] # type: ignore[var-annotated] +__all__ = [ + "AdaptiveDecision", + "ArgmaxDecision", + "BaseDecision", + "BaseEmbedding", + "BaseModule", + "BaseRegex", + "BaseScorer", + "DNNCScorer", + "DescriptionScorer", + "JinoosDecision", + "KNNScorer", + "LinearScorer", + "LogregAimedEmbedding", + "MLKnnScorer", + "RerankScorer", + "RetrievalAimedEmbedding", + "SimpleRegex", + "SklearnScorer", + "ThresholdDecision", + "TunableDecision", +] diff --git a/autointent/modules/scoring/_bert.py b/autointent/modules/scoring/_bert.py index 5fd075ebc..d292fea1c 100644 --- a/autointent/modules/scoring/_bert.py +++ b/autointent/modules/scoring/_bert.py @@ -7,7 +7,7 @@ import numpy.typing as npt import torch from datasets import Dataset -from transformers import ( +from transformers import ( # type: ignore[attr-defined] AutoModelForSequenceClassification, AutoTokenizer, DataCollatorWithPadding, @@ -31,14 +31,14 @@ class BertScorer(BaseScorer): def __init__( self, - model_config: HFModelConfig | str | dict[str, Any] | None = None, + classification_model_config: HFModelConfig | str | dict[str, Any] | None = None, num_train_epochs: int = 3, batch_size: int = 8, learning_rate: float = 5e-5, seed: int = 0, report_to: REPORTERS_NAMES | None = None, # type: ignore # noqa: PGH003 ) -> None: - self.model_config = HFModelConfig.from_search_config(model_config) + self.classification_model_config = HFModelConfig.from_search_config(classification_model_config) self.num_train_epochs = num_train_epochs self.batch_size = batch_size self.learning_rate = learning_rate @@ -49,19 +49,19 @@ def __init__( def from_context( cls, context: Context, - model_config: HFModelConfig | str | dict[str, Any] | None = None, + classification_model_config: HFModelConfig | str | dict[str, Any] | None = None, num_train_epochs: int = 3, batch_size: int = 8, learning_rate: float = 5e-5, seed: int = 0, ) -> "BertScorer": - if model_config is None: - model_config = context.resolve_embedder() + if classification_model_config is None: + classification_model_config = context.resolve_embedder() report_to = context.logging_config.report_to return cls( - model_config=model_config, + classification_model_config=classification_model_config, num_train_epochs=num_train_epochs, batch_size=batch_size, learning_rate=learning_rate, @@ -70,7 +70,7 @@ def from_context( ) def get_embedder_config(self) -> dict[str, Any]: - return 
self.model_config.model_dump() + return self.classification_model_config.model_dump() def fit( self, @@ -81,7 +81,7 @@ def fit( self.clear_cache() self._validate_task(labels) - model_name = self.model_config.model_name + model_name = self.classification_model_config.model_name self._tokenizer = AutoTokenizer.from_pretrained(model_name) label2id = {i: i for i in range(self._n_classes)} @@ -95,11 +95,11 @@ def fit( problem_type="multi_label_classification" if self._multilabel else "single_label_classification", ) - use_cpu = self.model_config.device == "cpu" + use_cpu = self.classification_model_config.device == "cpu" def tokenize_function(examples: dict[str, Any]) -> dict[str, Any]: return self._tokenizer( # type: ignore[no-any-return] - examples["text"], return_tensors="pt", **self.model_config.tokenizer_config.model_dump() + examples["text"], return_tensors="pt", **self.classification_model_config.tokenizer_config.model_dump() ) dataset = Dataset.from_dict({"text": utterances, "labels": labels}) @@ -127,7 +127,7 @@ def tokenize_function(examples: dict[str, Any]) -> dict[str, Any]: use_cpu=use_cpu, ) - trainer = Trainer( + trainer = Trainer( # type: ignore[no-untyped-call] model=self._model, args=training_args, train_dataset=tokenized_dataset, @@ -135,7 +135,7 @@ def tokenize_function(examples: dict[str, Any]) -> dict[str, Any]: data_collator=DataCollatorWithPadding(tokenizer=self._tokenizer), ) - trainer.train() + trainer.train() # type: ignore[attr-defined] self._model.eval() @@ -148,7 +148,9 @@ def predict(self, utterances: list[str]) -> npt.NDArray[Any]: all_predictions = [] for i in range(0, len(utterances), self.batch_size): batch = utterances[i : i + self.batch_size] - inputs = self._tokenizer(batch, return_tensors="pt", **self.model_config.tokenizer_config.model_dump()) + inputs = self._tokenizer( + batch, return_tensors="pt", **self.classification_model_config.tokenizer_config.model_dump() + ) inputs = {k: v.to(device) for k, v in inputs.items()} with torch.no_grad(): outputs = self._model(**inputs) diff --git a/autointent/modules/scoring/_linear.py b/autointent/modules/scoring/_linear.py index 06e04c4dd..be74ada89 100644 --- a/autointent/modules/scoring/_linear.py +++ b/autointent/modules/scoring/_linear.py @@ -4,6 +4,7 @@ import numpy as np import numpy.typing as npt +from pydantic import PositiveInt from sklearn.linear_model import LogisticRegression, LogisticRegressionCV from sklearn.multioutput import MultiOutputClassifier @@ -22,7 +23,6 @@ class LinearScorer(BaseScorer): Args: embedder_config: Config of the embedder model cv: Number of cross-validation folds, defaults to 3 - n_jobs: Number of parallel jobs for cross-validation, defaults to None seed: Random seed for reproducibility, defaults to 0 Example: @@ -72,18 +72,21 @@ def __init__( def from_context( cls, context: Context, + cv: PositiveInt = 3, embedder_config: EmbedderConfig | str | None = None, ) -> "LinearScorer": """Create a LinearScorer instance using a Context object. 
Args: context: Context containing configurations and utilities + cv: Number of cross-validation folds, defaults to 3 embedder_config: Config of the embedder, or None to use the best embedder """ if embedder_config is None: embedder_config = context.resolve_embedder() return cls( + cv=cv, embedder_config=embedder_config, ) diff --git a/autointent/nodes/_node_optimizer.py b/autointent/nodes/_node_optimizer.py index 3d4b4798e..8d1d4872f 100644 --- a/autointent/nodes/_node_optimizer.py +++ b/autointent/nodes/_node_optimizer.py @@ -11,32 +11,14 @@ import optuna import torch from optuna.trial import Trial -from pydantic import BaseModel, Field from typing_extensions import assert_never from autointent import Dataset from autointent.context import Context from autointent.custom_types import NodeType, SamplerType, SearchSpaceValidationMode +from autointent.nodes.emissions_tracker import EmissionsTracker from autointent.nodes.info import NODES_INFO - - -class ParamSpaceInt(BaseModel): - """Integer parameter search space configuration.""" - - low: int = Field(..., description="Lower boundary of the search space.") - high: int = Field(..., description="Upper boundary of the search space.") - step: int = Field(1, description="Step size for the search space.") - log: bool = Field(False, description="Indicates whether to use a logarithmic scale.") - - -class ParamSpaceFloat(BaseModel): - """Float parameter search space configuration.""" - - low: float = Field(..., description="Lower boundary of the search space.") - high: float = Field(..., description="Upper boundary of the search space.") - step: float | None = Field(None, description="Step size for the search space (if applicable).") - log: bool = Field(False, description="Indicates whether to use a logarithmic scale.") - +from autointent.schemas.node_validation import ParamSpaceFloat, ParamSpaceInt, SearchSpaceConfig logger = logging.getLogger(__name__) @@ -67,6 +49,7 @@ def __init__( self.node_type = node_type self.node_info = NODES_INFO[node_type] self.target_metric = target_metric + self.emissions_tracker = EmissionsTracker(project_name=f"{self.node_info.node_type}") self.metrics = metrics if metrics is not None else [] if self.target_metric not in self.metrics: @@ -141,8 +124,13 @@ def objective( context.callback_handler.start_module(module_name=module_name, num=self._counter, module_kwargs=config) self._logger.debug("Scoring %s module...", module_name) - all_metrics = module.score(context, metrics=self.metrics) - target_metric = all_metrics[self.target_metric] + + self.emissions_tracker.start_task("module_scoring") + final_metrics = module.score(context, metrics=self.metrics) + emissions_metrics = self.emissions_tracker.stop_task() + all_metrics = {**final_metrics, **emissions_metrics} + + target_metric = final_metrics[self.target_metric] context.callback_handler.log_metrics(all_metrics) context.callback_handler.end_module() @@ -161,7 +149,7 @@ def objective( config, target_metric, self.target_metric, - all_metrics, + final_metrics, module.get_assets(), # retriever name / scores / predictions module_dump_dir, module=module if not context.is_ram_to_clear() else None, @@ -270,7 +258,8 @@ def validate_nodes_with_dataset(self, dataset: Dataset, mode: SearchSpaceValidat def validate_search_space(self, search_space: list[dict[str, Any]]) -> None: """Check if search space is configured correctly.""" - for module_search_space in search_space: + validated_search_space = SearchSpaceConfig(search_space).model_dump() + for module_search_space in 
validated_search_space:
             module_search_space_no_optuna, module_name = self._reformat_search_space(deepcopy(module_search_space))
             for params_combination in it.product(*module_search_space_no_optuna.values()):
diff --git a/autointent/nodes/emissions_tracker.py b/autointent/nodes/emissions_tracker.py
new file mode 100644
index 000000000..a186a576b
--- /dev/null
+++ b/autointent/nodes/emissions_tracker.py
@@ -0,0 +1,53 @@
+"""Emissions tracking functionality for monitoring energy consumption and carbon emissions."""
+
+import json
+import logging
+
+from codecarbon import EmissionsTracker as CodeCarbonTracker  # type: ignore[import-untyped]
+from codecarbon.output import EmissionsData  # type: ignore[import-untyped]
+
+logger = logging.getLogger(__name__)
+
+
+class EmissionsTracker:
+    """Class for tracking energy consumption and carbon emissions."""
+
+    def __init__(self, project_name: str, measure_power_secs: int = 1) -> None:
+        """Initialize the emissions tracker.
+
+        Args:
+            project_name: Name of the project to track emissions for.
+            measure_power_secs: How often to measure power consumption in seconds.
+        """
+        self._logger = logger
+        self.tracker = CodeCarbonTracker(project_name=project_name, measure_power_secs=measure_power_secs)
+
+    def start_task(self, task_name: str) -> None:
+        """Start tracking emissions for a specific task.
+
+        Args:
+            task_name: Name of the task to track emissions for.
+        """
+        self.tracker.start_task(task_name)
+
+    def stop_task(self) -> dict[str, float]:
+        """Stop tracking emissions and return the emissions data.
+
+        Returns:
+            Dictionary containing emissions metrics.
+        """
+        emissions_data = self.tracker.stop_task()
+        _ = self.tracker.stop()
+        return self._process_metrics(emissions_data)
+
+    def _process_metrics(self, emissions_data: EmissionsData) -> dict[str, float]:
+        """Process emissions data into metrics with the 'emissions/' prefix.
+
+        Args:
+            emissions_data: Raw emissions data from the tracker.
+
+        Returns:
+            Dictionary of processed emissions metrics with the 'emissions/' prefix.
+ """ + emissions_data_dict = json.loads(emissions_data.toJSON()) + return {f"emissions/{k}": v for k, v in emissions_data_dict.items() if isinstance(v, int | float)} diff --git a/autointent/schemas/node_validation.py b/autointent/schemas/node_validation.py new file mode 100644 index 000000000..ca118ecf6 --- /dev/null +++ b/autointent/schemas/node_validation.py @@ -0,0 +1,366 @@ +"""Schemes.""" + +import inspect +from collections.abc import Iterator +from typing import Annotated, Any, Literal, TypeAlias, Union, get_args, get_origin, get_type_hints + +from pydantic import BaseModel, ConfigDict, Field, PositiveInt, RootModel, ValidationError, model_validator + +from autointent.custom_types import NodeType +from autointent.modules import BaseModule +from autointent.nodes.info import DecisionNodeInfo, EmbeddingNodeInfo, RegexNodeInfo, ScoringNodeInfo + + +class ParamSpaceInt(BaseModel): + """Integer parameter search space configuration.""" + + low: int = Field(..., description="Lower boundary of the search space.") + high: int = Field(..., description="Upper boundary of the search space.") + step: int = Field(1, description="Step size for the search space.") + log: bool = Field(False, description="Indicates whether to use a logarithmic scale.") + + +class ParamSpaceFloat(BaseModel): + """Float parameter search space configuration.""" + + low: float = Field(..., description="Lower boundary of the search space.") + high: float = Field(..., description="Upper boundary of the search space.") + step: float | None = Field(None, description="Step size for the search space (if applicable).") + log: bool = Field(False, description="Indicates whether to use a logarithmic scale.") + + +def unwrap_annotated(tp: type) -> type: + """Unwrap the Annotated type to get the actual type. + + :param tp: Type to unwrap + :return: Unwrapped type + """ + # Check if the type is an Annotated type using get_origin + # Annotated[int, "some metadata"] would have origin as Annotated + # If it is Annotated, extract the first argument which is the actual type + # Otherwise return the original type unchanged + return get_args(tp)[0] if get_origin(tp) is Annotated else tp + + +def type_matches(target: type, tp: type) -> bool: + """Recursively check if the target type is present in the given type. + + This function handles union types by unwrapping Annotated types where necessary. + + :param target: Target type + :param tp: Given type + :return: If the target type is present in the given type + """ + # Get the origin of the type to determine if it's a generic type + # For example, Union, List, Dict, etc. + origin = get_origin(tp) + + # If the type is a Union (e.g., int | str or Union[int, str]) + if origin is Union: + # Check if any of the union's arguments match the target type + # Recursively call type_matches for each argument in the union + return any(type_matches(target, arg) for arg in get_args(tp)) + + # For non-Union types, unwrap any Annotated wrapper and compare with the target type + # This handles cases like Annotated[int, "some description"] matching with int + return unwrap_annotated(tp) is target + + +def get_optuna_class(param_type: type) -> type[ParamSpaceInt | ParamSpaceFloat] | None: + """Get the Optuna class for the given parameter type. + + If the (possibly annotated or union) type includes int or float, this function + returns the corresponding search space class. 
+ + :param param_type: Parameter type (could be a union, annotated type, or container) + :return: ParamSpaceInt if the type matches int, ParamSpaceFloat if it matches float, else None. + """ + # Check if the parameter type matches or includes int + if type_matches(int, param_type): + return ParamSpaceInt + # Check if the parameter type matches or includes float + if type_matches(float, param_type): + return ParamSpaceFloat + # Return None if neither int nor float types match + return None + + +def generate_models_and_union_type_for_classes( + classes: list[type[BaseModule]], +) -> tuple[type[BaseModel], dict[str, type[BaseModel]]]: + """Dynamically generates Pydantic models for class constructors and creates a union type. + + This function takes a list of module classes and creates Pydantic models that represent + their initialization parameters. It also creates a union type of all these models. + + Args: + classes: A list of BaseModule subclasses to generate models for + + Returns: + A tuple containing: + - A union type of all generated models + - A dictionary mapping module names to their generated model classes + """ + # Dictionary to store the generated models, keyed by module name + models: dict[str, type[BaseModel]] = {} + + # Iterate through each module class + for cls in classes: + # Get the signature of the from_context method to extract parameters + init_signature = inspect.signature(cls.from_context) + # Get the global namespace for resolving variables in type hints + globalns = getattr(cls.from_context, "__globals__", {}) + # Get type hints with forward references resolved and extra info preserved + type_hints = get_type_hints(cls.from_context, globalns, None, include_extras=True) + + # Check if the method accepts arbitrary keyword arguments (**kwargs) + has_kwarg_arg = any(param.kind == inspect.Parameter.VAR_KEYWORD for param in init_signature.parameters.values()) + + # Initialize fields dictionary with common fields for all models + fields = { + # Module name field with a Literal type restricting it to this specific class name + "module_name": (Literal[cls.name], Field(...)), + # Optional field for number of trials in hyperparameter optimization + "n_trials": (PositiveInt | None, Field(None, description="Number of trials")), + # Config field to control extra fields behavior based on kwargs presence + "model_config": (ConfigDict, ConfigDict(extra="allow" if has_kwarg_arg else "forbid")), + } + + # Process each parameter from the method signature + for param_name, param in init_signature.parameters.items(): + # Skip self, cls, context parameters and **kwargs + if param_name in ("self", "cls", "context") or param.kind == inspect.Parameter.VAR_KEYWORD: + continue + + # Get the parameter's type annotation, defaulting to Any if not specified + param_type: TypeAlias = type_hints.get(param_name, Any) # type: ignore[valid-type] # noqa: PYI042 + + # Create a Field with default value if provided, otherwise make it required + field = Field(default=[param.default]) if param.default is not inspect.Parameter.empty else Field(...) 
+ + # Check if this parameter should have an Optuna search space + search_type = get_optuna_class(param_type) + + if search_type is None: + # Regular parameter: use a list of the parameter's type + fields[param_name] = (list[param_type], field) + else: + # Parameter eligible for optimization: allow either list of values or search space + fields[param_name] = (list[param_type] | search_type, field) + + # Generate a name for the model class + model_name = f"{cls.__name__}InitModel" + + # Dynamically create a Pydantic model class for this module + models[cls.name] = type( + model_name, + (BaseModel,), # Inherit from BaseModel + { + # Set type annotations for all fields + "__annotations__": {k: v[0] for k, v in fields.items()}, + # Set field objects for all fields + **{k: v[1] for k, v in fields.items()}, + }, + ) + + # Return a union type of all models and the dictionary of models + return Union[tuple(models.values())], models # type: ignore[return-value] # noqa: UP007 + + +DecisionSearchSpaceType, DecisionNodesBaseModels = generate_models_and_union_type_for_classes( + list(DecisionNodeInfo.modules_available.values()) +) +DecisionMetrics = Literal[tuple(DecisionNodeInfo.metrics_available.keys())] # type: ignore[valid-type] + + +class DecisionNodeValidator(BaseModel): + """Search space configuration for the Decision node.""" + + node_type: NodeType = NodeType.decision + target_metric: DecisionMetrics # type: ignore[valid-type] + metrics: list[DecisionMetrics] | None = None # type: ignore[valid-type] + search_space: list[DecisionSearchSpaceType] # type: ignore[valid-type] + + +EmbeddingSearchSpaceType, EmbeddingBaseModels = generate_models_and_union_type_for_classes( + list(EmbeddingNodeInfo.modules_available.values()) +) +EmbeddingMetrics: TypeAlias = Literal[tuple(EmbeddingNodeInfo.metrics_available.keys())] # type: ignore[valid-type] + + +class EmbeddingNodeValidator(BaseModel): + """Search space configuration for the Embedding node.""" + + node_type: NodeType = NodeType.embedding + target_metric: EmbeddingMetrics + metrics: list[EmbeddingMetrics] | None = None + search_space: list[EmbeddingSearchSpaceType] # type: ignore[valid-type] + + +ScoringSearchSpaceType, ScoringNodesBaseModels = generate_models_and_union_type_for_classes( + list(ScoringNodeInfo.modules_available.values()) +) +ScoringMetrics: TypeAlias = Literal[tuple(ScoringNodeInfo.metrics_available.keys())] # type: ignore[valid-type] + + +class ScoringNodeValidator(BaseModel): + """Search space configuration for the Scoring node.""" + + node_type: NodeType = NodeType.scoring + target_metric: ScoringMetrics + metrics: list[ScoringMetrics] | None = None + search_space: list[ScoringSearchSpaceType] # type: ignore[valid-type] + + +RegexpSearchSpaceType, RegexNodesBaseModels = generate_models_and_union_type_for_classes( + list(RegexNodeInfo.modules_available.values()) +) +RegexpMetrics: TypeAlias = Literal[tuple(RegexNodeInfo.metrics_available.keys())] # type: ignore[valid-type] + + +class RegexNodeValidator(BaseModel): + """Search space configuration for the Regexp node.""" + + node_type: NodeType = NodeType.regex + target_metric: RegexpMetrics + metrics: list[RegexpMetrics] | None = None + search_space: list[RegexpSearchSpaceType] # type: ignore[valid-type] + + +NodeValidatorType: TypeAlias = ( + EmbeddingNodeValidator | ScoringNodeValidator | DecisionNodeValidator | RegexNodeValidator +) +SearchSpaceType: TypeAlias = ( + DecisionSearchSpaceType | EmbeddingSearchSpaceType | ScoringSearchSpaceType | RegexpSearchSpaceType # type: 
ignore[valid-type]
+)
+
+
+class SearchSpaceConfig(RootModel[list[SearchSpaceType]]):
+    """Search space configuration."""
+
+    def __iter__(
+        self,
+    ) -> Iterator[SearchSpaceType]:
+        """Iterate over the root."""
+        return iter(self.root)
+
+    def __getitem__(self, item: int) -> SearchSpaceType:
+        """To get item directly from the root.
+
+        :param item: Index
+
+        :return: Item
+        """
+        return self.root[item]
+
+    @model_validator(mode="before")
+    @classmethod
+    def validate_nodes(cls, data: list[Any]) -> list[Any]:  # noqa: C901
+        """Validate the search space configuration.
+
+        Args:
+            data: List of search space configurations.
+
+        Returns:
+            List of validated search space configurations.
+        """
+        error_message = ""
+        for i, item in enumerate(data):
+            if isinstance(item, BaseModel):
+                continue
+            if not isinstance(item, dict):
+                msg = "Each search space configuration must be a dictionary."
+                raise TypeError(msg)
+            node_name = item.get("module_name")
+            if node_name is None:
+                error_message += f"Search space configuration at index {i} is missing 'module_name'.\n"
+                continue
+
+            if node_name in DecisionNodesBaseModels:
+                node_class = DecisionNodesBaseModels[node_name]
+            elif node_name in EmbeddingBaseModels:
+                node_class = EmbeddingBaseModels[node_name]
+            elif node_name in ScoringNodesBaseModels:
+                node_class = ScoringNodesBaseModels[node_name]
+            elif node_name in RegexNodesBaseModels:
+                node_class = RegexNodesBaseModels[node_name]
+            else:
+                error_message += f"Unknown module name '{node_name}' at index {i}.\n"
+                break
+            try:
+                node_class(**item)
+            except ValidationError as e:
+                error_message += f"Search space configuration at index {i} {node_name} is invalid: {e}\n"
+                continue
+        if len(error_message) > 0:
+            raise TypeError(error_message)
+        return data
+
+
+class OptimizationSearchSpaceConfig(RootModel[list[NodeValidatorType]]):
+    """Optimizer configuration."""
+
+    def __iter__(
+        self,
+    ) -> Iterator[NodeValidatorType]:
+        """Iterate over the root."""
+        return iter(self.root)
+
+    def __getitem__(self, item: int) -> NodeValidatorType:
+        """To get item directly from the root.
+
+        :param item: Index
+
+        :return: Item
+        """
+        return self.root[item]
+
+    @model_validator(mode="before")
+    @classmethod
+    def validate_nodes(cls, data: list[Any]) -> list[Any]:  # noqa: PLR0912,C901
+        """Validate the search space configuration.
+
+        Args:
+            data: List of search space configurations.
+
+        Returns:
+            List of validated search space configurations.
+        """
+        error_message = ""
+        for i, item in enumerate(data):
+            if isinstance(item, BaseModel):
+                continue
+            if not isinstance(item, dict):
+                msg = "Each search space configuration must be a dictionary."
+                raise TypeError(msg)
+            if "node_type" not in item:
+                msg = "Each search space configuration must have a 'node_type' key."
+                raise TypeError(msg)
+            if not isinstance(item.get("search_space"), list):
+                msg = "Each search space configuration must have a 'search_space' key of type list."
+ raise TypeError(msg) + for search_space in item["search_space"]: + node_name = search_space.get("module_name") + if node_name is None: + error_message += f"Search space configuration at index {i} is missing 'module_name'.\n" + continue + if item["node_type"] == NodeType.decision.value: + node_class = DecisionNodesBaseModels[node_name] + elif item["node_type"] == NodeType.embedding.value: + node_class = EmbeddingBaseModels[node_name] + elif item["node_type"] == NodeType.scoring.value: + node_class = ScoringNodesBaseModels[node_name] + elif item["node_type"] == NodeType.regex.value: + node_class = RegexNodesBaseModels[node_name] + else: + error_message += f"Unknown node type '{item['node_type']}' at index {i}.\n" + break + + try: + node_class(**search_space) + except ValidationError as e: + error_message += f"Search space configuration at index {i} {node_name} is invalid: {e}\n" + continue + if len(error_message) > 0: + raise TypeError(error_message) + return data diff --git a/docs/optimizer_config.schema.json b/docs/optimizer_config.schema.json index 938d10de0..192e86c5c 100644 --- a/docs/optimizer_config.schema.json +++ b/docs/optimizer_config.schema.json @@ -66,16 +66,16 @@ "validation_size": { "default": 0.2, "description": "Fraction of train samples to allocate for validation (if input dataset doesn't contain validation split).", - "maximum": 1.0, - "minimum": 0.0, + "maximum": 1, + "minimum": 0, "title": "Validation Size", "type": "number" }, "separation_ratio": { "anyOf": [ { - "maximum": 1.0, - "minimum": 0.0, + "maximum": 1, + "minimum": 0, "type": "number" }, { @@ -342,6 +342,7 @@ }, "search_space": { "items": { + "additionalProperties": true, "type": "object" }, "title": "Search Space", diff --git a/pyproject.toml b/pyproject.toml index 2cbe6273c..77e4f1b70 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,45 @@ dependencies = [ "xxhash (>=3.5.0,<4.0.0)", "python-dotenv (>=1.0.1,<2.0.0)", "transformers[torch] (>=4.49.0,<5.0.0)", + "codecarbon (==2.6)", +] + +[project.optional-dependencies] +dev = [ + "tach (>=0.11.3,<1.0.0)", + "ipykernel (>=6.29.5,<7.0.0)", + "ipywidgets (>=8.1.5,<9.0.0)", + "ruff (==0.8.4)", +] +test = [ + "pytest (>=8.3.2,<9.0.0)", + "pytest-cov (>=5.0.0,<6.0.0)", + "coverage (>=7.6.1,<8.0.0)", + "pytest-asyncio (>=0.24.0,<1.0.0)", + "pytest-rerunfailures (>=15.0,<16.0)", + "pytest-xdist (>=3.6.1,<4.0.0)", +] +typing = [ + "mypy (>=1,<2)", + "types-pyyaml (>=6.0.12.20240917,<7.0.0)", + "types-pygments (>=2.18.0.20240506,<3.0.0)", + "types-setuptools (>=75.2.0.20241019,<76.0.0)", + "joblib-stubs (>=1.4.2.5.20240918,<2.0.0)", +] +docs = [ + "sphinx (>=8.1.3,<9.0.0)", + "pydata-sphinx-theme (>=0.16.0,<1.0.0)", + "jupytext (>=1.16.4,<2.0.0)", + "nbsphinx (>=0.9.5,<1.0.0)", + "sphinx-autodoc-typehints (>=2.5.0,<3.0.0)", + "sphinx-copybutton (>=0.5.2,<1.0.0)", + "sphinx-autoapi (>=3.3.3,<4.0.0)", + "ipykernel (>=6.29.5,<7.0.0)", + "tensorboardx (>=2.6.2.2,<3.0.0)", + "sphinx-multiversion (>=0.2.4,<1.0.0)", +] +dspy = [ + "dspy (>=2.6.5,<3.0.0)", ] [project.urls] @@ -56,57 +95,6 @@ Documentation = "https://deeppavlov.github.io/AutoIntent/" "basic-aug" = "autointent.generation.utterances.basic.cli:main" "evolution-aug" = "autointent.generation.utterances.evolution.cli:main" -[tool.poetry.group.dev] -optional = true - -[tool.poetry.group.dev.dependencies] -tach = "^0.11.3" -ipykernel = "^6.29.5" -ipywidgets = "^8.1.5" -ruff = "==0.8.4" - -[tool.poetry.group.test] -optional = true - -[tool.poetry.group.test.dependencies] -pytest = "8.3.2" -pytest-cov = 
"^5.0.0" -coverage = "^7.6.1" -pytest-asyncio = "^0.24.0" - -[tool.poetry.group.typing] -optional = true - -[tool.poetry.group.typing.dependencies] -mypy = "^1" -types-pyyaml = "^6.0.12.20240917" -types-pygments = "^2.18.0.20240506" -types-setuptools = "^75.2.0.20241019" -joblib-stubs = "^1.4.2.5.20240918" - -[tool.poetry.group.docs] -optional = true - -[tool.poetry.group.docs.dependencies] -sphinx = "^8.1.3" -pydata-sphinx-theme = "^0.16.0" -jupytext = "^1.16.4" -nbsphinx = "^0.9.5" -sphinx-autodoc-typehints = "^2.5.0" -sphinx-copybutton = "^0.5.2" -sphinx-autoapi = "^3.3.3" -ipykernel = "^6.29.5" -tensorboardx = "^2.6.2.2" -sphinx-multiversion = "^0.2.4" - -[tool.poetry.group.dspy] -optional = true - - -[tool.poetry.group.dspy.dependencies] -dspy = "^2.6.5" - - [tool.ruff] line-length = 120 indent-width = 4 @@ -148,16 +136,41 @@ build-backend = "poetry.core.masonry.api" [tool.pytest.ini_options] minversion = "8.0" -addopts = "-ra" # `--cov` option breaks pycharm's test debugger testpaths = [ "tests", ] pythonpath = "autointent" +# `--cov` option breaks pycharm's test debugger +addopts = """ + -ra + --reruns 3 + --only-rerun requests.exceptions.ReadTimeout + --only-rerun huggingface_hub.errors.HfHubHTTPError + --only-rerun huggingface_hub.errors.LocalEntryNotFoundError + --only-rerun FileNotFoundError + --only-rerun OSError + --durations 5 + --reruns-delay 10 +""" +# --reruns 3 -> # Retry failed tests 3 times +# requests.exceptions.ReadTimeout -> # HF Read timed out +# huggingface_hub.errors.HfHubHTTPError -> # HF is unavailable +# huggingface_hub.errors.LocalEntryNotFoundError -> # Gateway Time-out from HF +# FileNotFoundError -> HF Cache is broken +# --reruns-delay 10 -> Delay between reruns in seconds to avoid running into the same issue again [tool.coverage.run] branch = true omit = [ "__init__.py", + "*/site-packages/*", + "*/dist-packages/*", + "*/venv/*", + "*/.env/*", + "*/.venv/*", + "*/virtualenv/*", + "*/tests/*", + "*/tmp/*", ] [tool.coverage.paths] @@ -212,3 +225,4 @@ module = [ "autointent.modules.abc.*", ] warn_unreachable = false + diff --git a/tests/assets/configs/multiclass.yaml b/tests/assets/configs/multiclass.yaml index c21eb779a..a8c883b40 100644 --- a/tests/assets/configs/multiclass.yaml +++ b/tests/assets/configs/multiclass.yaml @@ -29,7 +29,7 @@ clf_name: [RandomForestClassifier] n_estimators: [5, 10] - module_name: bert - model_config: + classification_model_config: - model_name: avsolatorio/GIST-small-Embedding-v0 num_train_epochs: [1] batch_size: [8, 16] diff --git a/tests/assets/configs/multilabel.yaml b/tests/assets/configs/multilabel.yaml index f867c6109..a5702eb54 100644 --- a/tests/assets/configs/multilabel.yaml +++ b/tests/assets/configs/multilabel.yaml @@ -25,7 +25,7 @@ clf_name: [RandomForestClassifier] n_estimators: [5, 10] - module_name: bert - model_config: + classification_model_config: - model_name: avsolatorio/GIST-small-Embedding-v0 num_train_epochs: [1] batch_size: [8] diff --git a/tests/callback/test_callback.py b/tests/callback/test_callback.py index bbb094f4a..30d2b0d3c 100644 --- a/tests/callback/test_callback.py +++ b/tests/callback/test_callback.py @@ -26,6 +26,7 @@ def log_value(self, **kwargs: dict[str, Any]) -> None: def log_metrics(self, **kwargs: dict[str, Any]) -> None: metrics = kwargs["metrics"] + metrics = {k: v for k, v in metrics.items() if not k.startswith("emissions/")} for metric_name, metric_value in metrics.items(): if not isinstance(metric_value, str) and np.isnan(metric_value): metrics[metric_name] = None @@ -103,7 
+104,14 @@ def test_pipeline_callbacks(dataset): "num": 0, }, ), - ("log_metric", {"metrics": {"retrieval_hit_rate": 1.0}}), + ( + "log_metric", + { + "metrics": { + "retrieval_hit_rate": 1.0, + } + }, + ), ("end_module", {}), ( "start_module", @@ -113,7 +121,14 @@ def test_pipeline_callbacks(dataset): "num": 1, }, ), - ("log_metric", {"metrics": {"retrieval_hit_rate": 1.0}}), + ( + "log_metric", + { + "metrics": { + "retrieval_hit_rate": 1.0, + } + }, + ), ("end_module", {}), ( "start_module", @@ -139,7 +154,15 @@ def test_pipeline_callbacks(dataset): "num": 0, }, ), - ("log_metric", {"metrics": {"scoring_accuracy": 1.0, "scoring_roc_auc": 1.0}}), + ( + "log_metric", + { + "metrics": { + "scoring_accuracy": 1.0, + "scoring_roc_auc": 1.0, + } + }, + ), ("end_module", {}), ( "start_module", @@ -165,7 +188,15 @@ def test_pipeline_callbacks(dataset): "num": 1, }, ), - ("log_metric", {"metrics": {"scoring_accuracy": 1.0, "scoring_roc_auc": 1.0}}), + ( + "log_metric", + { + "metrics": { + "scoring_accuracy": 1.0, + "scoring_roc_auc": 1.0, + } + }, + ), ("end_module", {}), ( "start_module", @@ -189,7 +220,15 @@ def test_pipeline_callbacks(dataset): "num": 0, }, ), - ("log_metric", {"metrics": {"scoring_accuracy": 0.75, "scoring_roc_auc": 1.0}}), + ( + "log_metric", + { + "metrics": { + "scoring_accuracy": 0.75, + "scoring_roc_auc": 1.0, + } + }, + ), ("end_module", {}), ("start_module", {"module_kwargs": {"thresh": 0.5}, "module_name": "threshold", "num": 0}), ( diff --git a/tests/configs/test_combined_config.py b/tests/configs/test_combined_config.py index 41dc5bc7a..81312c7f0 100644 --- a/tests/configs/test_combined_config.py +++ b/tests/configs/test_combined_config.py @@ -74,8 +74,7 @@ def test_invalid_optimizer_config_missing_field(): def test_invalid_optimizer_config_wrong_type(): """Test that an invalid field type raises ValidationError.""" - invalid_config = [ - { + invalid_config = { "node_type": "scoring", "target_metric": "scoring_roc_auc", "search_space": [ @@ -87,7 +86,6 @@ def test_invalid_optimizer_config_wrong_type(): } ], } - ] with pytest.raises(TypeError): NodeOptimizer(**invalid_config) diff --git a/tests/modules/scoring/test_bert.py b/tests/modules/scoring/test_bert.py index 3ef319703..4512cd1fd 100644 --- a/tests/modules/scoring/test_bert.py +++ b/tests/modules/scoring/test_bert.py @@ -1,3 +1,7 @@ +import shutil +import tempfile +from pathlib import Path + import numpy as np import pytest @@ -5,11 +9,56 @@ from autointent.modules import BertScorer +def test_bert_scorer_dump_load(dataset): + """Test that BertScorer can be saved and loaded while preserving predictions.""" + data_handler = DataHandler(dataset) + + # Create and train scorer + scorer_original = BertScorer(classification_model_config="prajjwal1/bert-tiny", num_train_epochs=1, batch_size=8) + scorer_original.fit(data_handler.train_utterances(0), data_handler.train_labels(0)) + + # Test data + test_data = [ + "why is there a hold on my account", + "why is my bank account frozen", + ] + + # Get predictions before saving + predictions_before = scorer_original.predict(test_data) + + # Create temp directory and save model + temp_dir_path = Path(tempfile.mkdtemp(prefix="bert_scorer_test_")) + try: + # Save the model + scorer_original.dump(str(temp_dir_path)) + + # Create a new scorer and load saved model + scorer_loaded = BertScorer(classification_model_config="prajjwal1/bert-tiny", num_train_epochs=1, batch_size=8) + scorer_loaded.load(str(temp_dir_path)) + + # Verify model and tokenizer are loaded + assert 
hasattr(scorer_loaded, "_model") + assert scorer_loaded._model is not None + assert hasattr(scorer_loaded, "_tokenizer") + assert scorer_loaded._tokenizer is not None + + # Get predictions after loading + predictions_after = scorer_loaded.predict(test_data) + + # Verify predictions match + assert predictions_before.shape == predictions_after.shape + np.testing.assert_allclose(predictions_before, predictions_after, atol=1e-6) + + finally: + # Clean up + shutil.rmtree(temp_dir_path, ignore_errors=True) # workaround for windows permission error + + def test_bert_prediction(dataset): """Test that the transformer model can fit and make predictions.""" data_handler = DataHandler(dataset) - scorer = BertScorer(model_config="prajjwal1/bert-tiny", num_train_epochs=1, batch_size=8) + scorer = BertScorer(classification_model_config="prajjwal1/bert-tiny", num_train_epochs=1, batch_size=8) scorer.fit(data_handler.train_utterances(0), data_handler.train_labels(0)) @@ -46,7 +95,7 @@ def test_bert_cache_clearing(dataset): """Test that the transformer model properly handles cache clearing.""" data_handler = DataHandler(dataset) - scorer = BertScorer(model_config="prajjwal1/bert-tiny", num_train_epochs=1, batch_size=8) + scorer = BertScorer(classification_model_config="prajjwal1/bert-tiny", num_train_epochs=1, batch_size=8) scorer.fit(data_handler.train_utterances(0), data_handler.train_labels(0)) diff --git a/tests/nodes/test_decision.py b/tests/nodes/test_decision.py index 8bc65f820..826ca9949 100644 --- a/tests/nodes/test_decision.py +++ b/tests/nodes/test_decision.py @@ -19,7 +19,7 @@ def test_decision_multiclass(scoring_optimizer_multiclass): "node_type": "decision", "search_space": [ {"module_name": "threshold", "thresh": [0.5]}, - {"module_name": "tunable", "n_trials": [None, 3]}, + {"module_name": "tunable", "n_trials": 3}, { "module_name": "argmax", }, @@ -58,7 +58,7 @@ def test_decision_multilabel(scoring_optimizer_multilabel): "node_type": "decision", "search_space": [ {"module_name": "threshold", "thresh": [0.5]}, - {"module_name": "tunable", "n_trials": [None, 3]}, + {"module_name": "tunable", "n_trials": 3}, {"module_name": "adaptive"}, ], } diff --git a/user_guides/advanced/01_data.py b/user_guides/advanced/01_data.py index 27314a8d4..51778775c 100644 --- a/user_guides/advanced/01_data.py +++ b/user_guides/advanced/01_data.py @@ -6,9 +6,8 @@ """ # %% -import importlib.resources as ires - import datasets +import huggingface_hub from autointent import Dataset @@ -180,7 +179,11 @@ """ # %% -path_to_dataset = ires.files("tests.assets.data").joinpath("clinc_subset.json") +path_to_dataset = huggingface_hub.hf_hub_download( + repo_id="DeepPavlov/clinc150_subset", + filename="clinc_subset.json", + repo_type="dataset", +) dataset = Dataset.from_json(path_to_dataset) # %% [markdown] diff --git a/user_guides/basic_usage/01_data.py b/user_guides/basic_usage/01_data.py index a03f39bb3..f56294a76 100644 --- a/user_guides/basic_usage/01_data.py +++ b/user_guides/basic_usage/01_data.py @@ -6,9 +6,8 @@ """ # %% -import importlib.resources as ires - import datasets +import huggingface_hub from autointent import Dataset @@ -53,7 +52,11 @@ """ # %% -path_to_dataset = ires.files("tests.assets.data").joinpath("clinc_subset_unsplitted.json") +path_to_dataset = huggingface_hub.hf_hub_download( + repo_id="DeepPavlov/clinc150_subset", + filename="clinc_subset_unsplitted.json", + repo_type="dataset", +) dataset = Dataset.from_json(path_to_dataset) # %% [markdown]