
Commit 7eee023
Parent: be94bdd
Commit message: wip
36 files changed: +2364 -11 lines

build.sh

Lines changed: 4 additions & 2 deletions
@@ -104,8 +104,7 @@ function copy_ops(){
     is_npu=`$python -c "import paddle; print(paddle.is_compiled_with_custom_device('npu'))"`
     if [ "$is_npu" = "True" ]; then
         DEVICE_TYPE="npu"
-        cp -r ${OPS_TMP_DIR}/${WHEEL_NAME}/* ../fastdeploy/model_executor/ops/npu
-        echo -e "npu ops have been copy to fastdeploy"
+        echo -e "npu ops are already present in fastdeploy"
         return
     fi
 

@@ -153,6 +152,7 @@ function build_and_install_ops() {
     echo -e "${BLUE}[build]${NONE} build and install fastdeploy_ops..."
     TMP_DIR_REAL_PATH=`readlink -f ${OPS_TMP_DIR}`
     is_xpu=`$python -c "import paddle; print(paddle.is_compiled_with_xpu())"`
+    is_npu=`$python -c "import paddle; print(paddle.is_compiled_with_custom_device('npu'))"`
     if [ "$is_xpu" = "True" ]; then
         cd xpu_ops/src
         bash build.sh ${TMP_DIR_REAL_PATH}

@@ -164,6 +164,8 @@ function build_and_install_ops() {
             FD_BUILDING_ARCS=${FD_BUILDING_ARCS} FD_CPU_USE_BF16=True ${python} setup_ops.py install --install-lib ${OPS_TMP_DIR}
         fi
         find ${OPS_TMP_DIR} -type f -name "*.o" -exec rm -f {} \;
+    elif [ "$is_npu" = "True" ]; then
+        echo -e "${BLUE}[build]${NONE} skipping NPU ops build (already present)"
     elif [ "$FD_CPU_USE_BF16" == "false" ]; then
         if [ "$FD_BUILDING_ARCS" == "" ]; then
             ${python} setup_ops.py install --install-lib ${OPS_TMP_DIR}
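The NPU handling in both functions keys off the same capability probe. As a standalone sanity check, a minimal Python sketch using only the call already present in the diff:

import paddle

# Prints "True" when this Paddle build ships the Ascend NPU custom-device plugin;
# build.sh uses the same check to set DEVICE_TYPE and to skip the NPU ops build.
print(paddle.is_compiled_with_custom_device("npu"))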

fastdeploy/model_executor/layers/activation.py

Lines changed: 1 addition & 1 deletion
@@ -71,7 +71,7 @@ def __init__(
             or current_platform.is_maca()
         ):
             self.forward = self.forward_cuda
-        elif current_platform.is_gcu():
+        elif current_platform.is_gcu() or current_platform.is_npu():
            self.forward = self.forward_gcu
         else:
             raise NotImplementedError

fastdeploy/model_executor/layers/attention/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -22,6 +22,7 @@
 from .mla_attention_backend import MLAAttentionBackend
 from .native_paddle_backend import PaddleNativeAttnBackend
 from .xpu_attn_backend import XPUAttentionBackend
+from .npu_fapa_attn_backend import NpuFaPaAttentionBackend

 __all__ = [
     "AttentionBackend",

@@ -34,4 +35,5 @@
     "IluvatarAttnBackend",
     "BlockAttentionBackend",
     "Attention",
+    "NpuFaPaAttentionBackend"
 ]
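With both hunks applied, the new backend is re-exported from the attention package. A minimal import check (this assumes an NPU build of the custom ops, since the new module imports fused_fapa_attention_npu from fastdeploy.model_executor.ops.npu at import time):

from fastdeploy.model_executor.layers.attention import NpuFaPaAttentionBackend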
fastdeploy/model_executor/layers/attention/npu_fapa_attn_backend.py (new file)

Lines changed: 212 additions & 0 deletions
@@ -0,0 +1,212 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

from __future__ import annotations

import os
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, List, Optional

import paddle
from paddle import core

from fastdeploy.config import FDConfig
from fastdeploy.model_executor.layers.attention.ops import (
    get_block_shape_and_split_kv_block, init_signal_layerwise,
    open_shm_and_get_meta_signal)
from fastdeploy.model_executor.ops.npu import fused_fapa_attention_npu

if TYPE_CHECKING:
    from paddle._typing.dtype_like import _DTypeLiteral

# from fastdeploy.config import LLMConfig
from fastdeploy.model_executor.layers.attention import Attention
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
    AttentionBackend, AttentionMetadata)


@dataclass
class NpuFaPaAttentionMetadata(AttentionMetadata):
    """
    NpuFaPaAttentionMetadata
    """

    max_len_kv: paddle.Tensor = None
    set_max_lengths: int = -1
    encoder_batch_ids: paddle.Tensor = None
    encoder_tile_ids_per_batch: paddle.Tensor = None
    encoder_num_blocks: paddle.Tensor = None
    kv_batch_ids: paddle.Tensor = None
    kv_tile_ids_per_batch: paddle.Tensor = None
    kv_num_blocks: paddle.Tensor = None
    decoder_batch_ids: paddle.Tensor = None
    decoder_tile_ids_per_batch: paddle.Tensor = None
    decoder_num_blocks: paddle.Tensor = None

    _dtype: _DTypeLiteral = paddle.bfloat16
    encoder_max_partition_size: int = 32768
    max_partition_size: int = 32768
    block_tables: Optional[paddle.Tensor] = None
    rotary_embs: Optional[paddle.Tensor] = None
    attn_mask: Optional[paddle.Tensor] = None
    encoder_block_shape_q: Optional[paddle.Tensor] = None
    decoder_block_shape_q: Optional[paddle.Tensor] = None
    _fuse_kernel_compute_dtype: str = "bf16"

    # pd_disaggregation
    kv_signal_metadata: Optional[paddle.Tensor] = None
    kv_signal_data_list: List[paddle.Tensor] = field(default_factory=list)


class NpuFaPaAttentionBackend(AttentionBackend):
    """
    NpuFaPaAttentionBackend backend implementation.
    """

    def __init__(self, fd_config: FDConfig, kv_num_heads: int, num_heads: int, head_dim: int):
        """
        NpuFaPaAttentionBackend __init__
        """
        super().__init__()
        self.attention_metadata: NpuFaPaAttentionMetadata = None
        # TODO(gongshaotian): Use fd_config parameters in the correct location
        self.block_size = fd_config.parallel_config.block_size
        self.max_seq_len = fd_config.parallel_config.max_model_len
        self.rope_theta = (
            10000.0
            if fd_config.model_config.rope_theta is None
            else fd_config.model_config.rope_theta
        )
        self.rope_3d = getattr(fd_config.model_config, "rope_3d", False)
        self.causal = getattr(fd_config.model_config, "causal", True)
        self.speculative_method: str = fd_config.speculative_config.method
        self.use_speculate: bool = self.speculative_method is not None
        self.speculate_max_draft_token_num: int = fd_config.speculative_config.num_speculative_tokens
        self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
        self.rank = fd_config.parallel_config.tensor_parallel_rank

        self.kv_num_heads = kv_num_heads
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.num_layers: int = fd_config.model_config.num_hidden_layers

        # pd_disaggregation
        self.use_pd_disaggregation = int(os.getenv("FLAGS_use_pd_disaggregation", 0))
        self.start_layer_index = fd_config.model_config.start_layer_index

    def init_attention_metadata(self, forward_meta):
        """Initialize attention metadata so that all layers in the forward pass can reuse it."""
        metadata = NpuFaPaAttentionMetadata()
        metadata.encoder_block_shape_q = 64
        metadata.decoder_block_shape_q = 16
        metadata.max_partition_size = 32768
        metadata.encoder_max_partition_size = 32768
        metadata._dtype = paddle.get_default_dtype()
        if metadata._dtype == "bfloat16":
            metadata._fuse_kernel_compute_dtype = "bf16"
        elif metadata._dtype == "float16":
            metadata._fuse_kernel_compute_dtype = "fp16"
        elif metadata._dtype == "float32":
            metadata._fuse_kernel_compute_dtype = "fp32"
        metadata.block_tables = forward_meta.block_tables
        metadata.rotary_embs = forward_meta.rotary_embs
        metadata.attn_mask = forward_meta.attn_mask
        metadata.pre_caches_length = forward_meta.pre_caches_length

        # # FIXME:
        # (
        #     metadata.encoder_batch_ids,
        #     metadata.encoder_tile_ids_per_batch,
        #     metadata.encoder_num_blocks,
        #     metadata.kv_batch_ids,
        #     metadata.kv_tile_ids_per_batch,
        #     metadata.kv_num_blocks,
        #     metadata.decoder_batch_ids,
        #     metadata.decoder_tile_ids_per_batch,
        #     metadata.decoder_num_blocks,
        #     metadata.max_len_kv,
        #     metadata.set_max_lengths,
        # ) = get_block_shape_and_split_kv_block(
        #     forward_meta.seq_lens_encoder,
        #     forward_meta.seq_lens_decoder,
        #     forward_meta.seq_lens_this_time,
        #     forward_meta.cum_offsets,
        #     metadata.encoder_block_shape_q,
        #     metadata.decoder_block_shape_q,
        #     self.num_heads // self.kv_num_heads,
        #     self.block_size,
        #     self.speculate_max_draft_token_num + 1,
        # )

        # pd_disaggregation
        metadata.kv_signal_data_list = [None] * self.num_layers
        if self.use_pd_disaggregation:
            metadata.kv_signal_metadata = open_shm_and_get_meta_signal(
                self.rank, self.keep_pd_step_flag
            )
        self.attention_metadata = metadata

    def get_attntion_meta(self):
        """get_attntion_meta"""
        return self.attention_metadata

    def get_kv_cache_shape(
        self,
        max_num_blocks: int,
        kv_cache_quant_type: str = None,
    ):
        """
        Calculate kv cache shape
        """
        return (max_num_blocks, self.kv_num_heads, self.block_size, self.head_dim)

    def forward_mixed(
        self,
        q,
        k,
        v,
        qkv,
        compressed_kv,
        k_pe,
        layer: Attention,
        forward_meta,
    ):
        """
        forward_mixed
        """
        metadata = self.attention_metadata

        if self.use_pd_disaggregation:
            metadata.kv_signal_data_list[layer.layer_id] = init_signal_layerwise(
                metadata.kv_signal_metadata, layer.layer_id + self.start_layer_index
            )
        # FIXME: guozr - change this to bfloat16

        res = fused_fapa_attention_npu(
            qkv,
            metadata.rotary_embs,
            forward_meta.caches[2 * layer.layer_id],
            forward_meta.caches[2 * layer.layer_id + 1],
            forward_meta.seq_lens_encoder,
            forward_meta.seq_lens_decoder,
            metadata.block_tables,
            self.num_heads,
            self.kv_num_heads,
            self.head_dim,
            self.max_seq_len,
            self.block_size,
        )
        return res
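For orientation, get_kv_cache_shape above returns the per-tensor paged layout, and forward_mixed indexes forward_meta.caches[2 * layer_id] / [2 * layer_id + 1] for the K and V caches of each layer. A rough sizing sketch with made-up numbers (the real values come from FDConfig; the metadata dtype defaults to bfloat16, i.e. 2 bytes per element):

import math

# Hypothetical sizes, for illustration only.
max_num_blocks = 1024   # paged KV-cache blocks allocated
kv_num_heads = 8        # KV heads after tensor-parallel split
block_size = 64         # tokens per block (fd_config.parallel_config.block_size)
head_dim = 128

# Shape returned by NpuFaPaAttentionBackend.get_kv_cache_shape(max_num_blocks)
shape = (max_num_blocks, kv_num_heads, block_size, head_dim)

# Two cache tensors per layer (K and V), bfloat16 = 2 bytes per element.
bytes_per_layer = 2 * 2 * math.prod(shape)
print(shape, f"-> {bytes_per_layer / 2**20:.0f} MiB of KV cache per layer")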

fastdeploy/model_executor/layers/backends/npu/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -15,3 +15,6 @@
 """
 npu backend methods
 """
+from .quantization.weight_only import NPUWeightOnlyLinearMethod
+
+__all__ = ['NPUWeightOnlyLinearMethod']
fastdeploy/model_executor/layers/backends/npu/quantization/weight_only.py (new file)

Lines changed: 69 additions & 0 deletions
@@ -0,0 +1,69 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import paddle

from fastdeploy.model_executor.layers.quantization.weight_only import (
    WeightOnlyConfig, WeightOnlyLinearMethod)
from fastdeploy.model_executor.ops.npu import fused_linear_op as weight_only_linear
from fastdeploy.model_executor.ops.npu import npu_quant_weight


class NPUWeightOnlyLinearMethod(WeightOnlyLinearMethod):
    """
    Weight only quantization method for linear layer on NPU
    """

    def __init__(
        self,
        quant_config: WeightOnlyConfig,
    ) -> None:
        super().__init__(quant_config)

    def create_weights(self, layer):
        """
        Create weights for linear layer on NPU
        """
        linear_weight_scale_shape = [layer.embed_dim]
        if hasattr(layer, "linear_weight_shape"):
            if isinstance(layer.linear_weight_shape, list):
                layer_weight_shape = layer.linear_weight_shape
                linear_weight_scale_shape = layer_weight_shape[:1]

        layer.linear_weight_scale = layer.create_parameter(
            shape=linear_weight_scale_shape,
            dtype="bfloat16",
            is_bias=False,
        )

    def process_loaded_weights(self, layer, weight) -> None:
        """
        Quantize loaded weights with the NPU-specific quantization op
        """
        quanted_weight_tensor, weight_scale_tensor = npu_quant_weight(weight)
        layer.linear_weight.set_value(quanted_weight_tensor.T)
        layer.linear_weight_scale.set_value(
            weight_scale_tensor.astype(paddle.get_default_dtype())
        )

    def apply(self, layer, x):
        """Run the fused weight-only linear op with the quantized weight and its scale."""
        linear_out = weight_only_linear(
            x,
            weight=layer.linear_weight.T,
            weight_scale=layer.linear_weight_scale,
        )
        return linear_out
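The flow above is: npu_quant_weight turns a high-precision weight into a quantized tensor plus per-channel scales at load time, and the fused linear op consumes both at run time. The exact layout produced by the NPU op is not shown in this commit; the sketch below only illustrates generic per-output-channel symmetric int8 weight-only quantization, to make the weight/scale split concrete:

import numpy as np

# Illustrative only: generic per-output-channel symmetric int8 weight quantization.
# The actual tensors returned by npu_quant_weight may differ in layout and precision.
def quantize_weight(w: np.ndarray):
    # w: [in_features, out_features]; one scale per output channel.
    scale = np.abs(w).max(axis=0) / 127.0
    q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)
    return q, scale.astype(np.float32)

def dequant_matmul(x: np.ndarray, q: np.ndarray, scale: np.ndarray):
    # What a fused weight-only linear conceptually computes: x @ (q * scale).
    return x @ (q.astype(np.float32) * scale)

w = np.random.randn(64, 32).astype(np.float32)
x = np.random.randn(4, 64).astype(np.float32)
q, s = quantize_weight(w)
print(np.abs(x @ w - dequant_matmul(x, q, s)).max())  # small quantization error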

fastdeploy/model_executor/layers/linear.py

Lines changed: 4 additions & 0 deletions
@@ -108,6 +108,7 @@ def __init__(
             or current_platform.is_gcu()
             or current_platform.is_dcu()
             or current_platform.is_maca()
+            or current_platform.is_npu()
         ):
             self.forward = self.forward_cuda
         else:

@@ -555,6 +556,9 @@ def load_weight(self, state_dict: dict):
         if self.fd_config.quant_config:
             self.quant_method.process_loaded_weights(self, weight_tensor)
         else:
+            # Handle dtype conversion for NPU compatibility
+            if self.weight.dtype != weight_tensor.dtype:  # FIXME: guozr - this may be where the problem is
+                weight_tensor = weight_tensor.cast(self.weight.dtype)
             self.weight.set_value(weight_tensor)

     def load_state_dict(self, state_dict: dict):
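The added guard only casts when the checkpoint dtype and the parameter dtype disagree (for example, float32 weights loaded into a bfloat16 model). A minimal sketch of just that cast, with made-up shapes:

import paddle

# Hypothetical example: a float32 checkpoint tensor destined for a bfloat16 parameter.
weight_tensor = paddle.randn([8, 8], dtype="float32")
param_dtype = paddle.bfloat16

if weight_tensor.dtype != param_dtype:
    weight_tensor = weight_tensor.cast(param_dtype)

print(weight_tensor.dtype)  # paddle.bfloat16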

fastdeploy/model_executor/layers/normalization.py

Lines changed: 3 additions & 0 deletions
@@ -24,6 +24,9 @@

 if current_platform.is_gcu():
     from fastdeploy.model_executor.ops.gcu import fused_add_rms_norm, rms_norm
+elif current_platform.is_npu():
+    from fastdeploy.model_executor.ops.npu import rms_norm_npu as fused_rms_norm
+    from paddle.incubate.nn.functional import fused_layer_norm
 else:
     from paddle.incubate.nn.functional import fused_layer_norm, fused_rms_norm

fastdeploy/model_executor/layers/quantization/weight_only.py

Lines changed: 3 additions & 0 deletions
@@ -104,6 +104,9 @@ def get_quant_method(self, layer) -> Optional[QuantMethodBase]:
             else:

                 return GPUWeightOnlyLinearMethod(self)
+        elif current_platform.is_npu():
+            from fastdeploy.model_executor.layers.backends import NPUWeightOnlyLinearMethod
+            return NPUWeightOnlyLinearMethod(self)
         else:
             if isinstance(layer, FusedMoE):
                 if layer.use_method == "cutlass":
