77from lightllm .utils .dist_utils import get_global_rank
88from lightllm .utils .config_utils import get_model_architectures
99from lightllm .utils .log_utils import init_logger
10- from lightllm .utils .envs_utils import get_env_start_args
11- from lightllm .common .basemodel .basemodel import get_model_init_status
10+ from lightllm .utils .envs_utils import get_env_start_args , get_model_init_status
1211
1312logger = init_logger (__name__ )
1413
@@ -30,25 +29,33 @@ def __init__(
3029 self .total_head_num = head_num * dist .get_world_size () if dist .is_initialized () else head_num
3130 self .count = 0
3231 self .scales = None
32+ self .scales_list = None
3333 self .abs_max = None
3434
3535 if is_export_mode :
36- self .abs_max = torch .zeros ((layer_num , 2 * head_num ), dtype = torch .float32 , device = "cuda" )
36+ scales_shape = [layer_num , 2 * head_num ] if get_env_start_args ().enable_fa3 else [layer_num , 2 ]
37+ self .abs_max = torch .zeros (scales_shape , dtype = torch .float32 , device = "cuda" )
3738 elif get_env_start_args ().kv_quant_calibration_config_path is not None :
3839 logger .info (
3940 f"kv_quant_calibration_config_path { get_env_start_args ().kv_quant_calibration_config_path } is set, "
4041 "will load kv quant calibration config"
4142 )
4243 cfg = self ._load_and_check_config ()
4344
44- self .scales = torch .tensor (cfg ["scales" ], dtype = torch .float32 , device = "cuda" ).view (cfg ["scales_shape" ])
45- if dist .is_initialized () and dist .get_world_size () > 1 :
45+ self .scales_list = cfg ["scales" ]
46+ self .scales = torch .tensor (self .scales_list , dtype = torch .float32 , device = "cuda" ).view (cfg ["scales_shape" ])
47+ if not get_env_start_args ().enable_fa3 :
48+ self .scales = torch .repeat_interleave (self .scales , self .head_num , dim = - 1 )
49+ if get_env_start_args ().enable_fa3 and dist .is_initialized () and dist .get_world_size () > 1 :
4650 half_head = self .total_head_num // 2
4751 start_head = dist .get_rank () * head_num
4852 end_head = start_head + head_num
4953 k_scales = self .scales [:, start_head :end_head ].contiguous ()
5054 v_scales = self .scales [:, start_head + half_head : end_head + half_head ].contiguous ()
51- self .scales = torch .cat ((k_scales , v_scales ), dim = - 1 )
55+ current_scales = torch .cat ((k_scales , v_scales ), dim = - 1 )
56+
57+ self .scales_list = current_scales .tolist ()
58+ self .scales = current_scales
5259 else :
5360 logger .warning ("scales is None, no kv_quant_calibration_config_path be set, will use 1.0 as scales" )
5461
@@ -74,8 +81,12 @@ def _load_and_check_config(self):
7481 raise ValueError (
7582 f"num_head { cfg ['num_head' ]} in config " f"not match current model head num { self .total_head_num } "
7683 )
77- if cfg ["quant_type" ] != "per_head" :
78- raise ValueError (f"quant type { cfg ['quant_type' ]} in config not match fa3 backend" )
84+ if get_env_start_args ().enable_fa3 :
85+ if cfg ["quant_type" ] != "per_head" :
86+ raise ValueError (f"quant type { cfg ['quant_type' ]} in config not match fa3 backend" )
87+ else :
88+ if cfg ["quant_type" ] != "per_tensor" :
89+ raise ValueError (f"quant type { cfg ['quant_type' ]} in config not match flashinfer backend" )
7990
8091 return cfg
8192 else :
@@ -93,21 +104,29 @@ def update_calibration_data(self, kv_buffer: torch.Tensor, layer_index: int):
93104 logger .info ("kv cache calibration mode will collect kv cache data for quantization calibration" )
94105
95106 if self .abs_max is not None and self .count >= warmup_counts :
96- kv_max = kv_buffer .abs ().amax (dim = (0 , 2 )).to (torch .float32 )
107+ if get_env_start_args ().enable_fa3 :
108+ kv_max = kv_buffer .abs ().amax (dim = (0 , 2 )).to (torch .float32 )
109+ else :
110+ k_max = kv_buffer [:, : self .head_num , :].abs ().max ().to (torch .float32 )
111+ v_max = kv_buffer [:, self .head_num :, :].abs ().max ().to (torch .float32 )
112+ kv_max = torch .stack ([k_max , v_max ])
97113 self .abs_max [layer_index ] = torch .maximum (self .abs_max [layer_index ], kv_max )
98114 if self .count == warmup_counts + inference_counts - 1 and layer_index == self .layer_num - 1 :
99115 final_abs_max = self .abs_max
100116 if dist .is_initialized () and dist .get_world_size () > 1 :
101- k_max , v_max = torch .chunk (self .abs_max , 2 , dim = - 1 )
102- k_max = k_max .contiguous ()
103- v_max = v_max .contiguous ()
104- gathered_k_max = [torch .zeros_like (k_max ) for _ in range (dist .get_world_size ())]
105- gathered_v_max = [torch .zeros_like (v_max ) for _ in range (dist .get_world_size ())]
106- dist .all_gather (gathered_k_max , k_max , group = None , async_op = False )
107- dist .all_gather (gathered_v_max , v_max , group = None , async_op = False )
108- k_max = torch .cat (gathered_k_max , dim = - 1 )
109- v_max = torch .cat (gathered_v_max , dim = - 1 )
110- final_abs_max = torch .cat ((k_max , v_max ), dim = - 1 )
117+ if get_env_start_args ().enable_fa3 :
118+ k_max , v_max = torch .chunk (self .abs_max , 2 , dim = - 1 )
119+ k_max = k_max .contiguous ()
120+ v_max = v_max .contiguous ()
121+ gathered_k_max = [torch .zeros_like (k_max ) for _ in range (dist .get_world_size ())]
122+ gathered_v_max = [torch .zeros_like (v_max ) for _ in range (dist .get_world_size ())]
123+ dist .all_gather (gathered_k_max , k_max , group = None , async_op = False )
124+ dist .all_gather (gathered_v_max , v_max , group = None , async_op = False )
125+ k_max = torch .cat (gathered_k_max , dim = - 1 )
126+ v_max = torch .cat (gathered_v_max , dim = - 1 )
127+ final_abs_max = torch .cat ((k_max , v_max ), dim = - 1 )
128+ else :
129+ dist .all_reduce (self .abs_max , op = dist .ReduceOp .MAX , group = None , async_op = False )
111130
112131 self .scales = final_abs_max / self .qmax
113132 self .scales = torch .where (self .scales > 0 , self .scales , torch .ones_like (self .scales ))
@@ -124,7 +143,7 @@ def _export_calibration_data(self):
124143 cfg = {
125144 "version" : "1.0" ,
126145 "architectures" : model_arch ,
127- "quant_type" : "per_head" ,
146+ "quant_type" : "per_head" if get_env_start_args (). enable_fa3 else "per_tensor" ,
128147 "qmin" : self .qmin ,
129148 "qmax" : self .qmax ,
130149 "num_layers" : self .layer_num ,
0 commit comments