@@ -30,25 +30,33 @@ def __init__(
         self.total_head_num = head_num * dist.get_world_size() if dist.is_initialized() else head_num
         self.count = 0
         self.scales = None
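+        # plain-Python list view of the scales, kept in sync with self.scales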
+        self.scales_list = []
         self.abs_max = None

         if is_export_mode:
-            self.abs_max = torch.zeros((layer_num, 2 * head_num), dtype=torch.float32, device="cuda")
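+            # fa3 calibrates one scale per K/V head per layer; the non-fa3 (flashinfer)
+            # path calibrates per-tensor: a single K scale and a single V scale per layer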
+            scales_shape = [layer_num, 2 * head_num] if get_env_start_args().enable_fa3 else [layer_num, 2]
+            self.abs_max = torch.zeros(scales_shape, dtype=torch.float32, device="cuda")
         elif get_env_start_args().kv_quant_calibration_config_path is not None:
             logger.info(
                 f"kv_quant_calibration_config_path {get_env_start_args().kv_quant_calibration_config_path} is set, "
                 "will load kv quant calibration config"
             )
             cfg = self._load_and_check_config()

-            self.scales = torch.tensor(cfg["scales"], dtype=torch.float32, device="cuda").view(cfg["scales_shape"])
-            if dist.is_initialized() and dist.get_world_size() > 1:
+            self.scales_list = cfg["scales"]
+            self.scales = torch.tensor(self.scales_list, dtype=torch.float32, device="cuda").view(cfg["scales_shape"])
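+            # broadcast per-tensor scales [layer_num, 2] to per-head width by repeating
+            # each K/V scale head_num times, keeping one layout for downstream kernels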
+            if not get_env_start_args().enable_fa3:
+                self.scales = torch.repeat_interleave(self.scales, self.head_num, dim=-1)
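+            # under tensor parallelism, fa3 slices this rank's K and V head scales
+            # out of the full table loaded from the calibration config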
+            if get_env_start_args().enable_fa3 and dist.is_initialized() and dist.get_world_size() > 1:
                 half_head = self.total_head_num // 2
                 start_head = dist.get_rank() * head_num
                 end_head = start_head + head_num
                 k_scales = self.scales[:, start_head:end_head].contiguous()
                 v_scales = self.scales[:, start_head + half_head : end_head + half_head].contiguous()
-                self.scales = torch.cat((k_scales, v_scales), dim=-1)
+                current_scales = torch.cat((k_scales, v_scales), dim=-1)
+
+                self.scales_list = current_scales.tolist()
+                self.scales = current_scales
         else:
             logger.warning("scales is None, no kv_quant_calibration_config_path is set, will use 1.0 as scales")

@@ -74,8 +82,12 @@ def _load_and_check_config(self):
                 raise ValueError(
                     f"num_head {cfg['num_head']} in config " f"not match current model head num {self.total_head_num}"
                 )
-            if cfg["quant_type"] != "per_head":
-                raise ValueError(f"quant type {cfg['quant_type']} in config not match fa3 backend")
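+            # the expected calibration granularity depends on the attention backend:
+            # fa3 needs per_head scales, flashinfer needs per_tensor scales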
+            if get_env_start_args().enable_fa3:
+                if cfg["quant_type"] != "per_head":
+                    raise ValueError(f"quant type {cfg['quant_type']} in config not match fa3 backend")
+            else:
+                if cfg["quant_type"] != "per_tensor":
+                    raise ValueError(f"quant type {cfg['quant_type']} in config not match flashinfer backend")

             return cfg
         else:
@@ -93,21 +105,29 @@ def update_calibration_data(self, kv_buffer: torch.Tensor, layer_index: int):
             logger.info("kv cache calibration mode will collect kv cache data for quantization calibration")

         if self.abs_max is not None and self.count >= warmup_counts:
-            kv_max = kv_buffer.abs().amax(dim=(0, 2)).to(torch.float32)
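+            # fa3: per-head abs-max, reducing over the token (0) and head-dim (2) axes;
+            # flashinfer: one abs-max each for the K half and the V half of the buffer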
+            if get_env_start_args().enable_fa3:
+                kv_max = kv_buffer.abs().amax(dim=(0, 2)).to(torch.float32)
+            else:
+                k_max = kv_buffer[:, : self.head_num, :].abs().amax().to(torch.float32)
+                v_max = kv_buffer[:, self.head_num :, :].abs().amax().to(torch.float32)
+                kv_max = torch.tensor([k_max, v_max], device="cuda", dtype=torch.float32)
             self.abs_max[layer_index] = torch.maximum(self.abs_max[layer_index], kv_max)
             if self.count == warmup_counts + inference_counts - 1 and layer_index == self.layer_num - 1:
                 final_abs_max = self.abs_max
                 if dist.is_initialized() and dist.get_world_size() > 1:
-                    k_max, v_max = torch.chunk(self.abs_max, 2, dim=-1)
-                    k_max = k_max.contiguous()
-                    v_max = v_max.contiguous()
-                    gathered_k_max = [torch.zeros_like(k_max) for _ in range(dist.get_world_size())]
-                    gathered_v_max = [torch.zeros_like(v_max) for _ in range(dist.get_world_size())]
-                    dist.all_gather(gathered_k_max, k_max, group=None, async_op=False)
-                    dist.all_gather(gathered_v_max, v_max, group=None, async_op=False)
-                    k_max = torch.cat(gathered_k_max, dim=-1)
-                    v_max = torch.cat(gathered_v_max, dim=-1)
-                    final_abs_max = torch.cat((k_max, v_max), dim=-1)
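+                    # fa3: all-gather every rank's per-head K/V maxima, then rebuild the
+                    # full [layer_num, 2 * total_head_num] table in K-then-V order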
+                    if get_env_start_args().enable_fa3:
+                        k_max, v_max = torch.chunk(self.abs_max, 2, dim=-1)
+                        k_max = k_max.contiguous()
+                        v_max = v_max.contiguous()
+                        gathered_k_max = [torch.zeros_like(k_max) for _ in range(dist.get_world_size())]
+                        gathered_v_max = [torch.zeros_like(v_max) for _ in range(dist.get_world_size())]
+                        dist.all_gather(gathered_k_max, k_max, group=None, async_op=False)
+                        dist.all_gather(gathered_v_max, v_max, group=None, async_op=False)
+                        k_max = torch.cat(gathered_k_max, dim=-1)
+                        v_max = torch.cat(gathered_v_max, dim=-1)
+                        final_abs_max = torch.cat((k_max, v_max), dim=-1)
+                    else:
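+                        # per-tensor maxima have the same shape on every rank, so an in-place
+                        # MAX all-reduce is enough (final_abs_max aliases self.abs_max)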
+                        dist.all_reduce(self.abs_max, op=dist.ReduceOp.MAX, group=None, async_op=False)

                 self.scales = final_abs_max / self.qmax
                 self.scales = torch.where(self.scales > 0, self.scales, torch.ones_like(self.scales))
@@ -124,7 +144,7 @@ def _export_calibration_data(self):
         cfg = {
             "version": "1.0",
             "architectures": model_arch,
-            "quant_type": "per_head",
147+ "quant_type" : "per_head" if get_env_start_args (). enable_fa3 else "per_tensor" ,
128148 "qmin" : self .qmin ,
129149 "qmax" : self .qmax ,
130150 "num_layers" : self .layer_num ,