32 changes: 28 additions & 4 deletions plugins/online-data-mixing/README.md
@@ -17,6 +17,32 @@ Plugin | Description | Depends | Loading | Augmentation | Callbacks

`OnlineMixingDataset` can be imported easily and integrated into existing training loops with minimal changes. A sample custom training loop implementation can be found [here](./artifacts/custom_loop_usage.py). The sample uses two instruction-tuning datasets and trains the `ibm-granite/granite-3.1-2b-instruct` model on a next-token prediction task.

### Automatic Categorization

When only a single dataset (without category splits) is passed, the dataset will be embedded with a sentence-transformer model and clustered (K-Means by default) to build pseudo categories used by the online data mixer.

```python
from datasets import load_dataset
from fms_acceleration_odm import OnlineMixingDataset

dataset = load_dataset("tatsu-lab/alpaca", split="train[:1%]")
collator = ... # e.g., DataCollatorForLanguageModeling(...)

odm_dataset = OnlineMixingDataset(
dataset_dict=dataset,
collators_dict={"train": collator},
eval_dataset_dict={},
eval_collators_dict={},
auto_categorize_config={
"input_column": "text",
"num_categories": 6,
"model_name": "sentence-transformers/all-MiniLM-L6-v2",
},
)
```

Without an explicit `num_categories`, a heuristic based on the square root of the dataset size is used, clamped between `min_categories` and `max_categories`. Additional knobs such as `category_prefix`, `batch_size`, or clustering-specific kwargs (`cluster_kwargs`) can also be provided through `auto_categorize_config`, as sketched below.

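For reference, a fuller configuration might look like the following sketch. It assumes the dictionary keys map one-to-one onto the fields of the `AutoCategorizeConfig` dataclass introduced in this plugin; the values themselves are illustrative.

```python
# Illustrative values only; keys are assumed to mirror AutoCategorizeConfig.
auto_categorize_config = {
    "input_column": "text",              # column holding the raw text
    "num_categories": 8,                 # omit to fall back to the sqrt heuristic
    "min_categories": 2,                 # lower clamp for the heuristic
    "max_categories": 15,                # upper clamp for the heuristic
    "model_name": "sentence-transformers/all-MiniLM-L6-v2",
    "batch_size": 64,                    # embedding batch size
    "category_prefix": "auto_category",  # pseudo categories become auto_category_0, ...
    "cluster_kwargs": {"n_init": 10, "random_state": 0},  # forwarded to scikit-learn KMeans
}
```
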
## Metrics

All metrics related to the online data mixing will be logged to `odm.jsonl` file in the checkpoint output directory.
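
Since the file is a `.jsonl` log, it can be inspected with a few lines of Python. A minimal sketch, assuming only that each line is a standalone JSON record (replace `./output_dir` with your checkpoint output directory):

```python
import json

# Pretty-print every record logged by the online data mixer.
# No schema is assumed beyond one JSON object per line.
with open("./output_dir/odm.jsonl", "r", encoding="utf-8") as fp:
    for line in fp:
        print(json.loads(line))
```
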
@@ -47,11 +73,9 @@ Rewards | Description
`GRADNORM` | Gradient-norm reward: norms are maintained per dataset/category and updated with the latest value for the sampled dataset/category. A higher value means fewer samples are drawn from that dataset/category (illustrated by the toy sketch below).

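To make that bookkeeping concrete, here is a self-contained toy sketch, not the plugin's implementation: the exponential moving average and the inverse weighting are assumptions chosen only to illustrate how per-category gradient norms can steer sampling.

```python
# Toy illustration only -- not the plugin's GRADNORM implementation.
# Track a running gradient norm per category; a higher norm translates into
# a lower sampling weight for that category.
ema_grad_norm = {"alpaca": 1.0, "oasst": 1.0}
DECAY = 0.9  # assumed smoothing factor


def update_grad_norm(category: str, latest_norm: float) -> None:
    """Fold the latest gradient norm for the sampled category into its running value."""
    ema_grad_norm[category] = DECAY * ema_grad_norm[category] + (1 - DECAY) * latest_norm


def sampling_weights() -> dict:
    """Turn per-category norms into normalized sampling weights (higher norm -> lower weight)."""
    inverse = {name: 1.0 / max(norm, 1e-8) for name, norm in ema_grad_norm.items()}
    total = sum(inverse.values())
    return {name: value / total for name, value in inverse.items()}


update_grad_norm("alpaca", latest_norm=2.5)
print(sampling_weights())  # alpaca now receives a smaller share than oasst
```
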
### Adding a Custom Reward
Custom rewards can be added by extending the `compute_reward` function and adding a corresponding entry to the `Reward` enum. If a custom reward needs a specific set of information from the training loop, the `_extract_information_from_state_for_reward` member function of `OnlineMixingDataset` must also be extended to pull that information out of the trainer state. A rough sketch of the shape such an extension could take is shown below.

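The sketch below is a self-contained toy, not a patch to the plugin: the extra enum member, the function signature, and the `info` fields are illustrative assumptions.

```python
# Stand-alone sketch; Reward here is a stand-in for fms_acceleration_odm.odm.reward.Reward.
from enum import Enum


class Reward(Enum):
    TRAIN_LOSS = "train_loss"
    GRADNORM = "gradnorm"
    TOKEN_ENTROPY = "token_entropy"  # hypothetical custom reward


def compute_reward(reward_type: Reward, info: dict) -> float:
    """Dispatch on the reward type; `info` holds whatever
    `_extract_information_from_state_for_reward` pulled from the trainer state."""
    if reward_type is Reward.TOKEN_ENTROPY:
        # The mixer interprets the returned value according to its sampling policy.
        return info["token_entropy"]
    raise NotImplementedError(f"Unhandled reward type: {reward_type}")


print(compute_reward(Reward.TOKEN_ENTROPY, {"token_entropy": 3.2}))
```
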
### Planned TODOs
Please see issue [#153](https://github.com/foundation-model-stack/fms-acceleration/issues/153).
70 changes: 37 additions & 33 deletions plugins/online-data-mixing/artifacts/custom_loop_usage.py
@@ -8,48 +8,46 @@

# Third Party
from accelerate import Accelerator, DataLoaderConfiguration
from datasets import load_dataset
from datasets import load_dataset, DatasetDict
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
DataCollatorForLanguageModeling,
)
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from functools import partial

# First Party
from fms_acceleration_odm import OnlineMixingDataset
from fms_acceleration_odm.odm.reward import Reward

model_name = "ibm-granite/granite-3.1-2b-instruct"
model_name = "ibm-granite/granite-4.0-h-1b"
output_dir = "./odm_custom_use"
max_steps = 125
batch_size = 12
batch_size = 4
log_file = os.path.join(output_dir, "loss.jsonl")

# odm related
step_idx = 0
update_interval = 1 # every step

# model
model = AutoModelForCausalLM.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, dtype=torch.bfloat16)

# tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token


# dataset related
def tokenize_fn(examples):
return tokenizer(
examples["text"], truncation=True, padding="max_length", max_length=128
)
# If you have a single dataset, you can declare it with a single key-value pair.
# ODM will auto-categorize the dataset into pseudo categories.
# If you have multiple dataset categories, you can declare them with multiple key-value pairs, e.g.:
# dataset_dict = {
# "alpaca": load_dataset("tatsu-lab/alpaca", split="train[:1%]"),
# "oasst": load_dataset("hakurei/open-instruct-v1", split="train[:1%]"),
# }


dataset_dict = {
"alpaca": load_dataset("tatsu-lab/alpaca", split="train[:1%]"),
"oasst": load_dataset("hakurei/open-instruct-v1", split="train[:1%]"),
}
dataset_dict = {"alpaca_train": load_dataset("tatsu-lab/alpaca", split="train[90%:]")}
eval_dict = {"alpaca_val": load_dataset("tatsu-lab/alpaca", split="train[:1%]")}


def format_example(example):
@@ -63,43 +61,49 @@ def format_example(example):
for name in dataset_dict:
dataset_dict[name] = dataset_dict[name].map(format_example)

for name in eval_dict:
eval_dict[name] = eval_dict[name].map(format_example)

dataset_dict = DatasetDict(dataset_dict) # type: ignore
eval_dict = DatasetDict(eval_dict) # type: ignore


def collate_fn(batch, tokenizer):
msgs = [b.pop("text") for b in batch]

def tokenize_fn(examples):
return tokenizer(
examples["text"],
msgs,
truncation=True,
padding="max_length",
max_length=1024,
return_tensors="pt",
)


for name in dataset_dict:
dataset_dict[name] = dataset_dict[name].map(
tokenize_fn,
batched=True,
remove_columns=dataset_dict[name].column_names,
)

collator_dict = {
name: DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
for name in dataset_dict
name: partial(collate_fn, tokenizer=tokenizer) for name in dataset_dict
}

eval_collator_dict = {
name: partial(collate_fn, tokenizer=tokenizer) for name in eval_dict
}

# dataset preparation
dataset = OnlineMixingDataset(
dataset_dict=dataset_dict,
collators_dict=collator_dict,
eval_dataset_dict={},
eval_collators_dict={},
eval_dataset_dict=eval_dict,
eval_collators_dict=eval_collator_dict,
output_dir=output_dir,
reward_type="train_loss",
reward_type=Reward.TRAIN_LOSS,
sampling_interval=batch_size,
auto_categorize_config={"input_column": "text"},
)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=None)

# distributed setup
dataloader_config = DataLoaderConfiguration(split_batches=True, dispatch_batches=True)
accelerator = Accelerator(split_batches=True, dataloader_config=dataloader_config)
accelerator = Accelerator(dataloader_config=dataloader_config)
Collaborator: What was the reason to remove `split_batches=True`?

Collaborator (author): `DataLoaderConfiguration` already includes `split_batches`. Do we still need to add it to `Accelerator`?

model, dataloader = accelerator.prepare(model, dataloader)

# training setup
17 changes: 16 additions & 1 deletion plugins/online-data-mixing/pyproject.toml
@@ -22,10 +22,25 @@ classifiers=[
"Programming Language :: Python :: 3.11",
]

dependencies = ["datasets", "torchdata"]
dependencies = [
"scikit-learn",
"datasets==4.*",
"torchdata==0.11.0",
"sentence-transformers==5.*",
]

[project.optional-dependencies]
dev = ["pytest"]

[tool.hatch.build.targets.wheel]
only-include = ["src/fms_acceleration_odm"]

[tool.hatch.build.targets.wheel.sources]
"src" = ""

[tool.hatch.build.targets.sdist]
include = [
"src",
"pyproject.toml",
"README.md",
]
@@ -0,0 +1,194 @@
"""Utilities to automatically cluster a dataset into pseudo categories."""

# Copyright The FMS HF Tuning Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Future
from __future__ import annotations

# Standard
from dataclasses import dataclass, field
from logging import getLogger
from typing import Any, Dict, List, Optional
import copy
import math

# Third Party
import torch
from datasets import Dataset, DatasetDict
from sentence_transformers import SentenceTransformer
from sklearn.cluster import KMeans
import numpy as np

logger = getLogger(__name__)

AUTO_CATEGORIZATION_COLUMN_NAME = "auto_categorization_odm_raw_text"


@dataclass
class AutoCategorizeConfig: # pylint: disable=too-many-instance-attributes
"""Configuration for sentence-embedding based auto-categorization."""

input_column: str = "text"
num_categories: Optional[int] = None
min_categories: int = 2
max_categories: int = 15
model_name: str = "Qwen/Qwen3-Embedding-0.6B"
batch_size: int = 64
cluster_algo: str = "kmeans"
category_prefix: str = "auto_category"
# Args for loading model
model_kwargs: Dict[str, Any] = field(
default_factory=lambda: {
"device_map": "auto",
# "attn_implementation": "flash_attention_2",
}
)
# Args for K means
cluster_kwargs: Dict[str, Any] = field(default_factory=dict)
# If the `input_column` provided does not contain str,
# the data is assumed to be pre-tokenized and the column will first be
# detokenized using the tokenizer before performing K-means.
tokenizer: Optional[Any] = None


class DatasetAutoCategorizer:
Collaborator: Any thoughts on supporting when the dataset is iterable?

Collaborator (author): Good catch, at least for the current implementation - no. Clustering would require all the data to be in memory, so with iterable datasets we would need to fetch all the records and then run clustering. This auto categorization is only suitable for smaller datasets (sub-million).

Collaborator: Sure, do you want to raise an error or something that iterable is currently not supported if the dataset is of that type?
"""Clusters a dataset into pseudo categories using embeddings."""

def __init__(self, config: Optional[AutoCategorizeConfig] = None):
self.config = copy.deepcopy(config) or AutoCategorizeConfig()

def __call__(self, dataset: Dataset) -> DatasetDict:
if isinstance(dataset, torch.utils.data.IterableDataset):
raise NotImplementedError(
"Iteratble (or streaming) datasets are not yet supported for auto categorization."
"Please use a non-iterable dataset."
)

if len(dataset) == 0:
raise ValueError("Cannot auto-categorize an empty dataset")
if self.config.input_column not in dataset.column_names:
raise ValueError(
"Dataset is missing column '{col}'. Provide a input field in "
"auto_categorize_config['input_column'].".format(
col=self.config.input_column
)
)

dataset = self._maybe_detokenize_data(dataset)

num_categories = self._determine_category_count(len(dataset))
logger.info(
"Auto-categorizing %s rows into %s clusters using %s",
len(dataset),
num_categories,
self.config.model_name,
)
embeddings = self._compute_embeddings(dataset)
labels = self._cluster_embeddings(embeddings, num_categories)

if AUTO_CATEGORIZATION_COLUMN_NAME in dataset.column_names:
dataset = dataset.remove_columns(AUTO_CATEGORIZATION_COLUMN_NAME)

return self._build_dataset_dict(dataset, labels)

def _maybe_detokenize_data(self, dataset: Dataset) -> Dataset:
existing_field = self.config.input_column

if isinstance(dataset[existing_field][0], str):
logger.info("Detokenization not needed, text data already provided")
return dataset

assert self.config.tokenizer is not None, (
f"Attempting to detokenize the data in column '{self.config.input_column}' "
"but no tokenizer was provided"
)
assert AUTO_CATEGORIZATION_COLUMN_NAME not in dataset.column_names, (
f"Default detokenization column '{AUTO_CATEGORIZATION_COLUMN_NAME}' "
"is already present in the dataset"
)

tokenizer = self.config.tokenizer

dataset = dataset.map(
lambda x: {
AUTO_CATEGORIZATION_COLUMN_NAME: tokenizer.batch_decode(
x[existing_field]
)
},
batched=True,
num_proc=12,
)
self.config.input_column = AUTO_CATEGORIZATION_COLUMN_NAME

return dataset

def _determine_category_count(self, dataset_size: int) -> int:
if self.config.num_categories is not None:
desired = self.config.num_categories
else:
# heuristic: sqrt scaling with dataset size
desired = int(math.sqrt(max(dataset_size, 1)))
desired = max(desired, self.config.min_categories)
desired = min(desired, self.config.max_categories)

# clusters cannot exceed dataset size and must be >=1
desired = max(1, min(dataset_size, desired))
return desired

def _compute_embeddings(self, dataset: Dataset) -> np.ndarray:
model = SentenceTransformer(
self.config.model_name,
model_kwargs=self.config.model_kwargs,
prompts={
"clustering": "Identify the topic or theme based on the text: ",
},
default_prompt_name="clustering",
)

vectors = model.encode(
dataset[self.config.input_column],
convert_to_numpy=True,
show_progress_bar=True,
batch_size=self.config.batch_size,
normalize_embeddings=True,
)
return vectors

def _cluster_embeddings(
self, embeddings: np.ndarray, num_categories: int
) -> np.ndarray:
if self.config.cluster_algo.lower() != "kmeans":
raise ValueError(
"Unsupported clustering algorithm '%s'. Only 'kmeans' is currently supported."
% self.config.cluster_algo
)
kwargs = {"n_init": 10}
kwargs.update(self.config.cluster_kwargs)
model = KMeans(n_clusters=num_categories, **kwargs)

logger.info("Starting %s clustering", self.config.cluster_algo)

return model.fit_predict(embeddings)

def _build_dataset_dict(self, dataset: Dataset, labels: np.ndarray) -> DatasetDict:
grouped_indices: Dict[int, List[int]] = {}
for idx, label in enumerate(labels.tolist()):
grouped_indices.setdefault(label, []).append(idx)
categorized = {}
for label, indices in sorted(grouped_indices.items()):
name = f"{self.config.category_prefix}_{label}"
categorized[name] = dataset.select(indices)
return DatasetDict(categorized)