
Commit 443731e

[Feature] update mm chunked_prefill op
1 parent 331c4d2 commit 443731e

File tree

3 files changed: +182 -0 lines changed

custom_ops/gpu_ops/get_mm_split_fuse_v2.cc

Lines changed: 88 additions & 0 deletions
@@ -0,0 +1,88 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/extension.h"
#include <map>

std::vector<paddle::Tensor> GetMmSplitFuseV2(const paddle::Tensor& task_input_ids,
                                             const paddle::Tensor& task_image_type_ids,
                                             const paddle::Tensor& task_input_ids_image_token_count,
                                             const paddle::Tensor& grid_thw,
                                             int64_t image_token_id,
                                             int64_t img_total,
                                             int seq_lens_origin,
                                             int split_fuse_text_size) {
    // All input tensors are on CPU.
    auto input_ids_cpu = task_input_ids.data<int64_t>();
    auto task_input_ids_image_token_count_cpu = task_input_ids_image_token_count.data<int>();
    auto grid_thw_cpu = grid_thw.data<int64_t>();
    std::vector<int> image_chunk_selections_vector;   // number of images in the current chunk
    std::vector<int> split_fuse_cur_seq_lens_vector;  // length of the current chunk
    std::vector<int> split_fuse_cur_mm_lens_vector;   // number of multimodal tokens in the current chunk
    // [Preprocessing] record the positions where a chunk boundary may be placed.
    std::map<int, int> mp;
    mp[0] = 1;  // init
    int st_idx = 0, last_st_ib = 0;
    int idx = 0;
    while (st_idx < seq_lens_origin) {
        // 1. st_idx points at text: advance to the end of the text run.
        if (input_ids_cpu[st_idx] != image_token_id) {
            do {
                st_idx++;
            } while (st_idx < seq_lens_origin && input_ids_cpu[st_idx] != image_token_id);
            mp[st_idx] = 1;  // record a chunk boundary, i.e. one past the last text token
        } else {  // 2. st_idx points at a multimodal token: skip to its end using the image's token length.
            int ib = last_st_ib;
            int cur_st_len = 0;
            int token_times = 4;
            cur_st_len = (grid_thw_cpu[ib * 3 + 1] * grid_thw_cpu[ib * 3 + 2]) / token_times;
            mp[st_idx + cur_st_len] = 1;
            last_st_ib = ++ib;
            st_idx += cur_st_len;
        }
    }
    int chunk_image_number = 0;
    int last_id = 0;
    for (idx = 0; idx < seq_lens_origin; idx++) {
        if (mp[idx] == 1 && input_ids_cpu[idx] == image_token_id) {
            chunk_image_number++;
        }
        if ((idx > 0 && (idx + 1) % split_fuse_text_size == 0) || idx == seq_lens_origin - 1) {
            int chunk_start = last_id * split_fuse_text_size;
            int chunk_end = idx;
            int chunk_image_token_number = task_input_ids_image_token_count_cpu[chunk_end + 1] - task_input_ids_image_token_count_cpu[chunk_start];
            image_chunk_selections_vector.emplace_back(chunk_image_number);
            split_fuse_cur_seq_lens_vector.emplace_back(chunk_end - chunk_start + 1);
            split_fuse_cur_mm_lens_vector.emplace_back(chunk_image_token_number);
            chunk_image_number = 0;
            last_id = (idx + 1) / split_fuse_text_size;
        }
    }
    // vector -> CPU tensor
    auto image_chunk_selections_out_cpu = paddle::from_blob(image_chunk_selections_vector.data(), {static_cast<int64_t>(image_chunk_selections_vector.size())}, task_image_type_ids.dtype());
    auto split_fuse_cur_seq_lens_out_cpu = paddle::from_blob(split_fuse_cur_seq_lens_vector.data(), {static_cast<int64_t>(split_fuse_cur_seq_lens_vector.size())}, task_image_type_ids.dtype());
    auto split_fuse_cur_mm_lens_out_cpu = paddle::from_blob(split_fuse_cur_mm_lens_vector.data(), {static_cast<int64_t>(split_fuse_cur_mm_lens_vector.size())}, task_image_type_ids.dtype());
    // CPU tensor -> GPU tensor
    auto image_chunk_selections_out = paddle::experimental::copy_to(image_chunk_selections_out_cpu, task_image_type_ids.place(), false);
    auto split_fuse_cur_seq_lens_out = paddle::experimental::copy_to(split_fuse_cur_seq_lens_out_cpu, task_image_type_ids.place(), false);
    auto split_fuse_cur_mm_lens_out = paddle::experimental::copy_to(split_fuse_cur_mm_lens_out_cpu, task_image_type_ids.place(), false);
    return {image_chunk_selections_out, split_fuse_cur_seq_lens_out, split_fuse_cur_mm_lens_out};
}

PD_BUILD_OP(get_mm_split_fuse_v2)
    .Inputs({"task_input_ids", "task_image_type_ids", "task_input_ids_image_token_count", "grid_thw"})
    .Attrs({"image_token_id: int64_t", "img_total: int64_t", "seq_lens_origin: int", "split_fuse_text_size: int"})
    .Outputs({"image_chunk_selections", "split_fuse_cur_seq_lens", "split_fuse_cur_mm_lens_out"})
    .SetKernelFn(PD_KERNEL(GetMmSplitFuseV2));
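For readers tracing the logic: the kernel makes two passes. Pass 1 walks the sequence and records every position where a text run ends or where an image's placeholder tokens end (each image is assumed to contribute h * w / 4 tokens, matching token_times = 4 above). Pass 2 cuts the sequence every split_fuse_text_size tokens and, per chunk, counts the images that start in it, its length, and its multimodal-token count via the prefix-sum input. The following is a rough pure-Python sketch of the same control flow, purely for illustration: split_fuse_chunks is a made-up helper, and grid_thw is taken as the flattened [t, h, w, ...] list the kernel indexes.

def split_fuse_chunks(input_ids, grid_thw, image_token_id, split_fuse_text_size):
    # Illustrative re-implementation of GetMmSplitFuseV2's control flow (not the real op).
    seq_len = len(input_ids)

    # Prefix sum of image tokens with a leading zero, mirroring task_input_ids_image_token_count.
    image_token_count = [0]
    for tok in input_ids:
        image_token_count.append(image_token_count[-1] + int(tok == image_token_id))

    # Pass 1: record positions where a chunk boundary may fall (the std::map `mp` above).
    boundaries = {0}
    st_idx, ib = 0, 0
    while st_idx < seq_len:
        if input_ids[st_idx] != image_token_id:
            # Text run: advance to the end of the text span.
            while st_idx < seq_len and input_ids[st_idx] != image_token_id:
                st_idx += 1
            boundaries.add(st_idx)
        else:
            # Image run: each image occupies h * w / 4 placeholder tokens.
            h, w = grid_thw[ib * 3 + 1], grid_thw[ib * 3 + 2]
            cur_len = (h * w) // 4
            boundaries.add(st_idx + cur_len)
            ib += 1
            st_idx += cur_len

    # Pass 2: emit per-chunk counters every split_fuse_text_size tokens.
    image_nums, seq_lens, mm_lens = [], [], []
    chunk_images, last_id = 0, 0
    for idx in range(seq_len):
        if idx in boundaries and input_ids[idx] == image_token_id:
            chunk_images += 1  # an image starts inside this chunk
        if (idx > 0 and (idx + 1) % split_fuse_text_size == 0) or idx == seq_len - 1:
            start = last_id * split_fuse_text_size
            image_nums.append(chunk_images)
            seq_lens.append(idx - start + 1)
            mm_lens.append(image_token_count[idx + 1] - image_token_count[start])
            chunk_images, last_id = 0, (idx + 1) // split_fuse_text_size
    return image_nums, seq_lens, mm_lens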

custom_ops/setup_ops.py

Lines changed: 1 addition & 0 deletions
@@ -261,6 +261,7 @@ def find_end_files(directory, end_str):
     "gpu_ops/gather_idx.cu",
     "gpu_ops/get_output_ep.cc",
     "gpu_ops/get_mm_split_fuse.cc",
+    "gpu_ops/get_mm_split_fuse_v2.cc",
     "gpu_ops/get_img_boundaries.cc",
     "gpu_ops/token_penalty_multi_scores.cu",
     "gpu_ops/token_penalty_only_once.cu",
Lines changed: 93 additions & 0 deletions
@@ -0,0 +1,93 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""UT for get_mm_split_fuse_v2"""
import unittest

import paddle

from fastdeploy.model_executor.ops.gpu import get_mm_split_fuse_v2


class TestSplitFuse(unittest.TestCase):
    def setUp(self):
        self.grid_thw = [[6, 20, 20], [6, 40, 20]]
        self.split_fuse_img_size = 16
        self.split_fuse_text_size = 384  # 1024
        self.max_seq_len = 2048
        self.image_token_id = 100295

    def split_grid(self, origin_grid_thw):
        # Split grid_thw along the temporal axis; this helper is for the video scenario.
        # e.g. origin_grid_thw = [6, 10, 12] ---> [2, 10, 12, 2, 10, 12, 2, 10, 12]
        grid_thw = []
        for t, h, w in origin_grid_thw:
            if t > 2:
                num_groups = t // 2
                remainder = t % 2
                for _ in range(num_groups):
                    grid_thw.extend([2, h, w])
                if remainder > 0:
                    grid_thw.extend([remainder, h, w])
            else:
                grid_thw.extend([t, h, w])
        return grid_thw

    def test_get_mm_split_fuse(self):
        grid_thw = self.split_grid(self.grid_thw)
        image_bs = len(grid_thw) // 3
        image_type_ids = [0] * image_bs

        # Assemble input_ids as [txt0 + img1 + txt1 + img2]
        input_ids = [2] * 19
        img1 = [self.image_token_id] * 100 * 3
        txt1 = [3] * 19
        img2 = [self.image_token_id] * 200 * 3
        input_ids.extend(img1)
        input_ids.extend(txt1)
        input_ids.extend(img2)

        seq_len = len(input_ids)
        input_ids_tensor = paddle.to_tensor(input_ids, dtype="int64")
        image_type_ids_tensor = paddle.to_tensor(image_type_ids, dtype="int32")
        is_image_token = paddle.where(input_ids_tensor == self.image_token_id, 1, 0)
        image_token_sum = paddle.cumsum(is_image_token)  # prefix sum of image tokens
        image_token_sum = paddle.concat([paddle.zeros([1], dtype="int64"), image_token_sum])

        grid_thw_tensor = paddle.to_tensor(grid_thw, dtype="int64")

        image_chunk_selections, split_fuse_cur_seq_lens, split_fuse_cur_mm_lens = get_mm_split_fuse_v2(
            input_ids_tensor.cpu(),
            image_type_ids_tensor.cast("int32").cpu(),
            image_token_sum.cast("int32").cpu(),
            grid_thw_tensor.cpu(),
            self.image_token_id,
            image_bs,
            seq_len,
            self.split_fuse_text_size,
        )

        # Verify the outputs are not None
        self.assertIsNotNone(image_chunk_selections)
        self.assertIsNotNone(split_fuse_cur_seq_lens)
        self.assertIsNotNone(split_fuse_cur_mm_lens)

        # Verify the shapes are as expected
        self.assertEqual(len(image_chunk_selections.shape), 1)
        self.assertEqual(len(split_fuse_cur_seq_lens.shape), 1)
        self.assertEqual(len(split_fuse_cur_mm_lens.shape), 1)


if __name__ == "__main__":
    unittest.main()
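The test currently checks only that the outputs exist and are 1-D. Since the input is 19 + 300 + 19 + 600 = 938 tokens and split_fuse_text_size is 384, the chunk lengths produced by the kernel's second loop should tile the sequence as 384, 384, 170. If that hand derivation holds, the test could be tightened with a value-level check; below is a small self-contained sketch of the expected tiling (expected_chunk_lengths is a hypothetical helper, not part of the commit):

# Hypothetical follow-up check, assuming split_fuse_cur_seq_lens simply tiles the sequence
# into split_fuse_text_size-sized chunks with a shorter tail (which is what the kernel's
# second loop does for this input).
def expected_chunk_lengths(seq_len, chunk_size):
    full, rest = divmod(seq_len, chunk_size)
    return [chunk_size] * full + ([rest] if rest else [])

assert expected_chunk_lengths(938, 384) == [384, 384, 170]

Inside the test this would amount to comparing split_fuse_cur_seq_lens.numpy().tolist() against expected_chunk_lengths(seq_len, self.split_fuse_text_size).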
