Commit 727e573

refine sampler (#1077)
1 parent: c4a1479

3 files changed, +97 -36 lines changed

auto_round/compressors/base.py

Lines changed: 38 additions & 32 deletions
```diff
@@ -33,6 +33,7 @@
 from auto_round import envs
 from auto_round.auto_scheme.gen_auto_scheme import AutoScheme
 from auto_round.compressors.utils import (
+    IndexSampler,
     block_forward,
     check_need_act_calibration,
     check_skippable_keywords,
@@ -196,7 +197,7 @@ def __init__(
             disable_opt_rtn (bool, optional): Disable RTN-mode optimization (iters=0). Defaults to False.
             enable_alg_ext (bool, optional): Enable algorithm extension (primarily for INT2). Defaults to False.
             **kwargs: Backward compatible options:
-                - enable_alg_ext, quant_lm_head, lr, lr_scheduler, sampler, not_use_best_mse, dynamic_max_gap,
+                - enable_alg_ext, quant_lm_head, lr, lr_scheduler, not_use_best_mse, dynamic_max_gap,
                 super_group_size, super_bits, scale_dtype ("fp16" etc.),
                 nblocks, to_quant_block_names,
                 enable_norm_bias_tuning, enable_quanted_input,
@@ -259,7 +260,6 @@ def __init__(
         enable_minmax_tuning = kwargs.pop("enable_minmax_tuning", True)
         minmax_lr = kwargs.pop("minmax_lr", None)
         lr_scheduler = kwargs.pop("lr_scheduler", None)
-        sampler = kwargs.pop("sampler", "rand")
         not_use_best_mse = kwargs.pop("not_use_best_mse", False)
         dynamic_max_gap = kwargs.pop("dynamic_max_gap", -1)
         nblocks = kwargs.pop("nblocks", 1)
@@ -350,7 +350,6 @@ def __init__(
         self.lr = lr
         self.minmax_lr = minmax_lr or self.lr
         self.enable_alg_ext = enable_alg_ext
-        self.sampler = sampler
         self.not_use_best_mse = not_use_best_mse
         self.dynamic_max_gap = dynamic_max_gap
         self.lr_scheduler = lr_scheduler
@@ -2487,29 +2486,33 @@ def _quantize_layer(
         scaler = self._get_scaler()  # pylint: disable=assignment-from-none
         init_loss = None
         gradient_accumulate_steps = self.batch_size  # Force to low gpu
-        batch_size = 1  # Force to low gpu
-        global_batch_size = batch_size * gradient_accumulate_steps
-        global_batch_size = min(nsamples, global_batch_size)
-        if self.sampler != "rand":
-            whole_indices = torch.randperm(nsamples)[:global_batch_size]
+
         total_loss = 0
         num_elm = 1
         mse_reduction = "mean"
         if gradient_accumulate_steps != 1:
             mse_reduction = "sum"
         mse_loss = torch.nn.MSELoss(reduction=mse_reduction).to(device)
+        batch_size = 1  # Force to low gpu
+        global_batch_size = self.batch_size * gradient_accumulate_steps
+        global_batch_size = min(nsamples, global_batch_size)
+        if gradient_accumulate_steps != 1 and not self.attention_mask:
+            whole_indices = torch.arange(global_batch_size)
+            if q_inputs is not None:
+                num_elm = self._get_current_num_elm(q_inputs, whole_indices)
+            else:
+                num_elm = self._get_current_num_elm(inputs, whole_indices)
+
+        index_sampler = IndexSampler(nsamples, global_batch_size)

         for i in range(self.iters):
             total_loss = 0
-            if self.sampler == "rand":
-                whole_indices = torch.randperm(nsamples)[:global_batch_size]
-                if gradient_accumulate_steps != 1:
-                    if q_inputs is not None:
-                        num_elm = self._get_current_num_elm(q_inputs, whole_indices)
-                    else:
-                        num_elm = self._get_current_num_elm(inputs, whole_indices)
+            global_indices = index_sampler.next_batch()
+            if self.attention_mask:
+                num_elm = self._get_non_zero_cnt(self.attention_mask, global_indices)
+
             for tmp_step in range(gradient_accumulate_steps):
-                indices = whole_indices[tmp_step * batch_size : (tmp_step + 1) * batch_size]
+                indices = global_indices[tmp_step * batch_size : (tmp_step + 1) * batch_size]
                 if q_inputs is not None:
                     current_input = [q_inputs[i] for i in indices]
                     current_input = torch.cat(current_input, dim=0).to(device)
@@ -2551,7 +2554,7 @@ def _quantize_layer(
                 loss = mse_loss(  # pylint: disable=not-callable
                     output_q.to(torch.float32), current_output.to(torch.float32)
                 )
-
+                num_elm = 1 if num_elm <= 0 else num_elm
                 total_loss += loss.item() / num_elm

                 self._scale_loss_and_backward(scaler, loss)
@@ -2615,6 +2618,13 @@ def _get_current_num_elm(
         current_input_ids = [input_ids[i] for i in indices]
         return sum(id.numel() for id in current_input_ids)

+    def _get_non_zero_cnt(self, tensor: list[torch.Tensor], indices: list[int]) -> int:
+        current_tensors = [tensor[i] for i in indices]
+        non_zero_cnt = 0
+        for t in current_tensors:
+            non_zero_cnt += torch.count_nonzero(t).item()
+        return non_zero_cnt
+
     def quantize_block(
         self,
         block: torch.nn.Module,
@@ -2808,7 +2818,7 @@ def _quantize_block(
                 f"layers in the block"
             )
             logger.info(dump_info)
-            unwrapper_block(block, {})  # TODO Quant layer should change
+            unwrapper_block(block, {})
             mv_module_from_gpu(block)
             return output, output

@@ -2823,11 +2833,6 @@ def _quantize_block(
             nsamples = len(input_ids["hidden_states"])
         else:
             nsamples = len(input_ids)
-
-        global_batch_size = self.batch_size * self.gradient_accumulate_steps
-        global_batch_size = min(nsamples, global_batch_size)
-        if self.sampler != "rand":
-            whole_indices = torch.randperm(nsamples)[:global_batch_size]
         last_best_iter = 0
         best_loss = torch.finfo(torch.float).max
         num_elm = 1
@@ -2839,30 +2844,31 @@ def _quantize_block(
         init_loss = None
         best_params = {}
         total_loss = 0
+        global_batch_size = self.batch_size * self.gradient_accumulate_steps
+        global_batch_size = min(nsamples, global_batch_size)
         # We assume the block input and output shape is same
-        if self.gradient_accumulate_steps != 1:
+        if self.gradient_accumulate_steps != 1 and not self.attention_mask:
             whole_indices = torch.arange(global_batch_size)
             num_elm = self._get_current_num_elm(input_ids, whole_indices)

+        index_sampler = IndexSampler(nsamples, global_batch_size)
+        batch_size = self.batch_size
         for i in range(self.iters):
            if self.enable_alg_ext and self.data_type.endswith("dq"):
                 for n, m in block.named_modules():
                     m.cur_iter = i
             total_loss = 0
-            if self.sampler == "rand":
-                whole_indices = torch.randperm(nsamples)[:global_batch_size]
+            global_indices = index_sampler.next_batch()
+            if self.attention_mask:
+                num_elm = self._get_non_zero_cnt(self.attention_mask, global_indices)

             for tmp_step in range(self.gradient_accumulate_steps):
-                indices = whole_indices[tmp_step * self.batch_size : (tmp_step + 1) * self.batch_size]
-
+                indices = global_indices[tmp_step * batch_size : (tmp_step + 1) * batch_size]
                 current_output = self._get_current_output(output, indices)
-
                 current_output = to_device(current_output, loss_device)
-
                 output_q = self._get_current_q_output(block, input_ids, input_others, indices, device, loss_device)
-
                 loss = self._get_loss(output_q, current_output, indices, mse_loss, device)
-
+                num_elm = 1 if num_elm <= 0 else num_elm
                 total_loss += loss.item() / num_elm

                 if self.low_gpu_mem_usage and card_0_in_high_risk:
```
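
The behavioral change in `_quantize_layer`/`_quantize_block` above: when an `attention_mask` is present, the per-iteration loss is divided by the number of non-padding tokens in the sampled batch (via the new `_get_non_zero_cnt`) instead of a count precomputed once over a fixed index set. A minimal standalone sketch of that normalization; the masks, counts, and the local `get_non_zero_cnt` name are toy stand-ins, not repo code:

```python
import torch

# Hypothetical calibration batch: three samples padded to length 8.
attention_masks = [
    torch.tensor([[1, 1, 1, 1, 1, 0, 0, 0]]),  # 5 real tokens
    torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]]),  # 8 real tokens
    torch.tensor([[1, 1, 1, 0, 0, 0, 0, 0]]),  # 3 real tokens
]

def get_non_zero_cnt(tensors: list[torch.Tensor], indices: list[int]) -> int:
    """Count non-padding tokens in the selected samples, mirroring the
    _get_non_zero_cnt helper added in this commit."""
    return sum(torch.count_nonzero(tensors[i]).item() for i in indices)

indices = [0, 2]                                       # one sampled batch
num_elm = get_non_zero_cnt(attention_masks, indices)   # 5 + 3 = 8
num_elm = 1 if num_elm <= 0 else num_elm               # guard, as in the diff
summed_loss = torch.tensor(4.0)                        # pretend sum-reduced MSE
print(summed_loss.item() / num_elm)                    # 0.5 per real token
```

With `reduction="sum"` MSE (which the code selects when `gradient_accumulate_steps != 1`), dividing by the real-token count keeps logged losses comparable across batches that carry different amounts of padding.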

auto_round/compressors/utils.py

Lines changed: 54 additions & 0 deletions
```diff
@@ -13,6 +13,7 @@
 # limitations under the License.
 import copy
 import os
+import random
 import re
 import sys
 from dataclasses import asdict, fields
@@ -1310,3 +1311,56 @@ def _flush_current_shard():
             clear_memory()
         except Exception as _cleanup_err:
             logger.warning(f"shard cleanup warning: {_cleanup_err}")
+
+
+class IndexSampler:
+    """A cyclic sampler that returns shuffled index batches.
+
+    This sampler maintains internal state so that each call to `next_batch()`
+    continues from where it left off. When the remaining number of samples is
+    less than `batch_size`, the sampler reshuffles all indices and starts from
+    the beginning, discarding the last incomplete batch.
+
+    Attributes:
+        nsamples (int): Total number of samples.
+        batch_size (int): Number of indices to return in each batch.
+        index (int): Current position in the index list.
+        indices (List[int]): Shuffled list of indices.
+    """
+
+    def __init__(self, nsamples: int, batch_size: int) -> None:
+        """Initializes the sampler.
+
+        Args:
+            nsamples (int): Total number of samples (must be >= batch_size).
+            batch_size (int): Number of indices per batch.
+
+        Raises:
+            ValueError: If batch_size is not in the range (0, nsamples].
+        """
+        if batch_size <= 0 or batch_size > nsamples:
+            raise ValueError("batch_size must be > 0 and <= nsamples")
+
+        self.nsamples: int = nsamples
+        self.batch_size: int = batch_size
+        self.index: int = 0
+
+        self.indices: list[int] = list(range(nsamples))
+        random.shuffle(self.indices)
+
+    def next_batch(self) -> list[int]:
+        """Returns the next batch of shuffled indices.
+
+        If the remaining indices are fewer than `batch_size`, the sampler
+        reshuffles the entire list and starts from the beginning.
+
+        Returns:
+            list[int]: A list of size `batch_size` containing sample indices.
+        """
+        if self.index + self.batch_size > self.nsamples:
+            random.shuffle(self.indices)
+            self.index = 0

+        batch = self.indices[self.index : self.index + self.batch_size]
+        self.index += self.batch_size
+        return batch
```
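
A quick usage sketch of the new sampler (the seed and sizes here are arbitrary; `IndexSampler` shuffles via the global `random` module, so seeding makes the run reproducible):

```python
import random

from auto_round.compressors.utils import IndexSampler

random.seed(0)  # IndexSampler uses random.shuffle internally

sampler = IndexSampler(nsamples=10, batch_size=4)

# Batches 1 and 2 consume 8 of the 10 shuffled indices; only 2 remain,
# fewer than batch_size, so batch 3 triggers a reshuffle and the two
# leftover indices are dropped for that pass.
for step in range(3):
    batch = sampler.next_batch()
    assert len(batch) == 4 and len(set(batch)) == 4  # distinct indices
    print(step, batch)
```

Compared with the old `torch.randperm(nsamples)[:global_batch_size]` draw each iteration, this cycles through the full shuffled index list before repeating, so every sample is visited roughly equally often across iterations.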

auto_round/utils/common.py

Lines changed: 5 additions & 4 deletions
```diff
@@ -17,7 +17,7 @@
 import os
 import re
 import sys
-from typing import Any, Callable, Dict, List, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import torch
 import transformers
@@ -331,11 +331,12 @@ def get_reciprocal(tensor):
     return recip


-def normalize_input(decoding_layer_inputs: list[tuple[Any]]) -> Tuple[List[torch.Tensor], Dict[str, Any]]:
+def normalize_input(
+    decoding_layer_inputs: tuple[Union[list[torch.Tensor], dict, Any], Optional[dict]],
+) -> Tuple[List[torch.Tensor], Dict[str, Any]]:
     """Normalize the decoding layer inputs into input_ids and other inputs."""
     input_ids = []
-    input_others = {}
-    input_others["positional_inputs"] = []
+    input_others = {"positional_inputs": []}
     for cur_inp in decoding_layer_inputs:
         input_ids.append(cur_inp[0][0][0])
         for key, val in cur_inp[0][1].items():
```
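
The visible body of `normalize_input` only shows how each entry is unpacked: `cur_inp[0][0][0]` yields the hidden-states tensor and `cur_inp[0][1]` the keyword arguments. A toy entry consistent with that indexing, purely as an illustration; the outer nesting is inferred from the diff, not the repo's actual capture format:

```python
import torch

# Hypothetical captured layer input: ((positional_args, kwarg_dict), ...),
# so cur_inp[0][0][0] is the hidden-states tensor and cur_inp[0][1] holds
# the keyword arguments passed to the decoding layer.
hidden = torch.randn(1, 8, 16)
cur_inp = (((hidden,), {"position_ids": torch.arange(8)}),)

assert cur_inp[0][0][0] is hidden
for key, val in cur_inp[0][1].items():
    print(key, tuple(val.shape))
```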
