// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <cstdint>
#include <map>
#include <vector>

#include "paddle/extension.h"

| 19 | +std::vector<paddle::Tensor> GetMmSplitFuseV2(const paddle::Tensor& task_input_ids, |
| 20 | + const paddle::Tensor& task_image_type_ids, |
| 21 | + const paddle::Tensor& task_input_ids_image_token_count, |
| 22 | + const paddle::Tensor& grid_thw, |
| 23 | + int64_t image_token_id, |
| 24 | + int64_t img_total, |
| 25 | + int seq_lens_origin, |
| 26 | + int split_fuse_text_size) { |
| 27 | + // All tensor in cpu |
| 28 | + auto input_ids_cpu = task_input_ids.data<int64_t>(); |
| 29 | + auto task_input_ids_image_token_count_cpu = task_input_ids_image_token_count.data<int>(); |
| 30 | + auto grid_thw_cpu = grid_thw.data<int64_t>(); |
| 31 | + std::vector<int> image_chunk_selections_vector; // 当前chunk 图片数目 |
| 32 | + std::vector<int> split_fuse_cur_seq_lens_vector; // 当前chunk 长度 |
| 33 | + std::vector<int> split_fuse_cur_mm_lens_vector; // 当前chunk mm_token数目 |
| 34 | + // [预处理] 记录可划分chunk的位置 |
| 35 | + std::map<int, int> mp; |
| 36 | + mp[0] = 1; // init |
| 37 | + int st_idx = 0, last_st_ib = 0; |
| 38 | + int idx = 0; |
| 39 | + while (st_idx < seq_lens_origin) { |
| 40 | + // 1. 当前st_idx为文本,找到文本末尾 |
| 41 | + if (input_ids_cpu[st_idx] != image_token_id) { |
| 42 | + do { |
| 43 | + st_idx ++; |
| 44 | + } while (st_idx < seq_lens_origin && input_ids_cpu[st_idx] != image_token_id); |
| 45 | + mp[st_idx] = 1; // 记录划分chunk的末尾位置,此处为文本的末位+1 |
| 46 | + } else { // 2. 当前 st_idx 为多模,根据多模token的长度找到末尾 |
| 47 | + int ib = last_st_ib; |
| 48 | + int cur_st_len = 0; |
| 49 | + int token_times = 4; |
| 50 | + cur_st_len = (grid_thw_cpu[ib * 3 + 1] * grid_thw_cpu[ib * 3 + 2]) / token_times; |
| 51 | + mp[st_idx + cur_st_len] = 1; |
| 52 | + last_st_ib = ++ib; |
| 53 | + st_idx += cur_st_len; |
| 54 | + } |
| 55 | + } |
| 56 | + int chunk_image_number = 0; |
| 57 | + int last_id = 0; |
| 58 | + for (idx = 0; idx < seq_lens_origin; idx++) { |
| 59 | + if (mp[idx] == 1 && input_ids_cpu[idx] == image_token_id) { |
| 60 | + chunk_image_number ++; |
| 61 | + } |
| 62 | + if (idx > 0 && (idx + 1) % split_fuse_text_size == 0 || idx == seq_lens_origin - 1) { |
| 63 | + int chunk_start = last_id * split_fuse_text_size; |
| 64 | + int chunk_end = idx; |
| 65 | + int chunk_image_token_number = task_input_ids_image_token_count_cpu[chunk_end + 1] - task_input_ids_image_token_count_cpu[chunk_start]; |
| 66 | + image_chunk_selections_vector.emplace_back(chunk_image_number); |
| 67 | + split_fuse_cur_seq_lens_vector.emplace_back(chunk_end - chunk_start + 1); |
| 68 | + split_fuse_cur_mm_lens_vector.emplace_back(chunk_image_token_number); |
| 69 | + chunk_image_number = 0; |
| 70 | + last_id = (idx + 1) / split_fuse_text_size; |
| 71 | + } |
| 72 | + } |
| 73 | + // vector to cpu tensor |
| 74 | + auto image_chunk_selections_out_cpu = paddle::from_blob(image_chunk_selections_vector.data(), {image_chunk_selections_vector.size()}, task_image_type_ids.dtype()); |
| 75 | + auto split_fuse_cur_seq_lens_out_cpu = paddle::from_blob(split_fuse_cur_seq_lens_vector.data(), {split_fuse_cur_seq_lens_vector.size()}, task_image_type_ids.dtype()); |
| 76 | + auto split_fuse_cur_mm_lens_out_cpu = paddle::from_blob(split_fuse_cur_mm_lens_vector.data(), {split_fuse_cur_mm_lens_vector.size()}, task_image_type_ids.dtype()); |
| 77 | + // cpu tensor to gpu tensor |
| 78 | + auto image_chunk_selections_out = paddle::experimental::copy_to(image_chunk_selections_out_cpu, task_image_type_ids.place(), false); |
| 79 | + auto split_fuse_cur_seq_lens_out = paddle::experimental::copy_to(split_fuse_cur_seq_lens_out_cpu, task_image_type_ids.place(), false); |
| 80 | + auto split_fuse_cur_mm_lens_out = paddle::experimental::copy_to(split_fuse_cur_mm_lens_out_cpu, task_image_type_ids.place(), false); |
| 81 | + return {image_chunk_selections_out, split_fuse_cur_seq_lens_out, split_fuse_cur_mm_lens_out}; |
| 82 | +} |
| 83 | + |
| 84 | +PD_BUILD_OP(get_mm_split_fuse_v2) |
| 85 | + .Inputs({"task_input_ids", "task_image_type_ids", "task_input_ids_image_token_count", "grid_thw"}) |
| 86 | + .Attrs({"image_token_id: int64_t", "img_total: int64_t", "seq_lens_origin: int", "split_fuse_text_size: int"}) |
| 87 | + .Outputs({"image_chunk_selections", "split_fuse_cur_seq_lens", "split_fuse_cur_mm_lens_out"}) |
| 88 | + .SetKernelFn(PD_KERNEL(GetMmSplitFuseV2)); |