Skip to content
121 changes: 119 additions & 2 deletions auto_round/calib_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import hashlib
import json
import logging
import multiprocessing
import os
import random
import shutil
import sys

logging.getLogger("datasets").setLevel(logging.WARNING)
Expand All @@ -23,6 +27,7 @@
from datasets import Dataset, Features, IterableDataset, Sequence, Value, concatenate_datasets, load_dataset
from torch.utils.data import DataLoader

from . import envs
from .utils import is_local_path, logger

CALIB_DATASETS = {}
Expand Down Expand Up @@ -641,8 +646,36 @@ def select_dataset(dataset, indices):
return dataset


def get_dataset(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", seed=42, nsamples=512):
"""Generate a dataset for calibration.
_DATASET_CACHE_VERSION = "v1"


def _compute_dataset_cache_key(tokenizer, seqlen, dataset_name, seed, nsamples):
"""Compute a deterministic cache key for dataset preprocessing."""
parts = [
_DATASET_CACHE_VERSION,
getattr(tokenizer, "name_or_path", type(tokenizer).__name__),
str(getattr(tokenizer, "vocab_size", 0)),
str(seqlen),
dataset_name,
str(seed),
str(nsamples),
]
return hashlib.sha256("|".join(parts).encode()).hexdigest()[:16]


def _get_dataset_cache_path(cache_key):
    """Resolve the on-disk directory that stores the dataset for *cache_key*.

    The root directory comes from ``envs.AR_DATASET_CACHE_DIR``; each cache
    key gets its own subdirectory underneath it.
    """
    return os.path.join(envs.AR_DATASET_CACHE_DIR, cache_key)


def _is_cache_valid(cache_path):
"""Check if a dataset cache directory is valid and complete."""
return os.path.isdir(cache_path) and os.path.exists(os.path.join(cache_path, "_done"))


def _get_dataset_impl(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", seed=42, nsamples=512):
"""Internal implementation: generate a dataset for calibration.

Args:
tokenizer (Tokenizer): The tokenizer to use for tokenization.
Expand Down Expand Up @@ -764,6 +797,7 @@ def concat_dataset_element(dataset):
)
if do_concat:
dataset = concat_dataset_element(dataset)

dataset = dataset.filter(filter_func)
if name in data_lens:
dataset = select_dataset(dataset, range(data_lens[name]))
Expand Down Expand Up @@ -829,6 +863,89 @@ def concat_dataset_element(dataset):
return dataset_final


def _subprocess_worker(tokenizer, seqlen, dataset_name, seed, nsamples, cache_path):
    """Child-process entry point: preprocess the calibration set and persist it.

    The heavy preprocessing pipeline runs inside this process so that, when
    the process exits, every byte of temporary memory is returned to the OS
    by the kernel rather than lingering in the parent's heap.
    """
    dataset = _get_dataset_impl(tokenizer, seqlen, dataset_name, seed, nsamples)
    # Drop any stale/partial cache left behind by an earlier failed attempt.
    if os.path.exists(cache_path):
        shutil.rmtree(cache_path, ignore_errors=True)
    os.makedirs(os.path.dirname(cache_path), exist_ok=True)
    dataset.save_to_disk(cache_path)
    # The marker is written only after save_to_disk succeeds, so readers can
    # treat its presence as proof the cache is complete.
    marker_path = os.path.join(cache_path, "_done")
    with open(marker_path, "w") as marker:
        marker.write("done")


def get_dataset(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", seed=42, nsamples=512):
    """Generate a dataset for calibration.

    Preprocessing normally runs in a forked subprocess so that all temporary
    memory it allocates is fully reclaimed by the OS when the subprocess
    exits. The result is saved to an on-disk cache keyed by the preprocessing
    parameters, so later runs with the same configuration load instantly.

    Set environment variable ``AR_DISABLE_DATASET_SUBPROCESS=1`` to disable
    subprocess mode and run preprocessing in the main process.
    Set environment variable ``AR_DATASET_CACHE_DIR`` to customize the cache
    directory (default: ``~/.cache/auto_round/datasets/``).

    Args:
        tokenizer: The tokenizer to use for tokenization.
        seqlen (int): The exact sequence length.
        dataset_name (str, optional): Dataset name(s) separated by commas.
        seed (int, optional): Random seed for reproducibility. Defaults to 42.
        nsamples (int, optional): Total number of samples to include. Defaults to 512.

    Returns:
        Dataset: The processed dataset ready for calibration.
    """
    # Explicit opt-out: do everything in the current process.
    if envs.AR_DISABLE_DATASET_SUBPROCESS:
        return _get_dataset_impl(tokenizer, seqlen, dataset_name, seed, nsamples)

    key = _compute_dataset_cache_key(tokenizer, seqlen, dataset_name, seed, nsamples)
    path = _get_dataset_cache_path(key)

    # Fast path: a complete cache from a previous run with identical settings.
    if _is_cache_valid(path):
        logger.info(f"Loading cached calibration dataset from {path}")
        return Dataset.load_from_disk(path)

    logger.info("Preprocessing calibration dataset in a subprocess to avoid memory leaks...")

    try:
        # The worker relies on fork semantics (tokenizer shared by memory),
        # which Windows does not provide.
        if os.name == "nt":
            raise OSError("fork is not available on Windows")

        worker = multiprocessing.get_context("fork").Process(
            target=_subprocess_worker,
            args=(tokenizer, seqlen, dataset_name, seed, nsamples, path),
        )
        worker.start()
        worker.join()

        if worker.exitcode != 0:
            raise RuntimeError(f"Dataset preprocessing subprocess exited with code {worker.exitcode}")
        if not _is_cache_valid(path):
            raise RuntimeError("Dataset cache was not created successfully")

        logger.info(f"Loading preprocessed calibration dataset from {path}")
        return Dataset.load_from_disk(path)

    except Exception as e:
        # Any failure in the subprocess path degrades gracefully to the
        # slower (but always available) in-process pipeline.
        logger.warning(f"Subprocess dataset preprocessing failed ({e}), falling back to in-process mode.")
        if os.path.exists(path):
            shutil.rmtree(path, ignore_errors=True)
        return _get_dataset_impl(tokenizer, seqlen, dataset_name, seed, nsamples)


def get_dataloader(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", seed=42, bs=8, nsamples=512):
"""Generate a DataLoader for calibration using specified parameters.

Expand Down
4 changes: 4 additions & 0 deletions auto_round/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@
in ["1", "true"],
"AR_OMP_NUM_THREADS": lambda: os.getenv("AR_OMP_NUM_THREADS", None),
"AR_DISABLE_OFFLOAD": lambda: os.getenv("AR_DISABLE_OFFLOAD", "0").lower() in ("1", "true", "yes"),
"AR_DISABLE_DATASET_SUBPROCESS": lambda: os.getenv("AR_DISABLE_DATASET_SUBPROCESS", "0").lower() in ("1", "true"),
"AR_DATASET_CACHE_DIR": lambda: os.getenv(
"AR_DATASET_CACHE_DIR", os.path.join(os.path.expanduser("~"), ".cache", "auto_round", "datasets")
),
}


Expand Down
19 changes: 19 additions & 0 deletions docs/environments.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,25 @@ export AR_USE_MODELSCOPE=true
export AR_WORK_SPACE=/path/to/custom/workspace
```

### AR_DISABLE_DATASET_SUBPROCESS
- **Description**: Disables the use of a subprocess for dataset preprocessing. By default, AutoRound uses a subprocess to ensure all temporary memory is reclaimed by the OS.
- **Default**: `False`
- **Valid Values**: `"1"` or `"true"` (case-insensitive) to disable subprocess preprocessing; any other value leaves it enabled
- **Usage**: Set this to run dataset preprocessing in the main process

```bash
export AR_DISABLE_DATASET_SUBPROCESS=true
```

### AR_DATASET_CACHE_DIR
- **Description**: Sets the cache directory for preprocessed datasets.
- **Default**: `"~/.cache/auto_round/datasets/"`
- **Usage**: Specify a custom directory to store cached datasets

```bash
export AR_DATASET_CACHE_DIR=/path/to/custom/cache
```

## Usage Examples

### Setting Environment Variables
Expand Down
Loading