
Commit 4325b73

【FIX】Change the name of sparse attn from moba to plas (#4006) (#4076)
* 【FIX】Change the name of sparse attn from moba to plas (#4006)
* Update docs
* 【docs】Update readme (#4000)
* Update docs
* update readme
* update docs
* 【FIX】Change the name of sparse attn from moba to plas (#3845)
* Update docs
* Update docs
* Update docs
* Update docs
* Change moba to plas
* code style
* update ci
* code style
* update ci
* code style
---------
Co-authored-by: Jiang-Jia-Jun <[email protected]>
* fix max_num_seqs
* fix test load attn
---------
Co-authored-by: Jiang-Jia-Jun <[email protected]>
1 parent 2c34a55 commit 4325b73

14 files changed (+152, -152 lines)

docs/features/plas_attention.md

Lines changed: 6 additions & 6 deletions
@@ -196,7 +196,7 @@ We selected a subset (longbook_sum_eng) from InfiniteBench as the performance ev
 ## Usage

 ```
-export FD_ATTENTION_BACKEND="MOBA_ATTN"
+export FD_ATTENTION_BACKEND="PLAS_ATTN"

 python -m fastdeploy.entrypoints.openai.api_server \
     --model baidu/ERNIE-4.5-300B-A47B-Paddle \
@@ -207,13 +207,13 @@ python -m fastdeploy.entrypoints.openai.api_server
     --max-num-batched-tokens 8192 \
     --max-model-len 131072 \
     --max-num-seqs 32 \
-    --moba-attention-config '{"moba_encoder_top_k_left": 50, "moba_encoder_top_k_right": 60, "moba_decoder_top_k_left": 100, "moba_decoder_top_k_right": 120}'
+    --plas-attention-config '{"plas_encoder_top_k_left": 50, "plas_encoder_top_k_right": 60, "plas_decoder_top_k_left": 100, "plas_decoder_top_k_right": 120}'
 ```

-**Note**: If sparse attention is enabled, the system will automatically load the MLP weights from `moba_mlp_weight.safetensors` in the weight directory. If the MLP weight file is not found, mean pooling will be applied to the key representations.
+**Note**: If sparse attention is enabled, the system will automatically load the MLP weights from `plas_attention_mlp_weight.safetensors` in the weight directory. If the MLP weight file is not found, mean pooling will be applied to the key representations.

 **Parameter Description:**

-* Setting `FD_ATTENTION_BACKEND="MOBA_ATTN"` enables MOBA sparse attention.
-* `moba_encoder_top_k_left=50, moba_encoder_top_k_right=60` indicates that the top-k range is between 50 and 60 when the encoder is sparse.
-* `moba_decoder_top_k_left=100, moba_decoder_top_k_right=120` indicates that the top-k range is between 100 and 120 when the decoder is sparse.
+* Setting `FD_ATTENTION_BACKEND="PLAS_ATTN"` enables PLAS sparse attention.
+* `plas_encoder_top_k_left=50, plas_encoder_top_k_right=60` indicates that the top-k range is between 50 and 60 when the encoder is sparse.
+* `plas_decoder_top_k_left=100, plas_decoder_top_k_right=120` indicates that the top-k range is between 100 and 120 when the decoder is sparse.
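
With the server from the command above running, a quick functional check can go through the OpenAI-compatible endpoint. The snippet below is an illustrative sketch, not part of this commit: it assumes the server listens on `localhost:8000` (no `--port` is shown in the launch command) and that the `openai` Python client is installed.

```python
# Minimal smoke test against the PLAS-enabled server launched above.
# Assumptions: host/port (localhost:8000) and the placeholder prompt.
import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# Long prompts exercise the sparse path; prompts shorter than
# plas_use_encoder_seq_limit (plas_encoder_top_k_left * 128 by default)
# are still processed with dense attention.
resp = client.chat.completions.create(
    model="baidu/ERNIE-4.5-300B-A47B-Paddle",
    messages=[{"role": "user", "content": "Summarize the following book: ..."}],
)
print(resp.choices[0].message.content)
```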

docs/zh/features/plas_attention.md

Lines changed: 8 additions & 8 deletions
@@ -18,7 +18,7 @@
 <img src="images/plas_training_distill.png" alt="Attention Gate Module" width="60%">
 </div>

-* **Attention Gate Module**: As shown in the figure above, to estimate the importance of each block at low computational cost, we design a lightweight attention gate module. The module first compresses each K block through an MLP layer into a representative low-dimensional representation: $K_c^T=W_{kp}K^T$, where $W_{kp}$ denotes the weights of the MLP layer. Compared with directly applying mean pooling, a learnable MLP captures the semantic relationships and importance distribution among different tokens more effectively, providing a fine-grained representation of each block. After obtaining the compressed representation $K_c$, the importance of each query token with respect to each block is estimated by $Softmax(Q\cdot K_c^T)$. To enhance the discriminative ability of the MLP layer, we use the 1D-max-pooled full attention result $1DMaxPooling(Softmax(Q \cdot K^T))$ as the ground truth; by minimizing the distribution gap between the two, the MLP layer is guided to learn feature representations that better match the true attention distribution.
+* **Attention Gate Module**: As shown in the figure above, to estimate the importance of each block at low computational cost, we design a lightweight attention gate module. The module first compresses each K block through an MLP layer into a representative low-dimensional representation: $K_c^T=W_{kp}K^T$ , where $W_{kp}$ denotes the weights of the MLP layer. Compared with directly applying mean pooling, a learnable MLP captures the semantic relationships and importance distribution among different tokens more effectively, providing a fine-grained representation of each block. After obtaining the compressed representation $K_c$, the importance of each query token with respect to each block is estimated by $Softmax(Q\cdot K_c^T)$. To enhance the discriminative ability of the MLP layer, we use the 1D-max-pooled full attention result $1DMaxPooling(Softmax(Q \cdot K^T))$ as the ground truth; by minimizing the distribution gap between the two, the MLP layer is guided to learn feature representations that better match the true attention distribution.

 * **Training Data**: Thanks to the efficiency of the model architecture and training paradigm, our method achieves near-lossless accuracy with only 1 billion training tokens. The training data is drawn from an internally constructed corpus mixing long and short texts, which enhances the module's adaptability to different sequence lengths.

@@ -36,7 +36,7 @@

 * **Prefill Token Union**: We observe that adjacent query tokens tend to select similar key blocks. Exploiting this locality, we take the union of the key blocks selected by 128 consecutive query tokens and compute sparse attention for those tokens jointly.

 * **Decode Head Union**: Given the wide adoption of GQA in modern models, we find that different query heads within the same group often select overlapping key blocks. We therefore merge the key blocks selected by all query heads within a group into a unified set and compute sparse attention jointly. This also reduces memory-access overhead and further improves decoding efficiency.

 * **Top-K Selection**: Conventional top-k algorithms rely on sorting or direct calls into the CUB library, which incurs significant runtime overhead. To mitigate this, we implement an approximate top-k selection algorithm based on binary search, which markedly reduces latency while preserving accuracy and ultimately delivers a substantial performance gain.

@@ -200,7 +200,7 @@
 ## Usage

 ```
-export FD_ATTENTION_BACKEND="MOBA_ATTN"
+export FD_ATTENTION_BACKEND="PLAS_ATTN"

 python -m fastdeploy.entrypoints.openai.api_server \
     --model baidu/ERNIE-4.5-300B-A47B-Paddle \
@@ -211,13 +211,13 @@ python -m fastdeploy.entrypoints.openai.api_server
     --max-num-batched-tokens 8192 \
     --max-model-len 131072 \
     --max-num-seqs 32 \
-    --moba-attention-config '{"moba_encoder_top_k_left": 50, "moba_encoder_top_k_right": 60, "moba_decoder_top_k_left": 100, "moba_decoder_top_k_right": 120}'
+    --plas-attention-config '{"plas_encoder_top_k_left": 50, "plas_encoder_top_k_right": 60, "plas_decoder_top_k_left": 100, "plas_decoder_top_k_right": 120}'
 ```

-**Note**: If sparse attention is enabled, the system automatically loads the MLP weights from `moba_mlp_weight.safetensors` in the weight directory. If the MLP weight file is not found, mean pooling is applied to the key representations.
+**Note**: If sparse attention is enabled, the system automatically loads the MLP weights from `plas_attention_mlp_weight.safetensors` in the weight directory. If the MLP weight file is not found, mean pooling is applied to the key representations.

 **Parameter Description:**

-* `FD_ATTENTION_BACKEND="MOBA_ATTN"` enables MOBA sparse attention.
-* `moba_encoder_top_k_left=50, moba_encoder_top_k_right=60` means the top-k range is between 50 and 60 when the encoder is sparse.
-* `moba_decoder_top_k_left=100, moba_decoder_top_k_right=120` means the top-k range is between 100 and 120 when the decoder is sparse.
+* `FD_ATTENTION_BACKEND="PLAS_ATTN"` enables PLAS sparse attention.
+* `plas_encoder_top_k_left=50, plas_encoder_top_k_right=60` means the top-k range is between 50 and 60 when the encoder is sparse.
+* `plas_decoder_top_k_left=100, plas_decoder_top_k_right=120` means the top-k range is between 100 and 120 when the decoder is sparse.
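
The gate computation described above reduces to a few tensor operations. The toy NumPy sketch below is illustrative only and not part of this commit: the single-head layout, the shapes, and the absence of softmax scaling are assumptions, with `W_kp` standing in for the learned MLP weights that production code loads from `plas_attention_mlp_weight.safetensors`.

```python
# Toy single-head sketch of the PLAS attention gate and block selection.
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

d, block_size, n_blocks, n_q, top_k = 64, 128, 16, 4, 3
rng = np.random.default_rng(0)
Q = rng.standard_normal((n_q, d)) * 0.1                   # query tokens
K = rng.standard_normal((n_blocks, block_size, d)) * 0.1  # keys grouped in blocks
W_kp = rng.standard_normal((block_size, 1))               # learned MLP pooling weight

# K_c^T = W_kp K^T: compress each key block into one representative vector.
K_c = np.einsum("nbd,bo->nod", K, W_kp)[:, 0, :]          # [n_blocks, d]

# Estimated per-block importance for each query: Softmax(Q · K_c^T).
est = softmax(Q @ K_c.T)                                  # [n_q, n_blocks]

# Training target: 1DMaxPooling(Softmax(Q · K^T)) over the full attention map.
full = softmax(Q @ K.reshape(-1, d).T)                    # [n_q, n_blocks * block_size]
target = full.reshape(n_q, n_blocks, block_size).max(axis=-1)

# Inference keeps only the top-k blocks per query; in the real config, k is
# chosen within [plas_encoder_top_k_left, plas_encoder_top_k_right].
selected = np.argpartition(-est, top_k, axis=-1)[:, :top_k]
print(selected)
```

Training minimizes the gap between `est` and `target` (the document does not pin down the exact loss); inference needs only `est` and the top-k selection.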

fastdeploy/config.py

Lines changed: 37 additions & 37 deletions
@@ -945,63 +945,63 @@ def update_use_cudagraph(self, argument: bool):
         argument = self.use_cudagraph


-class MobaAttentionConfig:
+class PlasAttentionConfig:
     def __init__(
         self,
         args,
     ):
-        self.moba_encoder_top_k_left: int = None
-        self.moba_encoder_top_k_right: int = None
-        "The sparse top-k of encoder attention lies in [moba_encoder_top_k_left, moba_encoder_top_k_right]"
-        self.moba_decoder_top_k_left: int = None
-        self.moba_decoder_top_k_right: int = None
-        "The sparse top-k of decoder attention lies in [moba_decoder_top_k_left, moba_decoder_top_k_right]"
-        self.moba_use_encoder_seq_limit: int = None
-        "When the number of encoder tokens is less than moba_use_encoder_seq_limit, attention is not sparse"
-        self.moba_use_decoder_seq_limit: int = None
-        "When the number of decoder tokens is less than moba_use_decoder_seq_limit, attention is not sparse"
-        self.moba_block_size: int = 128
-        self.mlp_weight_name: str = "moba_mlp_weight.safetensors"
-        self.moba_max_seq_length: int = 128 * 1024
+        self.plas_encoder_top_k_left: int = None
+        self.plas_encoder_top_k_right: int = None
+        "The sparse top-k of encoder attention lies in [plas_encoder_top_k_left, plas_encoder_top_k_right]"
+        self.plas_decoder_top_k_left: int = None
+        self.plas_decoder_top_k_right: int = None
+        "The sparse top-k of decoder attention lies in [plas_decoder_top_k_left, plas_decoder_top_k_right]"
+        self.plas_use_encoder_seq_limit: int = None
+        "When the number of encoder tokens is less than plas_use_encoder_seq_limit, attention is not sparse"
+        self.plas_use_decoder_seq_limit: int = None
+        "When the number of decoder tokens is less than plas_use_decoder_seq_limit, attention is not sparse"
+        self.plas_block_size: int = 128
+        self.mlp_weight_name: str = "plas_attention_mlp_weight.safetensors"
+        self.plas_max_seq_length: int = 128 * 1024
         if args is not None:
             for key, value in args.items():
                 if hasattr(self, key):
                     setattr(self, key, value)
-        if self.moba_use_encoder_seq_limit is None and self.moba_encoder_top_k_left is not None:
-            self.moba_use_encoder_seq_limit = self.moba_encoder_top_k_left * self.moba_block_size
-        if self.moba_use_decoder_seq_limit is None and self.moba_decoder_top_k_left is not None:
-            self.moba_use_decoder_seq_limit = self.moba_decoder_top_k_left * self.moba_block_size
+        if self.plas_use_encoder_seq_limit is None and self.plas_encoder_top_k_left is not None:
+            self.plas_use_encoder_seq_limit = self.plas_encoder_top_k_left * self.plas_block_size
+        if self.plas_use_decoder_seq_limit is None and self.plas_decoder_top_k_left is not None:
+            self.plas_use_decoder_seq_limit = self.plas_decoder_top_k_left * self.plas_block_size
         self.check_legality_parameters()

     def check_legality_parameters(
         self,
     ) -> None:
-        if self.moba_encoder_top_k_left is not None:
-            assert self.moba_encoder_top_k_left > 0, "moba_encoder_top_k_left must be larger than 0"
+        if self.plas_encoder_top_k_left is not None:
+            assert self.plas_encoder_top_k_left > 0, "plas_encoder_top_k_left must be larger than 0"

-        if self.moba_encoder_top_k_right is not None:
-            assert self.moba_encoder_top_k_right > 0, "moba_encoder_top_k_right must be larger than 0"
+        if self.plas_encoder_top_k_right is not None:
+            assert self.plas_encoder_top_k_right > 0, "plas_encoder_top_k_right must be larger than 0"
             assert (
-                self.moba_encoder_top_k_right >= self.moba_encoder_top_k_left
-            ), "moba_encoder_top_k_right must be no less than moba_encoder_top_k_left"
+                self.plas_encoder_top_k_right >= self.plas_encoder_top_k_left
+            ), "plas_encoder_top_k_right must be no less than plas_encoder_top_k_left"

-        if self.moba_decoder_top_k_left is not None:
-            assert self.moba_decoder_top_k_left > 0, "moba_decoder_top_k_left must be larger than 0"
+        if self.plas_decoder_top_k_left is not None:
+            assert self.plas_decoder_top_k_left > 0, "plas_decoder_top_k_left must be larger than 0"

-        if self.moba_decoder_top_k_right is not None:
-            assert self.moba_decoder_top_k_right > 0, "moba_decoder_top_k_right must be larger than 0"
+        if self.plas_decoder_top_k_right is not None:
+            assert self.plas_decoder_top_k_right > 0, "plas_decoder_top_k_right must be larger than 0"
             assert (
-                self.moba_decoder_top_k_right >= self.moba_decoder_top_k_left
-            ), "moba_decoder_top_k_right must be no less than moba_decoder_top_k_left"
+                self.plas_decoder_top_k_right >= self.plas_decoder_top_k_left
+            ), "plas_decoder_top_k_right must be no less than plas_decoder_top_k_left"

-        if self.moba_use_encoder_seq_limit is not None and self.moba_encoder_top_k_left is not None:
-            assert self.moba_use_encoder_seq_limit >= self.moba_encoder_top_k_left * self.moba_block_size
-        if self.moba_use_decoder_seq_limit is not None and self.moba_decoder_top_k_left is not None:
-            assert self.moba_use_decoder_seq_limit >= self.moba_decoder_top_k_left * self.moba_block_size
+        if self.plas_use_encoder_seq_limit is not None and self.plas_encoder_top_k_left is not None:
+            assert self.plas_use_encoder_seq_limit >= self.plas_encoder_top_k_left * self.plas_block_size
+        if self.plas_use_decoder_seq_limit is not None and self.plas_decoder_top_k_left is not None:
+            assert self.plas_use_decoder_seq_limit >= self.plas_decoder_top_k_left * self.plas_block_size

     def to_json_string(self):
         """
-        Convert moba_attention_config to a JSON string.
+        Convert plas_attention_config to a JSON string.
         """
         return json.dumps({key: value for key, value in self.__dict__.items() if value is not None})

@@ -1396,7 +1396,7 @@ def __init__(
         decoding_config: DecodingConfig = None,
         quant_config: QuantConfigBase = None,
         graph_opt_config: GraphOptimizationConfig = None,
-        moba_attention_config: MobaAttentionConfig = None,
+        plas_attention_config: PlasAttentionConfig = None,
         speculative_config: SpeculativeConfig = None,
         tokenizer: str = None,
         max_model_len: int = 8192,
@@ -1427,7 +1427,7 @@ def __init__(
         self.early_stop_config: Optional[EarlyStopConfig] = early_stop_config
         self.decoding_config: DecodingConfig = decoding_config  # type: ignore
         self.cache_config: CacheConfig = cache_config  # type: ignore
-        self.moba_attention_config: Optional[MobaAttentionConfig] = moba_attention_config
+        self.plas_attention_config: Optional[PlasAttentionConfig] = plas_attention_config
         # Initialize cuda graph capture list
         if self.graph_opt_config.cudagraph_capture_sizes is None:
             self.graph_opt_config._set_cudagraph_sizes(max_num_seqs=self.scheduler_config.max_num_seqs)
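
A quick way to see the defaulting and validation logic above in action (a sketch, assuming a FastDeploy build that includes this commit):

```python
# Exercising PlasAttentionConfig's derived defaults and validation.
from fastdeploy.config import PlasAttentionConfig

cfg = PlasAttentionConfig(
    {
        "plas_encoder_top_k_left": 50,
        "plas_encoder_top_k_right": 60,
        "plas_decoder_top_k_left": 100,
        "plas_decoder_top_k_right": 120,
    }
)

# Unset seq limits fall back to top_k_left * plas_block_size (128):
print(cfg.plas_use_encoder_seq_limit)  # 6400  -> shorter prefills stay dense
print(cfg.plas_use_decoder_seq_limit)  # 12800 -> shorter decodes stay dense
print(cfg.to_json_string())

# check_legality_parameters() runs in the constructor, so e.g. a right bound
# smaller than the left bound raises an AssertionError immediately.
```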

fastdeploy/engine/args_utils.py

Lines changed: 13 additions & 13 deletions
@@ -30,9 +30,9 @@
     FDConfig,
     GraphOptimizationConfig,
     LoadConfig,
-    MobaAttentionConfig,
     ModelConfig,
     ParallelConfig,
+    PlasAttentionConfig,
     PoolerConfig,
     RunnerOption,
     SpeculativeConfig,
@@ -361,9 +361,9 @@ class EngineArgs:
     """
     Configuration for graph optimization backend execution.
     """
-    moba_attention_config: Optional[Dict[str, Any]] = None
+    plas_attention_config: Optional[Dict[str, Any]] = None
     """
-    Configuration for moba attention.
+    Configuration for plas attention.
     """

     enable_logprob: bool = False
@@ -601,9 +601,9 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         help="",
     )
     model_group.add_argument(
-        "--moba-attention-config",
+        "--plas-attention-config",
         type=json.loads,
-        default=EngineArgs.moba_attention_config,
+        default=EngineArgs.plas_attention_config,
         help="",
     )
     model_group.add_argument(
@@ -993,17 +993,17 @@ def create_graph_optimization_config(self) -> GraphOptimizationConfig:
             graph_optimization_args[k] = v
         return GraphOptimizationConfig(graph_optimization_args)

-    def create_moba_attention_config(self) -> MobaAttentionConfig:
+    def create_plas_attention_config(self) -> PlasAttentionConfig:
         """
-        Create and return a MobaAttentionConfig object based on the current settings.
+        Create and return a PlasAttentionConfig object based on the current settings.
         """
         attention_args = asdict(self)
-        if self.moba_attention_config is not None:
-            for k, v in self.moba_attention_config.items():
+        if self.plas_attention_config is not None:
+            for k, v in self.plas_attention_config.items():
                 attention_args[k] = v
-            return MobaAttentionConfig(attention_args)
+            return PlasAttentionConfig(attention_args)
         else:
-            return MobaAttentionConfig(None)
+            return PlasAttentionConfig(None)

     def create_early_stop_config(self) -> EarlyStopConfig:
         """
@@ -1064,7 +1064,7 @@ def create_engine_config(self) -> FDConfig:
         scheduler_cfg = self.create_scheduler_config()
         graph_opt_cfg = self.create_graph_optimization_config()
         graph_opt_cfg.update_use_cudagraph(self.use_cudagraph)
-        moba_attention_config = self.create_moba_attention_config()
+        plas_attention_config = self.create_plas_attention_config()

         early_stop_cfg = self.create_early_stop_config()
         early_stop_cfg.update_enable_early_stop(self.enable_early_stop)
@@ -1093,7 +1093,7 @@ def create_engine_config(self) -> FDConfig:
             max_long_partial_prefills=self.max_long_partial_prefills,
             long_prefill_token_threshold=self.long_prefill_token_threshold,
             graph_opt_config=graph_opt_cfg,
-            moba_attention_config=moba_attention_config,
+            plas_attention_config=plas_attention_config,
             guided_decoding_backend=self.guided_decoding_backend,
             disable_any_whitespace=self.guided_decoding_disable_any_whitespace,
             early_stop_config=early_stop_cfg,
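
Because `--plas-attention-config` is parsed with `type=json.loads` and `PlasAttentionConfig.__init__` copies only keys that pass its `hasattr` check, misspelled keys are dropped silently. A small pre-flight sketch (illustrative, not part of the commit):

```python
# Sanity-check a --plas-attention-config value before launching the server.
import json

flag = (
    '{"plas_encoder_top_k_left": 50, "plas_encoder_top_k_right": 60,'
    ' "plas_decoder_top_k_left": 100, "plas_decoder_top_k_right": 120}'
)
args = json.loads(flag)  # what argparse stores on EngineArgs

# Attribute names defined by PlasAttentionConfig above; anything else
# would be silently ignored by the hasattr() filter.
known = {
    "plas_encoder_top_k_left", "plas_encoder_top_k_right",
    "plas_decoder_top_k_left", "plas_decoder_top_k_right",
    "plas_use_encoder_seq_limit", "plas_use_decoder_seq_limit",
    "plas_block_size", "mlp_weight_name", "plas_max_seq_length",
}
unknown = set(args) - known
assert not unknown, f"keys that PlasAttentionConfig would ignore: {unknown}"
```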

fastdeploy/engine/engine.py

Lines changed: 1 addition & 1 deletion
@@ -501,7 +501,7 @@ def _start_worker_service(self):
             f" --early_stop_config '{self.cfg.early_stop_config.to_json_string()}'"
             f" --reasoning_parser {self.cfg.reasoning_parser}"
             f" --load_choices {self.cfg.load_config.load_choices}"
-            f" --moba_attention_config '{self.cfg.moba_attention_config.to_json_string()}'"
+            f" --plas_attention_config '{self.cfg.plas_attention_config.to_json_string()}'"
             f" --ips {ips}"
             f" --cache-transfer-protocol {self.cfg.cache_config.cache_transfer_protocol}"
             f" --runner {self.cfg.model_config.runner}"

fastdeploy/model_executor/layers/attention/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -20,7 +20,7 @@
 from .flash_attn_backend import FlashAttentionBackend
 from .iluvatar_attn_backend import IluvatarAttnBackend
 from .mla_attention_backend import MLAAttentionBackend
-from .moba_attention_backend import MobaAttentionBackend
+from .moba_attention_backend import PlasAttentionBackend
 from .native_paddle_backend import PaddleNativeAttnBackend
 from .xpu_attn_backend import XPUAttentionBackend

@@ -35,5 +35,5 @@
     "IluvatarAttnBackend",
     "BlockAttentionBackend",
     "Attention",
-    "MobaAttentionBackend",
+    "PlasAttentionBackend",
 ]
