 from auto_round import envs
 from auto_round.auto_scheme.gen_auto_scheme import AutoScheme
 from auto_round.compressors.utils import (
+    IndexSampler,
     block_forward,
     check_need_act_calibration,
     check_skippable_keywords,
@@ -196,7 +197,7 @@ def __init__(
             disable_opt_rtn (bool, optional): Disable RTN-mode optimization (iters=0). Defaults to False.
             enable_alg_ext (bool, optional): Enable algorithm extension (primarily for INT2). Defaults to False.
             **kwargs: Backward compatible options:
-                - enable_alg_ext, quant_lm_head, lr, lr_scheduler, sampler, not_use_best_mse, dynamic_max_gap,
+                - enable_alg_ext, quant_lm_head, lr, lr_scheduler, not_use_best_mse, dynamic_max_gap,
                 super_group_size, super_bits, scale_dtype ("fp16" etc.),
                 nblocks, to_quant_block_names,
                 enable_norm_bias_tuning, enable_quanted_input,
@@ -259,7 +260,6 @@ def __init__(
         enable_minmax_tuning = kwargs.pop("enable_minmax_tuning", True)
         minmax_lr = kwargs.pop("minmax_lr", None)
         lr_scheduler = kwargs.pop("lr_scheduler", None)
-        sampler = kwargs.pop("sampler", "rand")
         not_use_best_mse = kwargs.pop("not_use_best_mse", False)
         dynamic_max_gap = kwargs.pop("dynamic_max_gap", -1)
         nblocks = kwargs.pop("nblocks", 1)
@@ -350,7 +350,6 @@ def __init__(
         self.lr = lr
         self.minmax_lr = minmax_lr or self.lr
         self.enable_alg_ext = enable_alg_ext
-        self.sampler = sampler
         self.not_use_best_mse = not_use_best_mse
         self.dynamic_max_gap = dynamic_max_gap
         self.lr_scheduler = lr_scheduler
@@ -2487,29 +2486,33 @@ def _quantize_layer(
         scaler = self._get_scaler()  # pylint: disable=assignment-from-none
         init_loss = None
         gradient_accumulate_steps = self.batch_size  # Force to low gpu
-        batch_size = 1  # Force to low gpu
-        global_batch_size = batch_size * gradient_accumulate_steps
-        global_batch_size = min(nsamples, global_batch_size)
-        if self.sampler != "rand":
-            whole_indices = torch.randperm(nsamples)[:global_batch_size]
+
         total_loss = 0
         num_elm = 1
         mse_reduction = "mean"
         if gradient_accumulate_steps != 1:
             mse_reduction = "sum"
         mse_loss = torch.nn.MSELoss(reduction=mse_reduction).to(device)
+        batch_size = 1  # Force to low gpu
+        global_batch_size = self.batch_size * gradient_accumulate_steps
+        global_batch_size = min(nsamples, global_batch_size)
+        if gradient_accumulate_steps != 1 and not self.attention_mask:
+            whole_indices = torch.arange(global_batch_size)
+            if q_inputs is not None:
+                num_elm = self._get_current_num_elm(q_inputs, whole_indices)
+            else:
+                num_elm = self._get_current_num_elm(inputs, whole_indices)
+
+        index_sampler = IndexSampler(nsamples, global_batch_size)
 
         for i in range(self.iters):
             total_loss = 0
-            if self.sampler == "rand":
-                whole_indices = torch.randperm(nsamples)[:global_batch_size]
-                if gradient_accumulate_steps != 1:
-                    if q_inputs is not None:
-                        num_elm = self._get_current_num_elm(q_inputs, whole_indices)
-                    else:
-                        num_elm = self._get_current_num_elm(inputs, whole_indices)
+            global_indices = index_sampler.next_batch()
+            if self.attention_mask:
+                num_elm = self._get_non_zero_cnt(self.attention_mask, global_indices)
+
             for tmp_step in range(gradient_accumulate_steps):
-                indices = whole_indices[tmp_step * batch_size : (tmp_step + 1) * batch_size]
+                indices = global_indices[tmp_step * batch_size : (tmp_step + 1) * batch_size]
                 if q_inputs is not None:
                     current_input = [q_inputs[i] for i in indices]
                     current_input = torch.cat(current_input, dim=0).to(device)
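
Note: IndexSampler lives in auto_round/compressors/utils.py and its implementation is not part of this diff. Based only on how it is used above (constructed as IndexSampler(nsamples, global_batch_size) and queried each iteration with next_batch() for a fresh set of indices), a minimal shuffle-and-cycle sketch could look like this; treat the internals as an assumption, not the library's actual code.

import torch

class IndexSampler:
    """Hypothetical sketch: yields `batch_size` sample indices per call and
    reshuffles the index pool once it has been exhausted."""

    def __init__(self, nsamples: int, batch_size: int):
        self.nsamples = nsamples
        self.batch_size = batch_size
        self._pool = torch.randperm(nsamples)
        self._pos = 0

    def next_batch(self) -> torch.Tensor:
        # Reshuffle when fewer than batch_size unused indices remain.
        if self._pos + self.batch_size > self.nsamples:
            self._pool = torch.randperm(self.nsamples)
            self._pos = 0
        batch = self._pool[self._pos : self._pos + self.batch_size]
        self._pos += self.batch_size
        return batch

Compared with the removed per-iteration torch.randperm(nsamples)[:global_batch_size], a sampler like this visits every calibration sample before repeating any of them.
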
@@ -2551,7 +2554,7 @@ def _quantize_layer(
                 loss = mse_loss(  # pylint: disable=not-callable
                     output_q.to(torch.float32), current_output.to(torch.float32)
                 )
-
+                num_elm = 1 if num_elm <= 0 else num_elm
                 total_loss += loss.item() / num_elm
 
                 self._scale_loss_and_backward(scaler, loss)
@@ -2615,6 +2618,13 @@ def _get_current_num_elm(
         current_input_ids = [input_ids[i] for i in indices]
         return sum(id.numel() for id in current_input_ids)
 
+    def _get_non_zero_cnt(self, tensor: list[torch.Tensor], indices: list[int]) -> int:
+        current_tensors = [tensor[i] for i in indices]
+        non_zero_cnt = 0
+        for t in current_tensors:
+            non_zero_cnt += torch.count_nonzero(t).item()
+        return non_zero_cnt
+
     def quantize_block(
         self,
         block: torch.nn.Module,
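
For context, _get_non_zero_cnt replaces the raw element count with the number of non-padded tokens whenever an attention mask is available, so padding does not dilute the normalized loss. A small self-contained illustration of the counting (the masks below are made up, not taken from the PR):

import torch

# Two right-padded attention masks of length 6: 4 and 2 real tokens.
attention_mask = [
    torch.tensor([[1, 1, 1, 1, 0, 0]]),
    torch.tensor([[1, 1, 0, 0, 0, 0]]),
]

indices = [0, 1]
# Same computation as _get_non_zero_cnt: count non-zero mask entries per sample.
num_elm = sum(torch.count_nonzero(attention_mask[i]).item() for i in indices)
print(num_elm)  # 6 -> the sum-reduced loss is divided by 6, not by 12
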
@@ -2808,7 +2818,7 @@ def _quantize_block(
                 f"layers in the block"
             )
             logger.info(dump_info)
-            unwrapper_block(block, {})  # TODO Quant layer should change
+            unwrapper_block(block, {})
             mv_module_from_gpu(block)
             return output, output
 
@@ -2823,11 +2833,6 @@ def _quantize_block(
             nsamples = len(input_ids["hidden_states"])
         else:
             nsamples = len(input_ids)
-
-        global_batch_size = self.batch_size * self.gradient_accumulate_steps
-        global_batch_size = min(nsamples, global_batch_size)
-        if self.sampler != "rand":
-            whole_indices = torch.randperm(nsamples)[:global_batch_size]
         last_best_iter = 0
         best_loss = torch.finfo(torch.float).max
         num_elm = 1
@@ -2839,30 +2844,31 @@ def _quantize_block(
         init_loss = None
         best_params = {}
         total_loss = 0
+        global_batch_size = self.batch_size * self.gradient_accumulate_steps
+        global_batch_size = min(nsamples, global_batch_size)
         # We assume the block input and output shape is same
-        if self.gradient_accumulate_steps != 1:
+        if self.gradient_accumulate_steps != 1 and not self.attention_mask:
             whole_indices = torch.arange(global_batch_size)
             num_elm = self._get_current_num_elm(input_ids, whole_indices)
 
+        index_sampler = IndexSampler(nsamples, global_batch_size)
+        batch_size = self.batch_size
         for i in range(self.iters):
             if self.enable_alg_ext and self.data_type.endswith("dq"):
                 for n, m in block.named_modules():
                     m.cur_iter = i
             total_loss = 0
-            if self.sampler == "rand":
-                whole_indices = torch.randperm(nsamples)[:global_batch_size]
+            global_indices = index_sampler.next_batch()
+            if self.attention_mask:
+                num_elm = self._get_non_zero_cnt(self.attention_mask, global_indices)
 
             for tmp_step in range(self.gradient_accumulate_steps):
-                indices = whole_indices[tmp_step * self.batch_size : (tmp_step + 1) * self.batch_size]
-
+                indices = global_indices[tmp_step * batch_size : (tmp_step + 1) * batch_size]
                 current_output = self._get_current_output(output, indices)
-
                 current_output = to_device(current_output, loss_device)
-
                 output_q = self._get_current_q_output(block, input_ids, input_others, indices, device, loss_device)
-
                 loss = self._get_loss(output_q, current_output, indices, mse_loss, device)
-
+                num_elm = 1 if num_elm <= 0 else num_elm
                 total_loss += loss.item() / num_elm
 
                 if self.low_gpu_mem_usage and card_0_in_high_risk:
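
With gradient accumulation the MSE reduction is "sum" and every micro-batch loss is divided by num_elm, the element (or non-padded-token) count of the whole global batch, so the accumulated total_loss reported per iteration is equivalent to a mean over that global batch. A quick numeric check of that equivalence, using made-up tensors rather than real block outputs:

import torch

torch.manual_seed(0)
pred = torch.randn(4, 8)    # pretend global batch: 4 samples, 8 elements each
target = torch.randn(4, 8)

num_elm = pred.numel()      # 32, analogous to _get_current_num_elm / _get_non_zero_cnt
sum_mse = torch.nn.MSELoss(reduction="sum")

# Accumulate sum-reduced losses over batch_size=1 micro-batches, then normalize ...
total = sum(sum_mse(pred[i : i + 1], target[i : i + 1]).item() for i in range(4)) / num_elm

# ... which matches the mean-reduced loss over the full global batch.
mean = torch.nn.MSELoss(reduction="mean")(pred, target).item()
assert abs(total - mean) < 1e-6
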