Skip to content
117 changes: 111 additions & 6 deletions auto_round/calib_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

import json
import logging
import multiprocessing
import os
import random
import sys

Expand All @@ -23,6 +25,7 @@
from datasets import Dataset, Features, IterableDataset, Sequence, Value, concatenate_datasets, load_dataset
from torch.utils.data import DataLoader

from . import envs
from .utils import is_local_path, logger

CALIB_DATASETS = {}
Expand Down Expand Up @@ -85,6 +88,30 @@ def apply_chat_template_to_samples(samples, tokenizer, seqlen, system_prompt=Non
return example


def _make_map_fingerprint(dataset, tokenizer, seqlen, apply_chat_template, system_prompt, text_key="text"):
"""Compute a stable fingerprint for Dataset.map() calls.

datasets uses dill to serialize the transform function for cache fingerprinting.
HuggingFace tokenizer objects are not reliably serializable by dill, causing
a random hash to be used each run — which breaks caching entirely.

This function computes a deterministic fingerprint from stable string
identifiers (tokenizer name, seqlen, etc.) so that caching works correctly
and subsequent runs can load from disk instead of re-tokenizing in RAM.
"""
import hashlib

parts = [
getattr(dataset, "_fingerprint", "no_fingerprint"),
getattr(tokenizer, "name_or_path", type(tokenizer).__name__),
str(seqlen),
str(apply_chat_template),
str(system_prompt),
text_key,
]
return hashlib.sha256("|".join(parts).encode()).hexdigest()


def get_tokenizer_function(tokenizer, seqlen, apply_chat_template=False, system_prompt=None):
"""Returns a default tokenizer function.

Expand Down Expand Up @@ -154,7 +181,13 @@ def get_pile_dataset(
logger.error(f"Failed to load the dataset: {error_message}")
sys.exit(1)
calib_dataset = calib_dataset.shuffle(seed=seed)
calib_dataset = calib_dataset.map(tokenizer_function, batched=True)
calib_dataset = calib_dataset.map(
tokenizer_function,
batched=True,
new_fingerprint=_make_map_fingerprint(
calib_dataset, tokenizer, seqlen, apply_chat_template, system_prompt, "text"
),
)

return calib_dataset

Expand Down Expand Up @@ -450,7 +483,13 @@ def default_tokenizer_function(examples):

calib_dataset = load_dataset("madao33/new-title-chinese", split=split)
calib_dataset = calib_dataset.shuffle(seed=seed)
calib_dataset = calib_dataset.map(tokenizer_function, batched=True)
calib_dataset = calib_dataset.map(
tokenizer_function,
batched=True,
new_fingerprint=_make_map_fingerprint(
calib_dataset, tokenizer, seqlen, apply_chat_template, system_prompt, "content"
),
)

return calib_dataset

Expand Down Expand Up @@ -502,7 +541,13 @@ def get_mbpp_dataset(
import datasets

calib_dataset = datasets.Dataset.from_list(samples)
calib_dataset = calib_dataset.map(tokenizer_function, batched=True)
calib_dataset = calib_dataset.map(
tokenizer_function,
batched=True,
new_fingerprint=_make_map_fingerprint(
calib_dataset, tokenizer, seqlen, apply_chat_template, system_prompt, "text"
),
)

return calib_dataset

Expand Down Expand Up @@ -571,7 +616,13 @@ def load_local_data(data_path):
import datasets

calib_dataset = datasets.Dataset.from_list(samples)
calib_dataset = calib_dataset.map(tokenizer_function, batched=True)
calib_dataset = calib_dataset.map(
tokenizer_function,
batched=True,
new_fingerprint=_make_map_fingerprint(
calib_dataset, tokenizer, seqlen, apply_chat_template, system_prompt, "text"
),
)
return calib_dataset


Expand Down Expand Up @@ -641,8 +692,8 @@ def select_dataset(dataset, indices):
return dataset


def get_dataset(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", seed=42, nsamples=512):
"""Generate a dataset for calibration.
def _get_dataset_impl(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", seed=42, nsamples=512):
"""Internal implementation: generate a dataset for calibration.

Args:
tokenizer (Tokenizer): The tokenizer to use for tokenization.
Expand Down Expand Up @@ -764,6 +815,7 @@ def concat_dataset_element(dataset):
)
if do_concat:
dataset = concat_dataset_element(dataset)

dataset = dataset.filter(filter_func)
if name in data_lens:
dataset = select_dataset(dataset, range(data_lens[name]))
Expand Down Expand Up @@ -829,6 +881,59 @@ def concat_dataset_element(dataset):
return dataset_final


def get_dataset(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", seed=42, nsamples=512):
    """Generate a dataset for calibration.

    Uses a subprocess for preprocessing to ensure all temporary memory is fully
    reclaimed by the OS when the subprocess exits. The HuggingFace ``datasets``
    library automatically caches intermediate results (e.g. ``.map()``,
    ``.filter()``), so the main process can reload them cheaply after the
    subprocess finishes.

    Set environment variable ``AR_DISABLE_DATASET_SUBPROCESS=1`` to disable
    subprocess mode and run preprocessing in the main process.

    Args:
        tokenizer: The tokenizer to use for tokenization.
        seqlen (int): The exact sequence length.
        dataset_name (str, optional): Dataset name(s) separated by commas.
        seed (int, optional): Random seed for reproducibility. Defaults to 42.
        nsamples (int, optional): Total number of samples to include. Defaults to 512.

    Returns:
        Dataset: The processed dataset ready for calibration.
    """
    # Allow disabling subprocess mode via environment variable.
    if envs.AR_DISABLE_DATASET_SUBPROCESS:
        return _get_dataset_impl(tokenizer, seqlen, dataset_name, seed, nsamples)

    # "fork" is required so the child inherits the (typically unpicklable)
    # tokenizer object. It is unavailable on Windows and possibly other
    # platforms, so consult the actual start-method list instead of raising
    # an exception as control flow.
    if "fork" not in multiprocessing.get_all_start_methods():
        logger.warning(
            "The 'fork' start method is unavailable on this platform; "
            "running dataset preprocessing in the main process."
        )
        return _get_dataset_impl(tokenizer, seqlen, dataset_name, seed, nsamples)

    # Run preprocessing in a subprocess so all temporary memory is freed on exit.
    # The HuggingFace datasets cache is warmed up as a side effect.
    logger.info("Preprocessing calibration dataset in a subprocess to avoid memory leaks...")

    try:
        ctx = multiprocessing.get_context("fork")
        p = ctx.Process(
            target=_get_dataset_impl,
            args=(tokenizer, seqlen, dataset_name, seed, nsamples),
        )
        p.start()
        try:
            p.join()
        finally:
            # If join() was interrupted (e.g. KeyboardInterrupt), do not leak
            # the child process.
            if p.is_alive():
                p.terminate()
                p.join()

        if p.exitcode != 0:
            raise RuntimeError(f"Dataset preprocessing subprocess exited with code {p.exitcode}")

    except Exception as e:
        logger.warning(f"Subprocess dataset preprocessing failed ({e}), falling back to in-process mode.")

    # (Re-)load the dataset in the main process. When the subprocess
    # succeeded the HF datasets cache makes this almost instant.
    return _get_dataset_impl(tokenizer, seqlen, dataset_name, seed, nsamples)


def get_dataloader(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", seed=42, bs=8, nsamples=512):
"""Generate a DataLoader for calibration using specified parameters.

Expand Down
1 change: 1 addition & 0 deletions auto_round/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
in ["1", "true"],
"AR_OMP_NUM_THREADS": lambda: os.getenv("AR_OMP_NUM_THREADS", None),
"AR_DISABLE_OFFLOAD": lambda: os.getenv("AR_DISABLE_OFFLOAD", "0").lower() in ("1", "true", "yes"),
"AR_DISABLE_DATASET_SUBPROCESS": lambda: os.getenv("AR_DISABLE_DATASET_SUBPROCESS", "0").lower() in ("1", "true"),
}


Expand Down
12 changes: 12 additions & 0 deletions docs/environments.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# AutoRound Environment Variables Configuration

English | [简体中文](./environments_CN.md)

This document describes the environment variables used by AutoRound for configuration and their usage.

## Overview
Expand Down Expand Up @@ -47,6 +49,16 @@ export AR_USE_MODELSCOPE=true
export AR_WORK_SPACE=/path/to/custom/workspace
```

### AR_DISABLE_DATASET_SUBPROCESS
- **Description**: Disables the use of a subprocess for dataset preprocessing. By default, AutoRound uses a subprocess to ensure all temporary memory is reclaimed by the OS.
- **Default**: `False`
- **Valid Values**: `"1"`, `"true"` (case-insensitive) to disable subprocess preprocessing; any other value keeps subprocess mode enabled
- **Usage**: Set this to run dataset preprocessing in the main process

```bash
export AR_DISABLE_DATASET_SUBPROCESS=true
```

## Usage Examples

### Setting Environment Variables
Expand Down
147 changes: 147 additions & 0 deletions docs/environments_CN.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
# AutoRound 环境变量配置

[English](./environments.md) | 简体中文

本文档介绍 AutoRound 使用的环境变量及其配置说明。

## 概述

AutoRound 通过 `envs.py` 模块提供统一的环境变量管理系统,支持懒加载求值与程序化配置。

## 可用环境变量

### AR_LOG_LEVEL
- **描述**:控制 AutoRound 默认日志级别
- **默认值**:`"INFO"`
- **有效值**:`"TRACE"`、`"DEBUG"`、`"INFO"`、`"WARNING"`、`"ERROR"`、`"CRITICAL"`
- **用途**:通过设置该变量控制 AutoRound 的日志详细程度

```bash
export AR_LOG_LEVEL=DEBUG
```

### AR_ENABLE_COMPILE_PACKING
- **描述**:启用编译打包优化
- **默认值**:`False`(等价于 `"0"`)
- **有效值**:`"1"`、`"true"`、`"yes"`(不区分大小写)表示启用;其他值表示禁用
- **用途**:启用后可在将 FP4 张量打包为 `uint8` 时获得性能优化

```bash
export AR_ENABLE_COMPILE_PACKING=1
```

### AR_USE_MODELSCOPE
- **描述**:控制是否使用 ModelScope 下载模型
- **默认值**:`False`
- **有效值**:`"1"`、`"true"`(不区分大小写)表示启用;其他值表示禁用
- **用途**:启用后将使用 ModelScope 替代 Hugging Face Hub 下载模型

```bash
export AR_USE_MODELSCOPE=true
```

### AR_WORK_SPACE
- **描述**:设置 AutoRound 操作的工作目录
- **默认值**:`"ar_work_space"`
- **用途**:指定 AutoRound 存储临时文件和输出结果的自定义目录

```bash
export AR_WORK_SPACE=/path/to/custom/workspace
```

### AR_DISABLE_DATASET_SUBPROCESS
- **描述**:禁用子进程方式进行数据集预处理。默认情况下,AutoRound 使用子进程确保所有临时内存在进程退出后被操作系统回收。
- **默认值**:`False`
- **有效值**:`"1"`、`"true"`(不区分大小写)表示禁用子进程;其他值表示启用子进程
- **用途**:设置后数据集预处理将在主进程中运行

```bash
export AR_DISABLE_DATASET_SUBPROCESS=true
```

## 使用示例

### 设置环境变量

#### 通过 Shell 命令
```bash
# 将日志级别设置为 DEBUG
export AR_LOG_LEVEL=DEBUG

# 启用编译打包
export AR_ENABLE_COMPILE_PACKING=1

# 使用 ModelScope 下载模型
export AR_USE_MODELSCOPE=true

# 设置自定义工作目录
export AR_WORK_SPACE=/tmp/autoround_workspace
```

#### 通过 Python 代码
```python
from auto_round.envs import set_config

# 同时配置多个环境变量
set_config(
AR_LOG_LEVEL="DEBUG",
AR_USE_MODELSCOPE=True,
AR_ENABLE_COMPILE_PACKING=True,
AR_WORK_SPACE="/tmp/autoround_workspace",
)
```

### 查看环境变量

#### 通过 Python 代码
```python
from auto_round import envs

# 访问环境变量(懒加载求值)
log_level = envs.AR_LOG_LEVEL
use_modelscope = envs.AR_USE_MODELSCOPE
enable_packing = envs.AR_ENABLE_COMPILE_PACKING
workspace = envs.AR_WORK_SPACE

print(f"日志级别: {log_level}")
print(f"使用 ModelScope: {use_modelscope}")
print(f"启用编译打包: {enable_packing}")
print(f"工作目录: {workspace}")
```

#### 检查变量是否显式设置
```python
from auto_round.envs import is_set

# 检查环境变量是否被显式设置
if is_set("AR_LOG_LEVEL"):
print("AR_LOG_LEVEL 已被显式设置")
else:
print("AR_LOG_LEVEL 正在使用默认值")
```

### AR_DISABLE_OFFLOAD
- **描述**:强制禁用 `OffloadManager` 中的权重卸载功能。在开发和调试时可跳过所有卸载/重载开销。
- **默认值**:`False`(等价于 `"0"`)
- **有效值**:`"1"`、`"true"`、`"yes"`(不区分大小写)表示禁用卸载;其他值保持默认行为
- **用途**:设置后将完全绕过权重卸载

```bash
export AR_DISABLE_OFFLOAD=1
```

## 配置最佳实践

1. **开发环境**:设置 `AR_LOG_LEVEL=TRACE` 或 `AR_LOG_LEVEL=DEBUG` 以获取详细日志
2. **生产环境**:使用 `AR_LOG_LEVEL=WARNING` 或 `AR_LOG_LEVEL=ERROR` 减少日志噪声
3. **中国用户**:建议设置 `AR_USE_MODELSCOPE=true` 以获得更好的模型下载速度
4. **性能优化**:如有足够算力,可启用 `AR_ENABLE_COMPILE_PACKING=1`
5. **自定义工作目录**:将 `AR_WORK_SPACE` 设置为磁盘空间充足的目录

## 注意事项

- 环境变量采用懒加载方式,仅在首次访问时读取
- `set_config()` 函数提供了便捷的程序化多变量配置方式
- `AR_USE_MODELSCOPE` 的布尔值会自动转换为适当的字符串表示
- 所有环境变量名称区分大小写
- 通过 `set_config()` 所做的修改将影响当前进程及其子进程
Loading