Skip to content
117 changes: 111 additions & 6 deletions auto_round/calib_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

import json
import logging
import multiprocessing
import os
import random
import sys

Expand All @@ -23,6 +25,7 @@
from datasets import Dataset, Features, IterableDataset, Sequence, Value, concatenate_datasets, load_dataset
from torch.utils.data import DataLoader

from . import envs
from .utils import is_local_path, logger

CALIB_DATASETS = {}
Expand Down Expand Up @@ -85,6 +88,30 @@ def apply_chat_template_to_samples(samples, tokenizer, seqlen, system_prompt=Non
return example


def _make_map_fingerprint(dataset, tokenizer, seqlen, apply_chat_template, system_prompt, text_key="text"):
"""Compute a stable fingerprint for Dataset.map() calls.

datasets uses dill to serialize the transform function for cache fingerprinting.
HuggingFace tokenizer objects are not reliably serializable by dill, causing
a random hash to be used each run — which breaks caching entirely.

This function computes a deterministic fingerprint from stable string
identifiers (tokenizer name, seqlen, etc.) so that caching works correctly
and subsequent runs can load from disk instead of re-tokenizing in RAM.
"""
import hashlib

parts = [
getattr(dataset, "_fingerprint", "no_fingerprint"),
getattr(tokenizer, "name_or_path", type(tokenizer).__name__),
str(seqlen),
str(apply_chat_template),
str(system_prompt),
text_key,
]
return hashlib.md5("|".join(parts).encode()).hexdigest()


def get_tokenizer_function(tokenizer, seqlen, apply_chat_template=False, system_prompt=None):
"""Returns a default tokenizer function.

Expand Down Expand Up @@ -154,7 +181,13 @@ def get_pile_dataset(
logger.error(f"Failed to load the dataset: {error_message}")
sys.exit(1)
calib_dataset = calib_dataset.shuffle(seed=seed)
calib_dataset = calib_dataset.map(tokenizer_function, batched=True)
calib_dataset = calib_dataset.map(
tokenizer_function,
batched=True,
new_fingerprint=_make_map_fingerprint(
calib_dataset, tokenizer, seqlen, apply_chat_template, system_prompt, "text"
),
)

return calib_dataset

Expand Down Expand Up @@ -450,7 +483,13 @@ def default_tokenizer_function(examples):

calib_dataset = load_dataset("madao33/new-title-chinese", split=split)
calib_dataset = calib_dataset.shuffle(seed=seed)
calib_dataset = calib_dataset.map(tokenizer_function, batched=True)
calib_dataset = calib_dataset.map(
tokenizer_function,
batched=True,
new_fingerprint=_make_map_fingerprint(
calib_dataset, tokenizer, seqlen, apply_chat_template, system_prompt, "content"
),
)

return calib_dataset

Expand Down Expand Up @@ -502,7 +541,13 @@ def get_mbpp_dataset(
import datasets

calib_dataset = datasets.Dataset.from_list(samples)
calib_dataset = calib_dataset.map(tokenizer_function, batched=True)
calib_dataset = calib_dataset.map(
tokenizer_function,
batched=True,
new_fingerprint=_make_map_fingerprint(
calib_dataset, tokenizer, seqlen, apply_chat_template, system_prompt, "text"
),
)

return calib_dataset

Expand Down Expand Up @@ -571,7 +616,13 @@ def load_local_data(data_path):
import datasets

calib_dataset = datasets.Dataset.from_list(samples)
calib_dataset = calib_dataset.map(tokenizer_function, batched=True)
calib_dataset = calib_dataset.map(
tokenizer_function,
batched=True,
new_fingerprint=_make_map_fingerprint(
calib_dataset, tokenizer, seqlen, apply_chat_template, system_prompt, "text"
),
)
return calib_dataset


Expand Down Expand Up @@ -641,8 +692,8 @@ def select_dataset(dataset, indices):
return dataset


def get_dataset(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", seed=42, nsamples=512):
"""Generate a dataset for calibration.
def _get_dataset_impl(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", seed=42, nsamples=512):
"""Internal implementation: generate a dataset for calibration.

Args:
tokenizer (Tokenizer): The tokenizer to use for tokenization.
Expand Down Expand Up @@ -764,6 +815,7 @@ def concat_dataset_element(dataset):
)
if do_concat:
dataset = concat_dataset_element(dataset)

dataset = dataset.filter(filter_func)
if name in data_lens:
dataset = select_dataset(dataset, range(data_lens[name]))
Expand Down Expand Up @@ -829,6 +881,59 @@ def concat_dataset_element(dataset):
return dataset_final


def get_dataset(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", seed=42, nsamples=512):
    """Generate a dataset for calibration.

    Preprocessing normally runs in a forked subprocess so that every byte of
    temporary memory is returned to the OS when that process exits.  Because
    the HuggingFace ``datasets`` library caches intermediate results on disk
    (``.map()``, ``.filter()``), the main process can then reload the finished
    dataset almost instantly instead of re-tokenizing.

    Set environment variable ``AR_DISABLE_DATASET_SUBPROCESS=1`` to skip the
    subprocess and preprocess directly in the main process.

    Args:
        tokenizer: The tokenizer to use for tokenization.
        seqlen (int): The exact sequence length.
        dataset_name (str, optional): Dataset name(s) separated by commas.
        seed (int, optional): Random seed for reproducibility. Defaults to 42.
        nsamples (int, optional): Total number of samples to include. Defaults to 512.

    Returns:
        Dataset: The processed dataset ready for calibration.
    """
    # Opt-out requested via environment variable: do everything in-process.
    if envs.AR_DISABLE_DATASET_SUBPROCESS:
        return _get_dataset_impl(tokenizer, seqlen, dataset_name, seed, nsamples)

    logger.info("Preprocessing calibration dataset in a subprocess to avoid memory leaks...")

    try:
        # The "fork" start method is required so the child inherits the live
        # tokenizer object; it does not exist on Windows, so take the
        # fallback path there.
        if os.name == "nt":
            raise OSError("fork is not available on Windows")

        worker = multiprocessing.get_context("fork").Process(
            target=_get_dataset_impl,
            args=(tokenizer, seqlen, dataset_name, seed, nsamples),
        )
        worker.start()
        worker.join()

        if worker.exitcode != 0:
            raise RuntimeError(f"Dataset preprocessing subprocess exited with code {worker.exitcode}")

    except Exception as e:
        logger.warning(f"Subprocess dataset preprocessing failed ({e}), falling back to in-process mode.")

    # When the subprocess succeeded it warmed the on-disk HF datasets cache,
    # making this in-process call nearly free; otherwise this is the real
    # (fallback) preprocessing run.
    return _get_dataset_impl(tokenizer, seqlen, dataset_name, seed, nsamples)


def get_dataloader(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", seed=42, bs=8, nsamples=512):
"""Generate a DataLoader for calibration using specified parameters.

Expand Down
1 change: 1 addition & 0 deletions auto_round/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
in ["1", "true"],
"AR_OMP_NUM_THREADS": lambda: os.getenv("AR_OMP_NUM_THREADS", None),
"AR_DISABLE_OFFLOAD": lambda: os.getenv("AR_DISABLE_OFFLOAD", "0").lower() in ("1", "true", "yes"),
"AR_DISABLE_DATASET_SUBPROCESS": lambda: os.getenv("AR_DISABLE_DATASET_SUBPROCESS", "0").lower() in ("1", "true"),
}


Expand Down
10 changes: 10 additions & 0 deletions docs/environments.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,16 @@ export AR_USE_MODELSCOPE=true
export AR_WORK_SPACE=/path/to/custom/workspace
```

### AR_DISABLE_DATASET_SUBPROCESS
- **Description**: Disables the use of a subprocess for dataset preprocessing. By default, AutoRound uses a subprocess to ensure all temporary memory is reclaimed by the OS.
- **Default**: `False`
- **Valid Values**: `"1"` or `"true"` (case-insensitive) disables subprocess preprocessing; any other value (or leaving it unset) keeps subprocess mode enabled
- **Usage**: Set this to run dataset preprocessing in the main process

```bash
export AR_DISABLE_DATASET_SUBPROCESS=true
```

## Usage Examples

### Setting Environment Variables
Expand Down
Loading