@@ -945,63 +945,63 @@ def update_use_cudagraph(self, argument: bool):
         argument = self.use_cudagraph


-class MobaAttentionConfig:
+class PlasAttentionConfig:
     def __init__(
         self,
         args,
     ):
-        self.moba_encoder_top_k_left: int = None
-        self.moba_encoder_top_k_right: int = None
-        "The sparse topk of encoder attention is located at [moba_encoder_top_k_left, moba_encoder_top_k_right]"
-        self.moba_decoder_top_k_left: int = None
-        self.moba_decoder_top_k_right: int = None
-        "The sparse topk of decoder attention is located at [moba_decoder_top_k_left, moba_decoder_top_k_right]"
-        self.moba_use_encoder_seq_limit: int = None
-        "When the number of encoder tokens is less than moba_use_encoder_seq_limit, attention is not sparse"
-        self.moba_use_decoder_seq_limit: int = None
-        "When the number of decoder tokens is less than moba_use_decoder_seq_limit, attention is not sparse"
-        self.moba_block_size: int = 128
-        self.mlp_weight_name: str = "moba_mlp_weight.safetensors"
-        self.moba_max_seq_length: int = 128 * 1024
+        self.plas_encoder_top_k_left: int = None
+        self.plas_encoder_top_k_right: int = None
+        "The sparse topk of encoder attention is located at [plas_encoder_top_k_left, plas_encoder_top_k_right]"
+        self.plas_decoder_top_k_left: int = None
+        self.plas_decoder_top_k_right: int = None
+        "The sparse topk of decoder attention is located at [plas_decoder_top_k_left, plas_decoder_top_k_right]"
+        self.plas_use_encoder_seq_limit: int = None
+        "When the number of encoder tokens is less than plas_use_encoder_seq_limit, attention is not sparse"
+        self.plas_use_decoder_seq_limit: int = None
+        "When the number of decoder tokens is less than plas_use_decoder_seq_limit, attention is not sparse"
+        self.plas_block_size: int = 128
+        self.mlp_weight_name: str = "plas_attention_mlp_weight.safetensors"
+        self.plas_max_seq_length: int = 128 * 1024
         if args is not None:
             for key, value in args.items():
                 if hasattr(self, key):
                     setattr(self, key, value)
-        if self.moba_use_encoder_seq_limit is None and self.moba_encoder_top_k_left is not None:
-            self.moba_use_encoder_seq_limit = self.moba_encoder_top_k_left * self.moba_block_size
-        if self.moba_use_decoder_seq_limit is None and self.moba_decoder_top_k_left is not None:
-            self.moba_use_decoder_seq_limit = self.moba_decoder_top_k_left * self.moba_block_size
+        if self.plas_use_encoder_seq_limit is None and self.plas_encoder_top_k_left is not None:
+            self.plas_use_encoder_seq_limit = self.plas_encoder_top_k_left * self.plas_block_size
+        if self.plas_use_decoder_seq_limit is None and self.plas_decoder_top_k_left is not None:
+            self.plas_use_decoder_seq_limit = self.plas_decoder_top_k_left * self.plas_block_size
         self.check_legality_parameters()

     def check_legality_parameters(
         self,
     ) -> None:
-        if self.moba_encoder_top_k_left is not None:
-            assert self.moba_encoder_top_k_left > 0, "moba_encoder_top_k_left must be greater than 0"
+        if self.plas_encoder_top_k_left is not None:
+            assert self.plas_encoder_top_k_left > 0, "plas_encoder_top_k_left must be greater than 0"

-        if self.moba_encoder_top_k_right is not None:
-            assert self.moba_encoder_top_k_right > 0, "moba_encoder_top_k_right must be greater than 0"
+        if self.plas_encoder_top_k_right is not None:
+            assert self.plas_encoder_top_k_right > 0, "plas_encoder_top_k_right must be greater than 0"
             assert (
-                self.moba_encoder_top_k_right >= self.moba_encoder_top_k_left
-            ), "moba_encoder_top_k_right must be no less than moba_encoder_top_k_left"
+                self.plas_encoder_top_k_right >= self.plas_encoder_top_k_left
+            ), "plas_encoder_top_k_right must be no less than plas_encoder_top_k_left"

-        if self.moba_decoder_top_k_left is not None:
-            assert self.moba_decoder_top_k_left > 0, "moba_decoder_top_k_left must be greater than 0"
+        if self.plas_decoder_top_k_left is not None:
+            assert self.plas_decoder_top_k_left > 0, "plas_decoder_top_k_left must be greater than 0"

-        if self.moba_decoder_top_k_right is not None:
-            assert self.moba_decoder_top_k_right > 0, "moba_decoder_top_k_right must be greater than 0"
+        if self.plas_decoder_top_k_right is not None:
+            assert self.plas_decoder_top_k_right > 0, "plas_decoder_top_k_right must be greater than 0"
             assert (
-                self.moba_decoder_top_k_right >= self.moba_decoder_top_k_left
-            ), "moba_decoder_top_k_right must be no less than moba_decoder_top_k_left"
+                self.plas_decoder_top_k_right >= self.plas_decoder_top_k_left
+            ), "plas_decoder_top_k_right must be no less than plas_decoder_top_k_left"

-        if self.moba_use_encoder_seq_limit is not None and self.moba_encoder_top_k_left is not None:
-            assert self.moba_use_encoder_seq_limit >= self.moba_encoder_top_k_left * self.moba_block_size
-        if self.moba_use_decoder_seq_limit is not None and self.moba_decoder_top_k_left is not None:
-            assert self.moba_use_decoder_seq_limit >= self.moba_decoder_top_k_left * self.moba_block_size
+        if self.plas_use_encoder_seq_limit is not None and self.plas_encoder_top_k_left is not None:
+            assert self.plas_use_encoder_seq_limit >= self.plas_encoder_top_k_left * self.plas_block_size
+        if self.plas_use_decoder_seq_limit is not None and self.plas_decoder_top_k_left is not None:
+            assert self.plas_use_decoder_seq_limit >= self.plas_decoder_top_k_left * self.plas_block_size

     def to_json_string(self):
         """
-        Convert moba_attention_config to json string.
+        Convert plas_attention_config to json string.
         """
         return json.dumps({key: value for key, value in self.__dict__.items() if value is not None})
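The rename is mechanical, but the behavior the class keeps is worth noting: keys in `args` that match an existing attribute override the defaults, the `plas_use_*_seq_limit` thresholds fall back to `top_k_left * plas_block_size` when unset, and `check_legality_parameters()` runs at the end of `__init__` to reject inconsistent ranges. A minimal usage sketch follows; the `fastdeploy.config` import path is an assumption, since this diff shows only the class definition, not its module:

```python
# Usage sketch for the renamed config. The import path is an assumption.
from fastdeploy.config import PlasAttentionConfig

cfg = PlasAttentionConfig(
    args={
        "plas_encoder_top_k_left": 2,   # encoder keeps between 2 and ...
        "plas_encoder_top_k_right": 4,  # ... 4 sparse blocks per query
        "plas_decoder_top_k_left": 1,
        "plas_decoder_top_k_right": 2,
    }
)

# Limits left unset are derived as top_k_left * plas_block_size
# (the block size defaults to 128):
assert cfg.plas_use_encoder_seq_limit == 2 * 128
assert cfg.plas_use_decoder_seq_limit == 1 * 128

# check_legality_parameters() has already run inside __init__; an inverted
# range such as top_k_right < top_k_left would have raised an AssertionError.

# Only non-None fields are serialized:
print(cfg.to_json_string())
```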
@@ -1396,7 +1396,7 @@ def __init__(
         decoding_config: DecodingConfig = None,
         quant_config: QuantConfigBase = None,
         graph_opt_config: GraphOptimizationConfig = None,
-        moba_attention_config: MobaAttentionConfig = None,
+        plas_attention_config: PlasAttentionConfig = None,
         speculative_config: SpeculativeConfig = None,
         tokenizer: str = None,
         max_model_len: int = 8192,
@@ -1427,7 +1427,7 @@ def __init__(
         self.early_stop_config: Optional[EarlyStopConfig] = early_stop_config
         self.decoding_config: DecodingConfig = decoding_config  # type: ignore
         self.cache_config: CacheConfig = cache_config  # type: ignore
-        self.moba_attention_config: Optional[MobaAttentionConfig] = moba_attention_config
+        self.plas_attention_config: Optional[PlasAttentionConfig] = plas_attention_config
         # Initialize cuda graph capture list
         if self.graph_opt_config.cudagraph_capture_sizes is None:
             self.graph_opt_config._set_cudagraph_sizes(max_num_seqs=self.scheduler_config.max_num_seqs)
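For callers, the migration is a keyword rename at the construction site plus an attribute rename at read sites. A hedged sketch of the call-site change; the enclosing class name `FDConfig` is assumed (these hunks show only its parameter list and body, not its name), and the other constructor arguments are elided:

```python
# Sketch of the call-site migration. `FDConfig` is the assumed name of the
# class whose __init__ these hunks edit; unrelated arguments are elided.
from fastdeploy.config import FDConfig, PlasAttentionConfig  # import path assumed

plas_cfg = PlasAttentionConfig(args=None)  # all-default sparse-attention config

config = FDConfig(
    # ... model/cache/scheduler configs unchanged by this PR ...
    plas_attention_config=plas_cfg,  # was: moba_attention_config=MobaAttentionConfig(...)
    max_model_len=8192,
)

# Attribute reads rename the same way:
assert config.plas_attention_config is plas_cfg  # was: config.moba_attention_config
```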