Skip to content

Commit 87179cb

Browse files
authored
[XPU] support XPU VL model inference (#4030)
* [XPU] support XPU VL model inference * fix image op import and device check * rebase develop * fix perf
1 parent e36eccf commit 87179cb

File tree

18 files changed

+1297
-143
lines changed

18 files changed

+1297
-143
lines changed

custom_ops/xpu_ops/src/ops/block_attn.cc

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,9 @@ std::vector<paddle::Tensor> BlockAttnKernel(
4141
const paddle::Tensor &encoder_seq_lod_cpu,
4242
const paddle::Tensor &encoder_batch_map_cpu,
4343
const paddle::Tensor &decoder_context_len_cpu,
44-
const paddle::Tensor &decoder_batch_map_cpu) {
44+
const paddle::Tensor &decoder_batch_map_cpu,
45+
const std::string &pos_emb_type="NORMAL",
46+
bool rope_3d=false) {
4547
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
4648
auto dev_ctx =
4749
paddle::experimental::DeviceContextPool::Instance().Get(place);
@@ -72,6 +74,14 @@ std::vector<paddle::Tensor> BlockAttnKernel(
7274
int enc_batch = enc_batch_tensor.data<int32_t>()[0];
7375
int dec_batch = dec_batch_tensor.data<int32_t>()[0];
7476
int total_enc_len = total_enc_len_tensor.data<int32_t>()[0];
77+
int rope_max_seqlen = 0;
78+
int rope_3d_num_seqs = 1;
79+
if (rope_3d) {
80+
rope_max_seqlen = rotary_embs.dims()[3];
81+
rope_3d_num_seqs = rotary_embs.dims()[0];
82+
} else {
83+
rope_max_seqlen = rotary_embs.dims()[2];
84+
}
7585

7686
auto block_attn_out =
7787
paddle::full({token_num, hidden_dim}, -1, qkv.type(), qkv.place());
@@ -151,10 +161,10 @@ std::vector<paddle::Tensor> BlockAttnKernel(
151161
prefix_lens_vp, // start_tokens
152162
param.batch_size, // batch_size
153163
1, // emb_batch_size
154-
rotary_embs.dims()[2], // max_seqlen
164+
rope_max_seqlen, // max_seqlen
155165
param.head_num, param.kv_head_num, param.head_dim,
156166
param.max_batch_size, block_size, max_block_per_seq, "BLHD",
157-
"HLD", "NORMAL",
167+
"HLD", pos_emb_type,
158168
!p_kcache_perhead_scale.defined()
159169
? nullptr
160170
: p_kcache_perhead_scale.data<float>() +
@@ -246,10 +256,10 @@ std::vector<paddle::Tensor> BlockAttnKernel(
246256
vsl.slot_mapping_vp, // real_batch
247257
param.batch_size, // batch_size
248258
1, // emb_batch_size
249-
rotary_embs.dims()[2], // max_seqlen TODO!!double check
259+
rope_max_seqlen, // max_seqlen
250260
param.head_num, param.kv_head_num, param.head_dim,
251261
param.max_batch_size, block_size, max_block_per_seq, "BLHD", "HLD",
252-
"NORMAL",
262+
pos_emb_type,
253263
!p_kcache_perhead_scale.defined()
254264
? nullptr
255265
: p_kcache_perhead_scale.data<float>() +
@@ -260,7 +270,9 @@ std::vector<paddle::Tensor> BlockAttnKernel(
260270
param.kv_head_num, // v_cache_scale_inv
261271
nullptr, // k_cache_zp
262272
nullptr, // v_cache_zp
263-
false); // b_c8_pc
273+
false, // b_c8_pc
274+
rope_3d, // rope_3d
275+
rope_3d_num_seqs);
264276
XFTBLOCK_CHECK_EQ(ret, api::SUCCESS);
265277

266278
// attn decode
@@ -314,6 +326,7 @@ PD_BUILD_OP(block_attn)
314326
"decoder_context_len_cpu",
315327
"decoder_batch_map_cpu",
316328
})
329+
.Attrs({"pos_emb_type:std::string", "rope_3d:bool"})
317330
.Outputs({"block_attn_out"})
318331
.SetKernelFn(PD_KERNEL(BlockAttnKernel))
319332
.SetInferShapeFn(PD_INFER_SHAPE(BlockAttnInferShape))
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/extension.h"
16+
17+
// Computes the segment boundaries of a mixed text / image token sequence.
//
// The sequence is walked left to right. A maximal run of tokens different
// from `image_patch_id` forms one segment. A position holding
// `image_patch_id` starts an image segment whose length is taken from the
// next unused `grid_thw` triple as (triple[1] * triple[2]) / 4.
// After every segment the current token offset and the number of images
// consumed so far are recorded.
//
// Args:
//   task_input_ids: int64 token ids, read through a host pointer (the tensor
//                   must therefore be CPU resident).
//   grid_thw:       int64 values read as consecutive triples, one per image;
//                   only elements 1 and 2 of each triple are used
//                   (presumably (t, h, w) - TODO confirm with the caller).
//   image_patch_id: token id that marks the first token of an image segment.
//
// Returns:
//   A single INT64 CPU tensor of shape {2, K}. Row 0 holds the boundary
//   offsets (starting with 0), row 1 the image count at each boundary.
//
// Throws (via PD_THROW) when the ids reference more images than `grid_thw`
// provides, or when an image would contribute no tokens; the former used to
// read past the end of `grid_thw`, the latter used to loop forever.
std::vector<paddle::Tensor> GetImgBoundaries(const paddle::Tensor& task_input_ids,
                                             const paddle::Tensor& grid_thw,
                                             const int64_t image_patch_id) {
    // All tensors are expected on CPU: they are dereferenced directly below.
    const int64_t* input_ids_ptr = task_input_ids.data<int64_t>();
    const int64_t seq_lens_origin = task_input_ids.numel();
    const int64_t* grid_thw_ptr = grid_thw.data<int64_t>();
    // grid_thw is consumed three values at a time, one triple per image.
    const int64_t num_images = grid_thw.numel() / 3;

    // Divisor applied to triple[1] * triple[2] to obtain the token count
    // (presumably a 2x2 patch merge - TODO confirm).
    constexpr int64_t kTokenTimes = 4;

    // 64-bit counters: the values are compared against numel() and written
    // into an INT64 tensor, so narrowing to int would only lose range.
    int64_t token_idx = 0;
    int64_t image_idx = 0;
    std::vector<int64_t> img_boundaries{0};
    std::vector<int64_t> img_nums{0};

    while (token_idx < seq_lens_origin) {
        if (input_ids_ptr[token_idx] != image_patch_id) {
            // Consume the whole run of non-image tokens as one segment.
            do {
                ++token_idx;
            } while (token_idx < seq_lens_origin &&
                     input_ids_ptr[token_idx] != image_patch_id);
        } else {
            if (image_idx >= num_images) {
                PD_THROW(
                    "get_img_boundaries: task_input_ids contains more image "
                    "segments than grid_thw has entries. ");
            }
            const int64_t cur_image_token_len =
                (grid_thw_ptr[image_idx * 3 + 1] *
                 grid_thw_ptr[image_idx * 3 + 2]) / kTokenTimes;
            if (cur_image_token_len <= 0) {
                // A non-positive length would never advance token_idx.
                PD_THROW(
                    "get_img_boundaries: grid_thw entry yields a non-positive "
                    "image token length. ");
            }
            ++image_idx;
            token_idx += cur_image_token_len;
        }
        img_boundaries.emplace_back(token_idx);
        img_nums.emplace_back(image_idx);
    }

    const int64_t num_img_boundaries =
        static_cast<int64_t>(img_boundaries.size());
    auto out = paddle::full({2, num_img_boundaries}, 0,
                            paddle::DataType::INT64, paddle::CPUPlace());

    // Row 0: boundaries, row 1: image counts (row-major {2, K} layout).
    int64_t* out_ptr = out.data<int64_t>();
    for (int64_t i = 0; i < num_img_boundaries; ++i) {
        out_ptr[i] = img_boundaries[i];
        out_ptr[num_img_boundaries + i] = img_nums[i];
    }

    return {out};
}
55+
56+
// Registers the `get_img_boundaries` custom op.
//   inputs : task_input_ids, grid_thw
//   attrs  : image_patch_id (int64_t)
//   outputs: img_boundaries
// The kernel is GetImgBoundaries; no infer-shape / infer-dtype function is
// registered here.
PD_BUILD_OP(get_img_boundaries)
57+
.Inputs({"task_input_ids", "grid_thw"})
58+
.Attrs({"image_patch_id: int64_t"})
59+
.Outputs({"img_boundaries"})
60+
.SetKernelFn(PD_KERNEL(GetImgBoundaries));

custom_ops/xpu_ops/src/ops/moe_layer.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,15 +145,17 @@ std::vector<paddle::Tensor> MoeLayerKernel(
145145
? up_gate_proj_weight_scale.get_ptr()->data<float>()
146146
: nullptr),
147147
xftblock_tw,
148-
std::vector<int64_t>{expert_num, inter_dim, hidden_dim});
148+
std::vector<int64_t>{expert_num, inter_dim, hidden_dim}
149+
);
149150

150151
xdown_proj_w = std::make_shared<xftblock::Tensor>(
151152
const_cast<TW *>(down_proj_weight.data<TW>()), nullptr,
152153
const_cast<float *>(down_proj_weight_scale.get_ptr()
153154
? down_proj_weight_scale.get_ptr()->data<float>()
154155
: nullptr),
155156
xftblock_tw,
156-
std::vector<int64_t>{expert_num, hidden_dim, outer_dim});
157+
std::vector<int64_t>{expert_num, hidden_dim, outer_dim}
158+
);
157159
}
158160
std::shared_ptr<xftblock::Tensor> xup_gate_proj_bias;
159161
std::shared_ptr<xftblock::Tensor> xdown_proj_bias;
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include <paddle/phi/backends/xpu/xpu_context.h>
16+
#include <xft/xdnn_plugin.h>
17+
#include "paddle/extension.h"
18+
#include "xpu/plugin.h"
19+
20+
// Dispatches the XPU `text_image_gather_scatter` plugin kernel.
//
// `input` is read as a 2-D tensor {token_num, hidden_size}; the leading
// dimensions of `text_input` and `image_input` give the text / image token
// counts. `token_type_ids`, `text_index` and `image_index` are passed to the
// plugin as int buffers, and `is_scatter` is forwarded unchanged.
//
// Only BFLOAT16 `input` is supported; any other dtype raises via PD_THROW.
// A non-success plugin status raises via PADDLE_ENFORCE_XDNN_SUCCESS.
void TextImageGatherScatter(
    paddle::Tensor& input,
    paddle::Tensor& text_input,
    paddle::Tensor& image_input,
    paddle::Tensor& token_type_ids,
    paddle::Tensor& text_index,
    paddle::Tensor& image_index,
    const bool is_scatter) {
  // Resolve the XPU context of the currently selected device.
  phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
  auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
  auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);

  // Problem sizes, taken from the tensor shapes.
  const int64_t token_num = input.dims()[0];
  const int64_t hidden_size = input.dims()[1];
  const int64_t text_token_num = text_input.dims()[0];
  const int64_t image_token_num = image_input.dims()[0];

  // Guard clause: bf16 is the only element type with a kernel instantiation.
  if (input.type() != paddle::DataType::BFLOAT16) {
    PD_THROW(
        "NOT supported data type. Only support BFLOAT16. ");
  }

  using XPUType = typename XPUTypeTrait<bfloat16>::Type;
  using data_t = paddle::bfloat16;

  // Paddle's bf16 buffers are reinterpreted as the XPU-native bf16 type.
  XPUType* input_ptr = reinterpret_cast<XPUType*>(input.data<data_t>());
  XPUType* text_ptr = reinterpret_cast<XPUType*>(text_input.data<data_t>());
  XPUType* image_ptr = reinterpret_cast<XPUType*>(image_input.data<data_t>());
  int* token_type_ptr = token_type_ids.data<int>();
  int* text_index_ptr = text_index.data<int>();
  int* image_index_ptr = image_index.data<int>();

  int r = baidu::xpu::api::plugin::text_image_gather_scatter<XPUType>(
      xpu_ctx->x_context(),
      input_ptr,
      text_ptr,
      image_ptr,
      token_type_ptr,
      text_index_ptr,
      image_index_ptr,
      token_num,
      text_token_num,
      image_token_num,
      hidden_size,
      is_scatter);
  PADDLE_ENFORCE_XDNN_SUCCESS(r, "text_image_gather_scatter");
}
65+
66+
67+
// Registers the `text_image_gather_scatter` custom op.
//   inputs : input, text_input, image_input, token_type_ids, text_index,
//            image_index
//   outputs: text_input_out, image_input_out, text_index_out, image_index_out
//   attrs  : is_scatter (bool)
// Every output is mapped in place onto the same-named input, and the kernel
// (TextImageGatherScatter) returns void.
// NOTE(review): `input` has no in-place output mapping; if the plugin writes
// to `input` for one of the `is_scatter` modes, that write is not declared
// here - confirm against the plugin implementation.
PD_BUILD_OP(text_image_gather_scatter)
68+
.Inputs({"input",
69+
"text_input",
70+
"image_input",
71+
"token_type_ids",
72+
"text_index",
73+
"image_index"})
74+
.Outputs({"text_input_out",
75+
"image_input_out",
76+
"text_index_out",
77+
"image_index_out"})
78+
.Attrs({"is_scatter:bool"})
79+
.SetInplaceMap({{"text_input", "text_input_out"},
80+
{"image_input", "image_input_out"},
81+
{"text_index", "text_index_out"},
82+
{"image_index", "image_index_out"}})
83+
.SetKernelFn(PD_KERNEL(TextImageGatherScatter));
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include <paddle/phi/backends/xpu/xpu_context.h>
16+
#include "paddle/extension.h"
17+
#include "xpu/plugin.h"
18+
19+
void TextImageIndexOut(
20+
const paddle::Tensor& token_type_ids,
21+
const paddle::Tensor& text_index,
22+
const paddle::Tensor& image_index) {
23+
if (token_type_ids.type() != paddle::DataType::INT32 || text_index.type()
24+
!= paddle::DataType::INT32 || image_index.type() != paddle::DataType::INT32) {
25+
PD_THROW("NOT supported data type. Only support BFLOAT16. ");
26+
}
27+
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
28+
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
29+
auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
30+
const int64_t token_num = token_type_ids.shape()[0];
31+
int r = baidu::xpu::api::plugin::text_image_index_out(xpu_ctx->x_context(),
32+
token_type_ids.data<int32_t>(),
33+
const_cast<int32_t*>(text_index.data<int32_t>()),
34+
const_cast<int32_t*>(image_index.data<int32_t>()),
35+
token_num);
36+
PADDLE_ENFORCE_XDNN_SUCCESS(r, "text_image_index_out");
37+
}
38+
39+
40+
// Registers the `text_image_index_out` custom op.
//   inputs : token_type_ids, text_index, image_index
//   outputs: text_index_out, image_index_out (in place on text_index /
//            image_index)
// The kernel is TextImageIndexOut, which returns void.
PD_BUILD_OP(text_image_index_out)
41+
.Inputs({"token_type_ids",
42+
"text_index",
43+
"image_index"})
44+
.Outputs({"text_index_out",
45+
"image_index_out"})
46+
.SetInplaceMap({{"text_index", "text_index_out"},
47+
{"image_index", "image_index_out"}})
48+
.SetKernelFn(PD_KERNEL(TextImageIndexOut));

custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,25 @@ DLL_EXPORT int quant2d_per_channel(api::Context *ctx, const TX *x,
140140
const TSCALE *scale_in, TY *y,
141141
TSCALE *scale_out, int64_t m, int64_t n);
142142

143+
// Plugin entry point behind the `text_image_index_out` op (declaration only;
// the implementation is not part of this header).
//   ctx            : XPU API context.
//   token_type_ids : read-only int buffer (x).
//   text_index     : writable int buffer (y1).
//   image_index    : writable int buffer (y2).
//   token_num      : token count; the op passes the first dimension of
//                    token_type_ids.
// Returns an int status that callers check with PADDLE_ENFORCE_XDNN_SUCCESS.
// NOTE(review): this declaration spells the context type `Context*` while the
// neighbouring declarations use `api::Context*` - presumably the same type.
DLL_EXPORT int text_image_index_out(Context* ctx,
144+
const int* token_type_ids, // x
145+
int* text_index, // y1
146+
int* image_index, // y2
147+
const int64_t token_num);
148+
149+
// Plugin entry point behind the `text_image_gather_scatter` op (declaration
// only; the implementation is not part of this header).
//   T               : element type of the three data buffers; the op
//                     instantiates it with the XPU bf16 type.
//   input           : data buffer; the op passes a {token_num, hidden_size}
//                     tensor.
//   text_input      : data buffer; the op passes text_token_num rows.
//   image_input     : data buffer; the op passes image_token_num rows.
//   token_type_ids, text_index, image_index : int buffers.
//   is_scatter      : mode flag (presumably selects scatter vs. gather
//                     direction - TODO confirm in the kernel source).
// Returns an int status that callers check with PADDLE_ENFORCE_XDNN_SUCCESS.
template <typename T>
150+
DLL_EXPORT int text_image_gather_scatter(api::Context* ctx,
151+
T* input,
152+
T* text_input,
153+
T* image_input,
154+
int* token_type_ids,
155+
int* text_index,
156+
int* image_index,
157+
int64_t token_num,
158+
int64_t text_token_num,
159+
int64_t image_token_num,
160+
int64_t hidden_size,
161+
bool is_scatter);
143162

144163
/*--------------------------------------- MTP begin --------------------------------------------*/
145164

0 commit comments

Comments
 (0)