diff --git a/plugins/online-data-mixing/README.md b/plugins/online-data-mixing/README.md index 57ef136b..8644440f 100644 --- a/plugins/online-data-mixing/README.md +++ b/plugins/online-data-mixing/README.md @@ -17,6 +17,32 @@ Plugin | Description | Depends | Loading | Augmentation | Callbacks `OnlineMixingDataset` can be imported easily and integrated into existing training loops with minimal changes. A sample custom training loop implementation can be found [here](./artifacts/custom_loop_usage.py). Given code sample uses two instruction tuning datasets and trains `ibm-granite/granite-3.1-2b-instruct` model for next token prediction task. +### Automatic Categorization + +When only a single dataset (without category splits) is passed, the dataset will be embedded with a sentence-transformer model and clustered (K-Means by default) to build pseudo categories used by the online data mixer. + +```python +from datasets import load_dataset +from fms_acceleration_odm import OnlineMixingDataset + +dataset = load_dataset("tatsu-lab/alpaca", split="train[:1%]") +collator = ... # e.g., DataCollatorForLanguageModeling(...) + +odm_dataset = OnlineMixingDataset( + dataset_dict=dataset, + collators_dict={"train": collator}, + eval_dataset_dict={}, + eval_collators_dict={}, + auto_categorize_config={ + "input_column": "text", + "num_categories": 6, + "model_name": "sentence-transformers/all-MiniLM-L6-v2", + }, +) +``` + +Without an explicit `num_categories`, a heuristic based on the square root of the dataset size is used. Additional knobs such as `category_prefix`, `batch_size`, or clustering-specific kwargs can also be provided through `auto_categorize_config`. + ## Metrics All metrics related to the online data mixing will be logged to `odm.jsonl` file in the checkpoint output directory. @@ -47,11 +73,9 @@ Rewards | Description `GRADNORM` | Gradient norm where norms are maintained across categories and are updated based on the latest values and sampled dataset/category. Higher values mean reducing samples from that particular dataset/category. ### Adding a Custom Reward -Custom rewards can be added to the `compute_reward` function and adding it to the `Reward` enum. If the custom reward requires specific set of information from the training loop then `_extract_information_from_state_for_reward` function has to be extended for extracting such information from trainer state. This is member function of `OnlineMixingDataset`. +Custom rewards can be added by extending the `compute_reward` function and adding the new reward to the `Reward` enum. If the custom reward requires a specific set of information from the training loop, the `_extract_information_from_state_for_reward` function has to be extended to extract that information from the trainer state. This is a member function of `OnlineMixingDataset`. ### Planned TODOs -Please see issue [#153](https://github.com/foundation-model-stack/fms-acceleration/issues/153). - - +Please see issue [#153](https://github.com/foundation-model-stack/fms-acceleration/issues/153).
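The README section above covers the raw-text path only. As a companion, here is a minimal sketch of the pre-tokenized path this change also supports: when `input_column` holds token IDs rather than strings, a `tokenizer` has to be passed through `auto_categorize_config` so the rows can be detokenized before embedding. The tokenizer checkpoint, dataset, and collator below are illustrative choices, not requirements.

```python
from datasets import DatasetDict, load_dataset
from transformers import AutoTokenizer, DataCollatorForLanguageModeling

from fms_acceleration_odm import OnlineMixingDataset

tokenizer = AutoTokenizer.from_pretrained("ibm-granite/granite-4.0-h-1b")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Illustrative pre-tokenized dataset: rows carry "input_ids" instead of raw "text".
dataset = load_dataset("tatsu-lab/alpaca", split="train[:1%]")
dataset = dataset.map(
    lambda x: tokenizer(x["text"], truncation=True, max_length=1024),
    batched=True,
    remove_columns=dataset.column_names,
)

collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

odm_dataset = OnlineMixingDataset(
    dataset_dict=DatasetDict({"train": dataset}),
    collators_dict={"train": collator},
    eval_dataset_dict={},
    eval_collators_dict={},
    auto_categorize_config={
        "input_column": "input_ids",  # token IDs, not strings
        "tokenizer": tokenizer,       # needed so rows can be detokenized first
        "num_categories": 6,
    },
)
```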
diff --git a/plugins/online-data-mixing/artifacts/custom_loop_usage.py b/plugins/online-data-mixing/artifacts/custom_loop_usage.py index d9802b5e..4c7f0cba 100644 --- a/plugins/online-data-mixing/artifacts/custom_loop_usage.py +++ b/plugins/online-data-mixing/artifacts/custom_loop_usage.py @@ -8,23 +8,21 @@ # Third Party from accelerate import Accelerator, DataLoaderConfiguration -from datasets import load_dataset +from datasets import load_dataset, DatasetDict from torch.utils.data import DataLoader from tqdm import tqdm -from transformers import ( - AutoModelForCausalLM, - AutoTokenizer, - DataCollatorForLanguageModeling, -) +from transformers import AutoModelForCausalLM, AutoTokenizer import torch +from functools import partial # First Party from fms_acceleration_odm import OnlineMixingDataset +from fms_acceleration_odm.odm.reward import Reward -model_name = "ibm-granite/granite-3.1-2b-instruct" +model_name = "ibm-granite/granite-4.0-h-1b" output_dir = "./odm_custom_use" max_steps = 125 -batch_size = 12 +batch_size = 4 log_file = os.path.join(output_dir, "loss.jsonl") # odm related @@ -32,7 +30,7 @@ update_interval = 1 # every step # model -model = AutoModelForCausalLM.from_pretrained(model_name) +model = AutoModelForCausalLM.from_pretrained(model_name, dtype=torch.bfloat16) # tokenizer tokenizer = AutoTokenizer.from_pretrained(model_name) @@ -40,16 +38,16 @@ # dataset related -def tokenize_fn(examples): - return tokenizer( - examples["text"], truncation=True, padding="max_length", max_length=128 - ) +# If you have a single dataset, you can declare it with a single key-value pair. +# ODM will then auto-categorize the dataset into pseudo categories. +# If you have multiple dataset categories, you can declare them with multiple key-value pairs, e.g.: +# dataset_dict = { +# "alpaca": load_dataset("tatsu-lab/alpaca", split="train[:1%]"), +# "oasst": load_dataset("hakurei/open-instruct-v1", split="train[:1%]"), +# } - -dataset_dict = { - "alpaca": load_dataset("tatsu-lab/alpaca", split="train[:1%]"), - "oasst": load_dataset("hakurei/open-instruct-v1", split="train[:1%]"), -} +dataset_dict = {"alpaca_train": load_dataset("tatsu-lab/alpaca", split="train[90%:]")} +eval_dict = {"alpaca_val": load_dataset("tatsu-lab/alpaca", split="train[:1%]")} def format_example(example): @@ -63,43 +61,49 @@ def format_example(example): for name in dataset_dict: dataset_dict[name] = dataset_dict[name].map(format_example) +for name in eval_dict: + eval_dict[name] = eval_dict[name].map(format_example) + +dataset_dict = DatasetDict(dataset_dict) # type: ignore +eval_dict = DatasetDict(eval_dict) # type: ignore + + +def collate_fn(batch, tokenizer): + msgs = [b.pop("text") for b in batch] -def tokenize_fn(examples): return tokenizer( - examples["text"], + msgs, truncation=True, padding="max_length", max_length=1024, + return_tensors="pt", ) -for name in dataset_dict: - dataset_dict[name] = dataset_dict[name].map( - tokenize_fn, - batched=True, - remove_columns=dataset_dict[name].column_names, - ) - collator_dict = { - name: DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) - for name in dataset_dict + name: partial(collate_fn, tokenizer=tokenizer) for name in dataset_dict +} + +eval_collator_dict = { + name: partial(collate_fn, tokenizer=tokenizer) for name in eval_dict } # dataset preparation dataset = OnlineMixingDataset( dataset_dict=dataset_dict, collators_dict=collator_dict, - eval_dataset_dict={}, - eval_collators_dict={}, + eval_dataset_dict=eval_dict, + eval_collators_dict=eval_collator_dict,
output_dir=output_dir, - reward_type="train_loss", + reward_type=Reward.TRAIN_LOSS, sampling_interval=batch_size, + auto_categorize_config={"input_column": "text"}, ) dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=None) # distributed setup dataloader_config = DataLoaderConfiguration(split_batches=True, dispatch_batches=True) -accelerator = Accelerator(split_batches=True, dataloader_config=dataloader_config) +accelerator = Accelerator(dataloader_config=dataloader_config) model, dataloader = accelerator.prepare(model, dataloader) # training setup diff --git a/plugins/online-data-mixing/pyproject.toml b/plugins/online-data-mixing/pyproject.toml index 5539a681..2be33b88 100644 --- a/plugins/online-data-mixing/pyproject.toml +++ b/plugins/online-data-mixing/pyproject.toml @@ -22,10 +22,25 @@ classifiers=[ "Programming Language :: Python :: 3.11", ] -dependencies = ["datasets", "torchdata"] +dependencies = [ + "scikit-learn", + "datasets==4.*", + "torchdata==0.11.0", + "sentence-transformers==5.*", +] + +[project.optional-dependencies] +dev = ["pytest"] [tool.hatch.build.targets.wheel] only-include = ["src/fms_acceleration_odm"] [tool.hatch.build.targets.wheel.sources] "src" = "" + +[tool.hatch.build.targets.sdist] +include = [ + "src", + "pyproject.toml", + "README.md", +] \ No newline at end of file diff --git a/plugins/online-data-mixing/src/fms_acceleration_odm/odm/auto_categorizer.py b/plugins/online-data-mixing/src/fms_acceleration_odm/odm/auto_categorizer.py new file mode 100644 index 00000000..337b49de --- /dev/null +++ b/plugins/online-data-mixing/src/fms_acceleration_odm/odm/auto_categorizer.py @@ -0,0 +1,194 @@ +"""Utilities to automatically cluster a dataset into pseudo categories.""" + +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# Future +from __future__ import annotations + +# Standard +from dataclasses import dataclass, field +from logging import getLogger +from typing import Any, Dict, List, Optional +import copy +import math + +# Third Party +import torch +from datasets import Dataset, DatasetDict +from sentence_transformers import SentenceTransformer +from sklearn.cluster import KMeans +import numpy as np + +logger = getLogger(__name__) + +AUTO_CATEGORIZATION_COLUMN_NAME = "auto_categorization_odm_raw_text" + + +@dataclass +class AutoCategorizeConfig: # pylint: disable=too-many-instance-attributes + """Configuration for sentence-embedding based auto-categorization.""" + + input_column: str = "text" + num_categories: Optional[int] = None + min_categories: int = 2 + max_categories: int = 15 + model_name: str = "Qwen/Qwen3-Embedding-0.6B" + batch_size: int = 64 + cluster_algo: str = "kmeans" + category_prefix: str = "auto_category" + # Args for loading model + model_kwargs: Dict[str, Any] = field( + default_factory=lambda: { + "device_map": "auto", + # "attn_implementation": "flash_attention_2", + } + ) + # Args for K means + cluster_kwargs: Dict[str, Any] = field(default_factory=dict) + # If the `input_column` provided does not contain str + # it is assumed that the data is pre-tokenized + # and the column will first be detokenized using the tokenizer + # before performing k means + tokenizer: Optional[Any] = None + + +class DatasetAutoCategorizer: + """Clusters a dataset into pseudo categories using embeddings.""" + + def __init__(self, config: Optional[AutoCategorizeConfig] = None): + self.config = copy.deepcopy(config) or AutoCategorizeConfig() + + def __call__(self, dataset: Dataset) -> DatasetDict: + if isinstance(dataset, torch.utils.data.IterableDataset): + raise NotImplementedError( + "Iterable (or streaming) datasets are not yet supported for auto categorization. " + "Please use a non-iterable dataset." + ) + + if len(dataset) == 0: + raise ValueError("Cannot auto-categorize an empty dataset") + if self.config.input_column not in dataset.column_names: + raise ValueError( + "Dataset is missing column '{col}'. 
Provide a valid input column via " + "auto_categorize_config['input_column'].".format( + col=self.config.input_column + ) + ) + + dataset = self._maybe_detokenize_data(dataset) + + num_categories = self._determine_category_count(len(dataset)) + logger.info( + "Auto-categorizing %s rows into %s clusters using %s", + len(dataset), + num_categories, + self.config.model_name, + ) + embeddings = self._compute_embeddings(dataset) + labels = self._cluster_embeddings(embeddings, num_categories) + + if AUTO_CATEGORIZATION_COLUMN_NAME in dataset.column_names: + dataset = dataset.remove_columns(AUTO_CATEGORIZATION_COLUMN_NAME) + + return self._build_dataset_dict(dataset, labels) + + def _maybe_detokenize_data(self, dataset: Dataset) -> Dataset: + existing_field = self.config.input_column + + if isinstance(dataset[existing_field][0], str): + logger.info("Detokenization not needed, text data already provided") + return dataset + + assert self.config.tokenizer is not None, ( + "Attempting to detokenize the data in column '%s' but no tokenizer was provided" + % self.config.input_column + ) + assert AUTO_CATEGORIZATION_COLUMN_NAME not in dataset.column_names, ( + "Default detokenizing column '%s' is already present in the dataset" + % AUTO_CATEGORIZATION_COLUMN_NAME + ) + + tokenizer = self.config.tokenizer + + dataset = dataset.map( + lambda x: { + AUTO_CATEGORIZATION_COLUMN_NAME: tokenizer.batch_decode( + x[existing_field] + ) + }, + batched=True, + num_proc=12, + ) + self.config.input_column = AUTO_CATEGORIZATION_COLUMN_NAME + + return dataset + + def _determine_category_count(self, dataset_size: int) -> int: + if self.config.num_categories is not None: + desired = self.config.num_categories + else: + # heuristic: sqrt scaling with dataset size + desired = int(math.sqrt(max(dataset_size, 1))) + desired = max(desired, self.config.min_categories) + desired = min(desired, self.config.max_categories) + + # clusters cannot exceed dataset size and must be >=1 + desired = max(1, min(dataset_size, desired)) + return desired + + def _compute_embeddings(self, dataset: Dataset) -> np.ndarray: + model = SentenceTransformer( + self.config.model_name, + model_kwargs=self.config.model_kwargs, + prompts={ + "clustering": "Identify the topic or theme based on the text: ", + }, + default_prompt_name="clustering", + ) + + vectors = model.encode( + dataset[self.config.input_column], + convert_to_numpy=True, + show_progress_bar=True, + batch_size=self.config.batch_size, + normalize_embeddings=True, + ) + return vectors + + def _cluster_embeddings( + self, embeddings: np.ndarray, num_categories: int + ) -> np.ndarray: + if self.config.cluster_algo.lower() != "kmeans": + raise ValueError( + "Unsupported clustering algorithm '%s'. Only 'kmeans' is currently supported."
+ % self.config.cluster_algo + ) + kwargs = {"n_init": 10} + kwargs.update(self.config.cluster_kwargs) + model = KMeans(n_clusters=num_categories, **kwargs) + + logger.info("Starting %s clustering", self.config.cluster_algo) + + return model.fit_predict(embeddings) + + def _build_dataset_dict(self, dataset: Dataset, labels: np.ndarray) -> DatasetDict: + grouped_indices: Dict[int, List[int]] = {} + for idx, label in enumerate(labels.tolist()): + grouped_indices.setdefault(label, []).append(idx) + categorized = {} + for label, indices in sorted(grouped_indices.items()): + name = f"{self.config.category_prefix}_{label}" + categorized[name] = dataset.select(indices) + return DatasetDict(categorized) diff --git a/plugins/online-data-mixing/src/fms_acceleration_odm/odm/dataset.py b/plugins/online-data-mixing/src/fms_acceleration_odm/odm/dataset.py index c52134c9..c7447a55 100644 --- a/plugins/online-data-mixing/src/fms_acceleration_odm/odm/dataset.py +++ b/plugins/online-data-mixing/src/fms_acceleration_odm/odm/dataset.py @@ -7,13 +7,14 @@ import random # Third Party -from datasets import DatasetDict +from datasets import Dataset, DatasetDict from torch.utils.data import DataLoader, IterableDataset from torchdata.stateful_dataloader import StatefulDataLoader from tqdm import tqdm import torch # Local +from .auto_categorizer import AutoCategorizeConfig, DatasetAutoCategorizer from .reward import Reward, compute_reward logger = getLogger(__name__) @@ -34,6 +35,7 @@ def __init__( eval_batch_size: int = 5, output_dir="odm", reward_type=Reward.ENTROPY, + auto_categorize_config: Optional[dict | AutoCategorizeConfig] = None, ): """Mixes datasets with sampling ratios learnt using Multi Armed Bandit (MAB) EXP3 and rewards defined. @@ -45,6 +47,8 @@ def __init__( Args: dataset_dict (DatasetDict): keys are category names and values are HF datasets. + If only a single key (dataset) is provided, we will run an auto categorization step. + Refer to `auto_categorize_config` for options regarding auto categorization. collators_dict (dict): collator corresponding to each dataset used while constructing torch dataloader. eval_dataset_dict (DatasetDict): keys are category names and values are HF @@ -61,7 +65,22 @@ def __init__( output_dir (str, optional): output dir to store logs. Defaults to "odm". reward_type (_type_, optional): type of reward to use, more details can be found in compute_reward function. Defaults to Reward.ENTROPY. + auto_categorize_config (dict | AutoCategorizeConfig, optional): + configuration overrides for the auto-categorizer such as text column, + embedding model, cluster count etc. This will only be used if the `dataset_dict` + has only one key. 
""" + self.auto_categorize = len(dataset_dict.keys()) == 1 + self._auto_categorize_config = self._build_auto_categorize_config( + auto_categorize_config + ) + dataset_dict, collators_dict = self._maybe_auto_categorize_dataset( + dataset_dict, collators_dict, dataset_role="train" + ) + eval_dataset_dict, eval_collators_dict = self._maybe_auto_categorize_dataset( + eval_dataset_dict, eval_collators_dict, dataset_role="eval" + ) + logger.info( """Values set to OnlineMixingDataset dataset_dict: {dataset_dict} @@ -75,6 +94,7 @@ def __init__( eval_batch_size: {eval_batch_size} output_dir: {output_dir} reward_type: {reward_type} + auto_categorize_config: {auto_categorize_config} """.format( dataset_dict=dataset_dict, collators_dict=collators_dict, @@ -87,6 +107,7 @@ def __init__( eval_batch_size=eval_batch_size, output_dir=output_dir, reward_type=reward_type, + auto_categorize_config=auto_categorize_config, ) ) @@ -302,6 +323,58 @@ def _reset_eval_dataloaders(self): else None ) + def _build_auto_categorize_config(self, config): + if isinstance(config, AutoCategorizeConfig): + return config + if config is None: + return AutoCategorizeConfig() + return AutoCategorizeConfig(**config) + + def _maybe_auto_categorize_dataset( + self, dataset_container, collators_dict, dataset_role + ): + if len(dataset_container) != 1: + return dataset_container, collators_dict + + logger.info("Starting auto categorization process") + + dataset_candidate: Dataset = next(iter(dataset_container.values())) + auto_categorizer = DatasetAutoCategorizer(config=self._auto_categorize_config) + categorized = auto_categorizer(dataset=dataset_candidate) + + # We can delete the auto categorizer object since + # it loads a sentence embedding model + del auto_categorizer + torch.cuda.empty_cache() + + collators_dict = self._broadcast_collators_to_auto_categories( + collators_dict, list(categorized.keys()) # type: ignore + ) + logger.info( + "Auto-categorized dataset into %d pseudo categories: %s", + len(categorized), + list(categorized.keys()), + ) + return categorized, collators_dict + + def _broadcast_collators_to_auto_categories( + self, collators_dict: Optional[dict], category_names: list[str] + ) -> dict: + if not category_names: + return collators_dict or {} + if not collators_dict: + return {name: None for name in category_names} + mapping = dict(collators_dict) + if set(mapping.keys()) == set(category_names): + return mapping + if len(mapping) == 1: + collator = next(iter(mapping.values())) + return {name: collator for name in category_names} + raise ValueError( + "Unable to broadcast collators to auto-categorized datasets. Provide a single " + "collator or one entry per generated category." 
+ ) + def _update_sampling_ratio(self, weights) -> list: """Helper function to convert weights to ratio diff --git a/plugins/online-data-mixing/tests/test_auto_categorization.py b/plugins/online-data-mixing/tests/test_auto_categorization.py new file mode 100644 index 00000000..67a1dac2 --- /dev/null +++ b/plugins/online-data-mixing/tests/test_auto_categorization.py @@ -0,0 +1,188 @@ +# Third Party +from datasets import Dataset, DatasetDict +from transformers import AutoTokenizer +import numpy as np +import pytest +import torch + +# First Party +from fms_acceleration_odm import OnlineMixingDataset + +np.random.seed(42) +torch.random.manual_seed(42) + + +class DummySentenceTransformer: + """Simple sentence embedder used to avoid network calls in tests.""" + + def __init__(self, *_, **__): + pass + + def encode(self, texts, **_): + vectors = [] + for _ in texts: + if np.random.uniform() < 0.5: + vectors.append([0.0, 0.0]) + else: + vectors.append([10.0, 10.0]) + return np.asarray(vectors, dtype=np.float32) + + +class DummyIterable(torch.utils.data.IterableDataset): + def __iter__(self): + yield {"x": 1} + + +def _patch_sentence_transformer(monkeypatch): + monkeypatch.setattr( + "fms_acceleration_odm.odm.auto_categorizer.SentenceTransformer", + DummySentenceTransformer, + ) + + +def test_auto_categorize_single_dataset(monkeypatch): + _patch_sentence_transformer(monkeypatch) + dataset = Dataset.from_dict( + {"text": ["cat", "dog", "wolf", "apple", "pear", "banana"]} + ) + dataset_dict = DatasetDict({"train": dataset}) + + def x(): # noqa: E731 - simple identity collator for test + return + + collator = x + odm_dataset = OnlineMixingDataset( + dataset_dict=dataset_dict, + collators_dict={"train": collator}, + eval_dataset_dict={}, + eval_collators_dict={}, + auto_categorize_config={ + "input_column": "text", + "num_categories": 2, + "category_prefix": "cluster", + "model_name": "dummy", + }, + ) + + assert len(odm_dataset.dataset_dict) == 2 + assert set(odm_dataset.category_list) == {"cluster_0", "cluster_1"} + assert set(odm_dataset.collators_dict.keys()) == set( + odm_dataset.dataset_dict.keys() + ) + + total_rows = sum(len(ds) for ds in odm_dataset.dataset_dict.values()) + assert total_rows == len(dataset) + + +def test_auto_categorize_requires_input_column(monkeypatch): + _patch_sentence_transformer(monkeypatch) + dataset = Dataset.from_dict({"content": ["hello", "world"]}) + dataset_dict = DatasetDict({"train": dataset}) + + with pytest.raises(ValueError): + OnlineMixingDataset( + dataset_dict=dataset_dict, + collators_dict={}, + eval_dataset_dict={}, + eval_collators_dict={}, + auto_categorize_config={"input_column": "text", "model_name": "dummy"}, + ) + + +def test_auto_categorize_pretokenized_data_w_tokenizer(monkeypatch): + _patch_sentence_transformer(monkeypatch) + + tok = AutoTokenizer.from_pretrained("openai-community/gpt2") + + batch_size, seq_len = 16, 50 + dataset = Dataset.from_dict( + {"input_ids": torch.randint(0, tok.vocab_size, (batch_size, seq_len))} + ) + dataset_dict = DatasetDict({"train": dataset}) + + def x(): # noqa: E731 - simple identity collator for test + return + + collator = x + odm_dataset = OnlineMixingDataset( + dataset_dict=dataset_dict, + collators_dict={"train": collator}, + eval_dataset_dict={}, + eval_collators_dict={}, + auto_categorize_config={ + "input_column": "input_ids", + "num_categories": 2, + "category_prefix": "cluster", + "model_name": "dummy", + "tokenizer": tok, + }, + ) + + print(odm_dataset.dataset_dict, len(odm_dataset.dataset_dict)) + + 
assert len(odm_dataset.dataset_dict) == 2 + assert set(odm_dataset.category_list) == {"cluster_0", "cluster_1"} + assert set(odm_dataset.collators_dict.keys()) == set( + odm_dataset.dataset_dict.keys() + ) + + total_rows = sum(len(ds) for ds in odm_dataset.dataset_dict.values()) + assert total_rows == len(dataset) == batch_size + + +def test_auto_categorize_pretokenized_data_wo_tokenizer(monkeypatch): + _patch_sentence_transformer(monkeypatch) + + batch_size, seq_len = 16, 50 + dataset = Dataset.from_dict( + {"input_ids": torch.randint(0, 100, (batch_size, seq_len))} + ) + dataset_dict = DatasetDict({"train": dataset}) + + def x(): # noqa: E731 - simple identity collator for test + return + + collator = x + + with pytest.raises(AssertionError): + _ = OnlineMixingDataset( + dataset_dict=dataset_dict, + collators_dict={"train": collator}, + eval_dataset_dict={}, + eval_collators_dict={}, + auto_categorize_config={ + "input_column": "input_ids", + "num_categories": 2, + "category_prefix": "cluster", + "model_name": "dummy", + }, + ) + + +def test_iterable_dataset_not_supported_auto_categorize(monkeypatch): + + _patch_sentence_transformer(monkeypatch) + + dataset = DummyIterable() + dataset_dict = DatasetDict({"train": dataset}) + + def x(): # noqa: E731 - simple identity collator for test + return + + collator = x + + with pytest.raises(NotImplementedError) as err: + _ = OnlineMixingDataset( + dataset_dict=dataset_dict, + collators_dict={"train": collator}, + eval_dataset_dict={}, + eval_collators_dict={}, + auto_categorize_config={ + "input_column": "input_ids", + "num_categories": 2, + "category_prefix": "cluster", + "model_name": "dummy", + }, + ) + + assert "not yet supported" in str(err.value) diff --git a/plugins/online-data-mixing/tox.ini b/plugins/online-data-mixing/tox.ini index 1a21a899..a93e35f4 100644 --- a/plugins/online-data-mixing/tox.ini +++ b/plugins/online-data-mixing/tox.ini @@ -21,6 +21,7 @@ deps = -e {toxinidir}/../framework pylint>=2.16.2,<=3.1.0 datasets() + .[dev] commands = pylint src tests allowlist_externals = pylint @@ -31,9 +32,9 @@ skip_install = true deps = black>=22.12 isort>=5.11 -commands = - black {posargs:.} - isort {posargs:.} +commands = + black src tests + isort src tests [testenv:build] description = build wheel
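The new `DatasetAutoCategorizer` can also be exercised on its own, which is useful for tuning `auto_categorize_config` before wiring it into `OnlineMixingDataset`. A minimal sketch follows; the import path mirrors what `dataset.py` uses, while the sample sentences and embedding model are placeholders.

```python
from datasets import Dataset

from fms_acceleration_odm.odm.auto_categorizer import (
    AutoCategorizeConfig,
    DatasetAutoCategorizer,
)

dataset = Dataset.from_dict(
    {"text": ["cats purr", "dogs bark", "apples are sweet", "pears are ripe"]}
)

# num_categories is left unset, so the heuristic clamps
# int(sqrt(len(dataset))) into [min_categories, max_categories].
config = AutoCategorizeConfig(
    input_column="text",
    model_name="sentence-transformers/all-MiniLM-L6-v2",
    category_prefix="cluster",
)
categorizer = DatasetAutoCategorizer(config=config)

categorized = categorizer(dataset)  # DatasetDict keyed cluster_0, cluster_1, ...
print({name: len(ds) for name, ds in categorized.items()})
```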