@@ -32,7 +32,7 @@ def __init__(
         self.abs_max = None

         if is_export_mode:
-            scales_shape = [layer_num, 2 * head_num] if get_env_start_args().enable_fa3 else [layer_num, 2]
+            scales_shape = [layer_num, 2 * head_num] if not get_env_start_args().disable_fa3 else [layer_num, 2]
             self.abs_max = torch.zeros(scales_shape, dtype=torch.float32, device="cuda")
         elif get_env_start_args().kv_quant_calibration_config_path is not None:
             logger.info(
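The flag flip above inverts the default: the opt-in `enable_fa3` becomes the opt-out `disable_fa3`, so every `enable_fa3` check is rewritten as `not ... disable_fa3` and FA3 becomes the default path. In export mode this selects the calibration buffer layout: per-head scales (`[layer_num, 2 * head_num]`, K heads then V heads) when FA3 is active, a single K and V scale per layer (`[layer_num, 2]`) otherwise. A minimal sketch of the shape selection, with invented sizes:

    # Illustration only; layer_num/head_num values are made up.
    import torch

    layer_num, head_num = 32, 8
    fa3_active = True  # corresponds to "not disable_fa3" after this change

    scales_shape = [layer_num, 2 * head_num] if fa3_active else [layer_num, 2]
    abs_max = torch.zeros(scales_shape, dtype=torch.float32)
    print(abs_max.shape)  # torch.Size([32, 16]) per-head, torch.Size([32, 2]) per-tensor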
@@ -43,15 +43,15 @@ def __init__(
 
             self.scales_list = cfg["scales"]
             self.scales = torch.tensor(self.scales_list, dtype=torch.float32, device="cuda").view(cfg["scales_shape"])
-            if not get_env_start_args().enable_fa3:
+            if get_env_start_args().disable_fa3:
                 self.scales = torch.repeat_interleave(self.scales, head_num, dim=-1)
             elif cfg["num_head"] > self.total_head_num:
                 factor = cfg["num_head"] // self.total_head_num
                 self.scales = self.scales[..., ::factor].contiguous()
             elif cfg["num_head"] < self.total_head_num:
                 factor = self.total_head_num // cfg["num_head"]
                 self.scales = torch.repeat_interleave(self.scales, factor, dim=-1).contiguous()
-            if get_env_start_args().enable_fa3 and dist.is_initialized() and dist.get_world_size() > 1:
+            if not get_env_start_args().disable_fa3 and dist.is_initialized() and dist.get_world_size() > 1:
                 half_head = self.total_head_num // 2
                 start_head = dist.get_rank() * head_num
                 end_head = start_head + head_num
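When a calibration file was produced for a different head count, the loaded scales are adapted: with FA3 off, the per-tensor K/V pair is broadcast across heads; with FA3 on, a file with more heads is strided down and one with fewer heads is repeated up, and under tensor parallelism each rank then slices out its own K and V head ranges (the slicing itself falls outside this hunk). A small sketch of the up/down-sampling, using invented sizes:

    # Sketch of the head-count adaptation; shapes are illustrative.
    import torch

    total_head_num, cfg_num_head = 4, 8
    scales = torch.rand(1, 2 * cfg_num_head)            # [num_layers, 2 * cfg_num_head]

    factor = cfg_num_head // total_head_num
    down = scales[..., ::factor].contiguous()           # -> [1, 8]: keep every 2nd head

    up = torch.repeat_interleave(down, factor, dim=-1)  # -> [1, 16]: repeat each head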
@@ -86,7 +86,7 @@ def _load_and_check_config(self):
             raise ValueError(
                 f"num_head {cfg['num_head']} in config " f"not match current model head num {self.total_head_num}"
             )
-        if get_env_start_args().enable_fa3:
+        if not get_env_start_args().disable_fa3:
             if cfg["quant_type"] != "per_head":
                 raise ValueError(f"quant type {cfg['quant_type']} in config not match fa3 backend")
         else:
@@ -109,7 +109,7 @@ def update_calibration_data(self, kv_buffer: torch.Tensor, layer_index: int):
             logger.info("kv cache calibration mode will collect kv cache data for quantization calibration")

         if self.abs_max is not None and self.count >= warmup_counts:
-            if get_env_start_args().enable_fa3:
+            if not get_env_start_args().disable_fa3:
                 kv_max = kv_buffer.abs().amax(dim=(0, 2)).to(torch.float32)
             else:
                 k_max = kv_buffer[:, : self.head_num, :].abs().amax(dim=()).to(torch.float32)
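The two branches implement the two quantization granularities: the FA3 path reduces a `[tokens, heads, head_dim]` buffer over dims 0 and 2, keeping one abs-max per K/V head, while the per-tensor path collapses the K slice (and, in the lines cut off below, presumably the V slice) to a single value. A sketch with an invented buffer shape:

    # Buffer layout [num_tokens, 2 * head_num, head_dim] (K heads then V heads)
    # is inferred from the slicing above; sizes are made up.
    import torch

    num_tokens, head_num, head_dim = 16, 4, 64
    kv_buffer = torch.randn(num_tokens, 2 * head_num, head_dim)

    kv_max = kv_buffer.abs().amax(dim=(0, 2))        # per-head: shape [8]
    k_max = kv_buffer[:, :head_num, :].abs().amax()  # per-tensor: scalar for K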
@@ -119,7 +119,7 @@ def update_calibration_data(self, kv_buffer: torch.Tensor, layer_index: int):
         if self.count == warmup_counts + inference_counts - 1 and layer_index == self.layer_num - 1:
             final_abs_max = self.abs_max
             if dist.is_initialized() and dist.get_world_size() > 1:
-                if get_env_start_args().enable_fa3:
+                if not get_env_start_args().disable_fa3:
                     k_max, v_max = torch.chunk(self.abs_max, 2, dim=-1)
                     k_max = k_max.contiguous()
                     v_max = v_max.contiguous()
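Under tensor parallelism each rank only observes its own head slice, so the per-head maxima have to be merged before export. The hunk ends before the collective call, so the all-gather below is an assumption about how the `k_max`/`v_max` halves are recombined, not this PR's literal code:

    # Hypothetical merge of per-rank abs-max halves; the actual collective used
    # by the PR is outside this hunk.
    import torch
    import torch.distributed as dist

    def merge_per_head_abs_max(abs_max: torch.Tensor, world_size: int) -> torch.Tensor:
        # abs_max: [layer_num, 2 * head_num_per_rank], K half then V half
        k_max, v_max = torch.chunk(abs_max, 2, dim=-1)
        k_parts = [torch.empty_like(k_max) for _ in range(world_size)]
        v_parts = [torch.empty_like(v_max) for _ in range(world_size)]
        dist.all_gather(k_parts, k_max.contiguous())
        dist.all_gather(v_parts, v_max.contiguous())
        return torch.cat(k_parts + v_parts, dim=-1)  # [layer_num, 2 * total_head_num]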
@@ -148,7 +148,7 @@ def _export_calibration_data(self):
         cfg = {
             "version": "1.0",
             "architectures": model_arch,
-            "quant_type": "per_head" if get_env_start_args().enable_fa3 else "per_tensor",
+            "quant_type": "per_head" if not get_env_start_args().disable_fa3 else "per_tensor",
             "qmin": self.qmin,
             "qmax": self.qmax,
             "num_layers": self.layer_num,
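Putting the pieces together, an exported calibration file consistent with the keys visible in this diff, plus the ones the loader reads (`num_head`, `scales_shape`, `scales`), could look like the following; every value here is invented for illustration:

    # Hypothetical calibration config; all values are examples, not real output.
    import json

    cfg = {
        "version": "1.0",
        "architectures": "ExampleForCausalLM",
        "quant_type": "per_head",   # FA3 path; "per_tensor" when FA3 is disabled
        "qmin": -448.0,             # assumed FP8 E4M3 bounds
        "qmax": 448.0,
        "num_layers": 2,
        "num_head": 4,
        "scales_shape": [2, 8],     # [num_layers, 2 * num_head]
        "scales": [[1.0] * 8, [1.0] * 8],
    }
    print(json.dumps(cfg, indent=2))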