diff --git a/custom_ops/gpu_ops/get_mm_split_fuse_v2.cc b/custom_ops/gpu_ops/get_mm_split_fuse_v2.cc new file mode 100644 index 0000000000..7114b38a0d --- /dev/null +++ b/custom_ops/gpu_ops/get_mm_split_fuse_v2.cc @@ -0,0 +1,88 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/extension.h" +#include + + +std::vector GetMmSplitFuseV2(const paddle::Tensor& task_input_ids, + const paddle::Tensor& task_image_type_ids, + const paddle::Tensor& task_input_ids_image_token_count, + const paddle::Tensor& grid_thw, + int64_t image_token_id, + int64_t img_total, + int seq_lens_origin, + int split_fuse_text_size) { + // All tensor in cpu + auto input_ids_cpu = task_input_ids.data(); + auto task_input_ids_image_token_count_cpu = task_input_ids_image_token_count.data(); + auto grid_thw_cpu = grid_thw.data(); + std::vector image_chunk_selections_vector; // 当前chunk 图片数目 + std::vector split_fuse_cur_seq_lens_vector; // 当前chunk 长度 + std::vector split_fuse_cur_mm_lens_vector; // 当前chunk mm_token数目 + // [预处理] 记录可划分chunk的位置 + std::map mp; + mp[0] = 1; // init + int st_idx = 0, last_st_ib = 0; + int idx = 0; + while (st_idx < seq_lens_origin) { + // 1. 当前st_idx为文本,找到文本末尾 + if (input_ids_cpu[st_idx] != image_token_id) { + do { + st_idx ++; + } while (st_idx < seq_lens_origin && input_ids_cpu[st_idx] != image_token_id); + mp[st_idx] = 1; // 记录划分chunk的末尾位置,此处为文本的末位+1 + } else { // 2. 
当前 st_idx 为多模,根据多模token的长度找到末尾 + int ib = last_st_ib; + int cur_st_len = 0; + int token_times = 4; + cur_st_len = (grid_thw_cpu[ib * 3 + 1] * grid_thw_cpu[ib * 3 + 2]) / token_times; + mp[st_idx + cur_st_len] = 1; + last_st_ib = ++ib; + st_idx += cur_st_len; + } + } + int chunk_image_number = 0; + int last_id = 0; + for (idx = 0; idx < seq_lens_origin; idx++) { + if (mp[idx] == 1 && input_ids_cpu[idx] == image_token_id) { + chunk_image_number ++; + } + if (idx > 0 && (idx + 1) % split_fuse_text_size == 0 || idx == seq_lens_origin - 1) { + int chunk_start = last_id * split_fuse_text_size; + int chunk_end = idx; + int chunk_image_token_number = task_input_ids_image_token_count_cpu[chunk_end + 1] - task_input_ids_image_token_count_cpu[chunk_start]; + image_chunk_selections_vector.emplace_back(chunk_image_number); + split_fuse_cur_seq_lens_vector.emplace_back(chunk_end - chunk_start + 1); + split_fuse_cur_mm_lens_vector.emplace_back(chunk_image_token_number); + chunk_image_number = 0; + last_id = (idx + 1) / split_fuse_text_size; + } + } + // vector to cpu tensor + auto image_chunk_selections_out_cpu = paddle::from_blob(image_chunk_selections_vector.data(), {image_chunk_selections_vector.size()}, task_image_type_ids.dtype()); + auto split_fuse_cur_seq_lens_out_cpu = paddle::from_blob(split_fuse_cur_seq_lens_vector.data(), {split_fuse_cur_seq_lens_vector.size()}, task_image_type_ids.dtype()); + auto split_fuse_cur_mm_lens_out_cpu = paddle::from_blob(split_fuse_cur_mm_lens_vector.data(), {split_fuse_cur_mm_lens_vector.size()}, task_image_type_ids.dtype()); + // cpu tensor to gpu tensor + auto image_chunk_selections_out = paddle::experimental::copy_to(image_chunk_selections_out_cpu, task_image_type_ids.place(), false); + auto split_fuse_cur_seq_lens_out = paddle::experimental::copy_to(split_fuse_cur_seq_lens_out_cpu, task_image_type_ids.place(), false); + auto split_fuse_cur_mm_lens_out = paddle::experimental::copy_to(split_fuse_cur_mm_lens_out_cpu, 
task_image_type_ids.place(), false); + return {image_chunk_selections_out, split_fuse_cur_seq_lens_out, split_fuse_cur_mm_lens_out}; +} + +PD_BUILD_OP(get_mm_split_fuse_v2) + .Inputs({"task_input_ids", "task_image_type_ids", "task_input_ids_image_token_count", "grid_thw"}) + .Attrs({"image_token_id: int64_t", "img_total: int64_t", "seq_lens_origin: int", "split_fuse_text_size: int"}) + .Outputs({"image_chunk_selections", "split_fuse_cur_seq_lens", "split_fuse_cur_mm_lens_out"}) + .SetKernelFn(PD_KERNEL(GetMmSplitFuseV2)); diff --git a/custom_ops/setup_ops.py b/custom_ops/setup_ops.py index 3ca8c3c3f3..128f41facf 100644 --- a/custom_ops/setup_ops.py +++ b/custom_ops/setup_ops.py @@ -261,6 +261,7 @@ def find_end_files(directory, end_str): "gpu_ops/gather_idx.cu", "gpu_ops/get_output_ep.cc", "gpu_ops/get_mm_split_fuse.cc", + "gpu_ops/get_mm_split_fuse_v2.cc", "gpu_ops/get_img_boundaries.cc", "gpu_ops/token_penalty_multi_scores.cu", "gpu_ops/token_penalty_only_once.cu", diff --git a/tests/operators/test_split_fuse_v2.py b/tests/operators/test_split_fuse_v2.py new file mode 100644 index 0000000000..d922f48bd2 --- /dev/null +++ b/tests/operators/test_split_fuse_v2.py @@ -0,0 +1,93 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
"""UT for get_mm_split_fuse_v2"""
import unittest

import paddle

from fastdeploy.model_executor.ops.gpu import get_mm_split_fuse_v2


class TestSplitFuseV2(unittest.TestCase):
    """Unit test for the get_mm_split_fuse_v2 custom op."""

    def setUp(self):
        # (t, h, w) patch grids of two videos, before temporal splitting.
        self.grid_thw = [[6, 20, 20], [6, 40, 20]]
        self.split_fuse_img_size = 16
        self.split_fuse_text_size = 384  # 1024
        self.max_seq_len = 2048
        self.image_token_id = 100295

    def split_grid(self, origin_grid_thw):
        """Split each (t, h, w) grid into groups of at most two frames.

        Used for the video case, e.g. a single [6, 10, 12] entry becomes the
        flattened list [2, 10, 12, 2, 10, 12, 2, 10, 12].
        """
        grid_thw = []
        for t, h, w in origin_grid_thw:
            if t > 2:
                num_groups, remainder = divmod(t, 2)
                grid_thw.extend([2, h, w] * num_groups)
                if remainder:
                    grid_thw.extend([remainder, h, w])
            else:
                grid_thw.extend([t, h, w])
        return grid_thw

    def test_get_mm_split_fuse_v2(self):
        grid_thw = self.split_grid(self.grid_thw)
        image_bs = len(grid_thw) // 3
        image_type_ids = [0] * image_bs

        # Concatenate input_ids as [txt0, img1, txt1, img2].  Each split
        # [2, h, w] grid entry contributes h * w / 4 image tokens, matching
        # the op's internal token_times == 4.
        input_ids = [2] * 19
        img1 = [self.image_token_id] * 100 * 3
        txt1 = [3] * 19
        img2 = [self.image_token_id] * 200 * 3
        input_ids.extend(img1)
        input_ids.extend(txt1)
        input_ids.extend(img2)

        seq_len = len(input_ids)
        input_ids_tensor = paddle.to_tensor(input_ids, dtype="int64")
        image_type_ids_tensor = paddle.to_tensor(image_type_ids, dtype="int32")
        is_image_token = paddle.where(input_ids_tensor == self.image_token_id, 1, 0)
        image_token_sum = paddle.cumsum(is_image_token)  # prefix sum
        # Leading zero so the op can compute count[end + 1] - count[start].
        image_token_sum = paddle.concat([paddle.zeros([1], dtype="int64"), image_token_sum])

        grid_thw_tensor = paddle.to_tensor(grid_thw, dtype="int64")

        image_chunk_selections, split_fuse_cur_seq_lens, split_fuse_cur_mm_lens = get_mm_split_fuse_v2(
            input_ids_tensor.cpu(),
            image_type_ids_tensor.cast("int32").cpu(),
            image_token_sum.cast("int32").cpu(),
            grid_thw_tensor.cpu(),
            self.image_token_id,
            image_bs,
            seq_len,
            self.split_fuse_text_size,
        )

        # Verify the outputs are not None
        self.assertIsNotNone(image_chunk_selections)
        self.assertIsNotNone(split_fuse_cur_seq_lens)
        self.assertIsNotNone(split_fuse_cur_mm_lens)

        # All three outputs are 1-D with one entry per split-fuse chunk.
        expected_chunks = (seq_len + self.split_fuse_text_size - 1) // self.split_fuse_text_size
        self.assertEqual(image_chunk_selections.shape, [expected_chunks])
        self.assertEqual(split_fuse_cur_seq_lens.shape, [expected_chunks])
        self.assertEqual(split_fuse_cur_mm_lens.shape, [expected_chunks])

        # The chunk lengths must partition the sequence exactly.
        self.assertEqual(int(paddle.sum(split_fuse_cur_seq_lens)), seq_len)
        # Per-chunk mm-token counts telescope to the total image-token count.
        self.assertEqual(int(paddle.sum(split_fuse_cur_mm_lens)), int(paddle.sum(is_image_token)))


if __name__ == "__main__":
    unittest.main()