@@ -25,7 +25,6 @@ def __init__(
 
         self.qmax = torch.finfo(torch.float8_e4m3fn).max
         self.qmin = torch.finfo(torch.float8_e4m3fn).min
-        self.layer_num = layer_num
         self.total_head_num = head_num * dist.get_world_size() if dist.is_initialized() else head_num
         self.count = 0
         self.scales = None
@@ -45,7 +44,13 @@ def __init__(
         self.scales_list = cfg["scales"]
         self.scales = torch.tensor(self.scales_list, dtype=torch.float32, device="cuda").view(cfg["scales_shape"])
         if not get_env_start_args().enable_fa3:
-            self.scales = torch.repeat_interleave(self.scales, self.head_num, dim=-1)
+            self.scales = torch.repeat_interleave(self.scales, head_num, dim=-1)
+        elif cfg["num_head"] > self.total_head_num:
+            factor = cfg["num_head"] // self.total_head_num
+            self.scales = self.scales[..., ::factor].contiguous()
+        elif cfg["num_head"] < self.total_head_num:
+            factor = self.total_head_num // cfg["num_head"]
+            self.scales = torch.repeat_interleave(self.scales, factor, dim=-1).contiguous()
         if get_env_start_args().enable_fa3 and dist.is_initialized() and dist.get_world_size() > 1:
             half_head = self.total_head_num // 2
             start_head = dist.get_rank() * head_num
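
For context, the new branches reconcile a calibration config whose `num_head` differs from the serving model's total head count by an integer factor: when the config has more heads, every `factor`-th scale is kept; when it has fewer, each scale is repeated across heads. A minimal standalone sketch of that logic, using hypothetical head counts and a dummy `scales` tensor rather than the real `self.scales`:

```python
import torch

# Hypothetical setup: calibration scales for 8 heads, model served with 4.
scales = torch.arange(8, dtype=torch.float32)  # one scale per config head
cfg_num_head, total_head_num = 8, 4

if cfg_num_head > total_head_num:
    # Config has more heads than the model: keep every factor-th scale.
    factor = cfg_num_head // total_head_num
    scales = scales[..., ::factor].contiguous()  # -> tensor([0., 2., 4., 6.])
elif cfg_num_head < total_head_num:
    # Config has fewer heads: repeat each scale to cover all model heads.
    factor = total_head_num // cfg_num_head
    scales = torch.repeat_interleave(scales, factor, dim=-1).contiguous()
```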
@@ -77,7 +82,7 @@ def _load_and_check_config(self):
             raise ValueError(
                 f"num_layers {cfg['num_layers']} in config " f"not match current layer_num {self.layer_num}"
             )
-        if cfg["num_head"] != self.total_head_num:
+        if cfg["num_head"] % self.total_head_num != 0 and self.total_head_num % cfg["num_head"] != 0:
             raise ValueError(
                 f"num_head {cfg['num_head']} in config " f"not match current model head num {self.total_head_num}"
             )
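
The relaxed check mirrors the branches above: instead of requiring an exact match, the config's head count must divide, or be divisible by, the model's total head count. A small sketch with a hypothetical helper name and example values:

```python
def head_counts_compatible(cfg_num_head: int, total_head_num: int) -> bool:
    # Accept any integer ratio in either direction, matching the relaxed check
    # (De Morgan of the guard that raises ValueError).
    return cfg_num_head % total_head_num == 0 or total_head_num % cfg_num_head == 0

assert head_counts_compatible(8, 8)      # exact match still passes
assert head_counts_compatible(8, 4)      # config has 2x heads: scales get subsampled
assert head_counts_compatible(4, 8)      # config has half the heads: scales get repeated
assert not head_counts_compatible(6, 4)  # non-integer ratio -> ValueError path
```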