From 166eb1c4526c1a8f94db7c958067f49c39f73755 Mon Sep 17 00:00:00 2001 From: mihao <892949708@qq.com> Date: Mon, 4 Aug 2025 09:55:15 +0800 Subject: [PATCH] [Ascend_npu]add /banckend/npu/opp --- backends/npu/README.md | 11 + backends/npu/README_cn.md | 11 + backends/npu/opp/ascendc_custom_ops/LICENSE | 201 ++++++++ .../build/build_ascendc_ops.sh | 128 +++++ .../opp/ascendc_custom_ops/build/build_ops.sh | 74 +++ .../opp/ascendc_custom_ops/build/ir_demo.json | 40 ++ .../change_rank_id_to_device_id.py | 107 ++++ .../npu/opp/ascendc_custom_ops/changelog.txt | 23 + backends/npu/opp/ascendc_custom_ops/main.py | 238 +++++++++ .../onnx_plugin/set_mask_value_plugin.cpp | 45 ++ .../set_stop_value_multi_ends_plugin.cpp | 45 ++ .../set_value_by_flags_and_idx_plugin.cpp | 45 ++ .../token_penalty_multi_scores_plugin.cpp | 45 ++ .../src/ops/ascendc/op_host/get_max_len.cpp | 137 +++++ .../ops/ascendc/op_host/get_max_len_tiling.h | 25 + .../ascendc/op_host/get_padding_offset.cpp | 173 +++++++ .../op_host/get_padding_offset_tiling.h | 25 + .../ops/ascendc/op_host/rebuild_padding.cpp | 151 ++++++ .../ascendc/op_host/rebuild_padding_tiling.h | 27 + .../ops/ascendc/op_host/set_mask_value.cpp | 107 ++++ .../ascendc/op_host/set_mask_value_tiling.h | 10 + .../op_host/set_stop_value_multi_ends.cpp | 93 ++++ .../set_stop_value_multi_ends_tiling.h | 10 + .../op_host/set_stop_value_multi_ends_v2.cpp | 167 +++++++ .../set_stop_value_multi_ends_v2_tiling.h | 25 + .../op_host/set_stop_value_multi_seqs.cpp | 182 +++++++ .../set_stop_value_multi_seqs_tiling.h | 28 ++ .../op_host/set_value_by_flags_and_idx.cpp | 122 +++++ .../set_value_by_flags_and_idx_tiling.h | 10 + .../op_host/set_value_by_flags_and_idx_v2.cpp | 157 ++++++ .../set_value_by_flags_and_idx_v2_tiling.h | 26 + .../src/ops/ascendc/op_host/step_paddle.cpp | 469 ++++++++++++++++++ .../ops/ascendc/op_host/step_paddle_tiling.h | 16 + .../op_host/token_penalty_multi_scores.cpp | 199 ++++++++ .../token_penalty_multi_scores_tiling.h | 29 ++ 
.../op_host/token_penalty_multi_scores_v2.cpp | 214 ++++++++ .../token_penalty_multi_scores_v2_tiling.h | 31 ++ ...en_penalty_multi_scores_with_stop_seqs.cpp | 217 ++++++++ ...nalty_multi_scores_with_stop_seqs_tiling.h | 31 ++ .../src/ops/ascendc/op_host/update_inputs.cpp | 239 +++++++++ .../ascendc/op_host/update_inputs_tiling.h | 26 + .../src/ops/ascendc/op_kernel/get_max_len.cpp | 52 ++ .../ascendc/op_kernel/get_padding_offset.cpp | 180 +++++++ .../ops/ascendc/op_kernel/rebuild_padding.cpp | 125 +++++ .../ops/ascendc/op_kernel/set_mask_value.cpp | 52 ++ .../op_kernel/set_stop_value_multi_ends.cpp | 72 +++ .../set_stop_value_multi_ends_v2.cpp | 101 ++++ .../op_kernel/set_stop_value_multi_seqs.cpp | 140 ++++++ .../op_kernel/set_value_by_flags_and_idx.cpp | 55 ++ .../set_value_by_flags_and_idx_v2.cpp | 86 ++++ .../src/ops/ascendc/op_kernel/step_paddle.cpp | 248 +++++++++ .../op_kernel/token_penalty_multi_scores.cpp | 219 ++++++++ .../token_penalty_multi_scores_v2.cpp | 262 ++++++++++ ...en_penalty_multi_scores_with_stop_seqs.cpp | 239 +++++++++ .../ops/ascendc/op_kernel/update_inputs.cpp | 86 ++++ .../tests/ascendc/utils/common.h | 106 ++++ backends/npu/tools/set_env.sh | 1 + 57 files changed, 5983 insertions(+) create mode 100644 backends/npu/opp/ascendc_custom_ops/LICENSE create mode 100644 backends/npu/opp/ascendc_custom_ops/build/build_ascendc_ops.sh create mode 100644 backends/npu/opp/ascendc_custom_ops/build/build_ops.sh create mode 100644 backends/npu/opp/ascendc_custom_ops/build/ir_demo.json create mode 100644 backends/npu/opp/ascendc_custom_ops/change_rank_id_to_device_id.py create mode 100644 backends/npu/opp/ascendc_custom_ops/changelog.txt create mode 100644 backends/npu/opp/ascendc_custom_ops/main.py create mode 100644 backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/framework/onnx_plugin/set_mask_value_plugin.cpp create mode 100644 backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/framework/onnx_plugin/set_stop_value_multi_ends_plugin.cpp create 
mode 100644 backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/framework/onnx_plugin/set_value_by_flags_and_idx_plugin.cpp create mode 100644 backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/framework/onnx_plugin/token_penalty_multi_scores_plugin.cpp create mode 100644 backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/get_max_len.cpp create mode 100644 backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/get_max_len_tiling.h create mode 100644 backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/get_padding_offset.cpp create mode 100644 backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/get_padding_offset_tiling.h create mode 100644 backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/rebuild_padding.cpp create mode 100644 backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/rebuild_padding_tiling.h create mode 100644 backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/set_mask_value.cpp create mode 100644 backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/set_mask_value_tiling.h create mode 100644 backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/set_stop_value_multi_ends.cpp create mode 100644 backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/set_stop_value_multi_ends_tiling.h create mode 100644 backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/set_stop_value_multi_ends_v2.cpp create mode 100644 backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/set_stop_value_multi_ends_v2_tiling.h create mode 100644 backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/set_stop_value_multi_seqs.cpp create mode 100644 backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/set_stop_value_multi_seqs_tiling.h create mode 100644 backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/set_value_by_flags_and_idx.cpp create mode 100644 backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/set_value_by_flags_and_idx_tiling.h create mode 100644 
backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/set_value_by_flags_and_idx_v2.cpp create mode 100644 backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/set_value_by_flags_and_idx_v2_tiling.h create mode 100644 backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/step_paddle.cpp create mode 100644 backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/step_paddle_tiling.h create mode 100644 backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/token_penalty_multi_scores.cpp create mode 100644 backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/token_penalty_multi_scores_tiling.h create mode 100644 backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/token_penalty_multi_scores_v2.cpp create mode 100644 backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/token_penalty_multi_scores_v2_tiling.h create mode 100644 backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/token_penalty_multi_scores_with_stop_seqs.cpp create mode 100644 backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/token_penalty_multi_scores_with_stop_seqs_tiling.h create mode 100644 backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/update_inputs.cpp create mode 100644 backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/update_inputs_tiling.h create mode 100644 backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_kernel/get_max_len.cpp create mode 100644 backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_kernel/get_padding_offset.cpp create mode 100644 backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_kernel/rebuild_padding.cpp create mode 100644 backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_kernel/set_mask_value.cpp create mode 100644 backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_kernel/set_stop_value_multi_ends.cpp create mode 100644 backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_kernel/set_stop_value_multi_ends_v2.cpp create mode 100644 
backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_kernel/set_stop_value_multi_seqs.cpp create mode 100644 backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_kernel/set_value_by_flags_and_idx.cpp create mode 100644 backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_kernel/set_value_by_flags_and_idx_v2.cpp create mode 100644 backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_kernel/step_paddle.cpp create mode 100644 backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_kernel/token_penalty_multi_scores.cpp create mode 100644 backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_kernel/token_penalty_multi_scores_v2.cpp create mode 100644 backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_kernel/token_penalty_multi_scores_with_stop_seqs.cpp create mode 100644 backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_kernel/update_inputs.cpp create mode 100644 backends/npu/opp/ascendc_custom_ops/tests/ascendc/utils/common.h create mode 100644 backends/npu/tools/set_env.sh diff --git a/backends/npu/README.md b/backends/npu/README.md index 96962f25710..777daa3822a 100644 --- a/backends/npu/README.md +++ b/backends/npu/README.md @@ -59,6 +59,17 @@ bash tools/compile.sh # 5. 
install the generated whl package, which is under build/dist directory pip install build/dist/paddle_custom_npu*.whl + +# 6) Set the execution environment variables +source tools/set_env.sh + +# 7) Install the ops library +cd opp/ascend_custom_ops/build +bash build_ops.sh +cd custom_project/build_out/ +./custom_opp*.run + +Follow the prompts to input export ``` ## Verification diff --git a/backends/npu/README_cn.md b/backends/npu/README_cn.md index 89e8931199b..cca7a3da42d 100644 --- a/backends/npu/README_cn.md +++ b/backends/npu/README_cn.md @@ -57,6 +57,17 @@ bash tools/compile.sh # 5) 编译产出在 build/dist 路径下,使用 pip 安装 pip install build/dist/paddle_custom_npu*.whl + +# 6) 执行环境变量设置 +source tools/set_env.sh + +# 7) 安装ops库 +cd opp/ascend_custom_ops/build +bash build_ops.sh +cd custom_project/build_out/ +./custom_opp*.run + +依照提示输入 export ``` ### 基础功能检查 diff --git a/backends/npu/opp/ascendc_custom_ops/LICENSE b/backends/npu/opp/ascendc_custom_ops/LICENSE new file mode 100644 index 00000000000..261eeb9e9f8 --- /dev/null +++ b/backends/npu/opp/ascendc_custom_ops/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. 
For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. 
You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. 
Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. 
+ + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/backends/npu/opp/ascendc_custom_ops/build/build_ascendc_ops.sh b/backends/npu/opp/ascendc_custom_ops/build/build_ascendc_ops.sh new file mode 100644 index 00000000000..44d6b33dd93 --- /dev/null +++ b/backends/npu/opp/ascendc_custom_ops/build/build_ascendc_ops.sh @@ -0,0 +1,128 @@ +#!/bin/bash +# -*- coding: utf-8 -*- + +src=${current_script_dir}/../src/ops/ascendc +dst=${current_script_dir}/custom_project + +function create_empty_custom_project(){ + cd ${current_script_dir} + rm -rf ${dst} + ${msopgen} gen -i ir_demo.json -f onnx \ + -c ai_core-ascend310p,ai_core-ascend910,ai_core-ascend910b -lan cpp -out ${dst} + rm ${dst}/framework/onnx_plugin/*.cc + rm ${dst}/op_host/*.h + rm ${dst}/op_host/*.cpp + rm ${dst}/op_kernel/*.cpp +} + +function release_framework_onnx(){ + cd ${src}/framework/onnx_plugin + # 如需控制哪些文件发布,可以按照字母序列举具体文件 + local files=( + set_mask_value_plugin.cpp + set_stop_value_multi_ends_plugin.cpp + set_value_by_flags_and_idx_plugin.cpp + token_penalty_multi_scores_plugin.cpp + ) + cp ${files[@]} ${dst}/framework/onnx_plugin +} + +function release_op_host(){ + cd ${src}/op_host + local files=( + set_value_by_flags_and_idx_tiling.h + set_value_by_flags_and_idx.cpp + + set_value_by_flags_and_idx_v2_tiling.h + set_value_by_flags_and_idx_v2.cpp + + set_stop_value_multi_ends_tiling.h + set_stop_value_multi_ends.cpp + + set_stop_value_multi_ends_v2_tiling.h + 
set_stop_value_multi_ends_v2.cpp + + set_stop_value_multi_seqs_tiling.h + set_stop_value_multi_seqs.cpp + + set_mask_value_tiling.h + set_mask_value.cpp + + token_penalty_multi_scores_tiling.h + token_penalty_multi_scores.cpp + + token_penalty_multi_scores_v2_tiling.h + token_penalty_multi_scores_v2.cpp + + token_penalty_multi_scores_with_stop_seqs_tiling.h + token_penalty_multi_scores_with_stop_seqs.cpp + + update_inputs_tiling.h + update_inputs.cpp + + get_max_len_tiling.h + get_max_len.cpp + + rebuild_padding_tiling.h + rebuild_padding.cpp + + get_padding_offset_tiling.h + get_padding_offset.cpp + + step_paddle_tiling.h + step_paddle.cpp + ) + cp ${files[@]} ${dst}/op_host +} + +function release_op_kernel(){ + cd ${src}/op_kernel + local files=( + set_value_by_flags_and_idx.cpp + set_value_by_flags_and_idx_v2.cpp + set_stop_value_multi_ends.cpp + set_stop_value_multi_ends_v2.cpp + set_stop_value_multi_seqs.cpp + set_mask_value.cpp + token_penalty_multi_scores.cpp + token_penalty_multi_scores_v2.cpp + token_penalty_multi_scores_with_stop_seqs.cpp + update_inputs.cpp + get_max_len.cpp + rebuild_padding.cpp + get_padding_offset.cpp + step_paddle.cpp + ) + cp ${files[@]} ${dst}/op_kernel +} + +function revise_settings(){ + cd ${dst} + sed -i "s#/usr/local/Ascend/latest#${local_toolkit}#g" CMakePresets.json + sed -i "s#\"value\": \"customize\"#\"value\": \"aie_ascendc\"#g" CMakePresets.json + sed -i "s#\"value\": \"True\"#\"value\": \"False\"#g" CMakePresets.json + + local line_num=$(grep -Fn "ENABLE_SOURCE_PACKAGE" CMakePresets.json | cut -d : -f 1) + local offset_line_num=$((line_num+2)) + sed -i "${offset_line_num}s#\"value\": \"False\"#\"value\": \"True\"#g" CMakePresets.json +} + +function build_and_install(){ + cd ${dst} + bash build.sh + bash ${dst}/build_out/*.run --install-path=${current_script_dir} +} + +function build_ascendc_ops(){ + ori_path=${PWD} + + create_empty_custom_project + release_framework_onnx + release_op_host + release_op_kernel + 
revise_settings + build_and_install + cd ${ori_path} +} + +build_ascendc_ops diff --git a/backends/npu/opp/ascendc_custom_ops/build/build_ops.sh b/backends/npu/opp/ascendc_custom_ops/build/build_ops.sh new file mode 100644 index 00000000000..c86ec779e0b --- /dev/null +++ b/backends/npu/opp/ascendc_custom_ops/build/build_ops.sh @@ -0,0 +1,74 @@ +#!/bin/bash +# -*- coding: utf-8 -*- + +# 构建环境使用CANN主线包,容易引入兼容性问题。同时为了更好地控制对外发布内容,我们 +# 在构建环境用msopgen工具生成工程,然后将要发布的算子交付件拷贝到新生成的工程构建 +set -e + +is_ci_build="n" +current_script_dir=$(dirname $(readlink -f $0)) +# 构建过程source该脚本需要传递实际路径,通过参数数量判断是否为构建流程 +if [ $# -ne 0 ]; then + is_ci_build="y" + current_script_dir=$(realpath $1) + if [ ! -f ${current_script_dir}/build_ops.sh ]; then + echo "${current_script_dir}/build_ops.sh not exists" + exit 1 + fi + if [ "x${RELEASE_TMP_DIR}" == "x" ]; then + echo "Did not define correct RELEASE_TMP_DIR" + exit 1 + fi + release_path=$(realpath ${RELEASE_TMP_DIR}) + if [ ! -d ${release_path} ]; then + echo "Invalid RELEASE_TMP_DIR" + exit 1 + fi + # 构建环境的toolkit默认安装路径 + local_toolkit=/home/slave1/Ascend/ascend-toolkit/latest +else + # 对于非构建环境,推荐整包安装,通过source set_env.sh脚本会定义环境变量 + if [ "x${ASCEND_TOOLKIT_HOME}" != "x" ]; then + local_toolkit=${ASCEND_TOOLKIT_HOME} + else + echo "Can not find toolkit path, please set ASCEND_TOOLKIT_HOME" + echo "eg: export ASCEND_TOOLKIT_HOME=/usr/local/Ascend/ascend-toolkit/latest" + exit 1 + fi +fi + +msopgen=${local_toolkit}/python/site-packages/bin/msopgen +if [ ! 
-f ${msopgen} ]; then + echo "${msopgen} not exists" + exit 1 +fi + +function make_package(){ + cd ${current_script_dir} + rm -rf pkg + mkdir pkg + chmod +w vendors + mv vendors pkg + chmod -w pkg/vendors + chmod -w pkg + ./custom_project/cmake/util/makeself/makeself.sh \ + --header ./custom_project/cmake/util/makeself/makeself-header.sh \ + --gzip --notemp --complevel 4 --nomd5 --sha256 --chown \ + ./pkg aie_ops.run 'aie ops' +} + +function build_ops(){ + ori_path=${PWD} + cd ${current_script_dir} + rm -rf vendors + source ${current_script_dir}/build_ascendc_ops.sh + rm -rf ${current_script_dir}/vendors/aie_ascendc/bin + rm -rf ${current_script_dir}/vendors/customize/bin + make_package + if [ "x${is_ci_build}" == "xy" ]; then + cp aie_ops.run ${release_path}/ + fi + cd ${ori_path} +} + +build_ops diff --git a/backends/npu/opp/ascendc_custom_ops/build/ir_demo.json b/backends/npu/opp/ascendc_custom_ops/build/ir_demo.json new file mode 100644 index 00000000000..200db39025b --- /dev/null +++ b/backends/npu/opp/ascendc_custom_ops/build/ir_demo.json @@ -0,0 +1,40 @@ +[ + { + "op": "GenSeqLen", + "language": "cpp", + "input_desc": [ + { + "name": "attnMask", + "param_type": "required", + "format": [ + "ND" + ], + "type": [ + "int32" + ] + } + ], + "output_desc": [ + { + "name": "seqLenAlign", + "param_type": "required", + "format": [ + "ND" + ], + "type": [ + "int32" + ] + }, + { + "name": "seqLenOri", + "param_type": "required", + "format": [ + "ND" + ], + "type": [ + "int32" + ] + } + ] + } +] \ No newline at end of file diff --git a/backends/npu/opp/ascendc_custom_ops/change_rank_id_to_device_id.py b/backends/npu/opp/ascendc_custom_ops/change_rank_id_to_device_id.py new file mode 100644 index 00000000000..0fc57524a91 --- /dev/null +++ b/backends/npu/opp/ascendc_custom_ops/change_rank_id_to_device_id.py @@ -0,0 +1,107 @@ +import json +import os +from functools import partial +from argparse import ArgumentParser + + +FILTER_DIRS = [".profiler", "HCCL_PROF", "timeline", 
"query", 'sqlite', 'log'] + + +def get_path_dir(path: str) -> list: + """ + check result path exist JOB dir + path : result path + """ + path_dir_filter = filter(partial(_path_dir_filter_func, root_dir=path), os.listdir(path)) + sub_dirs = list(path_dir_filter) + if not sub_dirs: + message = f"The path \"{path}\" does not have PROF dir. Please check the path." + print(message) + return sub_dirs + + +def _path_dir_filter_func(sub_path, root_dir): + return sub_path not in FILTER_DIRS and os.path.isdir(os.path.realpath(os.path.join(root_dir, sub_path))) + + +def change_rank_id_to_device_id(pro_dir): + for root, dirs, files in os.walk(pro_dir): + for dir_ in dirs: + if 'device_' in dir_: + device_id = dir_.split("_")[-1] + + info_path = os.path.join(root, dir_, f'info.json.{device_id}') + + with open(info_path, 'r+') as f: + info_data = json.load(f) + print("ori: ", info_data.get('rank_id')) + info_data['rank_id'] = int(device_id) + print("modify: ", info_data['rank_id']) + with open(info_path, "w+") as f: + json.dump(info_data, f) + + +def set_rank_id(pro_dir, rank_id_set): + for root, dirs, files in os.walk(pro_dir): + for dir_ in dirs: + if 'device_' in dir_: + device_id = dir_.split("_")[-1] + + info_path = os.path.join(root, dir_, f'info.json.{device_id}') + + with open(info_path, 'r+') as f: + info_data = json.load(f) + print("ori: ", info_data.get('rank_id')) + idx = 0 + while int(device_id) + idx * 8 in rank_id_set: + idx += 1 + rank_id = int(device_id) + idx * 8 + rank_id_set.add(rank_id) + info_data['rank_id'] = rank_id + print("modify: ", info_data['rank_id']) + with open(info_path, "w+") as f: + json.dump(info_data, f) + + +def get_node_id_set_rank_id(dir_path): + for dir_name in os.listdir(dir_path): + node_id = dir_name.split("_")[0] + for root, dirs, files in os.walk(os.path.join(dir_path, dir_name)): + for dir_ in dirs: + if 'device_' in dir_: + device_id = dir_.split("_")[-1] + + info_path = os.path.join(root, dir_, f'info.json.{device_id}') + + with 
open(info_path, 'r+') as f: + info_data = json.load(f) + print("ori: rank_id:", info_data.get('rank_id'), "device_id: ", device_id, "node_id: ", node_id) + rank_id = int(device_id) + int(node_id) * 8 + info_data['rank_id'] = rank_id + print("modify: ", info_data['rank_id']) + with open(info_path, "w+") as f: + json.dump(info_data, f) + + +def set_soft_link(args): + + cmd = "find %s -name PROF* | xargs -I {} mv {} %s" % (args.data, args.output) + print("rm_data cmd:{} begin".format(cmd)) + os.system(cmd) + print("rm_data cmd:{} end".format(cmd)) + + +def parse_args(): + parser = ArgumentParser(description="Merge timeline for multi card") + parser.add_argument("--data", "-d", default=None, help="root dir of PROF_* data") + parser.add_argument("--output", default=None, help="soft link dir") + arg = parser.parse_args() + return arg + + +if __name__ == "__main__": + args = parse_args() + print(" ======================== set rank id and soft link ========================") + change_rank_id_to_device_id(args.data) + # get_node_id_set_rank_id(args.data) + # set_soft_link(args) diff --git a/backends/npu/opp/ascendc_custom_ops/changelog.txt b/backends/npu/opp/ascendc_custom_ops/changelog.txt new file mode 100644 index 00000000000..cd53de92357 --- /dev/null +++ b/backends/npu/opp/ascendc_custom_ops/changelog.txt @@ -0,0 +1,23 @@ +【2023.10.25】 + 1. 优化 TokenPenaltyMultiScores算子的 Tiling策略 & 计算逻辑 +【2023.12.19】 + 1. 新增7个算子 + GetMaxLen + GetPaddingOffset + RebuildPadding + UpdateInputs + SetValueByFlagsAndIdxV2 + SetStopValueMultiEndsV2 + TokenPenaltyMultiScoresV2 + 2. 算子TokenPenaltyMultiScores 适配新版本cann, +【2024.01.17】 + 1. 新增1个算子 + StepPaddle +【2024.01.20】 + 1. TokenPenaltyMultiScores & TokenPenaltyMultiScoresV2 + break分支异常处理 + 算子执行完不修改repeat_times入参 + 2. SetValueByFlagsAndIdxV2 + 性能优化 + 3. 
GetPaddingOffset & RebuildPadding + 调整 MAX_BATCH_NUM 为256 diff --git a/backends/npu/opp/ascendc_custom_ops/main.py b/backends/npu/opp/ascendc_custom_ops/main.py new file mode 100644 index 00000000000..61ecdc923cd --- /dev/null +++ b/backends/npu/opp/ascendc_custom_ops/main.py @@ -0,0 +1,238 @@ +#! /usr/bin/python3 +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +import re + +from functools import partial +from argparse import ArgumentParser + + +FILTER_DIRS = [".profiler", "HCCL_PROF", "timeline", "query", 'sqlite', 'log'] +MAX_INDEX_COUNT = 1000 + + +# 获取时间差异文件中的node和时间差的对应关系,保存到字典中 +def get_node_time_diff(time_diff_file_path): + node_diff = {} + if not time_diff_file_path: + return None + with open(time_diff_file_path, 'r+', encoding='utf-8') as f: + all_time_diff = json.load(f) + node_idx = 0 + for ip, timediff in all_time_diff.items(): + node_diff[node_idx] = timediff + node_idx += 1 + return node_diff + + +def get_path_dir(path: str) -> list: + """ + check result path exist JOB dir + path : result path + """ + path_dir_filter = filter(partial(_path_dir_filter_func, root_dir=path), os.listdir(path)) + sub_dirs = list(path_dir_filter) + if not sub_dirs: + message = f"The path \"{path}\" does not have PROF dir. Please check the path." 
+ print(message) + return sub_dirs + + +def _path_dir_filter_func(sub_path, root_dir): + return sub_path not in FILTER_DIRS and os.path.isdir(os.path.realpath(os.path.join(root_dir, sub_path))) + + +def natural_sort(files): + convert = lambda text: int(text) if text.isdigit() else text.lower() + alphanum_key = lambda key: [convert(c) for c in re.split('([0-9]+)', key)] + return sorted(files, key=alphanum_key) + + +def get_timeline_info(args, prof_dirs): + timeline_info = {} + + for prof in prof_dirs: + pro_path = os.path.join(args.data, prof) + + # 从info.json读取rank_id + rank_id = get_rank_id_from_info_json(pro_path) + if rank_id is None: + print(f"WARN, There is not rank id info in {pro_path}") + continue + + timeline_path = get_timeline_path(pro_path, args.type) + + if os.path.exists(timeline_path): + timeline_info[rank_id] = timeline_path + else: + print(f"WARN, The file \"{timeline_path}\" does not exist.") + return timeline_info + + +def get_timeline_path(pro_path, type): + for root, dirs, files in os.walk(pro_path): + for dir_ in dirs: + if 'ASCEND_PROFILER_OUTPUT' == dir_ and type == 'pytorch': + timeline_path = os.path.realpath(os.path.join(root, dir_, 'trace_view.json')) + return timeline_path + + for file_ in sorted(files, reverse=True): + if 'msprof' in file_: + timeline_path = os.path.join(root, file_) + return timeline_path + return + + +def get_rank_id_from_info_json(pro_path): + info_json = "" + rank_id = None + for root, dirs, files in os.walk(pro_path): + for file in files: + if "info.json." 
in file and ".done" not in file: + info_json = os.path.join(root, file) + break + + if info_json: + with open(info_json, "r+") as f: + info = json.load(f) + rank_id = info.get("rank_id") + return rank_id + + +def merge_timeline_general(args): + """合并e2e profiling生成的msprof*.json""" + prof_dir = get_path_dir(args.data) + timeline_info = get_timeline_info(args, prof_dir) + timeline_files_dict = {} + + node_time_diff = get_node_time_diff(args.timediff) if args.timediff else None + + # 合并部分profiling items + process_list = args.items.split(",") if args.items else None + + # 合并部分rank + if args.rank: + rank_ids = [int(rank_id) for rank_id in args.rank.split(",")] + else: + rank_ids = list(timeline_info.keys()) + + for rank_id in rank_ids: + timeline_files_dict[rank_id] = timeline_info.get(rank_id) + merge_timeline_events(timeline_files_dict, process_list, node_time_diff) + + +def merge_timeline_custom(args): + """合并指定目录里所有timeline文件""" + timeline_files = natural_sort(os.listdir(args.data)) + timeline_files_dict = {} + for idx, timeline_file in enumerate(timeline_files): + timeline_files_dict[idx] = os.path.join(args.data, timeline_file) + node_time_diff = get_node_time_diff(args.timediff) if args.timediff else None + # 合并部分profiling items + process_list = args.items.split(",") if args.items else None + merge_timeline_events(timeline_files_dict, process_list, node_time_diff) + + +def merge_timeline_events(timeline_file_dict, process_list, node_time_diff=None): + """ + 输入需要合并的timeline文件路径及对应的rank_id/id、需要合并的process_list、校准时间差node_time_diff + 输出合并timeline + """ + new_events = [] + for rank_id, timeline_file_path in timeline_file_dict.items(): + node = rank_id // 8 + print("rank id: ", rank_id, "timeline file: ", timeline_file_path) + + # 获取相应的时间差异 + node_time = node_time_diff[node] if node_time_diff else None + try: + with open(timeline_file_path, 'r+') as f: + cur_events = jso except Exception as err: + print("[ERROR] %s" % err) + return + + proc_pid_dict = {} + for event in 
cur_events: + if event.get("name") == "process_name" and event.get("ph") == "M": + if event.get("args"): + proc_pid_dict[event["args"].get("name")] = event.get("pid") + process_list = process_list if process_list else list(proc_pid_dict.keys()) + # 提取待合并的items的pid + merged_pids = set() + for pro in process_list: + pro = " ".join(pro.split("_")) if "_" in pro else pro + + if pro not in proc_pid_dict.keys(): + print(f"{pro} is invalid item, valid items: {list(proc_pid_dict.keys())}") + continue + merged_pids.add(proc_pid_dict.get(pro)) + + for event in cur_events: + + # 只合并特定数据项 + if merged_pids and event.get('pid') not in merged_pids: + continue + + # 当前节点间时间误差可用时,进行时间校准 + if event.get("ts") and node_time: + event["ts"] = event["ts"] - node_time * 1000000 + + # 区分不同rank的同一进程的pid + if isinstance(event.get("pid"), (str, int)): + # 合并GPU profiling/ after timeline pid modify + event["pid"] = int(''.join(x for x in str(event.get("pid")) if x.isdigit()) + + str(rank_id)) + + # convert tid to int + if isinstance(event.get("tid"), str): + event["tid"] = int(''.join(x for x in event["tid"] if x.isdigit())) + + # 进程名加上rank_id区分不同rank + if event.get("name") == "process_name" and event.get("ph") == "M": + if event.get("args") is not None and event["args"].get("name") is not None: + event["args"]["name"] = event["args"]["name"] + f"_{rank_id}" + + new_events.append(event) + + output_path = os.path.join(args.output, f"msprof_merged_{len(timeline_file_dict)}p.json") + with open(output_path, 'w') as f: + json.dump(new_events, f) + print(f"timeline merged output path: {output_path}") + + +def parse_args(): + parser = ArgumentParser(description="Merge timeline for multi card") + parser.add_argument("--data", "-d", default=None, help="root dir of PROF_* data") + parser.add_argument("--timediff", "-t", default=None, help="JSON file for saving startup time differences") + parser.add_argument("--output", "-o", default=None, help="save path of msprof_merged.json ") + 
parser.add_argument("--rank", default=None, help="List of ranks to be merged. By default, all ranks are merged") + parser.add_argument("--items", default=None, help="Specify the data items to be merged. in the timeline.") + parser.add_argument("--type", choices=('pytorch', 'e2e', 'custom'), help="Customize the timeline file to be merged.") + arg = parser.parse_args() + return arg + + +if __name__ == "__main__": + args = parse_args() + + if not args.output: + args.output = args.data + print("========================== start merge timeline ====================") + if args.type == "custom": + merge_timeline_custom(args) + else: + merge_timeline_general(args) \ No newline at end of file diff --git a/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/framework/onnx_plugin/set_mask_value_plugin.cpp b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/framework/onnx_plugin/set_mask_value_plugin.cpp new file mode 100644 index 00000000000..87941696e93 --- /dev/null +++ b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/framework/onnx_plugin/set_mask_value_plugin.cpp @@ -0,0 +1,45 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "register/register.h" + +namespace domi { + // Onnx ParseParams + Status ParseParamSetMaskValue(const ge::Operator &opSrc, ge::Operator &opDest) + { + return SUCCESS; + } + + static std::vector g_supportedOnnxVersion ({ + "ai.onnx::8::SetMaskValue", + "ai.onnx::9::SetMaskValue", + "ai.onnx::10::SetMaskValue", + "ai.onnx::11::SetMaskValue", + "ai.onnx::12::SetMaskValue", + "ai.onnx::13::SetMaskValue", + "ai.onnx::14::SetMaskValue", + "ai.onnx::15::SetMaskValue", + "ai.onnx::16::SetMaskValue", + "ai.onnx::17::SetMaskValue", + "ai.onnx::18::SetMaskValue", + }); + + // register SetMaskValue op info to GE + REGISTER_CUSTOM_OP("SetMaskValue") // Set the registration name of operator + .FrameworkType(ONNX) // Operator name with the original framework + .OriginOpType(g_supportedOnnxVersion) // Set the original frame type of the operator + .ParseParamsByOperatorFn(ParseParamSetMaskValue); +} // namespace domi diff --git a/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/framework/onnx_plugin/set_stop_value_multi_ends_plugin.cpp b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/framework/onnx_plugin/set_stop_value_multi_ends_plugin.cpp new file mode 100644 index 00000000000..55726ad504d --- /dev/null +++ b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/framework/onnx_plugin/set_stop_value_multi_ends_plugin.cpp @@ -0,0 +1,45 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "register/register.h" + +namespace domi { + // Onnx ParseParams + Status ParseParamSetStopValueMultiEnds(const ge::Operator &opSrc, ge::Operator &opDest) + { + return SUCCESS; + } + + static std::vector g_supportedOnnxVersion ({ + "ai.onnx::8::SetStopValueMultiEnds", + "ai.onnx::9::SetStopValueMultiEnds", + "ai.onnx::10::SetStopValueMultiEnds", + "ai.onnx::11::SetStopValueMultiEnds", + "ai.onnx::12::SetStopValueMultiEnds", + "ai.onnx::13::SetStopValueMultiEnds", + "ai.onnx::14::SetStopValueMultiEnds", + "ai.onnx::15::SetStopValueMultiEnds", + "ai.onnx::16::SetStopValueMultiEnds", + "ai.onnx::17::SetStopValueMultiEnds", + "ai.onnx::18::SetStopValueMultiEnds", + }); + + // register SetStopValueMultiEnds op info to GE + REGISTER_CUSTOM_OP("SetStopValueMultiEnds") // Set the registration name of operator + .FrameworkType(ONNX) // Operator name with the original framework + .OriginOpType(g_supportedOnnxVersion) // Set the original frame type of the operator + .ParseParamsByOperatorFn(ParseParamSetStopValueMultiEnds); +} // namespace domi diff --git a/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/framework/onnx_plugin/set_value_by_flags_and_idx_plugin.cpp b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/framework/onnx_plugin/set_value_by_flags_and_idx_plugin.cpp new file mode 100644 index 00000000000..29595a27359 --- /dev/null +++ b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/framework/onnx_plugin/set_value_by_flags_and_idx_plugin.cpp @@ -0,0 +1,45 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "register/register.h" + +namespace domi { + // Onnx ParseParams + Status ParseParamSetValueByFlagsAndIdx(const ge::Operator &opSrc, ge::Operator &opDest) + { + return SUCCESS; + } + + static std::vector g_supportedOnnxVersion ({ + "ai.onnx::8::SetValueByFlagsAndIdx", + "ai.onnx::9::SetValueByFlagsAndIdx", + "ai.onnx::10::SetValueByFlagsAndIdx", + "ai.onnx::11::SetValueByFlagsAndIdx", + "ai.onnx::12::SetValueByFlagsAndIdx", + "ai.onnx::13::SetValueByFlagsAndIdx", + "ai.onnx::14::SetValueByFlagsAndIdx", + "ai.onnx::15::SetValueByFlagsAndIdx", + "ai.onnx::16::SetValueByFlagsAndIdx", + "ai.onnx::17::SetValueByFlagsAndIdx", + "ai.onnx::18::SetValueByFlagsAndIdx", + }); + + // register SetValueByFlagsAndIdx op info to GE + REGISTER_CUSTOM_OP("SetValueByFlagsAndIdx") // Set the registration name of operator + .FrameworkType(ONNX) // Operator name with the original framework + .OriginOpType(g_supportedOnnxVersion) // Set the original frame type of the operator + .ParseParamsByOperatorFn(ParseParamSetValueByFlagsAndIdx); +} // namespace domi diff --git a/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/framework/onnx_plugin/token_penalty_multi_scores_plugin.cpp b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/framework/onnx_plugin/token_penalty_multi_scores_plugin.cpp new file mode 100644 index 00000000000..5b1ec74f1c0 --- /dev/null +++ b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/framework/onnx_plugin/token_penalty_multi_scores_plugin.cpp @@ -0,0 +1,45 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2023. 
All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "register/register.h" + +namespace domi { + // Onnx ParseParams + Status ParseParamTokenPenaltyMultiScores(const ge::Operator &opSrc, ge::Operator &opDest) + { + return SUCCESS; + } + + static std::vector g_supportedOnnxVersion ({ + "ai.onnx::8::TokenPenaltyMultiScores", + "ai.onnx::9::TokenPenaltyMultiScores", + "ai.onnx::10::TokenPenaltyMultiScores", + "ai.onnx::11::TokenPenaltyMultiScores", + "ai.onnx::12::TokenPenaltyMultiScores", + "ai.onnx::13::TokenPenaltyMultiScores", + "ai.onnx::14::TokenPenaltyMultiScores", + "ai.onnx::15::TokenPenaltyMultiScores", + "ai.onnx::16::TokenPenaltyMultiScores", + "ai.onnx::17::TokenPenaltyMultiScores", + "ai.onnx::18::TokenPenaltyMultiScores", + }); + + // register TokenPenaltyMultiScores op info to GE + REGISTER_CUSTOM_OP("TokenPenaltyMultiScores") // Set the registration name of operator + .FrameworkType(ONNX) // Operator name with the original framework + .OriginOpType(g_supportedOnnxVersion) // Set the original frame type of the operator + .ParseParamsByOperatorFn(ParseParamTokenPenaltyMultiScores); +} // namespace domi diff --git a/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/get_max_len.cpp b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/get_max_len.cpp new file mode 100644 index 00000000000..f4c7a82ab73 --- /dev/null +++ b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/get_max_len.cpp 
@@ -0,0 +1,137 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "get_max_len_tiling.h" +#include "register/op_def_registry.h" + +using namespace std; + +namespace { +constexpr uint32_t MINIMAL_WORKSPACE = 16 * 1024 * 1024; +} + +namespace optiling { + static ge::graphStatus GetMaxLenTilingFunc(gert::TilingContext *context) + { + GetMaxLenTilingData tiling; + const gert::StorageShape* seqLensEncoder = context->GetInputShape(0); + + int bs = seqLensEncoder->GetStorageShape().GetDim(0); + tiling.set_bs(bs); + + int32_t blockDims = 1; + context->SetBlockDim(blockDims); + + std::cout << "GetMaxLenTilingFunc bs " << bs << std::endl; + + tiling.SaveToBuffer(context->GetRawTilingData()->GetData(), context->GetRawTilingData()->GetCapacity()); + context->GetRawTilingData()->SetDataSize(tiling.GetDataSize()); + + size_t *currentWorkspace = context->GetWorkspaceSizes(1); + if (context->GetWorkspaceNum() <= 0) { + std::cout << "GetWorkspaceNum Failed" << std::endl; + return ge::GRAPH_FAILED; + } + currentWorkspace[0] = MINIMAL_WORKSPACE; + + return ge::GRAPH_SUCCESS; + } +} + +namespace ge { + static ge::graphStatus GetMaxLenInferShape(gert::InferShapeContext *context) + { + const gert::Shape* x0Shape = context->GetInputShape(0); + const gert::Shape* x1Shape = context->GetInputShape(1); + + gert::Shape* y0Shape = context->GetOutputShape(0); + gert::Shape* y1Shape = 
context->GetOutputShape(1); + + y0Shape->SetDimNum(1); + y0Shape->SetDim(0, 8); + y1Shape->SetDimNum(1); + y1Shape->SetDim(0, 8); + + return GRAPH_SUCCESS; + } + + static ge::graphStatus GetMaxLenInferShapeRange(gert::InferShapeRangeContext *context) + { + const gert::Range *inputX0ShapeRange = context->GetInputShapeRange(0); + const gert::Range *inputX1ShapeRange = context->GetInputShapeRange(1); + gert::Range *y0ShapeRange = context->GetOutputShapeRange(0); + gert::Range *y1ShapeRange = context->GetOutputShapeRange(1); + + *y0ShapeRange = *inputX0ShapeRange; + *y1ShapeRange = *inputX1ShapeRange; + + return GRAPH_SUCCESS; + } + + static ge::graphStatus GetMaxLenInferDataType(gert::InferDataTypeContext *context) + { + const ge::DataType x0DataType = context->GetInputDataType(0); + + context->SetOutputDataType(0, x0DataType); + context->SetOutputDataType(1, x0DataType); + + return GRAPH_SUCCESS; + } +} + +namespace ops { + class GetMaxLen : public OpDef { + public: + explicit GetMaxLen(const char *name) : OpDef(name) + { + this->Input("seq_lens_encoder") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Input("seq_lens_decoder") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Output("seq_lens_encoder_out") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Output("seq_lens_decoder_out") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->SetInferShape(ge::GetMaxLenInferShape) + .SetInferShapeRange(ge::GetMaxLenInferShapeRange) + .SetInferDataType(ge::GetMaxLenInferDataType); + + this->AICore() + .SetTiling(optiling::GetMaxLenTilingFunc); + + this->AICore().AddConfig("ascend310p"); + this->AICore().AddConfig("ascend910"); + 
this->AICore().AddConfig("ascend910b"); + } + }; + + OP_ADD(GetMaxLen); +} diff --git a/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/get_max_len_tiling.h b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/get_max_len_tiling.h new file mode 100644 index 00000000000..e62ea2a7f95 --- /dev/null +++ b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/get_max_len_tiling.h @@ -0,0 +1,25 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "register/tilingdata_base.h" + +namespace optiling { +BEGIN_TILING_DATA_DEF(GetMaxLenTilingData) +TILING_DATA_FIELD_DEF(int32_t, bs); +TILING_DATA_FIELD_DEF(int32_t, bs1); +END_TILING_DATA_DEF; + +REGISTER_TILING_DATA_CLASS(GetMaxLen, GetMaxLenTilingData) +} \ No newline at end of file diff --git a/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/get_padding_offset.cpp b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/get_padding_offset.cpp new file mode 100644 index 00000000000..25fddb5e121 --- /dev/null +++ b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/get_padding_offset.cpp @@ -0,0 +1,173 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "get_padding_offset_tiling.h" +#include "register/op_def_registry.h" +namespace { +constexpr uint32_t MINIMAL_WORKSPACE = 16 * 1024 * 1024; +} + +namespace optiling { +static ge::graphStatus GetPaddingOffsetTilingFunc(gert::TilingContext *context) +{ + GetPaddingOffsetTilingData tiling; + + const gert::StorageShape *input_data_shape = context->GetInputShape(0); + + int32_t batch = input_data_shape->GetStorageShape().GetDim(0); + int32_t padLength = input_data_shape->GetStorageShape().GetDim(1); + + tiling.set_batch(batch); + tiling.set_padLength(padLength); + + int32_t blockSize = 1; + context->SetBlockDim(blockSize); + + tiling.SaveToBuffer(context->GetRawTilingData()->GetData(), context->GetRawTilingData()->GetCapacity()); + context->GetRawTilingData()->SetDataSize(tiling.GetDataSize()); + + size_t *currentWorkspace = context->GetWorkspaceSizes(1); + if (context->GetWorkspaceNum() <= 0) { + std::cout << "GetWorkspaceNum Failed" << std::endl; + return ge::GRAPH_FAILED; + } + currentWorkspace[0] = MINIMAL_WORKSPACE; + return ge::GRAPH_SUCCESS; +} +} + +namespace ge { +static ge::graphStatus GetPaddingOffsetInferShape(gert::InferShapeContext *context) +{ + const gert::Shape *x0_shape = context->GetInputShape(0); + const gert::Shape *x1_shape = context->GetInputShape(1); + gert::Shape *y0_shape = context->GetOutputShape(0); + gert::Shape *y1_shape = context->GetOutputShape(1); + gert::Shape *y2_shape = context->GetOutputShape(2); + gert::Shape *y3_shape = context->GetOutputShape(3); + gert::Shape *y4_shape = context->GetOutputShape(4); 
+ + y0_shape -> SetDimNum(1); + y0_shape -> SetDim(0, x0_shape -> GetDim(0) * x0_shape -> GetDim(1)); + + y1_shape -> SetDimNum(1); + y1_shape -> SetDim(0, x1_shape -> GetDim(0)); + + y2_shape -> SetDimNum(1); + y2_shape -> SetDim(0, x0_shape -> GetDim(0) * x0_shape -> GetDim(1)); + + y3_shape -> SetDimNum(1); + y3_shape -> SetDim(0, x1_shape -> GetDim(0) + 1); + + y4_shape -> SetDimNum(1); + y4_shape -> SetDim(0, x1_shape -> GetDim(0) + 1); + return GRAPH_SUCCESS; +} + +ge::graphStatus GetPaddingOffsetInferShapeRange(gert::InferShapeRangeContext *context) +{ + const gert::Range *inputXShapeRange = context->GetInputShapeRange(2); + gert::Range *outputShapeRange0 = context->GetOutputShapeRange(0); + gert::Range *outputShapeRange1 = context->GetOutputShapeRange(1); + gert::Range *outputShapeRange2 = context->GetOutputShapeRange(2); + gert::Range *outputShapeRange3 = context->GetOutputShapeRange(3); + gert::Range *outputShapeRange4 = context->GetOutputShapeRange(4); + *outputShapeRange0 = *inputXShapeRange; + *outputShapeRange1 = *inputXShapeRange; + *outputShapeRange2 = *inputXShapeRange; + *outputShapeRange3 = *inputXShapeRange; + *outputShapeRange4 = *inputXShapeRange; + + return GRAPH_SUCCESS; +} + +ge::graphStatus GetPaddingOffsetInferDataType(gert::InferDataTypeContext *context) +{ + const ge::DataType x0DataType = context->GetInputDataType(0); + const ge::DataType x1DataType = context->GetInputDataType(1); + + context->SetOutputDataType(0, x0DataType); + context->SetOutputDataType(1, x1DataType); + context->SetOutputDataType(2, x1DataType); + context->SetOutputDataType(3, x1DataType); + context->SetOutputDataType(4, x1DataType); + + return GRAPH_SUCCESS; +} +} + +namespace ops { +class GetPaddingOffset : public OpDef { +public: + GetPaddingOffset(const char *name) : OpDef(name) + { + this->Input("input_ids") + .ParamType(REQUIRED) + .DataType({ ge::DT_INT64 }) + .Format({ ge::FORMAT_ND }) + .UnknownShapeFormat({ ge::FORMAT_ND }); + this->Input("cum_offsets") + 
.ParamType(REQUIRED) + .DataType({ ge::DT_INT32 }) + .Format({ ge::FORMAT_ND }) + .UnknownShapeFormat({ ge::FORMAT_ND }); + this->Input("token_num") + .ParamType(REQUIRED) + .DataType({ ge::DT_INT64 }) + .Format({ ge::FORMAT_ND }) + .UnknownShapeFormat({ ge::FORMAT_ND }); + this->Input("seq_len") + .ParamType(REQUIRED) + .DataType({ ge::DT_INT32 }) + .Format({ ge::FORMAT_ND }) + .UnknownShapeFormat({ ge::FORMAT_ND }); + this->Output("x_remove_padding") + .ParamType(REQUIRED) + .DataType({ ge::DT_INT64 }) + .Format({ ge::FORMAT_ND }) + .UnknownShapeFormat({ ge::FORMAT_ND }); + this->Output("cum_offsets_out") + .ParamType(REQUIRED) + .DataType({ ge::DT_INT32 }) + .Format({ ge::FORMAT_ND }) + .UnknownShapeFormat({ ge::FORMAT_ND }); + this->Output("padding_offset") + .ParamType(REQUIRED) + .DataType({ ge::DT_INT32 }) + .Format({ ge::FORMAT_ND }) + .UnknownShapeFormat({ ge::FORMAT_ND }); + this->Output("cu_seqlens_q") + .ParamType(REQUIRED) + .DataType({ ge::DT_INT32 }) + .Format({ ge::FORMAT_ND }) + .UnknownShapeFormat({ ge::FORMAT_ND }); + this->Output("cu_seqlens_k") + .ParamType(REQUIRED) + .DataType({ ge::DT_INT32 }) + .Format({ ge::FORMAT_ND }) + .UnknownShapeFormat({ ge::FORMAT_ND }); + + this->SetInferShape(ge::GetPaddingOffsetInferShape); + + this->AICore().SetTiling(optiling::GetPaddingOffsetTilingFunc); + + this->AICore().AddConfig("ascend310p"); + this->AICore().AddConfig("ascend910"); + this->AICore().AddConfig("ascend910b"); + } +}; + +OP_ADD(GetPaddingOffset); +} \ No newline at end of file diff --git a/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/get_padding_offset_tiling.h b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/get_padding_offset_tiling.h new file mode 100644 index 00000000000..44a2e6d7ba5 --- /dev/null +++ b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/get_padding_offset_tiling.h @@ -0,0 +1,25 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "register/tilingdata_base.h" + +namespace optiling { +BEGIN_TILING_DATA_DEF(GetPaddingOffsetTilingData) +TILING_DATA_FIELD_DEF(uint32_t, padLength); +TILING_DATA_FIELD_DEF(int32_t, batch); +END_TILING_DATA_DEF; + +REGISTER_TILING_DATA_CLASS(GetPaddingOffset, GetPaddingOffsetTilingData) +} \ No newline at end of file diff --git a/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/rebuild_padding.cpp b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/rebuild_padding.cpp new file mode 100644 index 00000000000..dff72cb8ced --- /dev/null +++ b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/rebuild_padding.cpp @@ -0,0 +1,151 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "rebuild_padding_tiling.h" +#include "register/op_def_registry.h" + +using namespace std; + +namespace { +constexpr uint32_t MINIMAL_WORKSPACE = 16 * 1024 * 1024; +} + +namespace optiling { + static ge::graphStatus RebuildPaddingTilingFunc(gert::TilingContext *context) + { + RebuildPaddingTilingData tiling; + const gert::StorageShape* tmpOut = context->GetInputShape(0); + const gert::StorageShape* cumOffsetsShape = context->GetInputShape(1); + + int bs = cumOffsetsShape->GetStorageShape().GetDim(0); + int token_num = tmpOut->GetStorageShape().GetDim(0); + int dim_embed = tmpOut->GetStorageShape().GetDim(1); + + auto attrs = context->GetAttrs(); + auto max_input_length_ptr = attrs->GetAttrPointer(0); + auto max_input_length = *max_input_length_ptr; + + tiling.set_bs(bs); + tiling.set_token_num(token_num); + tiling.set_dim_embed(dim_embed); + tiling.set_max_input_length(max_input_length); + + int32_t blockDims = 1; + context->SetBlockDim(blockDims); + + tiling.SaveToBuffer(context->GetRawTilingData()->GetData(), context->GetRawTilingData()->GetCapacity()); + context->GetRawTilingData()->SetDataSize(tiling.GetDataSize()); + + size_t *currentWorkspace = context->GetWorkspaceSizes(1); + if (context->GetWorkspaceNum() <= 0) { + std::cout << "GetWorkspaceNum Failed" << std::endl; + return ge::GRAPH_FAILED; + } + currentWorkspace[0] = MINIMAL_WORKSPACE; + return ge::GRAPH_SUCCESS; + } +} + +namespace ge { + static ge::graphStatus RebuildPaddingInferShape(gert::InferShapeContext *context) + { + const gert::Shape* x0Shape = context->GetInputShape(0); + // const gert::Shape* x1Shape = context->GetInputShape(1); + const gert::Shape* x3Shape = context->GetInputShape(3); + int x0num = x0Shape->GetDimNum(); + + gert::Shape* y0Shape = context->GetOutputShape(0); + + y0Shape->SetDimNum(2); + // y0Shape->SetDim(0,x1Shape->GetDim(0)); + y0Shape->SetDim(0,x3Shape->GetDim(0)); + y0Shape->SetDim(1,x0Shape->GetDim(x0num-1)); + + return GRAPH_SUCCESS; + } + + static 
ge::graphStatus RebuildPaddingInferShapeRange(gert::InferShapeRangeContext *context) + { + gert::Range *y0ShapeRange = context->GetOutputShapeRange(0); + gert::Shape min = gert::Shape({1}); + gert::Shape max = gert::Shape({-1}); + y0ShapeRange->SetMin(&min); + y0ShapeRange->SetMax(&max); + + return GRAPH_SUCCESS; + } + + static ge::graphStatus RebuildPaddingInferDataType(gert::InferDataTypeContext *context) + { + const ge::DataType x0DataType = context->GetInputDataType(0); + + context->SetOutputDataType(0, x0DataType); + + return GRAPH_SUCCESS; + } +} + +namespace ops { + class RebuildPadding : public OpDef { + public: + explicit RebuildPadding(const char *name) : OpDef(name) + { + this->Input("tmpOut") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT16}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Input("cumOffsets") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Input("seqLensDecoder") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Input("seqLensEncoder") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Output("out") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT16}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->SetInferShape(ge::RebuildPaddingInferShape) + .SetInferShapeRange(ge::RebuildPaddingInferShapeRange) + .SetInferDataType(ge::RebuildPaddingInferDataType); + + this->Attr("max_input_length").AttrType(REQUIRED).Int(); + + this->AICore() + .SetTiling(optiling::RebuildPaddingTilingFunc); + + this->AICore().AddConfig("ascend310p"); + this->AICore().AddConfig("ascend910"); + this->AICore().AddConfig("ascend910b"); + } + }; + + OP_ADD(RebuildPadding); +} \ No newline at end of file diff --git 
a/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/rebuild_padding_tiling.h b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/rebuild_padding_tiling.h new file mode 100644 index 00000000000..8076f795feb --- /dev/null +++ b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/rebuild_padding_tiling.h @@ -0,0 +1,27 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "register/tilingdata_base.h" + +namespace optiling { +BEGIN_TILING_DATA_DEF(RebuildPaddingTilingData) +TILING_DATA_FIELD_DEF(int32_t, bs); +TILING_DATA_FIELD_DEF(int32_t, dim_embed); +TILING_DATA_FIELD_DEF(int32_t, token_num); +TILING_DATA_FIELD_DEF(int32_t, max_input_length); +END_TILING_DATA_DEF; + +REGISTER_TILING_DATA_CLASS(RebuildPadding, RebuildPaddingTilingData) +} \ No newline at end of file diff --git a/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/set_mask_value.cpp b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/set_mask_value.cpp new file mode 100644 index 00000000000..57b2433aac0 --- /dev/null +++ b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/set_mask_value.cpp @@ -0,0 +1,107 @@ +#include "set_mask_value_tiling.h" +#include "register/op_def_registry.h" +namespace { +constexpr uint32_t MINIMAL_WORKSPACE = 16 * 1024 * 1024; +} + +namespace optiling { +static ge::graphStatus SetMaskValueTilingFunc(gert::TilingContext 
*context) +{ + SetMaskValueTilingData tiling; + + const gert::StorageShape *input_data_shape = context->GetInputShape(0); + const gert::StorageShape *seq_lens_shape = context->GetInputShape(2); + + + int32_t inputBs = input_data_shape->GetStorageShape().GetDim(0); + int32_t length = input_data_shape->GetStorageShape().GetDim(3); + int32_t seqBs = seq_lens_shape->GetStorageShape().GetDim(0); + + int32_t blockSize = 1; + + tiling.set_seqBs(seqBs); + tiling.set_length(length); + + context->SetBlockDim(blockSize); + + tiling.SaveToBuffer(context->GetRawTilingData()->GetData(), context->GetRawTilingData()->GetCapacity()); + context->GetRawTilingData()->SetDataSize(tiling.GetDataSize()); + + size_t *currentWorkspace = context->GetWorkspaceSizes(1); + if (context->GetWorkspaceNum() <= 0) { + std::cout << "GetWorkspaceNum Failed" << std::endl; + return ge::GRAPH_FAILED; + } + currentWorkspace[0] = MINIMAL_WORKSPACE; + return ge::GRAPH_SUCCESS; +} +} + +namespace ge { +static ge::graphStatus SetMaskValueInferShape(gert::InferShapeContext *context) +{ + const gert::Shape *x3_shape = context->GetInputShape(2); + gert::Shape *y0_shape = context->GetOutputShape(0); + + *y0_shape = *x3_shape; + return GRAPH_SUCCESS; +} + +ge::graphStatus SetMaskValueInferShapeRange(gert::InferShapeRangeContext *context) +{ + const gert::Range *inputXShapeRange = context->GetInputShapeRange(2); + gert::Range *outputShapeRange = context->GetOutputShapeRange(0); + *outputShapeRange = *inputXShapeRange; + + return GRAPH_SUCCESS; +} + +ge::graphStatus SetMaskValueInferDataType(gert::InferDataTypeContext *context) +{ + const ge::DataType Q = context->GetInputDataType(2); + + context->SetOutputDataType(0, Q); + return GRAPH_SUCCESS; +} +} + +namespace ops { +class SetMaskValue : public OpDef { +public: + SetMaskValue(const char *name) : OpDef(name) + { + this->Input("input_data") + .ParamType(REQUIRED) + .DataType({ ge::DT_FLOAT16 }) + .Format({ ge::FORMAT_ND }) + .UnknownShapeFormat({ ge::FORMAT_ND 
}); + this->Input("stop_flags") + .ParamType(REQUIRED) + .DataType({ ge::DT_BOOL }) + .Format({ ge::FORMAT_ND }) + .UnknownShapeFormat({ ge::FORMAT_ND }); + this->Input("seq_lens") + .ParamType(REQUIRED) + .DataType({ ge::DT_INT32 }) + .Format({ ge::FORMAT_ND }) + .UnknownShapeFormat({ ge::FORMAT_ND }); + this->Output("sequence_lengths") + .ParamType(REQUIRED) + .DataType({ ge::DT_INT32 }) + .Format({ ge::FORMAT_ND }) + .UnknownShapeFormat({ ge::FORMAT_ND }); + + this->SetInferShape(ge::SetMaskValueInferShape) + .SetInferShapeRange(ge::SetMaskValueInferShapeRange) + .SetInferDataType(ge::SetMaskValueInferDataType); + + this->AICore().SetTiling(optiling::SetMaskValueTilingFunc); + + this->AICore().AddConfig("ascend310p"); + this->AICore().AddConfig("ascend910"); + this->AICore().AddConfig("ascend910b"); + } +}; + +OP_ADD(SetMaskValue); +} diff --git a/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/set_mask_value_tiling.h b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/set_mask_value_tiling.h new file mode 100644 index 00000000000..afdaf33ec22 --- /dev/null +++ b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/set_mask_value_tiling.h @@ -0,0 +1,10 @@ +#include "register/tilingdata_base.h" + +namespace optiling { +BEGIN_TILING_DATA_DEF(SetMaskValueTilingData) +TILING_DATA_FIELD_DEF(int32_t, seqBs); +TILING_DATA_FIELD_DEF(int32_t, length); +END_TILING_DATA_DEF; + +REGISTER_TILING_DATA_CLASS(SetMaskValue, SetMaskValueTilingData) +} diff --git a/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/set_stop_value_multi_ends.cpp b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/set_stop_value_multi_ends.cpp new file mode 100644 index 00000000000..b355fc6047c --- /dev/null +++ b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/set_stop_value_multi_ends.cpp @@ -0,0 +1,93 @@ + +#include "set_stop_value_multi_ends_tiling.h" +#include "register/op_def_registry.h" + +namespace { + constexpr int32_t MODE = 2; + 
constexpr uint32_t MINIMAL_WORKSPACE = 16 * 1024 * 1024; +} +namespace optiling { +static ge::graphStatus TilingFunc(gert::TilingContext* context) +{ + SetStopValueMultiEndsTilingData tiling; + + const gert::StorageShape *input_data_shape = context->GetInputShape(0); + const gert::StorageShape *seq_lens_shape = context->GetInputShape(2); + + int32_t inputBs = input_data_shape->GetStorageShape().GetDim(0); + int32_t length = seq_lens_shape->GetStorageShape().GetDim(0); + int32_t seqBs = input_data_shape->GetStorageShape().GetDim(0); + + int32_t blockSize = 1; + tiling.set_seqBs(seqBs); + tiling.set_length(length); + + context->SetBlockDim(blockSize); + + tiling.SaveToBuffer(context->GetRawTilingData()->GetData(), context->GetRawTilingData()->GetCapacity()); + context->GetRawTilingData()->SetDataSize(tiling.GetDataSize()); + + size_t *currentWorkspace = context->GetWorkspaceSizes(1); + if (context->GetWorkspaceNum() <= 0) { + std::cout << "GetWorkspaceNum Failed" << std::endl; + return ge::GRAPH_FAILED; + } + currentWorkspace[0] = MINIMAL_WORKSPACE; + return ge::GRAPH_SUCCESS; +} +} + +namespace ge { +static ge::graphStatus InferShape(gert::InferShapeContext* context) +{ + const gert::Shape* x1_shape = context->GetInputShape(0); + gert::Shape* y_shape = context->GetOutputShape(0); + *y_shape = *x1_shape; + return GRAPH_SUCCESS; +} +} + + +namespace ops { +class SetStopValueMultiEnds : public OpDef { +public: + explicit SetStopValueMultiEnds(const char* name) : OpDef(name) + { + this->Input("topk_ids") + .ParamType(REQUIRED) + .DataType({ge::DT_INT64}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + this->Input("stop_flags") + .ParamType(REQUIRED) + .DataType({ge::DT_BOOL}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + this->Input("end_ids") + .ParamType(REQUIRED) + .DataType({ge::DT_INT64}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + this->Output("topk_ids_out") + .ParamType(REQUIRED) + 
.DataType({ge::DT_INT64}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + this->Output("stop_flags_out") + .ParamType(REQUIRED) + .DataType({ge::DT_BOOL}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->SetInferShape(ge::InferShape); + + this->AICore() + .SetTiling(optiling::TilingFunc); + this->AICore().AddConfig("ascend310p"); + this->AICore().AddConfig("ascend910"); + this->AICore().AddConfig("ascend910b"); + } +}; + +OP_ADD(SetStopValueMultiEnds); +} diff --git a/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/set_stop_value_multi_ends_tiling.h b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/set_stop_value_multi_ends_tiling.h new file mode 100644 index 00000000000..e0d0d2b95a9 --- /dev/null +++ b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/set_stop_value_multi_ends_tiling.h @@ -0,0 +1,10 @@ +#include "register/tilingdata_base.h" + +namespace optiling { +BEGIN_TILING_DATA_DEF(SetStopValueMultiEndsTilingData) +TILING_DATA_FIELD_DEF(int32_t, seqBs); +TILING_DATA_FIELD_DEF(int32_t, length); +END_TILING_DATA_DEF; + +REGISTER_TILING_DATA_CLASS(SetStopValueMultiEnds, SetStopValueMultiEndsTilingData) +} diff --git a/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/set_stop_value_multi_ends_v2.cpp b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/set_stop_value_multi_ends_v2.cpp new file mode 100644 index 00000000000..c4a8c68d53c --- /dev/null +++ b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/set_stop_value_multi_ends_v2.cpp @@ -0,0 +1,167 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "set_stop_value_multi_ends_v2_tiling.h" +#include "register/op_def_registry.h" + +namespace { + constexpr uint32_t MINIMAL_WORKSPACE = 16 * 1024 * 1024; +} +namespace optiling { + static ge::graphStatus SetStopValueMultiEndsV2TilingFunc(gert::TilingContext* context) + { + SetStopValueMultiEndsV2TilingData tiling; + const gert::StorageShape* topkIdsShape = context->GetInputShape(0); + const gert::StorageShape* endIdsShape = context->GetInputShape(3); + + int32_t bs = topkIdsShape->GetStorageShape().GetDim(0); + int32_t length = endIdsShape->GetStorageShape().GetDim(0); + + tiling.set_bs(bs); + tiling.set_length(length); + + int32_t blockSize = 1; + context->SetBlockDim(blockSize); + + tiling.SaveToBuffer(context->GetRawTilingData()->GetData(), context->GetRawTilingData()->GetCapacity()); + context->GetRawTilingData()->SetDataSize(tiling.GetDataSize()); + + size_t *currentWorkspace = context->GetWorkspaceSizes(1); + if (context->GetWorkspaceNum() <= 0) { + std::cout << "GetWorkspaceNum Failed" << std::endl; + return ge::GRAPH_FAILED; + } + currentWorkspace[0] = MINIMAL_WORKSPACE; + return ge::GRAPH_SUCCESS; + } +} + +namespace ge { + static ge::graphStatus SetStopValueMultiEndsV2InferShape(gert::InferShapeContext* context) + { + const gert::Shape* x0_shape = context->GetInputShape(0); + const gert::Shape* x1_shape = context->GetInputShape(1); + const gert::Shape* x4_shape = context->GetInputShape(4); + + gert::Shape* y0_shape = context->GetOutputShape(0); + gert::Shape* y1_shape = context->GetOutputShape(1); + gert::Shape* y2_shape = 
context->GetOutputShape(2); + + *y0_shape = *x0_shape; + *y1_shape = *x1_shape; + *y2_shape = *x4_shape; + + return GRAPH_SUCCESS; + } + + static ge::graphStatus SetStopValueMultiEndsV2InferShapeRange(gert::InferShapeRangeContext *context) + { + const gert::Range *inputX0ShapeRange = context->GetInputShapeRange(0); + const gert::Range *inputX1ShapeRange = context->GetInputShapeRange(1); + const gert::Range *inputX4ShapeRange = context->GetInputShapeRange(4); + + gert::Range *y0ShapeRange = context->GetOutputShapeRange(0); + gert::Range *y1ShapeRange = context->GetOutputShapeRange(1); + gert::Range *y2ShapeRange = context->GetOutputShapeRange(2); + + *y0ShapeRange = *inputX0ShapeRange; + *y1ShapeRange = *inputX1ShapeRange; + *y2ShapeRange = *inputX4ShapeRange; + + return GRAPH_SUCCESS; + } + + static ge::graphStatus SetStopValueMultiEndsV2InferDataType(gert::InferDataTypeContext *context) + { + const ge::DataType x0DataType = context->GetInputDataType(0); + const ge::DataType x1DataType = context->GetInputDataType(1); + const ge::DataType x4DataType = context->GetInputDataType(4); + + context->SetOutputDataType(0, x0DataType); + context->SetOutputDataType(1, x1DataType); + context->SetOutputDataType(2, x4DataType); + + return GRAPH_SUCCESS; + } +} + +namespace ops { +class SetStopValueMultiEndsV2 : public OpDef { +public: + explicit SetStopValueMultiEndsV2(const char* name) : OpDef(name) + { + this->Input("topkIds") + .ParamType(REQUIRED) + .DataType({ge::DT_INT64}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Input("stopFlags") + .ParamType(REQUIRED) + .DataType({ge::DT_BOOL}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Input("seqLens") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Input("endIds") + .ParamType(REQUIRED) + .DataType({ge::DT_INT64}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + 
this->Input("nextTokens") + .ParamType(REQUIRED) + .DataType({ge::DT_INT64}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Output("topkIdsOut") + .ParamType(REQUIRED) + .DataType({ge::DT_INT64}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Output("stopFlagsOut") + .ParamType(REQUIRED) + .DataType({ge::DT_BOOL}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Output("nextTokensOut") + .ParamType(REQUIRED) + .DataType({ge::DT_INT64}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->SetInferShape(ge::SetStopValueMultiEndsV2InferShape) + .SetInferShapeRange(ge::SetStopValueMultiEndsV2InferShapeRange) + .SetInferDataType(ge::SetStopValueMultiEndsV2InferDataType);; + + this->AICore() + .SetTiling(optiling::SetStopValueMultiEndsV2TilingFunc); + + this->AICore().AddConfig("ascend310p"); + this->AICore().AddConfig("ascend910"); + this->AICore().AddConfig("ascend910b"); + } +}; + +OP_ADD(SetStopValueMultiEndsV2); +} \ No newline at end of file diff --git a/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/set_stop_value_multi_ends_v2_tiling.h b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/set_stop_value_multi_ends_v2_tiling.h new file mode 100644 index 00000000000..dcede468aad --- /dev/null +++ b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/set_stop_value_multi_ends_v2_tiling.h @@ -0,0 +1,25 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "register/tilingdata_base.h" + +namespace optiling { +BEGIN_TILING_DATA_DEF(SetStopValueMultiEndsV2TilingData) +TILING_DATA_FIELD_DEF(int32_t, bs); +TILING_DATA_FIELD_DEF(int32_t, length); +END_TILING_DATA_DEF; + +REGISTER_TILING_DATA_CLASS(SetStopValueMultiEndsV2, SetStopValueMultiEndsV2TilingData) +} \ No newline at end of file diff --git a/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/set_stop_value_multi_seqs.cpp b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/set_stop_value_multi_seqs.cpp new file mode 100644 index 00000000000..9b3e7b680b1 --- /dev/null +++ b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/set_stop_value_multi_seqs.cpp @@ -0,0 +1,182 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "set_stop_value_multi_seqs_tiling.h" +#include "register/op_def_registry.h" + +namespace { + constexpr uint32_t MINIMAL_WORKSPACE = 16 * 1024 * 1024; +} +namespace optiling { + static ge::graphStatus SetStopValueMultiSeqsTilingFunc(gert::TilingContext* context) + { + SetStopValueMultiSeqsTilingData tiling; + const gert::StorageShape* preIdsShape = context->GetInputShape(1); + const gert::StorageShape* stopSeqsShape = context->GetInputShape(5); + const gert::StorageShape* eosShape = context->GetInputShape(7); + + int32_t bs = preIdsShape->GetStorageShape().GetDim(0); + int32_t length = preIdsShape->GetStorageShape().GetDim(1); + int32_t stop_seqs_num = stopSeqsShape->GetStorageShape().GetDim(0); + int32_t stop_seqs_max_len = stopSeqsShape->GetStorageShape().GetDim(1); + int32_t eos_len = eosShape->GetStorageShape().GetDim(0); + + + std::cout << "[INFO] SetStopValueMultiSeqs tiling result, bs " << bs << ", length " << length + << ", stop_seqs_num " << stop_seqs_num << ", stop_seqs_max_len " << stop_seqs_max_len << ", eos_len " << eos_len << std::endl; + + tiling.set_bs(bs); + tiling.set_length(length); + tiling.set_stop_seqs_num(stop_seqs_num); + tiling.set_stop_seqs_max_len(stop_seqs_max_len); + tiling.set_eos_len(eos_len); + + int32_t blockSize = 1; + context->SetBlockDim(blockSize); + + tiling.SaveToBuffer(context->GetRawTilingData()->GetData(), context->GetRawTilingData()->GetCapacity()); + context->GetRawTilingData()->SetDataSize(tiling.GetDataSize()); + + size_t *currentWorkspace = context->GetWorkspaceSizes(1); + if (context->GetWorkspaceNum() <= 0) { + std::cout << "GetWorkspaceNum Failed" << std::endl; + return ge::GRAPH_FAILED; + } + currentWorkspace[0] = MINIMAL_WORKSPACE; + return ge::GRAPH_SUCCESS; + } +} + +namespace ge { + static ge::graphStatus SetStopValueMultiSeqsInferShape(gert::InferShapeContext* context) + { + const gert::Shape* x0_shape = context->GetInputShape(0); + const gert::Shape* x1_shape = context->GetInputShape(3); + + 
gert::Shape* y0_shape = context->GetOutputShape(0); + gert::Shape* y1_shape = context->GetOutputShape(1); + + *y0_shape = *x0_shape; + *y1_shape = *x1_shape; + + return GRAPH_SUCCESS; + } + + static ge::graphStatus SetStopValueMultiSeqsInferShapeRange(gert::InferShapeRangeContext *context) + { + const gert::Range *inputX0ShapeRange = context->GetInputShapeRange(0); + const gert::Range *inputX1ShapeRange = context->GetInputShapeRange(3); + + gert::Range *y0ShapeRange = context->GetOutputShapeRange(0); + gert::Range *y1ShapeRange = context->GetOutputShapeRange(1); + + *y0ShapeRange = *inputX0ShapeRange; + *y1ShapeRange = *inputX1ShapeRange; + + return GRAPH_SUCCESS; + } + + static ge::graphStatus SetStopValueMultiSeqsInferDataType(gert::InferDataTypeContext *context) + { + const ge::DataType x0DataType = context->GetInputDataType(0); + const ge::DataType x1DataType = context->GetInputDataType(3); + + context->SetOutputDataType(0, x0DataType); + context->SetOutputDataType(1, x1DataType); + + return GRAPH_SUCCESS; + } +} + +namespace ops { +class SetStopValueMultiSeqs : public OpDef { +public: + explicit SetStopValueMultiSeqs(const char* name) : OpDef(name) + { + this->Input("topkIds") + .ParamType(REQUIRED) + .DataType({ge::DT_INT64}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Input("preIds") + .ParamType(REQUIRED) + .DataType({ge::DT_INT64}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Input("stepIdx") + .ParamType(REQUIRED) + .DataType({ge::DT_INT64}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Input("stopFlags") + .ParamType(REQUIRED) + .DataType({ge::DT_BOOL}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Input("seqLens") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Input("stopSeqs") + .ParamType(REQUIRED) + .DataType({ge::DT_INT64}) + 
.Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Input("stopSeqsLen") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Input("endIds") + .ParamType(REQUIRED) + .DataType({ge::DT_INT64}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Output("topkIdsOut") + .ParamType(REQUIRED) + .DataType({ge::DT_INT64}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Output("stopFlagsOut") + .ParamType(REQUIRED) + .DataType({ge::DT_BOOL}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->SetInferShape(ge::SetStopValueMultiSeqsInferShape) + .SetInferShapeRange(ge::SetStopValueMultiSeqsInferShapeRange) + .SetInferDataType(ge::SetStopValueMultiSeqsInferDataType);; + + this->AICore() + .SetTiling(optiling::SetStopValueMultiSeqsTilingFunc); + + this->AICore().AddConfig("ascend310p"); + this->AICore().AddConfig("ascend910"); + this->AICore().AddConfig("ascend910b"); + } +}; + +OP_ADD(SetStopValueMultiSeqs); +} \ No newline at end of file diff --git a/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/set_stop_value_multi_seqs_tiling.h b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/set_stop_value_multi_seqs_tiling.h new file mode 100644 index 00000000000..4a952946358 --- /dev/null +++ b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/set_stop_value_multi_seqs_tiling.h @@ -0,0 +1,28 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "register/tilingdata_base.h" + +namespace optiling { +BEGIN_TILING_DATA_DEF(SetStopValueMultiSeqsTilingData) +TILING_DATA_FIELD_DEF(int32_t, bs); +TILING_DATA_FIELD_DEF(int32_t, length); +TILING_DATA_FIELD_DEF(int32_t, stop_seqs_num); +TILING_DATA_FIELD_DEF(int32_t, stop_seqs_max_len); +TILING_DATA_FIELD_DEF(int32_t, eos_len); +END_TILING_DATA_DEF; + +REGISTER_TILING_DATA_CLASS(SetStopValueMultiSeqs, SetStopValueMultiSeqsTilingData) +} \ No newline at end of file diff --git a/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/set_value_by_flags_and_idx.cpp b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/set_value_by_flags_and_idx.cpp new file mode 100644 index 00000000000..a2238c25233 --- /dev/null +++ b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/set_value_by_flags_and_idx.cpp @@ -0,0 +1,122 @@ +#include "set_value_by_flags_and_idx_tiling.h" +#include "register/op_def_registry.h" + +using namespace std; + +namespace { +constexpr uint32_t MINIMAL_WORKSPACE = 16 * 1024 * 1024; +} + +namespace optiling { + static ge::graphStatus SetValueByFlagsAndIdxTilingFunc(gert::TilingContext *context) + { + SetValueByFlagsAndIdxTilingData tiling; + const gert::StorageShape* preIdsAllShape = context->GetInputShape(0); + const gert::StorageShape* stopFlagsShape = context->GetInputShape(3); + + int bs = stopFlagsShape->GetStorageShape().GetDim(0); + int length = preIdsAllShape->GetStorageShape().GetDim(1); + + tiling.set_bs(bs); + tiling.set_length(length); + + int32_t blockDims = 1; + 
context->SetBlockDim(blockDims); + + tiling.SaveToBuffer(context->GetRawTilingData()->GetData(), context->GetRawTilingData()->GetCapacity()); + context->GetRawTilingData()->SetDataSize(tiling.GetDataSize()); + + size_t *currentWorkspace = context->GetWorkspaceSizes(1); + if (context->GetWorkspaceNum() <= 0) { + std::cout << "GetWorkspaceNum Failed" << std::endl; + return ge::GRAPH_FAILED; + } + currentWorkspace[0] = MINIMAL_WORKSPACE; + return ge::GRAPH_SUCCESS; + } +} + +namespace ge { + static ge::graphStatus SetValueByFlagsAndIdxInferShape(gert::InferShapeContext *context) + { + const gert::Shape* x3Shape = context->GetInputShape(3); + + gert::Shape* y0Shape = context->GetOutputShape(0); + + *y0Shape = *x3Shape; + + int num = y0Shape->GetDimNum(); + + return GRAPH_SUCCESS; + } + + static ge::graphStatus SetValueByFlagsAndIdxInferShapeRange(gert::InferShapeRangeContext *context) + { + const gert::Range *inputXShapeRange = context->GetInputShapeRange(3); + gert::Range *y0ShapeRange = context->GetOutputShapeRange(0); + + *y0ShapeRange = *inputXShapeRange; + + return GRAPH_SUCCESS; + } + + static ge::graphStatus SetValueByFlagsAndIdxInferDataType(gert::InferDataTypeContext *context) + { + const ge::DataType x3DataType = context->GetInputDataType(3); + + context->SetOutputDataType(0, x3DataType); + + return GRAPH_SUCCESS; + } +} + +namespace ops { + class SetValueByFlagsAndIdx : public OpDef { + public: + explicit SetValueByFlagsAndIdx(const char *name) : OpDef(name) + { + this->Input("preIdsAll") + .ParamType(REQUIRED) + .DataType({ge::DT_INT64}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Input("preIdsNow") + .ParamType(REQUIRED) + .DataType({ge::DT_INT64}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Input("stepIdx") + .ParamType(REQUIRED) + .DataType({ge::DT_INT64}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Input("stopFlags") + .ParamType(REQUIRED) + 
.DataType({ge::DT_BOOL}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Output("stopFlagsOut") + .ParamType(REQUIRED) + .DataType({ge::DT_BOOL}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->SetInferShape(ge::SetValueByFlagsAndIdxInferShape) + .SetInferShapeRange(ge::SetValueByFlagsAndIdxInferShapeRange) + .SetInferDataType(ge::SetValueByFlagsAndIdxInferDataType); + + this->AICore() + .SetTiling(optiling::SetValueByFlagsAndIdxTilingFunc); + + this->AICore().AddConfig("ascend310p"); + this->AICore().AddConfig("ascend910"); + this->AICore().AddConfig("ascend910b"); + } + }; + + OP_ADD(SetValueByFlagsAndIdx); +} diff --git a/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/set_value_by_flags_and_idx_tiling.h b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/set_value_by_flags_and_idx_tiling.h new file mode 100644 index 00000000000..622204b3189 --- /dev/null +++ b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/set_value_by_flags_and_idx_tiling.h @@ -0,0 +1,10 @@ +#include "register/tilingdata_base.h" + +namespace optiling { +BEGIN_TILING_DATA_DEF(SetValueByFlagsAndIdxTilingData) + TILING_DATA_FIELD_DEF(int32_t, bs); + TILING_DATA_FIELD_DEF(int32_t, length); +END_TILING_DATA_DEF; + +REGISTER_TILING_DATA_CLASS(SetValueByFlagsAndIdx, SetValueByFlagsAndIdxTilingData) +} \ No newline at end of file diff --git a/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/set_value_by_flags_and_idx_v2.cpp b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/set_value_by_flags_and_idx_v2.cpp new file mode 100644 index 00000000000..cffb41f4acc --- /dev/null +++ b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/set_value_by_flags_and_idx_v2.cpp @@ -0,0 +1,157 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "set_value_by_flags_and_idx_v2_tiling.h" +#include "register/op_def_registry.h" + +using namespace std; + +namespace { +constexpr uint32_t MINIMAL_WORKSPACE = 16 * 1024 * 1024; +} + +namespace optiling { + static ge::graphStatus SetValueByFlagsAndIdxV2TilingFunc(gert::TilingContext *context) + { + SetValueByFlagsAndIdxV2TilingData tiling; + const gert::StorageShape* seqLensThisTimeShape = context->GetInputShape(2); + const gert::StorageShape* preIdsAllShape = context->GetInputShape(0); + const gert::StorageShape* inputIdsShape = context->GetInputShape(1); + + int bs = seqLensThisTimeShape->GetStorageShape().GetDim(0); + int length = preIdsAllShape->GetStorageShape().GetDim(1); + int lengthInput = inputIdsShape->GetStorageShape().GetDim(1); + + tiling.set_bs(bs); + tiling.set_length(length); + tiling.set_lengthInput(lengthInput); + + int32_t blockDims = 1; + context->SetBlockDim(blockDims); + + tiling.SaveToBuffer(context->GetRawTilingData()->GetData(), context->GetRawTilingData()->GetCapacity()); + context->GetRawTilingData()->SetDataSize(tiling.GetDataSize()); + + size_t *currentWorkspace = context->GetWorkspaceSizes(1); + if (context->GetWorkspaceNum() <= 0) { + std::cout << "GetWorkspaceNum Failed" << std::endl; + return ge::GRAPH_FAILED; + } + currentWorkspace[0] = MINIMAL_WORKSPACE; + return ge::GRAPH_SUCCESS; + } +} + +namespace ge { + static ge::graphStatus 
SetValueByFlagsAndIdxV2InferShape(gert::InferShapeContext *context) + { + const gert::Shape* x0Shape = context->GetInputShape(0); + + gert::Shape* y0Shape = context->GetOutputShape(0); + + *y0Shape = *x0Shape; + + return GRAPH_SUCCESS; + } + + static ge::graphStatus SetValueByFlagsAndIdxV2InferShapeRange(gert::InferShapeRangeContext *context) + { + const gert::Range *inputXShapeRange = context->GetInputShapeRange(0); + + gert::Range *y0ShapeRange = context->GetOutputShapeRange(0); + + *y0ShapeRange = *inputXShapeRange; + + return GRAPH_SUCCESS; + } + + static ge::graphStatus SetValueByFlagsAndIdxV2InferDataType(gert::InferDataTypeContext *context) + { + const ge::DataType x0DataType = context->GetInputDataType(0); + + context->SetOutputDataType(0, x0DataType); + + return GRAPH_SUCCESS; + } +} + +namespace ops { + class SetValueByFlagsAndIdxV2 : public OpDef { + public: + explicit SetValueByFlagsAndIdxV2(const char *name) : OpDef(name) + { + this->Input("preIdsAll") + .ParamType(REQUIRED) + .DataType({ge::DT_INT64}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Input("inputIds") + .ParamType(REQUIRED) + .DataType({ge::DT_INT64}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Input("seqLensThisTime") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Input("seqLensEncoder") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Input("seqLensDecoder") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Input("stepIdx") + .ParamType(REQUIRED) + .DataType({ge::DT_INT64}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Input("stopFlags") + .ParamType(REQUIRED) + .DataType({ge::DT_BOOL}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + 
this->Output("preIdsAllOut") + .ParamType(REQUIRED) + .DataType({ge::DT_INT64}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->SetInferShape(ge::SetValueByFlagsAndIdxV2InferShape) + .SetInferShapeRange(ge::SetValueByFlagsAndIdxV2InferShapeRange) + .SetInferDataType(ge::SetValueByFlagsAndIdxV2InferDataType); + + this->AICore() + .SetTiling(optiling::SetValueByFlagsAndIdxV2TilingFunc); + + this->AICore().AddConfig("ascend310p"); + this->AICore().AddConfig("ascend910"); + this->AICore().AddConfig("ascend910b"); + } + }; + + OP_ADD(SetValueByFlagsAndIdxV2); +} \ No newline at end of file diff --git a/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/set_value_by_flags_and_idx_v2_tiling.h b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/set_value_by_flags_and_idx_v2_tiling.h new file mode 100644 index 00000000000..53e10559013 --- /dev/null +++ b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/set_value_by_flags_and_idx_v2_tiling.h @@ -0,0 +1,26 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "register/tilingdata_base.h" + +namespace optiling { +BEGIN_TILING_DATA_DEF(SetValueByFlagsAndIdxV2TilingData) + TILING_DATA_FIELD_DEF(int32_t, bs); + TILING_DATA_FIELD_DEF(int32_t, length); + TILING_DATA_FIELD_DEF(int32_t, lengthInput); +END_TILING_DATA_DEF; + +REGISTER_TILING_DATA_CLASS(SetValueByFlagsAndIdxV2, SetValueByFlagsAndIdxV2TilingData) +} \ No newline at end of file diff --git a/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/step_paddle.cpp b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/step_paddle.cpp new file mode 100644 index 00000000000..587465a9961 --- /dev/null +++ b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/step_paddle.cpp @@ -0,0 +1,469 @@ +#include "step_paddle_tiling.h" +#include "register/op_def_registry.h" + +using namespace std; + +namespace { +constexpr uint32_t MINIMAL_WORKSPACE = 16 * 1024 * 1024; +} + +namespace optiling { + static ge::graphStatus StepPaddleTilingFunc(gert::TilingContext *context) + { + GetStepPaddleTilingData tiling; + const gert::StorageShape* seqLensThisTimeShape = context->GetInputShape(1); + const gert::StorageShape* blockTablesShape = context->GetInputShape(5); + const gert::StorageShape* inputIdsShape = context->GetInputShape(17); + const gert::StorageShape* preIdsShape = context->GetInputShape(18); + + int bsz = seqLensThisTimeShape->GetStorageShape().GetDim(0); + int block_num_per_seq = blockTablesShape->GetStorageShape().GetDim(1); + int length = inputIdsShape->GetStorageShape().GetDim(1); + int pre_id_length = preIdsShape->GetStorageShape().GetDim(1); + + auto attrs = context->GetAttrs(); + auto block_size_ptr = attrs->GetAttrPointer(0); + int block_size = *block_size_ptr; + auto encoder_decoder_block_num_ptr = attrs->GetAttrPointer(1); + int encoder_decoder_block_num = *encoder_decoder_block_num_ptr; + int max_decoder_block_num = pre_id_length / block_size - encoder_decoder_block_num; + auto first_token_id_ptr = attrs->GetAttrPointer(2); + auto 
first_token_id = *first_token_id_ptr; + + tiling.set_bsz(bsz); + tiling.set_block_size(block_size); + tiling.set_block_num_per_seq(block_num_per_seq); + tiling.set_max_decoder_block_num(max_decoder_block_num); + tiling.set_length(length); + tiling.set_pre_id_length(pre_id_length); + tiling.set_first_token_id(first_token_id); + + int32_t blockDims = 1; + context->SetBlockDim(blockDims); + + tiling.SaveToBuffer(context->GetRawTilingData()->GetData(), context->GetRawTilingData()->GetCapacity()); + context->GetRawTilingData()->SetDataSize(tiling.GetDataSize()); + + size_t *currentWorkspace = context->GetWorkspaceSizes(1); + if (context->GetWorkspaceNum() <= 0) { + std::cout << "GetWorkspaceNum Failed" << std::endl; + return ge::GRAPH_FAILED; + } + currentWorkspace[0] = MINIMAL_WORKSPACE; + return ge::GRAPH_SUCCESS; + } +} + +namespace ge { + static ge::graphStatus StepPaddleInferShape(gert::InferShapeContext *context) + { + const gert::Shape* x0Shape = context->GetInputShape(0); + const gert::Shape* x1Shape = context->GetInputShape(1); + const gert::Shape* x3Shape = context->GetInputShape(3); + const gert::Shape* x4Shape = context->GetInputShape(4); + const gert::Shape* x5Shape = context->GetInputShape(5); + const gert::Shape* x6Shape = context->GetInputShape(6); + const gert::Shape* x7Shape = context->GetInputShape(7); + const gert::Shape* x8Shape = context->GetInputShape(8); + const gert::Shape* x9Shape = context->GetInputShape(9); + const gert::Shape* x10Shape = context->GetInputShape(10); + const gert::Shape* x11Shape = context->GetInputShape(11); + const gert::Shape* x12Shape = context->GetInputShape(12); + const gert::Shape* x13Shape = context->GetInputShape(13); + const gert::Shape* x14Shape = context->GetInputShape(14); + const gert::Shape* x15Shape = context->GetInputShape(15); + const gert::Shape* x16Shape = context->GetInputShape(16); + const gert::Shape* x17Shape = context->GetInputShape(17); + + gert::Shape* y0Shape = context->GetOutputShape(0); + 
gert::Shape* y1Shape = context->GetOutputShape(1); + gert::Shape* y2Shape = context->GetOutputShape(2); + gert::Shape* y3Shape = context->GetOutputShape(3); + gert::Shape* y4Shape = context->GetOutputShape(4); + gert::Shape* y5Shape = context->GetOutputShape(5); + gert::Shape* y6Shape = context->GetOutputShape(6); + gert::Shape* y7Shape = context->GetOutputShape(7); + gert::Shape* y8Shape = context->GetOutputShape(8); + gert::Shape* y9Shape = context->GetOutputShape(9); + gert::Shape* y10Shape = context->GetOutputShape(10); + gert::Shape* y11Shape = context->GetOutputShape(11); + gert::Shape* y12Shape = context->GetOutputShape(12); + gert::Shape* y13Shape = context->GetOutputShape(13); + gert::Shape* y14Shape = context->GetOutputShape(14); + gert::Shape* y15Shape = context->GetOutputShape(15); + gert::Shape* y16Shape = context->GetOutputShape(16); + + *y0Shape = *x0Shape; + *y1Shape = *x1Shape; + *y2Shape = *x3Shape; + *y3Shape = *x4Shape; + *y4Shape = *x5Shape; + *y5Shape = *x6Shape; + *y6Shape = *x7Shape; + *y7Shape = *x8Shape; + *y8Shape = *x9Shape; + *y9Shape = *x10Shape; + *y10Shape = *x11Shape; + *y11Shape = *x12Shape; + *y12Shape = *x13Shape; + *y13Shape = *x14Shape; + *y14Shape = *x15Shape; + *y15Shape = *x16Shape; + *y16Shape = *x17Shape; + + return GRAPH_SUCCESS; + } + + static ge::graphStatus StepPaddleInferShapeRange(gert::InferShapeRangeContext *context) + { + const gert::Range *inputX0ShapeRange = context->GetInputShapeRange(0); + const gert::Range *inputX1ShapeRange = context->GetInputShapeRange(1); + const gert::Range *inputX3ShapeRange = context->GetInputShapeRange(3); + const gert::Range *inputX4ShapeRange = context->GetInputShapeRange(4); + const gert::Range *inputX5ShapeRange = context->GetInputShapeRange(5); + const gert::Range *inputX6ShapeRange = context->GetInputShapeRange(6); + const gert::Range *inputX7ShapeRange = context->GetInputShapeRange(7); + const gert::Range *inputX8ShapeRange = context->GetInputShapeRange(8); + const gert::Range 
*inputX9ShapeRange = context->GetInputShapeRange(9); + const gert::Range *inputX10ShapeRange = context->GetInputShapeRange(10); + const gert::Range *inputX11ShapeRange = context->GetInputShapeRange(11); + const gert::Range *inputX12ShapeRange = context->GetInputShapeRange(12); + const gert::Range *inputX13ShapeRange = context->GetInputShapeRange(13); + const gert::Range *inputX14ShapeRange = context->GetInputShapeRange(14); + const gert::Range *inputX15ShapeRange = context->GetInputShapeRange(15); + const gert::Range *inputX16ShapeRange = context->GetInputShapeRange(16); + const gert::Range *inputX17ShapeRange = context->GetInputShapeRange(17); + + gert::Range *y0ShapeRange = context->GetOutputShapeRange(0); + gert::Range *y1ShapeRange = context->GetOutputShapeRange(1); + gert::Range *y2ShapeRange = context->GetOutputShapeRange(2); + gert::Range *y3ShapeRange = context->GetOutputShapeRange(3); + gert::Range *y4ShapeRange = context->GetOutputShapeRange(4); + gert::Range *y5ShapeRange = context->GetOutputShapeRange(5); + gert::Range *y6ShapeRange = context->GetOutputShapeRange(6); + gert::Range *y7ShapeRange = context->GetOutputShapeRange(7); + gert::Range *y8ShapeRange = context->GetOutputShapeRange(8); + gert::Range *y9ShapeRange = context->GetOutputShapeRange(9); + gert::Range *y10ShapeRange = context->GetOutputShapeRange(10); + gert::Range *y11ShapeRange = context->GetOutputShapeRange(11); + gert::Range *y12ShapeRange = context->GetOutputShapeRange(12); + gert::Range *y13ShapeRange = context->GetOutputShapeRange(13); + gert::Range *y14ShapeRange = context->GetOutputShapeRange(14); + gert::Range *y15ShapeRange = context->GetOutputShapeRange(15); + gert::Range *y16ShapeRange = context->GetOutputShapeRange(16); + + *y0ShapeRange = *inputX0ShapeRange; + *y1ShapeRange = *inputX1ShapeRange; + *y2ShapeRange = *inputX3ShapeRange; + *y3ShapeRange = *inputX4ShapeRange; + *y4ShapeRange = *inputX5ShapeRange; + *y5ShapeRange = *inputX6ShapeRange; + *y6ShapeRange = 
*inputX7ShapeRange; + *y7ShapeRange = *inputX8ShapeRange; + *y8ShapeRange = *inputX9ShapeRange; + *y9ShapeRange = *inputX10ShapeRange; + *y10ShapeRange = *inputX11ShapeRange; + *y11ShapeRange = *inputX12ShapeRange; + *y12ShapeRange = *inputX13ShapeRange; + *y13ShapeRange = *inputX14ShapeRange; + *y14ShapeRange = *inputX15ShapeRange; + *y15ShapeRange = *inputX16ShapeRange; + *y16ShapeRange = *inputX17ShapeRange; + + return GRAPH_SUCCESS; + } + + static ge::graphStatus StepPaddleInferDataType(gert::InferDataTypeContext *context) + { + const ge::DataType x0DataType = context->GetInputDataType(0); + const ge::DataType x1DataType = context->GetInputDataType(1); + const ge::DataType x3DataType = context->GetInputDataType(3); + const ge::DataType x4DataType = context->GetInputDataType(4); + const ge::DataType x5DataType = context->GetInputDataType(5); + const ge::DataType x6DataType = context->GetInputDataType(6); + const ge::DataType x7DataType = context->GetInputDataType(7); + const ge::DataType x8DataType = context->GetInputDataType(8); + const ge::DataType x9DataType = context->GetInputDataType(9); + const ge::DataType x10DataType = context->GetInputDataType(10); + const ge::DataType x11DataType = context->GetInputDataType(11); + const ge::DataType x12DataType = context->GetInputDataType(12); + const ge::DataType x13DataType = context->GetInputDataType(13); + const ge::DataType x14DataType = context->GetInputDataType(14); + const ge::DataType x15DataType = context->GetInputDataType(15); + const ge::DataType x16DataType = context->GetInputDataType(16); + const ge::DataType x17DataType = context->GetInputDataType(17); + + context->SetOutputDataType(0, x0DataType); + context->SetOutputDataType(1, x1DataType); + context->SetOutputDataType(2, x3DataType); + context->SetOutputDataType(3, x4DataType); + context->SetOutputDataType(4, x5DataType); + context->SetOutputDataType(5, x6DataType); + context->SetOutputDataType(6, x7DataType); + context->SetOutputDataType(7, 
x8DataType); + context->SetOutputDataType(8, x9DataType); + context->SetOutputDataType(9, x10DataType); + context->SetOutputDataType(10, x11DataType); + context->SetOutputDataType(11, x12DataType); + context->SetOutputDataType(12, x13DataType); + context->SetOutputDataType(13, x14DataType); + context->SetOutputDataType(14, x15DataType); + context->SetOutputDataType(15, x16DataType); + context->SetOutputDataType(16, x17DataType); + + return GRAPH_SUCCESS; + } +} + +namespace ops { + class StepPaddle : public OpDef { + public: + explicit StepPaddle(const char *name) : OpDef(name) + { + this->Input("stop_flags") + .ParamType(REQUIRED) + .DataType({ge::DT_BOOL}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Input("seq_lens_this_time") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Input("ori_seq_lens_encoder") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Input("seq_lens_encoder") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Input("seq_lens_decoder") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Input("block_tables") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Input("encoder_block_lens") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Input("is_block_step") + .ParamType(REQUIRED) + .DataType({ge::DT_BOOL}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Input("step_block_list") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Input("step_lens") + 
.ParamType(REQUIRED) + .DataType({ge::DT_INT32}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Input("recover_block_list") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Input("recover_lens") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Input("need_block_list") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Input("need_block_len") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Input("used_list_len") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Input("free_list") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Input("free_list_len") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Input("input_ids") + .ParamType(REQUIRED) + .DataType({ge::DT_INT64}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Input("pre_ids") + .ParamType(REQUIRED) + .DataType({ge::DT_INT64}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Input("step_idx") + .ParamType(REQUIRED) + .DataType({ge::DT_INT64}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Input("next_tokens") + .ParamType(REQUIRED) + .DataType({ge::DT_INT64}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Output("stop_flags_out") + .ParamType(REQUIRED) + .DataType({ge::DT_BOOL}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Output("seq_lens_this_time_out") + .ParamType(REQUIRED) + 
.DataType({ge::DT_INT32}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Output("seq_lens_encoder_out") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Output("seq_lens_decoder_out") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Output("block_tables_out") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Output("encoder_block_lens_out") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Output("is_block_step_out") + .ParamType(REQUIRED) + .DataType({ge::DT_BOOL}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Output("step_block_list_out") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Output("step_lens_out") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Output("recover_block_list_out") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Output("recover_lens_out") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Output("need_block_list_out") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Output("need_block_len_out") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Output("used_list_len_out") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + 
this->Output("free_list_out") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Output("free_list_len_out") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Output("input_ids_out") + .ParamType(REQUIRED) + .DataType({ge::DT_INT64}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->SetInferShape(ge::StepPaddleInferShape) + .SetInferShapeRange(ge::StepPaddleInferShapeRange) + .SetInferDataType(ge::StepPaddleInferDataType); + + this->Attr("block_size").AttrType(REQUIRED).Int(); + this->Attr("encoder_decoder_block_num").AttrType(REQUIRED).Int(); + this->Attr("first_token_id").AttrType(REQUIRED).Int(0); + + this->AICore() + .SetTiling(optiling::StepPaddleTilingFunc); + + this->AICore().AddConfig("ascend310p"); + this->AICore().AddConfig("ascend910"); + this->AICore().AddConfig("ascend910b"); + } + }; + + OP_ADD(StepPaddle); +} diff --git a/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/step_paddle_tiling.h b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/step_paddle_tiling.h new file mode 100644 index 00000000000..8167b298470 --- /dev/null +++ b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/step_paddle_tiling.h @@ -0,0 +1,16 @@ +#include "register/tilingdata_base.h" + +namespace optiling { +BEGIN_TILING_DATA_DEF(GetStepPaddleTilingData) + TILING_DATA_FIELD_DEF(int32_t, bsz); + TILING_DATA_FIELD_DEF(int32_t, block_size); + TILING_DATA_FIELD_DEF(int32_t, block_num_per_seq); + TILING_DATA_FIELD_DEF(int32_t, max_decoder_block_num); + TILING_DATA_FIELD_DEF(int32_t, length); + TILING_DATA_FIELD_DEF(int32_t, pre_id_length); + TILING_DATA_FIELD_DEF(int64_t, first_token_id); + +END_TILING_DATA_DEF; + +REGISTER_TILING_DATA_CLASS(StepPaddle, GetStepPaddleTilingData) +} \ No newline at end of file diff --git 
a/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/token_penalty_multi_scores.cpp b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/token_penalty_multi_scores.cpp new file mode 100644 index 00000000000..7c50c400a0a --- /dev/null +++ b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/token_penalty_multi_scores.cpp @@ -0,0 +1,199 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "token_penalty_multi_scores_tiling.h" + +#include "register/op_def_registry.h" +#include "tiling/platform/platform_ascendc.h" + +using namespace std; + +namespace { +constexpr uint32_t MINIMAL_WORKSPACE = 16 * 1024 * 1024; +} + +namespace optiling { + static ge::graphStatus TokenPenaltyMultiScoresTilingFunc(gert::TilingContext *context) + { + TokenPenaltyMultiScoresTilingData tiling; + const gert::StorageShape* preIdsShape = context->GetInputShape(0); + const gert::StorageShape* logitsShape = context->GetInputShape(1); + const gert::StorageShape* endTokenIdShape = context->GetInputShape(8); + + int bs = logitsShape->GetStorageShape().GetDim(0); + int vs = logitsShape->GetStorageShape().GetDim(1); + int seqLen = preIdsShape->GetStorageShape().GetDim(1); + int etiLength = endTokenIdShape->GetStorageShape().GetDim(0); + + auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo()); + int coreNum = ascendcPlatform.GetCoreNumAiv(); + int coreLoop = (bs + coreNum - 1) / coreNum; + int blockDim = (bs + coreLoop - 1) / coreLoop; + + int blockElements = 256 / sizeof(float); // 256 bytes for select + if (vs % blockElements != 0) { + std::cout << "[ERROR] TokenPenaltyMultiScores voc_size " << vs << " is invalid" << std::endl; + return ge::GRAPH_FAILED; + } + int vsBlockNum = vs / blockElements; + int vsBlockBase = blockElements; + int vsBlock = vsBlockBase; + for (int i = 2; i < vsBlockNum; i++) { + if (vsBlockNum % i == 0) { + int vsBlockTmp = vsBlockBase * i; + if (vsBlockTmp > 2048) { // 2048 is max vsBlock + break; + } + vsBlock = vsBlockTmp; + } + } + + std::cout << "[INFO] TokenPenaltyMultiScores tiling result, vs " << vs + << ", vsBlock " << vsBlock << ", seqLen " << seqLen << ", etil " << etiLength + << ", bs " << bs << ", bsBlock " << coreLoop << std::endl; + + tiling.set_vs(vs); + tiling.set_vsBlock(vsBlock); + tiling.set_seqLen(seqLen); + tiling.set_etil(etiLength); + tiling.set_bs(bs); + tiling.set_bsBlock(coreLoop); + + 
context->SetBlockDim(blockDim); + + tiling.SaveToBuffer(context->GetRawTilingData()->GetData(), context->GetRawTilingData()->GetCapacity()); + context->GetRawTilingData()->SetDataSize(tiling.GetDataSize()); + + size_t *currentWorkspace = context->GetWorkspaceSizes(1); + if (context->GetWorkspaceNum() <= 0) { + return ge::GRAPH_FAILED; + } + currentWorkspace[0] = MINIMAL_WORKSPACE; + return ge::GRAPH_SUCCESS; + } +} + +namespace ge { + static ge::graphStatus TokenPenaltyMultiScoresInferShape(gert::InferShapeContext *context) + { + const gert::Shape* x3Shape = context->GetInputShape(1); + + gert::Shape* y0Shape = context->GetOutputShape(0); + + *y0Shape = *x3Shape; + + return GRAPH_SUCCESS; + } + + static ge::graphStatus TokenPenaltyMultiScoresInferShapeRange(gert::InferShapeRangeContext *context) + { + const gert::Range *inputXShapeRange = context->GetInputShapeRange(1); + gert::Range *y0ShapeRange = context->GetOutputShapeRange(0); + + *y0ShapeRange = *inputXShapeRange; + + return GRAPH_SUCCESS; + } + + static ge::graphStatus TokenPenaltyMultiScoresInferDataType(gert::InferDataTypeContext *context) + { + const ge::DataType x2DataType = context->GetInputDataType(1); + + context->SetOutputDataType(0, x2DataType); + + return GRAPH_SUCCESS; + } +} + +namespace ops { + class TokenPenaltyMultiScores : public OpDef { + public: + explicit TokenPenaltyMultiScores(const char *name) : OpDef(name) + { + this->Input("preIds") + .ParamType(REQUIRED) + .DataType({ge::DT_INT64}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Input("logits") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Input("repeatTimes") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Input("penaltyScores") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); 
+ + this->Input("frequencyScores") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Input("presenceScores") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Input("curLen") + .ParamType(REQUIRED) + .DataType({ge::DT_INT64}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Input("minLen") + .ParamType(REQUIRED) + .DataType({ge::DT_INT64}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Input("eosTokenId") + .ParamType(REQUIRED) + .DataType({ge::DT_INT64}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Output("logitsOut") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->SetInferShape(ge::TokenPenaltyMultiScoresInferShape) + .SetInferShapeRange(ge::TokenPenaltyMultiScoresInferShapeRange) + .SetInferDataType(ge::TokenPenaltyMultiScoresInferDataType); + + this->AICore() + .SetTiling(optiling::TokenPenaltyMultiScoresTilingFunc); + + this->AICore().AddConfig("ascend310p"); + this->AICore().AddConfig("ascend910"); + this->AICore().AddConfig("ascend910b"); + } + }; + + OP_ADD(TokenPenaltyMultiScores); +} diff --git a/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/token_penalty_multi_scores_tiling.h b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/token_penalty_multi_scores_tiling.h new file mode 100644 index 00000000000..b8f417490ed --- /dev/null +++ b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/token_penalty_multi_scores_tiling.h @@ -0,0 +1,29 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "register/tilingdata_base.h" + +namespace optiling { +BEGIN_TILING_DATA_DEF(TokenPenaltyMultiScoresTilingData) +TILING_DATA_FIELD_DEF(int32_t, vs); // vocab size +TILING_DATA_FIELD_DEF(int32_t, vsBlock); // vocab size per core +TILING_DATA_FIELD_DEF(int32_t, seqLen); // seq length +TILING_DATA_FIELD_DEF(int32_t, etil); // eos_token_id length +TILING_DATA_FIELD_DEF(int32_t, bs); // batch size +TILING_DATA_FIELD_DEF(int32_t, bsBlock); // batch size per core +END_TILING_DATA_DEF; + +REGISTER_TILING_DATA_CLASS(TokenPenaltyMultiScores, TokenPenaltyMultiScoresTilingData) +} diff --git a/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/token_penalty_multi_scores_v2.cpp b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/token_penalty_multi_scores_v2.cpp new file mode 100644 index 00000000000..9766e80434e --- /dev/null +++ b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/token_penalty_multi_scores_v2.cpp @@ -0,0 +1,214 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "token_penalty_multi_scores_v2_tiling.h" + +#include "register/op_def_registry.h" +#include "tiling/platform/platform_ascendc.h" + +using namespace std; + +namespace { +constexpr uint32_t MINIMAL_WORKSPACE = 16 * 1024 * 1024; +} + +namespace optiling { + static ge::graphStatus TokenPenaltyMultiScoresV2TilingFunc(gert::TilingContext *context) + { + TokenPenaltyMultiScoresV2TilingData tiling; + const gert::StorageShape* preIdsShape = context->GetInputShape(0); + const gert::StorageShape* logitsShape = context->GetInputShape(1); + const gert::StorageShape* badWordsShape = context->GetInputShape(7); + const gert::StorageShape* endTokenIdShape = context->GetInputShape(10); + + int bs = logitsShape->GetStorageShape().GetDim(0); + int vs = logitsShape->GetStorageShape().GetDim(1); + int seqLen = preIdsShape->GetStorageShape().GetDim(1); + int etiLength = endTokenIdShape->GetStorageShape().GetDim(0); + int badWLen = badWordsShape->GetStorageShape().GetDim(0); + + auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo()); + int coreNum = ascendcPlatform.GetCoreNumAiv(); + int coreLoop = (bs + coreNum - 1) / coreNum; + int blockDim = (bs + coreLoop - 1) / coreLoop; + + int blockElements = 256 / sizeof(float); // 256 bytes for select + if (vs % blockElements != 0) { + std::cout << "[ERROR] TokenPenaltyMultiScoresV2 voc_size " << vs << " is invalid" << std::endl; + return ge::GRAPH_FAILED; + } + int vsBlockNum = vs / blockElements; + int vsBlockBase = blockElements; + int vsBlock = vsBlockBase; + for (int i = 2; i < vsBlockNum; i++) { + if (vsBlockNum % i == 0) { + int vsBlockTmp = vsBlockBase * i; + if (vsBlockTmp > 2048) { // 2048 is max vsBlock + break; + } + vsBlock = vsBlockTmp; + } + } + + std::cout << "[INFO] TokenPenaltyMultiScoresV2 tiling result, vs " << vs + << ", vsBlock " << vsBlock << ", seqLen " << seqLen << ", etil 
" << etiLength + << ", bs " << bs << ", bsBlock " << coreLoop << ", badTokenLen " << badWLen <SetBlockDim(blockDim); + + tiling.SaveToBuffer(context->GetRawTilingData()->GetData(), context->GetRawTilingData()->GetCapacity()); + context->GetRawTilingData()->SetDataSize(tiling.GetDataSize()); + + size_t *currentWorkspace = context->GetWorkspaceSizes(1); + if (context->GetWorkspaceNum() <= 0) { + return ge::GRAPH_FAILED; + } + currentWorkspace[0] = MINIMAL_WORKSPACE; + return ge::GRAPH_SUCCESS; + } +} + +namespace ge { + static ge::graphStatus TokenPenaltyMultiScoresV2InferShape(gert::InferShapeContext *context) + { + const gert::Shape* x3Shape = context->GetInputShape(1); + + gert::Shape* y0Shape = context->GetOutputShape(0); + + *y0Shape = *x3Shape; + + return GRAPH_SUCCESS; + } + + static ge::graphStatus TokenPenaltyMultiScoresV2InferShapeRange(gert::InferShapeRangeContext *context) + { + const gert::Range *inputXShapeRange = context->GetInputShapeRange(1); + gert::Range *y0ShapeRange = context->GetOutputShapeRange(0); + + *y0ShapeRange = *inputXShapeRange; + + return GRAPH_SUCCESS; + } + + static ge::graphStatus TokenPenaltyMultiScoresV2InferDataType(gert::InferDataTypeContext *context) + { + const ge::DataType x2DataType = context->GetInputDataType(1); + + context->SetOutputDataType(0, x2DataType); + + return GRAPH_SUCCESS; + } +} + +namespace ops { + class TokenPenaltyMultiScoresV2 : public OpDef { + public: + explicit TokenPenaltyMultiScoresV2(const char *name) : OpDef(name) + { + this->Input("preIds") + .ParamType(REQUIRED) + .DataType({ge::DT_INT64}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Input("logits") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Input("repeatTimes") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Input("penaltyScores") + .ParamType(REQUIRED) 
+ .DataType({ge::DT_FLOAT}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Input("frequencyScores") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Input("presenceScores") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Input("temperatures") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Input("badTokens") + .ParamType(REQUIRED) + .DataType({ge::DT_INT64}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Input("curLen") + .ParamType(REQUIRED) + .DataType({ge::DT_INT64}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Input("minLen") + .ParamType(REQUIRED) + .DataType({ge::DT_INT64}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Input("eosTokenId") + .ParamType(REQUIRED) + .DataType({ge::DT_INT64}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Output("logitsOut") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->SetInferShape(ge::TokenPenaltyMultiScoresV2InferShape) + .SetInferShapeRange(ge::TokenPenaltyMultiScoresV2InferShapeRange) + .SetInferDataType(ge::TokenPenaltyMultiScoresV2InferDataType); + + this->AICore() + .SetTiling(optiling::TokenPenaltyMultiScoresV2TilingFunc); + + this->AICore().AddConfig("ascend310p"); + this->AICore().AddConfig("ascend910"); + this->AICore().AddConfig("ascend910b"); + } + }; + + OP_ADD(TokenPenaltyMultiScoresV2); +} \ No newline at end of file diff --git a/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/token_penalty_multi_scores_v2_tiling.h b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/token_penalty_multi_scores_v2_tiling.h new file mode 100644 
index 00000000000..aee1bdcc5f9 --- /dev/null +++ b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/token_penalty_multi_scores_v2_tiling.h @@ -0,0 +1,31 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "register/tilingdata_base.h" + +namespace optiling { +BEGIN_TILING_DATA_DEF(TokenPenaltyMultiScoresV2TilingData) +TILING_DATA_FIELD_DEF(int32_t, vs); // vocab size +TILING_DATA_FIELD_DEF(int32_t, vsBlock); // vocab size per core +TILING_DATA_FIELD_DEF(int32_t, seqLen); // seq length +TILING_DATA_FIELD_DEF(int32_t, etil); // eos_token_id length +TILING_DATA_FIELD_DEF(int32_t, bs); // batch size +TILING_DATA_FIELD_DEF(int32_t, bsBlock); // batch size per core +TILING_DATA_FIELD_DEF(int32_t, badWLen); // bad token length +TILING_DATA_FIELD_DEF(int32_t, reserve); // +END_TILING_DATA_DEF; + +REGISTER_TILING_DATA_CLASS(TokenPenaltyMultiScoresV2, TokenPenaltyMultiScoresV2TilingData) +} \ No newline at end of file diff --git a/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/token_penalty_multi_scores_with_stop_seqs.cpp b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/token_penalty_multi_scores_with_stop_seqs.cpp new file mode 100644 index 00000000000..c08a77b60b1 --- /dev/null +++ b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/token_penalty_multi_scores_with_stop_seqs.cpp @@ -0,0 +1,217 @@ +/** + * Copyright (c) Huawei 
Technologies Co., Ltd. 2023. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "token_penalty_multi_scores_with_stop_seqs_tiling.h" + +#include "register/op_def_registry.h" +#include "tiling/platform/platform_ascendc.h" + +using namespace std; + +namespace { +constexpr uint32_t MINIMAL_WORKSPACE = 16 * 1024 * 1024; +} + +namespace optiling { + static ge::graphStatus TokenPenaltyMultiScoresWithStopSeqsTilingFunc(gert::TilingContext *context) + { + TokenPenaltyMultiScoresWithStopSeqsTilingData tiling; + const gert::StorageShape* preIdsShape = context->GetInputShape(0); + const gert::StorageShape* logitsShape = context->GetInputShape(1); + const gert::StorageShape* stopSeqsShape = context->GetInputShape(8); + const gert::StorageShape* eosTokenIdsShape = context->GetInputShape(10); + + int bs = logitsShape->GetStorageShape().GetDim(0); + int vs = logitsShape->GetStorageShape().GetDim(1); + int seqLen = preIdsShape->GetStorageShape().GetDim(1); + int32_t stop_seqs_num = stopSeqsShape->GetStorageShape().GetDim(0); + int32_t stop_seqs_max_len = stopSeqsShape->GetStorageShape().GetDim(1); + int32_t eos_len = eosTokenIdsShape->GetStorageShape().GetDim(0); + + auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo()); + int coreNum = ascendcPlatform.GetCoreNumAiv(); + int coreLoop = (bs + coreNum - 1) / coreNum; + int blockDim = (bs + coreLoop - 1) / coreLoop; + + int blockElements = 256 / sizeof(float); // 
256 bytes for select + if (vs % blockElements != 0) { + std::cout << "[ERROR] TokenPenaltyMultiScoresWithStopSeqs voc_size " << vs << " is invalid" << std::endl; + return ge::GRAPH_FAILED; + } + int vsBlockNum = vs / blockElements; + int vsBlockBase = blockElements; + int vsBlock = vsBlockBase; + for (int i = 2; i < vsBlockNum; i++) { + if (vsBlockNum % i == 0) { + int vsBlockTmp = vsBlockBase * i; + if (vsBlockTmp > 2048) { // 2048 is max vsBlock + break; + } + vsBlock = vsBlockTmp; + } + } + + std::cout << "[INFO] TokenPenaltyMultiScoresWithStopSeqs tiling result, vs " << vs + << ", vsBlock " << vsBlock << ", seqLen " << seqLen << ", stop_seqs_num " << stop_seqs_num + << ", stop_seqs_max_len " << stop_seqs_max_len << ", eos_len " << eos_len + << ", bs " << bs << ", bsBlock " << coreLoop << std::endl; + + tiling.set_vs(vs); + tiling.set_vsBlock(vsBlock); + tiling.set_seqLen(seqLen); + tiling.set_bs(bs); + tiling.set_bsBlock(coreLoop); + tiling.set_stop_seqs_num(stop_seqs_num); + tiling.set_stop_seqs_max_len(stop_seqs_max_len); + tiling.set_eos_len(eos_len); + + context->SetBlockDim(blockDim); + + tiling.SaveToBuffer(context->GetRawTilingData()->GetData(), context->GetRawTilingData()->GetCapacity()); + context->GetRawTilingData()->SetDataSize(tiling.GetDataSize()); + + size_t *currentWorkspace = context->GetWorkspaceSizes(1); + if (context->GetWorkspaceNum() <= 0) { + return ge::GRAPH_FAILED; + } + currentWorkspace[0] = MINIMAL_WORKSPACE; + return ge::GRAPH_SUCCESS; + } +} + +namespace ge { + static ge::graphStatus TokenPenaltyMultiScoresWithStopSeqsInferShape(gert::InferShapeContext *context) + { + const gert::Shape* x3Shape = context->GetInputShape(1); + + gert::Shape* y0Shape = context->GetOutputShape(0); + + *y0Shape = *x3Shape; + + return GRAPH_SUCCESS; + } + + static ge::graphStatus TokenPenaltyMultiScoresWithStopSeqsInferShapeRange(gert::InferShapeRangeContext *context) + { + const gert::Range *inputXShapeRange = context->GetInputShapeRange(1); + gert::Range 
*y0ShapeRange = context->GetOutputShapeRange(0); + + *y0ShapeRange = *inputXShapeRange; + + return GRAPH_SUCCESS; + } + + static ge::graphStatus TokenPenaltyMultiScoresWithStopSeqsInferDataType(gert::InferDataTypeContext *context) + { + const ge::DataType x2DataType = context->GetInputDataType(1); + + context->SetOutputDataType(0, x2DataType); + + return GRAPH_SUCCESS; + } +} + +namespace ops { + class TokenPenaltyMultiScoresWithStopSeqs : public OpDef { + public: + explicit TokenPenaltyMultiScoresWithStopSeqs(const char *name) : OpDef(name) + { + this->Input("preIds") + .ParamType(REQUIRED) + .DataType({ge::DT_INT64}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Input("logits") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Input("repeatTimes") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Input("penaltyScores") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Input("frequencyScores") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Input("presenceScores") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Input("curLen") + .ParamType(REQUIRED) + .DataType({ge::DT_INT64}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Input("minLen") + .ParamType(REQUIRED) + .DataType({ge::DT_INT64}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Input("stopSeqs") + .ParamType(REQUIRED) + .DataType({ge::DT_INT64}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Input("stopSeqsLen") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32}) + .Format({ge::FORMAT_ND}) + 
.UnknownShapeFormat({ge::FORMAT_ND}); + + this->Input("eosTokenIds") + .ParamType(REQUIRED) + .DataType({ge::DT_INT64}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Output("logitsOut") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->SetInferShape(ge::TokenPenaltyMultiScoresWithStopSeqsInferShape) + .SetInferShapeRange(ge::TokenPenaltyMultiScoresWithStopSeqsInferShapeRange) + .SetInferDataType(ge::TokenPenaltyMultiScoresWithStopSeqsInferDataType); + + this->AICore() + .SetTiling(optiling::TokenPenaltyMultiScoresWithStopSeqsTilingFunc); + + this->AICore().AddConfig("ascend310p"); + this->AICore().AddConfig("ascend910"); + this->AICore().AddConfig("ascend910b"); + } + }; + + OP_ADD(TokenPenaltyMultiScoresWithStopSeqs); +} diff --git a/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/token_penalty_multi_scores_with_stop_seqs_tiling.h b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/token_penalty_multi_scores_with_stop_seqs_tiling.h new file mode 100644 index 00000000000..a26d22b641c --- /dev/null +++ b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/token_penalty_multi_scores_with_stop_seqs_tiling.h @@ -0,0 +1,31 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "register/tilingdata_base.h" + +namespace optiling { +BEGIN_TILING_DATA_DEF(TokenPenaltyMultiScoresWithStopSeqsTilingData) +TILING_DATA_FIELD_DEF(int32_t, vs); // vocab size +TILING_DATA_FIELD_DEF(int32_t, vsBlock); // vocab size per core +TILING_DATA_FIELD_DEF(int32_t, seqLen); // seq length +TILING_DATA_FIELD_DEF(int32_t, stop_seqs_num); // eos_token_id length +TILING_DATA_FIELD_DEF(int32_t, stop_seqs_max_len); // eos_token_id length +TILING_DATA_FIELD_DEF(int32_t, eos_len); // eos_token_id length +TILING_DATA_FIELD_DEF(int32_t, bs); // batch size +TILING_DATA_FIELD_DEF(int32_t, bsBlock) // batch size per core +END_TILING_DATA_DEF; + +REGISTER_TILING_DATA_CLASS(TokenPenaltyMultiScoresWithStopSeqs, TokenPenaltyMultiScoresWithStopSeqsTilingData) +} diff --git a/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/update_inputs.cpp b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/update_inputs.cpp new file mode 100644 index 00000000000..678a2383f92 --- /dev/null +++ b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/update_inputs.cpp @@ -0,0 +1,239 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "update_inputs_tiling.h" +#include "register/op_def_registry.h" +#include + +using namespace std; + +namespace { +constexpr uint32_t MINIMAL_WORKSPACE = 16 * 1024 * 1024; +} + +namespace optiling { + static ge::graphStatus UpdateInputsTilingFunc(gert::TilingContext *context) + { + UpdateInputsTilingData tiling; + const gert::StorageShape* stopFlagsShape = context->GetInputShape(0); + const gert::StorageShape* seqLensThisTimeShape = context->GetInputShape(2); + const gert::StorageShape* inputIdsShape = context->GetInputShape(5); + int max_bs = stopFlagsShape->GetStorageShape().GetDim(0); + int bs = seqLensThisTimeShape->GetStorageShape().GetDim(0); + int length = inputIdsShape->GetStorageShape().GetDim(1); + + tiling.set_bs(bs); + tiling.set_max_bs(max_bs); + tiling.set_length(length); + + int32_t blockDims = 1; + context->SetBlockDim(blockDims); + + tiling.SaveToBuffer(context->GetRawTilingData()->GetData(), context->GetRawTilingData()->GetCapacity()); + context->GetRawTilingData()->SetDataSize(tiling.GetDataSize()); + + size_t *currentWorkspace = context->GetWorkspaceSizes(1); + if (context->GetWorkspaceNum() <= 0) { + std::cout << "GetWorkspaceNum Failed" << std::endl; + return ge::GRAPH_FAILED; + } + currentWorkspace[0] = MINIMAL_WORKSPACE; + return ge::GRAPH_SUCCESS; + } +} + +namespace ge { + static ge::graphStatus UpdateInputsInferShape(gert::InferShapeContext *context) + { + + const gert::Shape* x0Shape = context->GetInputShape(1); + const gert::Shape* x1Shape = context->GetInputShape(2); + const gert::Shape* x2Shape = context->GetInputShape(3); + const gert::Shape* x3Shape = context->GetInputShape(4); + const gert::Shape* x4Shape = context->GetInputShape(5); + + gert::Shape* y0Shape = context->GetOutputShape(0); + + gert::Shape* y1Shape = context->GetOutputShape(1); + + gert::Shape* y2Shape = context->GetOutputShape(2); + + gert::Shape* y3Shape = context->GetOutputShape(3); + + gert::Shape* y4Shape = context->GetOutputShape(4); + + *y0Shape = 
*x0Shape; + + *y1Shape = *x1Shape; + + *y2Shape = *x2Shape; + + *y3Shape = *x3Shape; + + *y4Shape = *x4Shape; + + + return GRAPH_SUCCESS; + } + + static ge::graphStatus UpdateInputsInferShapeRange(gert::InferShapeRangeContext *context) + { + const gert::Range *inputX0ShapeRange = context->GetInputShapeRange(1); + const gert::Range *inputX1ShapeRange = context->GetInputShapeRange(2); + const gert::Range *inputX2ShapeRange = context->GetInputShapeRange(3); + const gert::Range *inputX3ShapeRange = context->GetInputShapeRange(4); + const gert::Range *inputX4ShapeRange = context->GetInputShapeRange(5); + + gert::Range *y0ShapeRange = context->GetOutputShapeRange(0); + gert::Range *y1ShapeRange = context->GetOutputShapeRange(1); + gert::Range *y2ShapeRange = context->GetOutputShapeRange(2); + gert::Range *y3ShapeRange = context->GetOutputShapeRange(3); + gert::Range *y4ShapeRange = context->GetOutputShapeRange(4); + + + + *y0ShapeRange = *inputX0ShapeRange; + *y1ShapeRange = *inputX1ShapeRange; + *y2ShapeRange = *inputX2ShapeRange; + *y3ShapeRange = *inputX3ShapeRange; + *y4ShapeRange = *inputX4ShapeRange; + + return GRAPH_SUCCESS; + } + + static ge::graphStatus UpdateInputsInferDataType(gert::InferDataTypeContext *context) + { + + const ge::DataType x0DataType = context->GetInputDataType(1); + const ge::DataType x1DataType = context->GetInputDataType(2); + const ge::DataType x2DataType = context->GetInputDataType(3); + const ge::DataType x3DataType = context->GetInputDataType(4); + const ge::DataType x4DataType = context->GetInputDataType(5); + + + context->SetOutputDataType(0, x0DataType); + context->SetOutputDataType(1, x1DataType); + context->SetOutputDataType(2, x2DataType); + context->SetOutputDataType(3, x3DataType); + context->SetOutputDataType(4, x4DataType); + + return GRAPH_SUCCESS; + } +} + +namespace ops { + class UpdateInputs : public OpDef { + public: + explicit UpdateInputs(const char *name) : OpDef(name) + { + this->Input("stop_flags") + 
.ParamType(REQUIRED) + .DataType({ge::DT_BOOL}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Input("not_need_stop") + .ParamType(REQUIRED) + .DataType({ge::DT_BOOL}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Input("seq_lens_this_time") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Input("seq_lens_encoder") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Input("seq_lens_decoder") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Input("input_ids") + .ParamType(REQUIRED) + .DataType({ge::DT_INT64}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Input("stop_nums") + .ParamType(REQUIRED) + .DataType({ge::DT_INT64}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Input("next_tokens") + .ParamType(REQUIRED) + .DataType({ge::DT_INT64}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Input("is_block_step") + .ParamType(REQUIRED) + .DataType({ge::DT_BOOL}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Output("not_need_stop_out") + .ParamType(REQUIRED) + .DataType({ge::DT_BOOL}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Output("seq_lens_this_time_out") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Output("seq_lens_encoder_out") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Output("seq_lens_decoder_out") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->Output("input_ids_out") + 
.ParamType(REQUIRED) + .DataType({ge::DT_INT64}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->SetInferShape(ge::UpdateInputsInferShape) + .SetInferShapeRange(ge::UpdateInputsInferShapeRange) + .SetInferDataType(ge::UpdateInputsInferDataType); + + this->AICore() + .SetTiling(optiling::UpdateInputsTilingFunc); + + this->AICore().AddConfig("ascend310p"); + this->AICore().AddConfig("ascend910"); + this->AICore().AddConfig("ascend910b"); + } + }; + + OP_ADD(UpdateInputs); +} \ No newline at end of file diff --git a/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/update_inputs_tiling.h b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/update_inputs_tiling.h new file mode 100644 index 00000000000..7d5d16bf8fc --- /dev/null +++ b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_host/update_inputs_tiling.h @@ -0,0 +1,26 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "register/tilingdata_base.h" + +namespace optiling { +BEGIN_TILING_DATA_DEF(UpdateInputsTilingData) +TILING_DATA_FIELD_DEF(int32_t, bs); +TILING_DATA_FIELD_DEF(int32_t, max_bs); +TILING_DATA_FIELD_DEF(int32_t, length); +END_TILING_DATA_DEF; + +REGISTER_TILING_DATA_CLASS(UpdateInputs, UpdateInputsTilingData) +} \ No newline at end of file diff --git a/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_kernel/get_max_len.cpp b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_kernel/get_max_len.cpp new file mode 100644 index 00000000000..691aadacbe9 --- /dev/null +++ b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_kernel/get_max_len.cpp @@ -0,0 +1,52 @@ +#include "kernel_operator.h" +using namespace AscendC; + +class GetMaxLen { +public: + __aicore__ inline GetMaxLen(int32_t bs) + { + this->batchNum = bs; + } + + __aicore__ inline void Init(__gm__ uint8_t *seqLensEncoder, __gm__ uint8_t *seqLensDecoder, __gm__ uint8_t *seqLensEncoderOut, __gm__ uint8_t *seqLensDecoderOut) + { + seqLensEncoderGm = (__gm__ int32_t *)seqLensEncoder; + seqLensDecoderGm = (__gm__ int32_t *)seqLensDecoder; + seqLensEncoderOutGm = (__gm__ int32_t *)seqLensEncoderOut; + seqLensDecoderOutGm = (__gm__ int32_t *)seqLensDecoderOut; + } + + __aicore__ inline void Process() + { + *(seqLensEncoderOutGm) = 0; + *(seqLensDecoderOutGm) = 0; + + for (int32_t i = 0; i < batchNum; i++) { + if (*(seqLensEncoderGm + i) > *seqLensEncoderOutGm) { + *seqLensEncoderOutGm = *(seqLensEncoderGm + i); + } + + if (*(seqLensDecoderGm + i) > *seqLensDecoderOutGm) { + *seqLensDecoderOutGm = *(seqLensDecoderGm + i); + } + } + pipe_barrier(PIPE_ALL); + } + +private: + int32_t batchNum = 0; + + __gm__ int32_t *seqLensEncoderGm; + __gm__ int32_t *seqLensDecoderGm; + __gm__ int32_t *seqLensEncoderOutGm; + __gm__ int32_t *seqLensDecoderOutGm; +}; + +extern "C" __global__ __aicore__ void get_max_len(GM_ADDR seqLensEncoder, GM_ADDR seqLensDecoder, GM_ADDR seqLensEncoderOut, GM_ADDR 
seqLensDecoderOut, + GM_ADDR workspace, GM_ADDR tiling) +{ + GET_TILING_DATA(tilingData, tiling); + GetMaxLen op(tilingData.bs); + op.Init(seqLensEncoder, seqLensDecoder, seqLensEncoderOut, seqLensDecoderOut); + op.Process(); +} \ No newline at end of file diff --git a/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_kernel/get_padding_offset.cpp b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_kernel/get_padding_offset.cpp new file mode 100644 index 00000000000..82756f1e9a7 --- /dev/null +++ b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_kernel/get_padding_offset.cpp @@ -0,0 +1,180 @@ +#include "kernel_operator.h" +using namespace AscendC; + +constexpr int32_t BUFFER_NUM = 1; +constexpr int32_t ELE_PER_BLK = 8; +constexpr int32_t MAX_BATCH_NUM = 256; + +namespace { + +class GetPaddingOffset { +public: + __aicore__ inline GetPaddingOffset() {} + __aicore__ inline void Init(GM_ADDR input_ids, + GM_ADDR cum_offsets_now, GM_ADDR token_num, + GM_ADDR seq_len, GM_ADDR x_remove_padding, + GM_ADDR cum_offsets_out, GM_ADDR padding_offset, + GM_ADDR cu_seqlens_q, GM_ADDR cu_seqlens_k, + uint32_t padLength_, uint32_t batch) + { + this->padLength_ = padLength_; + this->batch_ = batch; + inputIdsGm.SetGlobalBuffer((__gm__ int64_t *)input_ids, padLength_ * batch_); + cumOffsetsNowGm.SetGlobalBuffer((__gm__ int32_t *)cum_offsets_now, batch_); + tokenNumGm.SetGlobalBuffer((__gm__ int64_t *)token_num, 1); + seqLenGm.SetGlobalBuffer((__gm__ int32_t *)seq_len, batch_); + xRemovePaddingGm.SetGlobalBuffer((__gm__ int64_t *)x_remove_padding, padLength_ * batch_); + cumOffsetOutGm.SetGlobalBuffer((__gm__ int32_t *)cum_offsets_out, batch_); + paddingOffsetGm.SetGlobalBuffer((__gm__ int32_t *)padding_offset, padLength_ * batch_); + cuSeqlensQGm = (__gm__ int32_t *)cu_seqlens_q; + cuSeqlensKGm = (__gm__ int32_t *)cu_seqlens_k; + padLengthAlign_ = ((padLength_ + ELE_PER_BLK - 1) / ELE_PER_BLK) * ELE_PER_BLK; + batchAlign_ = ((batch_ + ELE_PER_BLK - 1) / ELE_PER_BLK) 
* ELE_PER_BLK; + pipe_.InitBuffer(inputIdsQueue_, BUFFER_NUM, padLengthAlign_ * sizeof(int64_t)); + pipe_.InitBuffer(cumOffsetsQueue_, BUFFER_NUM, MAX_BATCH_NUM * sizeof(int32_t)); + pipe_.InitBuffer(seqLenQueue_, BUFFER_NUM, MAX_BATCH_NUM * sizeof(int32_t)); + pipe_.InitBuffer(xRemovePaddingQueue_, BUFFER_NUM, padLengthAlign_ * sizeof(int64_t)); + pipe_.InitBuffer(cumOffsetOutQueue_, BUFFER_NUM, MAX_BATCH_NUM * sizeof(int32_t)); + pipe_.InitBuffer(broadCastBuf_, padLengthAlign_ * sizeof(int32_t)); + pipe_.InitBuffer(cumOffsetsBuf_, MAX_BATCH_NUM * sizeof(int32_t)); + pipe_.InitBuffer(seqLenBuf_, MAX_BATCH_NUM * sizeof(int32_t)); + + } + + __aicore__ inline void Process() + { + pipe_barrier(PIPE_MTE2); + pipe_barrier(PIPE_ALL); + for (int32_t i = 0; i < batch_; i++) { + CopyIn(i); + CopyOnce(); + pipe_barrier(PIPE_MTE2); + ComputeOnce(i); + ComputeRemovePadding(); + pipe_barrier(PIPE_V); + CopyOut(i); + CopyOutOnce(i); + } + + for (int32_t i = 0; i < batch_; i++) { + int32_t cum_seq_len = (i+1) * padLength_ - *(cumOffsetOutGm.GetPhyAddr()+ i); + *(cuSeqlensQGm + i + 1) = cum_seq_len; + *(cuSeqlensKGm + i + 1) = cum_seq_len; + } + + } + +private: + __aicore__ inline void CopyOnce() + { + LocalTensor seqLenLocal = seqLenQueue_.AllocTensor(); + AscendC::LocalTensor cumOffsetsBuffer = cumOffsetsBuf_.Get(); + AscendC::LocalTensor seqLenBufBuffer = seqLenBuf_.Get(); + DataCopy(seqLenLocal, seqLenGm, MAX_BATCH_NUM); + DataCopy(cumOffsetsBuffer, cumOffsetsNowGm, MAX_BATCH_NUM); + DataCopy(seqLenBufBuffer, seqLenGm, MAX_BATCH_NUM); + seqLenQueue_.EnQue(seqLenLocal); + } + + __aicore__ inline void ComputeOnce(uint32_t progress) + { + AscendC::LocalTensor cumOffsetsBuffer = cumOffsetsBuf_.Get(); + LocalTensor seqLenLocal = seqLenQueue_.DeQue(); + AscendC::LocalTensor broadCast = broadCastBuf_.Get(); + AscendC::LocalTensor cumOffsetOutLocal = cumOffsetOutQueue_.AllocTensor(); + seqLenZero_ = seqLenLocal.GetValue(0); + if (progress == 0) { + Duplicate(broadCast, (int32_t)0, 
padLengthAlign_); + pipe_barrier(PIPE_ALL); + } else { + Duplicate(broadCast, (int32_t)cumOffsetsBuffer.GetValue(progress - 1), padLengthAlign_); + pipe_barrier(PIPE_V); + } + pipe_barrier(PIPE_ALL); + cumOffsetOutLocal.SetValue(0, (int32_t)0); + for (uint32_t x = 1; x < batch_; x++) { + cumOffsetOutLocal.SetValue(x, cumOffsetsBuffer.GetValue(x - 1)); + } + cumOffsetOutQueue_.EnQue(cumOffsetOutLocal); + seqLenQueue_.FreeTensor(seqLenLocal); + } + + __aicore__ inline void CopyOutOnce(uint32_t progress) + { + AscendC::LocalTensor cumOffsetsBuffer = cumOffsetsBuf_.Get(); + LocalTensor cumOffsetOutLocal = cumOffsetOutQueue_.DeQue(); + AscendC::LocalTensor broadCast = broadCastBuf_.Get(); + if (progress == 0) { + DataCopy(paddingOffsetGm, broadCast, padLengthAlign_); + } else { + DataCopy(paddingOffsetGm[seqLenZero_ + (progress - 1) * padLength_ - + cumOffsetsBuffer.GetValue(progress - 1) + + cumOffsetsBuffer.GetValue(0)], broadCast, padLengthAlign_); + } + DataCopy(cumOffsetOutGm, cumOffsetOutLocal, batchAlign_); + cumOffsetOutQueue_.FreeTensor(cumOffsetOutLocal); + } + + __aicore__ inline void CopyIn(uint32_t progress) + { + LocalTensor inputIdsLocal = inputIdsQueue_.AllocTensor(); + DataCopy(inputIdsLocal, inputIdsGm[progress * padLength_], padLengthAlign_); + inputIdsQueue_.EnQue(inputIdsLocal); + } + + __aicore__ inline void ComputeRemovePadding() + { + LocalTensor inputIdsLocal = inputIdsQueue_.DeQue(); + AscendC::LocalTensor broadCast = broadCastBuf_.Get(); + AscendC::LocalTensor xRemovePaddingLocal = xRemovePaddingQueue_.AllocTensor(); + DataCopy(xRemovePaddingLocal, inputIdsLocal, padLengthAlign_); + pipe_barrier(PIPE_ALL); + inputIdsQueue_.FreeTensor(inputIdsLocal); + xRemovePaddingQueue_.EnQue(xRemovePaddingLocal); + } + + __aicore__ inline void CopyOut(uint32_t progress) + { + AscendC::LocalTensor cumOffsetsBuffer = cumOffsetsBuf_.Get(); + AscendC::LocalTensor seqLenBufBuffer = seqLenBuf_.Get(); + LocalTensor xRemovePaddingLocal = 
xRemovePaddingQueue_.DeQue(); + if (progress == 0) { + DataCopy(xRemovePaddingGm, xRemovePaddingLocal, padLengthAlign_); + } else { + DataCopy(xRemovePaddingGm[progress * padLength_ - cumOffsetsBuffer.GetValue(progress - 1)], + xRemovePaddingLocal, padLengthAlign_); + } + xRemovePaddingQueue_.FreeTensor(xRemovePaddingLocal); + } + +private: + TPipe pipe_; + TQue inputIdsQueue_, cumOffsetsQueue_, tokenNumQueue_, seqLenQueue_; + TQue xRemovePaddingQueue_, cumOffsetOutQueue_, paddingOffsetQueue_; + AscendC::TBuf broadCastBuf_; + AscendC::TBuf cumOffsetsBuf_; + AscendC::TBuf seqLenBuf_; + GlobalTensor cumOffsetsNowGm, seqLenGm, cumOffsetOutGm, paddingOffsetGm; + GlobalTensor inputIdsGm, tokenNumGm, xRemovePaddingGm; + uint32_t padLength_{1}; + uint32_t batch_{1}; + uint32_t padLengthAlign_{16}; + uint32_t batchAlign_{8}; + int32_t seqLenZero_{0}; + __gm__ int32_t *cuSeqlensQGm; + __gm__ int32_t *cuSeqlensKGm; +}; +} + +extern "C" __global__ __aicore__ void get_padding_offset(GM_ADDR input_ids, + GM_ADDR cum_offsets_now, GM_ADDR token_num, GM_ADDR seq_len, GM_ADDR x_remove_padding, + GM_ADDR cum_offsets_out, GM_ADDR padding_offset, GM_ADDR cu_seqlens_q, GM_ADDR cu_seqlens_k, + GM_ADDR workspace, GM_ADDR tiling) +{ + GET_TILING_DATA(tilingData, tiling); + GetPaddingOffset op; + op.Init(input_ids, cum_offsets_now, token_num, seq_len, + x_remove_padding, cum_offsets_out, padding_offset, cu_seqlens_q, cu_seqlens_k, + tilingData.padLength, tilingData.batch); + op.Process(); +} \ No newline at end of file diff --git a/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_kernel/rebuild_padding.cpp b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_kernel/rebuild_padding.cpp new file mode 100644 index 00000000000..13ea521c1e4 --- /dev/null +++ b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_kernel/rebuild_padding.cpp @@ -0,0 +1,125 @@ +#include "kernel_operator.h" +using namespace AscendC; + +namespace { + +constexpr int32_t BUFFER_NUM = 1; +constexpr int32_t 
ELE_PER_BLK = 16;
// Upper bound on batch size for the UB staging buffers below.
constexpr int32_t MAX_BATCH_NUM = 256;

// Copies, for every sequence in the batch, one dim_embed-sized row out of the
// un-padded token buffer (tmp_out) into a dense [bs, dim_embed] output.  The
// source row index is derived from seq_lens_encoder / seq_lens_decoder and
// cum_offsets (see CopyIn).
// NOTE(review): template arguments (e.g. on TQue / LocalTensor / GlobalTensor)
// appear to have been stripped from this patch text by the transport that lost
// angle brackets — restore them from the original file before compiling.
class RebuildPadding {
public:
    // bs: batch size; dim_embed: hidden width; token_num: total un-padded
    // tokens in tmp_out; max_input_length: padded sequence length.
    __aicore__ inline RebuildPadding(int32_t bs, int32_t dim_embed, int32_t token_num,int32_t max_input_length)
    {
        this->bs_ = bs;
        this->dimEmbed_ = dim_embed;
        this->tokenNum_ = token_num;
        this->maxInputLength_ = max_input_length;
    }

    // Binds global buffers and sizes the single-row / per-batch UB queues.
    __aicore__ inline void Init(GM_ADDR tmpOut,
        GM_ADDR cum_offsets, GM_ADDR seq_lens_decoder,
        GM_ADDR seq_lens_encoder, GM_ADDR out)
    {
        // Round the row width up to a 16-element (one block of half) multiple
        // because DataCopy moves whole 32B blocks.
        dimEmbedAlign_ = (dimEmbed_ + ELE_PER_BLK - 1) / ELE_PER_BLK * ELE_PER_BLK;
        tmpOutGm.SetGlobalBuffer((__gm__ half *)tmpOut, tokenNum_ * dimEmbed_);
        cumOffsetsGm.SetGlobalBuffer((__gm__ int32_t *)cum_offsets, bs_);
        seqLensDecoderGm.SetGlobalBuffer((__gm__ int32_t *)seq_lens_decoder, bs_);
        seqLensEncoderGm.SetGlobalBuffer((__gm__ int32_t *)seq_lens_encoder, bs_);
        outGm.SetGlobalBuffer((__gm__ half *)out, bs_ * dimEmbed_);

        pipe_.InitBuffer(tmpOutQueue_, BUFFER_NUM, dimEmbedAlign_ * sizeof(half));
        pipe_.InitBuffer(cumOffsetsQueue_, BUFFER_NUM, MAX_BATCH_NUM * sizeof(int32_t));
        pipe_.InitBuffer(seqLensDecoderQueue_, BUFFER_NUM, MAX_BATCH_NUM * sizeof(int32_t));
        pipe_.InitBuffer(seqLensEncoderQueue_, BUFFER_NUM, MAX_BATCH_NUM * sizeof(int32_t));
        pipe_.InitBuffer(outQueue_, BUFFER_NUM, dimEmbedAlign_ * sizeof(half));
    }

    // One CopyOnce/CopyIn/CopyOut round trip per batch entry.
    // NOTE(review): CopyOnce re-loads the same three per-batch vectors on every
    // iteration — hoisting it out of the loop looks possible; verify queue
    // depth semantics before changing.
    __aicore__ inline void Process()
    {
        for (int32_t i = 0; i < bs_; i++) {
            CopyOnce();
            pipe_barrier(PIPE_ALL);
            CopyIn(i);
            pipe_barrier(PIPE_ALL);
            CopyOut(i);
        }
    }

private:
    // Stages cum_offsets, seq_lens_decoder and seq_lens_encoder into UB.
    __aicore__ inline void CopyOnce()
    {
        LocalTensor cumOffsetsLocal = cumOffsetsQueue_.AllocTensor();
        pipe_barrier(PIPE_ALL);
        DataCopy(cumOffsetsLocal, cumOffsetsGm, MAX_BATCH_NUM);
        cumOffsetsQueue_.EnQue(cumOffsetsLocal);

        LocalTensor seqLensDecoderLocal = seqLensDecoderQueue_.AllocTensor();
        pipe_barrier(PIPE_ALL);
        DataCopy(seqLensDecoderLocal, seqLensDecoderGm, MAX_BATCH_NUM);
        seqLensDecoderQueue_.EnQue(seqLensDecoderLocal);

        LocalTensor seqLensEncoderLocal = seqLensEncoderQueue_.AllocTensor();
        pipe_barrier(PIPE_ALL);
        DataCopy(seqLensEncoderLocal, seqLensEncoderGm, MAX_BATCH_NUM);
        seqLensEncoderQueue_.EnQue(seqLensEncoderLocal);
    }
    // Picks the source row for batch entry `progress`:
    //  - decoder phase (decoderVal != 0): first token of this sequence's span;
    //  - encoder phase (decoderVal == 0, encoderVal != 0): its last token.
    // If both lengths are 0 nothing is copied and the stale tmpOutLocal is
    // forwarded — NOTE(review): confirm callers never hit that case.
    __aicore__ inline void CopyIn(uint32_t progress)
    {
        LocalTensor tmpOutLocal = tmpOutQueue_.AllocTensor();
        LocalTensor cumOffsetsLocal = cumOffsetsQueue_.DeQue();
        LocalTensor seqLensDecoderLocal = seqLensDecoderQueue_.DeQue();
        LocalTensor seqLensEncoderLocal = seqLensEncoderQueue_.DeQue();
        int32_t decoderVal = seqLensDecoderLocal.GetValue(progress);
        int32_t encoderVal = seqLensEncoderLocal.GetValue(progress);
        int32_t cumOffset = cumOffsetsLocal.GetValue(progress);
        pipe_barrier(PIPE_ALL);

        if (decoderVal == 0) {
            if (encoderVal != 0) {
                tempVal_ = progress * maxInputLength_ - cumOffset + (encoderVal - 1);
                pipe_barrier(PIPE_ALL);
                DataCopy(tmpOutLocal, tmpOutGm[tempVal_ * dimEmbed_], dimEmbedAlign_);
            }
        } else {
            tempVal_ = progress * maxInputLength_ - cumOffset;
            pipe_barrier(PIPE_ALL);
            DataCopy(tmpOutLocal, tmpOutGm[tempVal_ * dimEmbed_], dimEmbedAlign_);
        }
        tmpOutQueue_.EnQue(tmpOutLocal);
        cumOffsetsQueue_.FreeTensor(cumOffsetsLocal);
        seqLensDecoderQueue_.FreeTensor(seqLensDecoderLocal);
        seqLensEncoderQueue_.FreeTensor(seqLensEncoderLocal);
    }

    // Writes the staged row to out[progress, :].
    __aicore__ inline void CopyOut(uint32_t progress)
    {
        LocalTensor tmpOutLocal = tmpOutQueue_.DeQue();
        pipe_barrier(PIPE_ALL);
        DataCopy(outGm[progress * dimEmbed_], tmpOutLocal, dimEmbedAlign_);
        tmpOutQueue_.FreeTensor(tmpOutLocal);
    }

private:
    TPipe pipe_;
    TQue tmpOutQueue_, cumOffsetsQueue_, seqLensDecoderQueue_, seqLensEncoderQueue_;
    TQue outQueue_;
    GlobalTensor tmpOutGm, outGm;
    GlobalTensor cumOffsetsGm, seqLensDecoderGm, seqLensEncoderGm;
    int32_t bs_{1};
    int32_t dimEmbed_{16};
    int32_t dimEmbedAlign_{16};
    int32_t maxInputLength_{64};
    int32_t tempVal_ = 0;
    int32_t tokenNum_ = 0;
};
}

// Kernel entry: tiling supplies bs / dim_embed / token_num / max_input_length.
extern "C" __global__ __aicore__ void rebuild_padding(GM_ADDR tmp_out, GM_ADDR cum_offsets, GM_ADDR seq_lens_decoder,
    GM_ADDR seq_lens_encoder, GM_ADDR out, GM_ADDR workspace, GM_ADDR tiling)
{
    GET_TILING_DATA(tilingData, tiling);
    RebuildPadding op(tilingData.bs, tilingData.dim_embed, tilingData.token_num,tilingData.max_input_length);
    op.Init(tmp_out, cum_offsets, seq_lens_decoder, seq_lens_encoder, out);
    op.Process();
}
\ No newline at end of file
diff --git a/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_kernel/set_mask_value.cpp b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_kernel/set_mask_value.cpp
new file mode 100644
index 00000000000..d50be06a5d1
--- /dev/null
+++ b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_kernel/set_mask_value.cpp
@@ -0,0 +1,52 @@
#include "kernel_operator.h"
using namespace AscendC;

namespace {

// Scalar (pointer-arithmetic) kernel: per sequence, publish its effective
// length (0 if stopped) and set a 1.0 marker in the attention-mask row at the
// position given by seq_lens.
class SetMaskValue {
public:
    // seqBs: number of sequences; length: stride of one mask row.
    __aicore__ inline SetMaskValue(int32_t seqBs, int32_t length)
    {
        this->seqBs = seqBs;
        this->length = length;
    }

    __aicore__ inline void Process(__gm__ uint8_t *inputData, __gm__ uint8_t *stopFlags, __gm__ uint8_t *seqLens,
        __gm__ uint8_t *sequenceLengths)
    {
        inputDataGm = (__gm__ half *)inputData;
        stopFlagsGm = (__gm__ bool *)stopFlags;
        seqLensGm = (__gm__ int32_t *)seqLens;
        sequenceLengthsGm = (__gm__ int32_t *)sequenceLengths;

        for (int32_t i = 0; i < seqBs; i++) {
            if (*(stopFlagsGm + i)) {
                // Stopped sequence: report length 0.
                *(sequenceLengthsGm + i) = 0;
                pipe_barrier(PIPE_ALL);
            } else {
                *(sequenceLengthsGm + i) = *(seqLensGm + i);
                pipe_barrier(PIPE_ALL);
            }
            // NOTE(review): the mask cell is written even when the sequence is
            // stopped — confirm that is intentional.
            *((inputDataGm + i * length + *(seqLensGm + i))) = (half)1.0;
            pipe_barrier(PIPE_ALL);
        }

    }
private:
    int32_t seqBs;
    int32_t length;
    __gm__ half *inputDataGm;
    __gm__ bool *stopFlagsGm;
    __gm__ int32_t *seqLensGm;
    __gm__ int32_t *sequenceLengthsGm;

};
}

extern "C" __global__ __aicore__ void set_mask_value(GM_ADDR input_data, GM_ADDR stop_flags, GM_ADDR seq_lens,
    GM_ADDR sequence_lengths, GM_ADDR workspace, GM_ADDR
tiling)
{
    GET_TILING_DATA(tilingData, tiling);
    SetMaskValue op(tilingData.seqBs, tilingData.length);
    op.Process(input_data, stop_flags, seq_lens, sequence_lengths);
}
diff --git a/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_kernel/set_stop_value_multi_ends.cpp b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_kernel/set_stop_value_multi_ends.cpp
new file mode 100644
index 00000000000..d33a5559310
--- /dev/null
+++ b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_kernel/set_stop_value_multi_ends.cpp
@@ -0,0 +1,72 @@
#include "kernel_operator.h"
using namespace AscendC;

namespace {
// In vector core, repeat size equal to 256B = 128xhalf
const uint32_t REPEATSIZE = 128;
// Default repeat stride is 8
const uint8_t DEFAULREPEATSTRIDE = 8;
const uint32_t BLOCK_SIZE = 16;
// 12288 * 8 means we can process 8 tokens in one loop
const uint64_t MAX_PROCESS_NUM = 12288 * 8;

// For each sequence: if already stopped, force its sampled token to end_ids[0];
// then raise stop_flags_out for any sequence whose token matches one of the
// `length` end ids.
class SetStopValueMultiEnds {
public:
    __aicore__ inline SetStopValueMultiEnds(int32_t seqBs, int32_t length)
    {
        // notkens headDim headNum
        this->seqBs = seqBs;
        this->length = length;
    }

    __aicore__ inline void Process(__gm__ uint8_t *topk_ids, __gm__ uint8_t *stopFlags, __gm__ uint8_t *end_ids,
        __gm__ uint8_t *topk_ids_out, __gm__ uint8_t *stop_flags_out)
    {
        topk_idsGm = (__gm__ int64_t *)topk_ids;
        stopFlagsGm = (__gm__ bool *)stopFlags;
        end_idsGm = (__gm__ int64_t *)end_ids;
        topk_ids_outGm = (__gm__ int64_t *)topk_ids_out;
        stop_flags_outGm = (__gm__ bool *)stop_flags_out;
        for (int i = 0; i < this->seqBs; ++i) {
            // NOTE(review): reads the raw uint8_t* `stopFlags` parameter here
            // instead of the casted `stopFlagsGm` (same bytes, inconsistent
            // with the rest of the file).
            *(topk_idsGm + i) = (*(stopFlags + i)) ? *(end_idsGm) : *(topk_idsGm + i);
            *(stop_flags_outGm + i) = *(stopFlagsGm + i);
        }
        pipe_barrier(PIPE_ALL);

        // NOTE(review): results are written in-place to topk_idsGm and
        // topk_ids_outGm is never written — confirm the framework aliases the
        // output buffer onto the input, otherwise topk_ids_out stays garbage.
        for (int i = 0; i < this->seqBs; ++i) {
            int64_t id = *(topk_idsGm + i);
            bool flag = 0;
            for(int j = 0; j < this->length; ++j) {
                if (id == *(end_idsGm + j)){
                    flag = 1;
                    break;
                }
            }
            if (flag) {
                *(stop_flags_outGm + i) = 1;
            }
        }
        pipe_barrier(PIPE_ALL);
    }

private:
    /* data */
    __gm__ int64_t *topk_idsGm;
    __gm__ bool *stopFlagsGm;
    __gm__ int64_t *end_idsGm;

    __gm__ int64_t *topk_ids_outGm;
    __gm__ bool *stop_flags_outGm;

    int32_t seqBs;
    int32_t length;
};
}

extern "C" __global__ __aicore__ void set_stop_value_multi_ends(GM_ADDR topk_ids, GM_ADDR stop_flags, GM_ADDR end_ids,
    GM_ADDR topk_ids_out, GM_ADDR stop_flags_out, GM_ADDR workspace, GM_ADDR tiling)
{
    GET_TILING_DATA(tilingData, tiling);
    SetStopValueMultiEnds op(tilingData.seqBs, tilingData.length);
    op.Process(topk_ids, stop_flags, end_ids, topk_ids_out, stop_flags_out);
}
diff --git a/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_kernel/set_stop_value_multi_ends_v2.cpp b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_kernel/set_stop_value_multi_ends_v2.cpp
new file mode 100644
index 00000000000..532582aba73
--- /dev/null
+++ b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_kernel/set_stop_value_multi_ends_v2.cpp
@@ -0,0 +1,101 @@
#include "kernel_operator.h"
using namespace AscendC;

namespace {
// In vector core, repeat size equal to 256B = 128xhalf
const uint32_t REPEATSIZE = 128;
// Default repeat stride is 8
const uint8_t DEFAULREPEATSTRIDE = 8;
const uint32_t BLOCK_SIZE = 16;
// 12288 * 8 means we can process 8 tokens in one loop
const uint64_t MAX_PROCESS_NUM = 12288 * 8;

// V2 of the end-id stop kernel: additionally maintains next_tokens and treats
// seq_lens==0 on an already-stopped sequence as "emit -1".  Updates are made
// in-place on the input buffers, then mirrored to the *_out buffers.
class SetStopValueMultiEndsV2 {
public:
    __aicore__ inline SetStopValueMultiEndsV2(int32_t bs, int32_t length)
    {
        this->batchNum = bs;
        this->lengthNum = length;
    }

    __aicore__ inline void Init(__gm__ uint8_t *topkIds,
        __gm__ uint8_t *stopFlags,
        __gm__ uint8_t *seqLens,
        __gm__ uint8_t *endIds,
        __gm__ uint8_t *nextTokens,
        __gm__ uint8_t *topkIdsOut,
        __gm__ uint8_t *stopFlagsOut,
        __gm__ uint8_t *nextTokensOut)
    {
        topkIdsGm = (__gm__ int64_t *)topkIds;
        stopFlagsGm = (__gm__ bool *)stopFlags;
        seqLensGm = (__gm__ int32_t *)seqLens;
        endIdsGm = (__gm__ int64_t *)endIds;
        nextTokensGm = (__gm__ int64_t *)nextTokens;
        topkIdsOutGm = (__gm__ int64_t *)topkIdsOut;
        stopFlagsOutGm = (__gm__ bool *)stopFlagsOut;
        nextTokensOutGm = (__gm__ int64_t *)nextTokensOut;
    }

    __aicore__ inline void Process()
    {
        for (int32_t i = 0; i < batchNum; i++) {
            pipe_barrier(PIPE_ALL);
            if (*(stopFlagsGm + i)) {
                if (*(seqLensGm + i) == 0) {
                    // Inactive slot: emit sentinel token.
                    *(topkIdsGm + i) = -1;
                } else {
                    // Already stopped this step: pin token to the end id.
                    *(topkIdsGm + i) = *endIdsGm;
                    *(nextTokensGm + i) = *endIdsGm;
                }
            } else {
                *(nextTokensGm + i) = *(topkIdsGm + i);
            }

            // Raise the stop flag if the (possibly rewritten) token is any end id.
            for (int32_t j = 0; j < lengthNum; j++) {
                if (*(topkIdsGm + i) == *(endIdsGm + j)) {
                    *(stopFlagsGm + i) = true;
                    break;
                }
            }
            pipe_barrier(PIPE_ALL);
        }

        // Mirror the in-place results to the declared outputs.
        for (int32_t i = 0; i < batchNum; i++) {
            *(topkIdsOutGm + i) = *(topkIdsGm + i);
            *(stopFlagsOutGm + i) = *(stopFlagsGm + i);
            *(nextTokensOutGm + i) = *(nextTokensGm + i);
        }
    }

private:
    __gm__ int64_t *topkIdsGm;
    __gm__ bool *stopFlagsGm;
    __gm__ int32_t *seqLensGm;
    __gm__ int64_t *endIdsGm;
    __gm__ int64_t *nextTokensGm;
    __gm__ int64_t *topkIdsOutGm;
    __gm__ bool *stopFlagsOutGm;
    __gm__ int64_t *nextTokensOutGm;

    int32_t batchNum;
    int32_t lengthNum;
};
}

extern "C" __global__ __aicore__ void set_stop_value_multi_ends_v2(GM_ADDR topkIds,
    GM_ADDR stopFlags,
    GM_ADDR seqLens,
    GM_ADDR endIds,
    GM_ADDR nextTokens,
    GM_ADDR topkIdsOut,
    GM_ADDR stopFlagsOut,
    GM_ADDR nextTokensOut,
    GM_ADDR workspace,
    GM_ADDR tiling)
{
    GET_TILING_DATA(tilingData, tiling);
    SetStopValueMultiEndsV2 op(tilingData.bs, tilingData.length);
    op.Init(topkIds, stopFlags, seqLens, endIds,
nextTokens, topkIdsOut, stopFlagsOut, nextTokensOut);
    op.Process();
}
\ No newline at end of file
diff --git a/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_kernel/set_stop_value_multi_seqs.cpp b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_kernel/set_stop_value_multi_seqs.cpp
new file mode 100644
index 00000000000..909890efddf
--- /dev/null
+++ b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_kernel/set_stop_value_multi_seqs.cpp
@@ -0,0 +1,140 @@
#include "kernel_operator.h"
using namespace AscendC;

namespace {
// In vector core, repeat size equal to 256B = 128xhalf
const uint32_t REPEATSIZE = 128;
// Default repeat stride is 8
const uint8_t DEFAULREPEATSTRIDE = 8;
const uint32_t BLOCK_SIZE = 16;
// 12288 * 8 means we can process 8 tokens in one loop
const uint64_t MAX_PROCESS_NUM = 12288 * 8;

// Stop detection with multi-token stop sequences: a sequence stops when the
// current token is an eos id, or when the tail of (pre_ids history + current
// token) equals one of `stop_seqs_num` stop sequences of length
// stop_seqs_len[j] (rows padded to stop_seqs_max_len).
class SetStopValueMultiSeqs {
public:
    __aicore__ inline SetStopValueMultiSeqs(int32_t bs, int32_t length, int32_t stop_seqs_num, int32_t stop_seqs_max_len, int32_t eos_len)
    {
        this->batchNum = bs;
        this->lengthNum = length;
        this->stopSeqsNum = stop_seqs_num;
        this->stopSeqsMaxLen = stop_seqs_max_len;
        this->eosLen = eos_len;
    }

    __aicore__ inline void Init(__gm__ uint8_t *topkIds,
        __gm__ uint8_t *preIds,
        __gm__ uint8_t *stepIdx,
        __gm__ uint8_t *stopFlags,
        __gm__ uint8_t *seqLens,
        __gm__ uint8_t *stopSeqs,
        __gm__ uint8_t *stopSeqsLen,
        __gm__ uint8_t *endIds,
        __gm__ uint8_t *topkIdsOut,
        __gm__ uint8_t *stopFlagsOut)
    {
        topkIdsGm = (__gm__ int64_t *)topkIds;
        preIdsGm = (__gm__ int64_t *)preIds;
        stepIdxGm = (__gm__ int64_t *)stepIdx;
        stopFlagsGm = (__gm__ bool *)stopFlags;
        seqLensGm = (__gm__ int32_t *)seqLens;
        stopSeqsGm = (__gm__ int64_t *)stopSeqs;
        stopSeqsLenGm = (__gm__ int32_t *)stopSeqsLen;
        endIdsGm = (__gm__ int64_t *)endIds;
        topkIdsOutGm = (__gm__ int64_t *)topkIdsOut;
        stopFlagsOutGm = (__gm__ bool *)stopFlagsOut;
    }

    __aicore__ inline void Process()
    {
        for (int32_t i = 0; i < batchNum; i++) {
            pipe_barrier(PIPE_ALL);
            if (*(stopFlagsGm + i)) {
                // Already stopped: emit -1 for inactive slots, end id otherwise.
                if (*(seqLensGm + i) == 0) {
                    *(topkIdsGm + i) = -1;
                } else {
                    *(topkIdsGm + i) = *endIdsGm;
                }
            } else {
                // 1) direct eos-id match
                for (int32_t j = 0; j < eosLen; j++) {
                    if (*(topkIdsGm + i) == *(endIdsGm + j)) {
                        *(stopFlagsGm + i) = true;
                        *(topkIdsGm + i) = *endIdsGm;
                        break;
                    }
                }
                pipe_barrier(PIPE_ALL);
                // 2) stop-sequence suffix match: the last len-1 tokens come from
                // pre_ids, the final one is the freshly sampled token.
                if (*(stopFlagsGm + i) == false) {
                    for (int32_t j = 0; j < stopSeqsNum; j++) {
                        if (*(stopSeqsLenGm + j) > 0 && *(stepIdxGm + i) >= *(stopSeqsLenGm + j)) {
                            int32_t sameTokenCount = 0;
                            for (int32_t k = 0; k < *(stopSeqsLenGm + j); k++) {
                                if (k < *(stopSeqsLenGm + j) - 1) {
                                    if (*(preIdsGm + i * lengthNum + *(stepIdxGm + i) - *(stopSeqsLenGm + j) + 1 + k) == *(stopSeqsGm + j * stopSeqsMaxLen + k)) {
                                        sameTokenCount++;
                                    } else {
                                        break;
                                    }
                                } else {
                                    if (*(topkIdsGm + i) == *(stopSeqsGm + j * stopSeqsMaxLen + k)) {
                                        sameTokenCount++;
                                    } else {
                                        break;
                                    }
                                }
                            }
                            if (sameTokenCount == *(stopSeqsLenGm + j)) {
                                *(stopFlagsGm + i) = true;
                                *(topkIdsGm + i) = *endIdsGm;
                                break;
                            }
                        }
                    }
                }
            }
            pipe_barrier(PIPE_ALL);
        }

        // Mirror the in-place results to the declared outputs.
        for (int32_t i = 0; i < batchNum; i++) {
            *(topkIdsOutGm + i) = *(topkIdsGm + i);
            *(stopFlagsOutGm + i) = *(stopFlagsGm + i);
        }
    }

private:
    __gm__ int64_t *topkIdsGm;
    __gm__ int64_t *preIdsGm;
    __gm__ int64_t *stepIdxGm;
    __gm__ bool *stopFlagsGm;
    __gm__ int32_t *seqLensGm;
    __gm__ int64_t *endIdsGm;
    __gm__ int64_t *stopSeqsGm;
    __gm__ int32_t *stopSeqsLenGm;
    __gm__ int64_t *topkIdsOutGm;
    __gm__ bool *stopFlagsOutGm;

    int32_t batchNum;
    int32_t lengthNum;
    int32_t stopSeqsNum;
    int32_t stopSeqsMaxLen;
    int32_t eosLen;
};
}

extern "C" __global__ __aicore__ void set_stop_value_multi_seqs(GM_ADDR topkIds,
    GM_ADDR preIds,
    GM_ADDR stepIdx,
    GM_ADDR stopFlags,
    GM_ADDR seqLens,
    GM_ADDR stopSeqs,
    GM_ADDR stopSeqsLen,
    GM_ADDR endIds,
    GM_ADDR topkIdsOut,
    GM_ADDR stopFlagsOut,
    GM_ADDR workspace,
    GM_ADDR tiling)
{
    GET_TILING_DATA(tilingData, tiling);
    SetStopValueMultiSeqs op(tilingData.bs, tilingData.length, tilingData.stop_seqs_num, tilingData.stop_seqs_max_len, tilingData.eos_len);
    op.Init(topkIds, preIds, stepIdx, stopFlags, seqLens, stopSeqs, stopSeqsLen, endIds, topkIdsOut, stopFlagsOut);
    op.Process();
}
\ No newline at end of file
diff --git a/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_kernel/set_value_by_flags_and_idx.cpp b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_kernel/set_value_by_flags_and_idx.cpp
new file mode 100644
index 00000000000..aeaef4914cc
--- /dev/null
+++ b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_kernel/set_value_by_flags_and_idx.cpp
@@ -0,0 +1,55 @@
#include "kernel_operator.h"
using namespace AscendC;

// For every not-stopped sequence, records the current token preIdsNow[i] into
// its history row preIdsAll[i, stepIdx[i]]; always mirrors stopFlags to
// stopFlagsOut.
class SetValueByFlagsAndIdx {
public:
    __aicore__ inline SetValueByFlagsAndIdx(int32_t bs, int32_t length)
    {
        this->lengthNum = length;
        this->batchNum = bs;
    }

    __aicore__ inline void Init(__gm__ uint8_t *preIdsAll, __gm__ uint8_t *preIdsNow, __gm__ uint8_t *stepIdx, __gm__ uint8_t *stopFlags,
        __gm__ uint8_t *stopFlagsOut)
    {
        preIdsAllGm = (__gm__ int64_t *)preIdsAll;
        preIdsNowGm = (__gm__ int64_t *)preIdsNow;
        stepIdxGm = (__gm__ int64_t *)stepIdx;
        stopFlagsGm = (__gm__ bool *)stopFlags;
        stopFlagsOutGm = (__gm__ bool *)stopFlagsOut;
    }

    __aicore__ inline void Process()
    {
        for (int32_t i = 0; i < batchNum; i++) {
            *(stopFlagsOutGm + i) = *(stopFlagsGm + i);
            pipe_barrier(PIPE_ALL);

            if (!(*(stopFlagsGm + i))) {
                // stepIdx < 0 marks slots with no valid write position.
                if (*(stepIdxGm + i) >= 0) {
                    *(preIdsAllGm + i * lengthNum + *(stepIdxGm + i)) = *(preIdsNowGm + i);
                }
            }
            pipe_barrier(PIPE_ALL);
        }
    }

private:
    __gm__ int64_t *preIdsAllGm;
    __gm__ int64_t *preIdsNowGm;
    __gm__ int64_t *stepIdxGm;
    __gm__ bool *stopFlagsGm;
    __gm__ bool *stopFlagsOutGm;

    int32_t batchNum;
    int32_t lengthNum;
};

extern "C" __global__ __aicore__ void set_value_by_flags_and_idx(GM_ADDR preIdsAll,
    GM_ADDR preIdsNow, GM_ADDR stepIdx, GM_ADDR stopFlags,
    GM_ADDR stopFlagsOut, GM_ADDR workspace, GM_ADDR tiling)
{
    GET_TILING_DATA(tilingData, tiling);
    SetValueByFlagsAndIdx op(tilingData.bs, tilingData.length);
    op.Init(preIdsAll, preIdsNow, stepIdx, stopFlags, stopFlagsOut);
    op.Process();
}
diff --git a/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_kernel/set_value_by_flags_and_idx_v2.cpp b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_kernel/set_value_by_flags_and_idx_v2.cpp
new file mode 100644
index 00000000000..5b8116a2aa9
--- /dev/null
+++ b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_kernel/set_value_by_flags_and_idx_v2.cpp
@@ -0,0 +1,86 @@
#include "kernel_operator.h"
using namespace AscendC;

// V2: the token written into the history row comes from input_ids — the last
// encoder token (seqLensEncoder-1) when the sequence just finished prefill,
// otherwise the first token of its input row.
class SetValueByFlagsAndIdxV2 {
public:
    __aicore__ inline SetValueByFlagsAndIdxV2(int32_t bs, int32_t length, int32_t lengthInput)
    {
        this->batchNum = bs;
        this->lengthNum = length;
        this->lengthInputNum = lengthInput;
    }

    __aicore__ inline void Init(__gm__ uint8_t *preIdsAll,
        __gm__ uint8_t *inputIds,
        __gm__ uint8_t *seqLensThisTime,
        __gm__ uint8_t *seqLensEncoder,
        __gm__ uint8_t *seqLensDecoder,
        __gm__ uint8_t *stepIdx,
        __gm__ uint8_t *stopFlags,
        __gm__ uint8_t *preIdsAllOut)
    {
        preIdsAllGm = (__gm__ int64_t *)preIdsAll;
        inputIdsGm = (__gm__ int64_t *)inputIds;
        seqLensThisTimeGm = (__gm__ int32_t *)seqLensThisTime;
        seqLensEncoderGm = (__gm__ int32_t *)seqLensEncoder;
        seqLensDecoderGm = (__gm__ int32_t *)seqLensDecoder;
        stepIdxGm = (__gm__ int64_t *)stepIdx;
        stopFlagsGm = (__gm__ bool *)stopFlags;
        preIdsAllOutGm = (__gm__ int64_t *)preIdsAllOut;
    }

    __aicore__ inline void Process()
    {
        for (int32_t i = 0; i < batchNum; i++) {
            pipe_barrier(PIPE_ALL);
            if (!(*(stopFlagsGm + i))) {
                int32_t seqLenDec = *(seqLensDecoderGm + i);
                int32_t seqLenEnc = *(seqLensEncoderGm + i);
                // Fully idle slot: nothing to record.
                if ((seqLenDec == 0) && (seqLenEnc == 0)) {
                    continue;
                }
                if (*(stepIdxGm + i) >= 0) {
                    if
(seqLenDec == 0) {
                        // Prefill just ended: record the last encoder token.
                        *(preIdsAllGm + i * lengthNum + *(stepIdxGm + i)) =
                            *(inputIdsGm + i * lengthInputNum + (seqLenEnc - 1));
                    } else {
                        // Decode step: record the first token of the input row.
                        *(preIdsAllGm + i * lengthNum + *(stepIdxGm + i)) =
                            *(inputIdsGm + i * lengthInputNum);
                    }
                }
            }
            pipe_barrier(PIPE_ALL);
        }
    }

private:
    __gm__ int64_t *preIdsAllGm;
    __gm__ int64_t *inputIdsGm;
    __gm__ int32_t *seqLensThisTimeGm;
    __gm__ int32_t *seqLensEncoderGm;
    __gm__ int32_t *seqLensDecoderGm;
    __gm__ int64_t *stepIdxGm;
    __gm__ bool *stopFlagsGm;
    // NOTE(review): preIdsAllOutGm is initialized but never written — writes go
    // in-place to preIdsAllGm; confirm the output aliases the input.
    __gm__ int64_t *preIdsAllOutGm;

    int32_t batchNum;
    int32_t lengthNum;
    int32_t lengthInputNum;
};

extern "C" __global__ __aicore__ void set_value_by_flags_and_idx_v2(GM_ADDR preIdsAll,
    GM_ADDR inputIds,
    GM_ADDR seqLensThisTime,
    GM_ADDR seqLensEncoder,
    GM_ADDR seqLensDecoder,
    GM_ADDR stepIdx,
    GM_ADDR stopFlags,
    GM_ADDR preIdsAllOut,
    GM_ADDR workspace,
    GM_ADDR tiling)
{
    GET_TILING_DATA(tilingData, tiling);
    SetValueByFlagsAndIdxV2 op(tilingData.bs, tilingData.length, tilingData.lengthInput);
    op.Init(preIdsAll, inputIds, seqLensThisTime, seqLensEncoder, seqLensDecoder, stepIdx, stopFlags, preIdsAllOut);
    op.Process();
}
\ No newline at end of file
diff --git a/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_kernel/step_paddle.cpp b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_kernel/step_paddle.cpp
new file mode 100644
index 00000000000..b575b410c09
--- /dev/null
+++ b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_kernel/step_paddle.cpp
@@ -0,0 +1,248 @@
#include "kernel_operator.h"
using namespace AscendC;

// Single-core scalar kernel managing the paged KV-cache block tables:
//  phase 1: reclaim blocks of finished sequences / queue sequences that need a
//           new block;
//  phase 2: while demand exceeds free blocks, preempt ("step") the sequence
//           holding the most decoder blocks;
//  phase 3: hand one free block to each queued sequence;
//  phase 4: recover previously stepped sequences while enough free blocks
//           remain, rebuilding their input_ids from pre_ids history.
class StepPaddle {
public:
    __aicore__ inline StepPaddle(int32_t bsz, int32_t block_size, int32_t block_num_per_seq, int32_t max_decoder_block_num, int32_t length, int32_t pre_id_length, int64_t first_token_id)
    {
        this->bsz = bsz;
        this->block_size = block_size;
        this->block_num_per_seq = block_num_per_seq;
        this->max_decoder_block_num = max_decoder_block_num;
        this->length = length;
        this->pre_id_length = pre_id_length;
        this->first_token_id = first_token_id;
        pipe_barrier(PIPE_ALL);
    }

    // Casts all raw GM pointers; no data is staged in UB — everything below is
    // scalar global-memory access.
    __aicore__ inline void Init(__gm__ uint8_t* stop_flags, __gm__ uint8_t* seq_lens_this_time, __gm__ uint8_t* ori_seq_lens_encoder, __gm__ uint8_t* seq_lens_encoder,
        __gm__ uint8_t* seq_lens_decoder, __gm__ uint8_t* block_tables, __gm__ uint8_t* encoder_block_lens, __gm__ uint8_t* is_block_step,
        __gm__ uint8_t* step_block_list, __gm__ uint8_t* step_len, __gm__ uint8_t* recover_block_list, __gm__ uint8_t* recover_len,
        __gm__ uint8_t* need_block_list, __gm__ uint8_t* need_block_len, __gm__ uint8_t* used_list_len, __gm__ uint8_t* free_list,
        __gm__ uint8_t* free_list_len, __gm__ uint8_t* input_ids, __gm__ uint8_t* pre_ids, __gm__ uint8_t* step_idx, __gm__ uint8_t* next_tokens)
    {
        this->stop_flags = (__gm__ bool *)stop_flags;
        this->seq_lens_this_time = (__gm__ int *)seq_lens_this_time;
        this->ori_seq_lens_encoder = (__gm__ int *)ori_seq_lens_encoder;
        this->seq_lens_encoder = (__gm__ int *)seq_lens_encoder;
        this->seq_lens_decoder = (__gm__ int *)seq_lens_decoder;
        this->block_tables = (__gm__ int *)block_tables;
        this->encoder_block_lens = (__gm__ int *)encoder_block_lens;
        this->is_block_step = (__gm__ bool *)is_block_step;
        this->step_block_list = (__gm__ int *)step_block_list;
        this->step_len = (__gm__ int *)step_len;
        this->recover_block_list = (__gm__ int *)recover_block_list;
        this->recover_len = (__gm__ int *)recover_len;
        this->need_block_list = (__gm__ int *)need_block_list;
        this->need_block_len = (__gm__ int *)need_block_len;
        this->used_list_len = (__gm__ int *)used_list_len;
        this->free_list = (__gm__ int *)free_list;
        this->free_list_len = (__gm__ int *)free_list_len;
        this->input_ids = (__gm__ int64_t *)input_ids;
        this->pre_ids = (__gm__ int64_t *)pre_ids;
        this->step_idx = (__gm__ int64_t *)step_idx;
        this->next_tokens = (__gm__ int64_t *)next_tokens;
        pipe_barrier(PIPE_ALL);
    }

    __aicore__ inline void Process()
    {
        // Phase 1: reclaim / collect demand.
        for (int32_t i = 0; i < bsz; i++) {
            __gm__ int *block_table_now = block_tables + i * block_num_per_seq;
            if (stop_flags[i] && !is_block_step[i]) {
                // Finished (not preempted): return its decoder blocks to free_list.
                const int encoder_block_len = encoder_block_lens[i];
                const int decoder_used_len = used_list_len[i];
                if (decoder_used_len > 0) {
                    pipe_barrier(PIPE_ALL);
                    const int ori_free_list_len = free_list_len[0];
                    free_list_len[0] += decoder_used_len;
                    pipe_barrier(PIPE_ALL);
                    for (int32_t j = 0; j < decoder_used_len; j++) {
                        free_list[ori_free_list_len + j] = block_table_now[encoder_block_len + j];
                        block_table_now[encoder_block_len + j] = -1;
                    }
                    encoder_block_lens[i] = 0;
                    used_list_len[i] = 0;
                    pipe_barrier(PIPE_ALL);
                }
            } else if (seq_lens_decoder[i] != 0 && block_table_now[(seq_lens_decoder[i] + 1) / block_size] == -1) {
                // Active decoder about to cross into an unmapped block: queue it.
                const int ori_need_block_len = need_block_len[0];
                need_block_len[0] = need_block_len[0] + 1;
                need_block_list[ori_need_block_len] = i;
            }
            pipe_barrier(PIPE_ALL);
        }

        pipe_barrier(PIPE_ALL);

        // Phase 2: preempt largest holders until supply meets demand.
        // NOTE(review): if every candidate has used_list_len == 0 this loop
        // cannot make progress — confirm an upstream invariant rules that out.
        while (need_block_len[0] > free_list_len[0]) {
            int max_idx = 0;
            int max_used_block_num = 0;
            for (int i = 0; i < bsz; i++) {
                const int cur_block_num = is_block_step[i] ? 0 : used_list_len[i];
                if (cur_block_num > max_used_block_num) {
                    pipe_barrier(PIPE_ALL);
                    max_idx = i;
                    pipe_barrier(PIPE_ALL);
                    max_used_block_num = cur_block_num;
                    pipe_barrier(PIPE_ALL);
                }
            }

            const int encoder_block_len = encoder_block_lens[max_idx];
            __gm__ int *block_table_now = block_tables + max_idx * block_num_per_seq;
            for (int i = 0; i < max_used_block_num; i++) {
                free_list[free_list_len[0] + i] = block_table_now[encoder_block_len + i];
                block_table_now[encoder_block_len + i] = -1;
                pipe_barrier(PIPE_ALL);
            }

            step_block_list[step_len[0]] = max_idx;
            step_len[0] += 1;
            free_list_len[0] += max_used_block_num;
            stop_flags[max_idx] = true;
            is_block_step[max_idx] = true;
            seq_lens_this_time[max_idx] = 0;
            seq_lens_decoder[max_idx] = 0;
            seq_lens_encoder[max_idx] = 0;
            pipe_barrier(PIPE_ALL);
        }

        // Phase 3: serve each queued sequence one block from the free list tail.
        for (int32_t i = 0; i < need_block_len[0]; i++) {
            const int need_block_id = need_block_list[i];
            if (need_block_list[i] != -1) {
                if (!stop_flags[need_block_id]) {
                    used_list_len[need_block_id] += 1;
                    pipe_barrier(PIPE_ALL);
                    const int ori_free_list_len = free_list_len[0];
                    free_list_len[0] -= 1;
                    pipe_barrier(PIPE_ALL);
                    __gm__ int *block_table_now = block_tables + need_block_id * block_num_per_seq;
                    block_table_now[(seq_lens_decoder[need_block_id] + 1) / block_size] = free_list[ori_free_list_len - 1];
                }
                need_block_list[i] = -1;
                pipe_barrier(PIPE_ALL);
            }
        }

        // Phase 4a: decide which stepped sequences can be recovered (LIFO).
        // NOTE(review): step_block_list[-1] is read when step_len[0] == 0; the
        // loop guard prevents use, but the out-of-range read itself is suspect.
        int ori_free_list_len = free_list_len[0];
        int ori_step_len = step_len[0];
        int ori_step_block_id = step_block_list[ori_step_len - 1];
        int tmp_used_len = used_list_len[ori_step_block_id];
        // A recovered sequence needs one extra block unless already at the cap.
        int used_len = tmp_used_len < max_decoder_block_num ? tmp_used_len + 1 : tmp_used_len;
        pipe_barrier(PIPE_ALL);
        while (ori_step_len > 0 && ori_free_list_len >= used_len) {
            recover_block_list[recover_len[0]] = ori_step_block_id;
            is_block_step[ori_step_block_id] = false;
            used_list_len[ori_step_block_id] = used_len;
            ori_free_list_len -= used_len;
            step_block_list[ori_step_len - 1] = -1;
            step_len[0] -= 1;
            recover_len[0] += 1;
            ori_step_len = step_len[0];
            if (ori_step_len > 0) {
                ori_step_block_id = step_block_list[ori_step_len - 1];
                tmp_used_len = used_list_len[ori_step_block_id];
                used_len = tmp_used_len < max_decoder_block_num ? tmp_used_len + 1 : tmp_used_len;
            }
            pipe_barrier(PIPE_ALL);
        }

        need_block_len[0] = 0;

        pipe_barrier(PIPE_ALL);

        // Phase 4b: re-map blocks and rebuild input_ids for recovered sequences.
        if (recover_len[0] > 0) {
            // NOTE(review): this shadows the outer ori_free_list_len — the
            // outer value computed in phase 4a is never used afterwards.
            int ori_free_list_len;
            for (int32_t i = 0; i < recover_len[0]; i++) {
                const int recover_id = recover_block_list[i];
                const int ori_seq_len_encoder = ori_seq_lens_encoder[recover_id];
                const int step_idx_now = step_idx[recover_id];
                const int seq_len = ori_seq_len_encoder + step_idx_now;
                const int encoder_block_len = encoder_block_lens[recover_id];
                const int decoder_used_len = used_list_len[recover_id];
                pipe_barrier(PIPE_ALL);
                __gm__ int *block_table_now = block_tables + recover_id * block_num_per_seq;
                __gm__ int64_t *input_ids_now = input_ids + recover_id * length;
                __gm__ int64_t *pre_ids_now = pre_ids + recover_id * pre_id_length;
                pipe_barrier(PIPE_ALL);
                // Recovered sequence re-runs prefill over prompt + generated tokens.
                seq_lens_this_time[recover_id] = seq_len;
                seq_lens_encoder[recover_id] = seq_len;
                stop_flags[recover_id] = false;
                input_ids_now[ori_seq_len_encoder + step_idx_now - 1] = next_tokens[recover_id];
                input_ids_now[0] = first_token_id;
                pipe_barrier(PIPE_ALL);
                const int ori_free_list_len_0 = free_list_len[0];
                free_list_len[0] -= decoder_used_len;
                pipe_barrier(PIPE_ALL);
                ori_free_list_len = ori_free_list_len_0;
                pipe_barrier(PIPE_ALL);

                // Take decoder blocks back from the tail of the free list.
                for (int32_t j = 0; j < decoder_used_len; j++) {
                    block_table_now[encoder_block_len + j] = free_list[ori_free_list_len - decoder_used_len + j];
                }

                pipe_barrier(PIPE_ALL);

                // Splice previously generated tokens (pre_ids[1..]) after the prompt.
                for (int32_t j = 0; j < step_idx_now - 1; j++) {
                    input_ids_now[ori_seq_len_encoder + j] = pre_ids_now[j + 1];
                }

                pipe_barrier(PIPE_ALL);
            }


            recover_len[0] = 0;
        }
    }

private:
    int32_t bsz = 0;
    int32_t block_size = 0;
    int32_t block_num_per_seq = 0;
    int32_t max_decoder_block_num = 0;
    int32_t length = 0;
    int32_t pre_id_length = 0;
    int64_t first_token_id = 0;

    __gm__ bool *stop_flags;
    __gm__ int *seq_lens_this_time;
    __gm__ int *ori_seq_lens_encoder;
    __gm__ int *seq_lens_encoder;
    __gm__ int *seq_lens_decoder;
    __gm__ int *block_tables;
    __gm__ int *encoder_block_lens;
    __gm__ bool *is_block_step;
    __gm__ int *step_block_list;
    __gm__ int *step_len;
    __gm__ int *recover_block_list;
    __gm__ int *recover_len;
    __gm__ int *need_block_list;
    __gm__ int *need_block_len;
    __gm__ int *used_list_len;
    __gm__ int *free_list;
    __gm__ int *free_list_len;
    __gm__ int64_t *input_ids;
    __gm__ int64_t *pre_ids;
    __gm__ int64_t *step_idx;
    __gm__ int64_t *next_tokens;

};

// Kernel entry.  NOTE(review): the *_out parameters are accepted but never
// used — all mutation happens in-place on the input buffers; confirm the op
// definition declares them as in-place aliases.
extern "C" __global__ __aicore__ void step_paddle(GM_ADDR stop_flags, GM_ADDR seq_lens_this_time, GM_ADDR ori_seq_lens_encoder, GM_ADDR seq_lens_encoder,
    GM_ADDR seq_lens_decoder, GM_ADDR block_tables, GM_ADDR encoder_block_lens, GM_ADDR is_block_step,
    GM_ADDR step_block_list, GM_ADDR step_len, GM_ADDR recover_block_list, GM_ADDR recover_len,
    GM_ADDR need_block_list, GM_ADDR need_block_len, GM_ADDR used_list_len, GM_ADDR free_list,
    GM_ADDR free_list_len, GM_ADDR input_ids, GM_ADDR pre_ids, GM_ADDR step_idx, GM_ADDR next_tokens,
    GM_ADDR stop_flags_out, GM_ADDR seq_lens_this_time_out, GM_ADDR seq_lens_encoder_out, GM_ADDR seq_lens_decoder_out, GM_ADDR block_tables_out,
    GM_ADDR encoder_block_lens_out, GM_ADDR is_block_step_out, GM_ADDR step_block_list_out, GM_ADDR step_lens_out, GM_ADDR recover_block_list_out,
    GM_ADDR recover_len_out, GM_ADDR need_block_list_out, GM_ADDR need_block_len_out, GM_ADDR used_list_len_out, GM_ADDR free_list_out,
    GM_ADDR free_list_len_out, GM_ADDR input_ids_out,
    GM_ADDR workspace, GM_ADDR tiling)
{
    GET_TILING_DATA(tilingData, tiling);
    StepPaddle op(tilingData.bsz, tilingData.block_size, tilingData.block_num_per_seq, tilingData.max_decoder_block_num, tilingData.length, tilingData.pre_id_length, tilingData.first_token_id);
    op.Init(stop_flags, seq_lens_this_time, ori_seq_lens_encoder, seq_lens_encoder, seq_lens_decoder, block_tables, encoder_block_lens, is_block_step, step_block_list, step_len, recover_block_list, recover_len,
        need_block_list, need_block_len, used_list_len, free_list, free_list_len, input_ids, pre_ids, step_idx, next_tokens);
    op.Process();
}
diff --git a/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_kernel/token_penalty_multi_scores.cpp b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_kernel/token_penalty_multi_scores.cpp
new file mode 100644
index 00000000000..deb1aa4f352
--- /dev/null
+++ b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_kernel/token_penalty_multi_scores.cpp
@@ -0,0 +1,219 @@
#include "kernel_operator.h"
using namespace AscendC;

const int32_t BUFFER_NUM = 1;

// Applies min-length eos masking and repetition/frequency/presence penalties
// over vocab-sized logit rows, tiled vocabLocalSize_ at a time per batch entry.
class TokenPenaltyMultiScores {
public:
    __aicore__ inline TokenPenaltyMultiScores(
        int32_t vs, int32_t vsBlock, int32_t sl, int32_t etil, int32_t bs, int32_t bsBlock)
    {
        vocabSize_ = vs;
        vocabLocalSize_ = vsBlock;
        seqLen_ = sl;
        eosTokenIdLength_ = etil;
        bs_ = bs;
        bsEachCore_ = bsBlock;
    }

    __aicore__ inline void Init(__gm__ uint8_t *preIds,
        __gm__ uint8_t *logitsIn,
        __gm__ uint8_t *repeatTimes,
        __gm__ uint8_t *penaltyScores,
        __gm__ uint8_t *frequencyScores,
        __gm__ uint8_t *presenceScores,
        __gm__ uint8_t *curLen,
        __gm__ uint8_t *minLen,
        __gm__ uint8_t *eosTokenId,
        __gm__ uint8_t *logitsOut)
    {
        preIdsGm_ = (__gm__ int64_t *)preIds;
        logitsInGm_ = (__gm__ float *)logitsIn;
        repeatTimesGm_ = (__gm__
int32_t *)repeatTimes; + penaltyScoresGm_ = (__gm__ float *)penaltyScores; + frequencyScoresGm_ = (__gm__ float *)frequencyScores; + presenceScoresGm_ = (__gm__ float *)presenceScores; + curLenGm_ = (__gm__ int64_t *)curLen; + minLenGm_ = (__gm__ int64_t *)minLen; + eosTokenIdGm_ = (__gm__ int64_t *)eosTokenId; + logitsOutGm_ = (__gm__ float *)logitsOut; + + pipe_.InitBuffer(logitsInQueue_, BUFFER_NUM, vocabLocalSize_ * sizeof(float)); + pipe_.InitBuffer(repeatTimesInQueue_, BUFFER_NUM, vocabLocalSize_ * sizeof(int32_t)); + pipe_.InitBuffer(logitsOutQueue_, BUFFER_NUM, vocabLocalSize_ * sizeof(float)); + pipe_.InitBuffer(repeatTimesFp32Buff_, vocabLocalSize_ * sizeof(float)); + pipe_.InitBuffer(repeatTimesUint8Buff_, vocabLocalSize_ * sizeof(uint8_t)); + pipe_.InitBuffer(repeatTimesTmpBuff_, vocabLocalSize_ * sizeof(float)); + pipe_.InitBuffer(logitsTmpBuff_, vocabLocalSize_ * sizeof(float)); + pipe_.InitBuffer(logitsTmpBuff2_, vocabLocalSize_ * sizeof(float)); + } +/** + preIds: bs, seqLen + logits: bs, vocabSize + repeatTimes: bs, vocabSize + penaltyScores: bs + frequencyScores: bs + presenceScores: bs + curLen: bs + minLen: bs + eosTokenId: end_length, + logitsOut: bs, vocabSize +*/ + __aicore__ inline void Process() + { + int32_t vocLoop = (vocabSize_ + vocabLocalSize_ - 1) / vocabLocalSize_; + for (int32_t bsIdEachCore = 0; bsIdEachCore < bsEachCore_; bsIdEachCore++) { + for (int32_t i = 0; i < vocLoop; i++) { + int32_t bsOffset = GetBlockIdx() * bsEachCore_ + bsIdEachCore; + if (bsOffset >= bs_) { + break; + } + int32_t gmOffset = bsOffset * vocabSize_ + i * vocabLocalSize_; + logitsInG_.SetGlobalBuffer(logitsInGm_ + gmOffset); + logitsOutG_.SetGlobalBuffer(logitsOutGm_ + gmOffset); + repeateTimesG_.SetGlobalBuffer(repeatTimesGm_ + gmOffset); + CopyIn(); + Compute(bsOffset, i, vocabLocalSize_); + CopyOut(); + } + } + } + + __aicore__ inline void CopyIn() + { + LocalTensor logitsInL = logitsInQueue_.AllocTensor(); + LocalTensor repeatTimesL = 
repeatTimesInQueue_.AllocTensor(); + DataCopy(logitsInL, logitsInG_, vocabLocalSize_); + DataCopy(repeatTimesL, repeateTimesG_, vocabLocalSize_); + logitsInQueue_.EnQue(logitsInL); + repeatTimesInQueue_.EnQue(repeatTimesL); + } + + __aicore__ inline void Compute(int32_t bsId, int32_t vocSizeId, int32_t vocabLocalSize_) + { + LocalTensor logitsInL = logitsInQueue_.DeQue(); + LocalTensor repeatTimesL = repeatTimesInQueue_.DeQue(); + + int64_t startVocId = vocabLocalSize_ * vocSizeId; + if (*(curLenGm_ + bsId) >= 0) { + // min_length_logits_process + if (*(curLenGm_ + bsId) < *(minLenGm_ + bsId)) { + for (int i = 0; i < eosTokenIdLength_; ++i) { + int64_t eosTokenIdOffset = *(eosTokenIdGm_ + i) - startVocId; + pipe_barrier(PIPE_ALL); + if (eosTokenIdOffset >= 0 && eosTokenIdOffset < vocabLocalSize_) { + logitsInL.SetValue(eosTokenIdOffset, (float)-1e10); + pipe_barrier(PIPE_ALL); + } + } + } + + // update_repeat_times + for (int i = 0; i < seqLen_; i++) { + int64_t predId = *(preIdsGm_ + bsId * seqLen_ + i); + if (predId < 0) { + break; + } + int64_t predIdOffset = predId - startVocId; + pipe_barrier(PIPE_ALL); + if (predIdOffset >= 0 && predIdOffset < vocabLocalSize_) { + int32_t repeatNew = repeatTimesL.GetValue(predIdOffset) + 1; + repeatTimesL.SetValue(predIdOffset, repeatNew); + pipe_barrier(PIPE_ALL); + } + } + } + + LocalTensor logitsOutL = logitsOutQueue_.AllocTensor(); + LocalTensor logitsInTmpL = logitsTmpBuff_.Get(); + LocalTensor logitsInTmpL2 = logitsTmpBuff2_.Get(); + LocalTensor repeatTimesFp32L = repeatTimesFp32Buff_.Get(); + LocalTensor repeatTimesUint8L = repeatTimesUint8Buff_.Get(); + LocalTensor repeatTimesTmpL = repeatTimesTmpBuff_.Get(); + + // update_value_by_repeat_times + float alpha = *(penaltyScoresGm_ + bsId); + float alphaR = 1.0f / alpha; + float beta = *(frequencyScoresGm_ + bsId); + float gamma = *(presenceScoresGm_ + bsId); + + Cast(repeatTimesFp32L, repeatTimesL, RoundMode::CAST_NONE, vocabLocalSize_); + Duplicate(repeatTimesTmpL, 
0.5f, vocabLocalSize_); + Compare(repeatTimesUint8L, repeatTimesFp32L, repeatTimesTmpL, CMPMODE::GT, vocabLocalSize_); + Muls(repeatTimesTmpL, repeatTimesFp32L, beta, vocabLocalSize_); + Adds(repeatTimesTmpL, repeatTimesTmpL, gamma, vocabLocalSize_); + Select(repeatTimesFp32L, repeatTimesUint8L, repeatTimesTmpL, 0.0f, + SELMODE::VSEL_TENSOR_SCALAR_MODE, vocabLocalSize_); + pipe_barrier(PIPE_ALL); + + Maxs(logitsInTmpL, logitsInL, 0.0f, vocabLocalSize_); + Duplicate(logitsInTmpL2, alphaR, vocabLocalSize_); + Select(repeatTimesTmpL, repeatTimesUint8L, logitsInTmpL2, 1.0f, + SELMODE::VSEL_TENSOR_SCALAR_MODE, vocabLocalSize_); + Mul(logitsInTmpL, logitsInTmpL, repeatTimesTmpL, vocabLocalSize_); + Mins(logitsInL, logitsInL, 0.0f, vocabLocalSize_); + Duplicate(logitsInTmpL2, alpha, vocabLocalSize_); + Select(repeatTimesTmpL, repeatTimesUint8L, logitsInTmpL2, 1.0f, + SELMODE::VSEL_TENSOR_SCALAR_MODE, vocabLocalSize_); + Mul(logitsInL, logitsInL, repeatTimesTmpL, vocabLocalSize_); + Add(logitsInL, logitsInL, logitsInTmpL, vocabLocalSize_); + Sub(logitsOutL, logitsInL, repeatTimesFp32L, vocabLocalSize_); + + logitsOutQueue_.EnQue(logitsOutL); + logitsInQueue_.FreeTensor(logitsInL); + repeatTimesInQueue_.FreeTensor(repeatTimesL); + } + + __aicore__ inline void CopyOut() + { + LocalTensor logitsOutL = logitsOutQueue_.DeQue(); + DataCopy(logitsOutG_, logitsOutL, vocabLocalSize_); + logitsOutQueue_.FreeTensor(logitsOutL); + } + +private: + __gm__ int64_t *preIdsGm_; + __gm__ float *logitsInGm_; + __gm__ int32_t *repeatTimesGm_; + __gm__ float *penaltyScoresGm_; + __gm__ float *frequencyScoresGm_; + __gm__ float *presenceScoresGm_; + __gm__ int64_t *curLenGm_; + __gm__ int64_t *minLenGm_; + __gm__ int64_t *eosTokenIdGm_; + __gm__ float *logitsOutGm_; + + TPipe pipe_; + TQue logitsInQueue_; + TQue repeatTimesInQueue_; + TQue logitsOutQueue_; + TQue repeatTimesOutQueue_; + TBuf repeatTimesFp32Buff_; + TBuf repeatTimesUint8Buff_; + TBuf repeatTimesTmpBuff_; + TBuf logitsTmpBuff_; + 
TBuf logitsTmpBuff2_; + + GlobalTensor logitsInG_; + GlobalTensor logitsOutG_; + GlobalTensor repeateTimesG_; + + int32_t vocabSize_; + int32_t vocabLocalSize_; + int32_t seqLen_; + int32_t eosTokenIdLength_; + int32_t bs_; + int32_t bsEachCore_; +}; + +extern "C" __global__ __aicore__ void token_penalty_multi_scores( + GM_ADDR preIds, GM_ADDR logits, GM_ADDR repeatTimes, GM_ADDR penaltyScores, GM_ADDR frequencyScores, + GM_ADDR presenceScores, GM_ADDR curLen, GM_ADDR minLen, GM_ADDR eosTokenId, + GM_ADDR logitsOut, GM_ADDR workspace, GM_ADDR tiling) +{ + GET_TILING_DATA(tilingData, tiling); + TokenPenaltyMultiScores op( + tilingData.vs, tilingData.vsBlock, tilingData.seqLen, tilingData.etil, tilingData.bs, tilingData.bsBlock); + op.Init(preIds, logits, repeatTimes, penaltyScores, frequencyScores, presenceScores, curLen, minLen, eosTokenId, logitsOut); + op.Process(); +} diff --git a/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_kernel/token_penalty_multi_scores_v2.cpp b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_kernel/token_penalty_multi_scores_v2.cpp new file mode 100644 index 00000000000..a3af2ee92fb --- /dev/null +++ b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_kernel/token_penalty_multi_scores_v2.cpp @@ -0,0 +1,262 @@ +#include "kernel_operator.h" +using namespace AscendC; + +const int32_t BUFFER_NUM = 1; + +class TokenPenaltyMultiScoresV2 { +public: + __aicore__ inline TokenPenaltyMultiScoresV2( + int32_t vs, int32_t vsBlock, int32_t sl, int32_t etil, int32_t bs, int32_t bsBlock, int32_t badWordsLen) + { + vocabSize_ = vs; + vocabLocalSize_ = vsBlock; + seqLen_ = sl; + eosTokenIdLength_ = etil; + bs_ = bs; + bsEachCore_ = bsBlock; + badWLen_ = badWordsLen; + } + + __aicore__ inline void Init(__gm__ uint8_t *preIds, + __gm__ uint8_t *logitsIn, + __gm__ uint8_t *repeatTimes, + __gm__ uint8_t *penaltyScores, + __gm__ uint8_t *frequencyScores, + __gm__ uint8_t *presenceScores, + __gm__ uint8_t *temperatures, + __gm__ uint8_t 
*badWords, + __gm__ uint8_t *curLen, + __gm__ uint8_t *minLen, + __gm__ uint8_t *eosTokenId, + __gm__ uint8_t *logitsOut) + { + preIdsGm_ = (__gm__ int64_t *)preIds; + logitsInGm_ = (__gm__ float *)logitsIn; + repeatTimesGm_ = (__gm__ int32_t *)repeatTimes; + penaltyScoresGm_ = (__gm__ float *)penaltyScores; + frequencyScoresGm_ = (__gm__ float *)frequencyScores; + presenceScoresGm_ = (__gm__ float *)presenceScores; + temperaturesGm_ = (__gm__ float *)temperatures; + badTokenIdGm_ = (__gm__ int64_t *)badWords; + curLenGm_ = (__gm__ int64_t *)curLen; + minLenGm_ = (__gm__ int64_t *)minLen; + eosTokenIdGm_ = (__gm__ int64_t *)eosTokenId; + logitsOutGm_ = (__gm__ float *)logitsOut; + + pipe_.InitBuffer(logitsInQueue_, BUFFER_NUM, vocabLocalSize_ * sizeof(float)); + pipe_.InitBuffer(repeatTimesInQueue_, BUFFER_NUM, vocabLocalSize_ * sizeof(int32_t)); + pipe_.InitBuffer(logitsOutQueue_, BUFFER_NUM, vocabLocalSize_ * sizeof(float)); + pipe_.InitBuffer(repeatTimesFp32Buff_, vocabLocalSize_ * sizeof(float)); + pipe_.InitBuffer(repeatTimesUint8Buff_, vocabLocalSize_ * sizeof(uint8_t)); + pipe_.InitBuffer(repeatTimesTmpBuff_, vocabLocalSize_ * sizeof(float)); + pipe_.InitBuffer(temperaturesTmpBuff_, vocabLocalSize_ * sizeof(float)); + pipe_.InitBuffer(logitsTmpBuff_, vocabLocalSize_ * sizeof(float)); + pipe_.InitBuffer(logitsTmpBuff2_, vocabLocalSize_ * sizeof(float)); + } + + __aicore__ inline void Process() + { + int32_t vocLoop = (vocabSize_ + vocabLocalSize_ - 1) / vocabLocalSize_; + for (int32_t bsIdEachCore = 0; bsIdEachCore < bsEachCore_; bsIdEachCore++) { + for (int32_t i = 0; i < vocLoop; i++) { + int32_t bsOffset = GetBlockIdx() * bsEachCore_ + bsIdEachCore; + if (bsOffset >= bs_) { + break; + } + int32_t gmOffset = bsOffset * vocabSize_ + i * vocabLocalSize_; + logitsInG_.SetGlobalBuffer(logitsInGm_ + gmOffset); + logitsOutG_.SetGlobalBuffer(logitsOutGm_ + gmOffset); + repeateTimesG_.SetGlobalBuffer(repeatTimesGm_ + gmOffset); + CopyIn(); + Compute(bsOffset, i, 
vocabLocalSize_); + CopyOut(); + } + // pipe_barrier(PIPE_ALL); + + // ban_bad_words + for (int32_t i = 0; i < vocLoop; i++) { + int32_t bsOffset = GetBlockIdx() * bsEachCore_ + bsIdEachCore; + if (bsOffset >= bs_) { + break; + } + int32_t gmOffset = bsOffset * vocabSize_; + for (int32_t j = 0; j < badWLen_; j++) { + int64_t badTokenId = *(badTokenIdGm_ + j); + pipe_barrier(PIPE_ALL); + if (badTokenId >= static_cast(vocabSize_) || badTokenId < 0) { + continue; + } + *(logitsOutGm_ + gmOffset + badTokenId) = -1e10f; + pipe_barrier(PIPE_ALL); + } + } + } + } + + __aicore__ inline void CopyIn() + { + LocalTensor logitsInL = logitsInQueue_.AllocTensor(); + LocalTensor repeatTimesL = repeatTimesInQueue_.AllocTensor(); + DataCopy(logitsInL, logitsInG_, vocabLocalSize_); + DataCopy(repeatTimesL, repeateTimesG_, vocabLocalSize_); + logitsInQueue_.EnQue(logitsInL); + repeatTimesInQueue_.EnQue(repeatTimesL); + } + + __aicore__ inline void Compute(int32_t bsId, int32_t vocSizeId, int32_t vocabLocalSize_) + { + LocalTensor logitsInL = logitsInQueue_.DeQue(); + LocalTensor repeatTimesL = repeatTimesInQueue_.DeQue(); + + int64_t startVocId = vocabLocalSize_ * vocSizeId; + if (*(curLenGm_ + bsId) >= 0) { + // min_length_logits_process + if (*(curLenGm_ + bsId) < *(minLenGm_ + bsId)) { + for (int i = 0; i < eosTokenIdLength_; ++i) { + int64_t eosTokenIdOffset = *(eosTokenIdGm_ + i) - startVocId; + pipe_barrier(PIPE_ALL); + if (eosTokenIdOffset >= 0 && eosTokenIdOffset < vocabLocalSize_) { + logitsInL.SetValue(eosTokenIdOffset, (float)-1e10); + pipe_barrier(PIPE_ALL); + } + } + } + + // update_repeat_times + for (int i = 0; i < seqLen_; i++) { + int64_t predId = *(preIdsGm_ + bsId * seqLen_ + i); + if (predId < 0) { + break; + } + int64_t predIdOffset = predId - startVocId; + pipe_barrier(PIPE_ALL); + if (predIdOffset >= 0 && predIdOffset < vocabLocalSize_) { + int32_t repeatNew = repeatTimesL.GetValue(predIdOffset) + 1; + repeatTimesL.SetValue(predIdOffset, repeatNew); + 
pipe_barrier(PIPE_ALL); + } + } + } + + LocalTensor logitsOutL = logitsOutQueue_.AllocTensor(); + LocalTensor logitsInTmpL = logitsTmpBuff_.Get(); + LocalTensor logitsInTmpL2 = logitsTmpBuff2_.Get(); + LocalTensor repeatTimesFp32L = repeatTimesFp32Buff_.Get(); + LocalTensor repeatTimesUint8L = repeatTimesUint8Buff_.Get(); + LocalTensor repeatTimesTmpL = repeatTimesTmpBuff_.Get(); + LocalTensor temperaturesTmpL = temperaturesTmpBuff_.Get(); + + // update_value_by_repeat_times + float alpha = *(penaltyScoresGm_ + bsId); + float alphaR = 1.0f / alpha; + float beta = *(frequencyScoresGm_ + bsId); + float gamma = *(presenceScoresGm_ + bsId); + float temperatures = *(temperaturesGm_ + bsId); + float temperaturesR = 1.0f / temperatures; + + Cast(repeatTimesFp32L, repeatTimesL, RoundMode::CAST_NONE, vocabLocalSize_); + Duplicate(repeatTimesTmpL, 0.5f, vocabLocalSize_); + Compare(repeatTimesUint8L, repeatTimesFp32L, repeatTimesTmpL, CMPMODE::GT, vocabLocalSize_); + Muls(repeatTimesTmpL, repeatTimesFp32L, beta, vocabLocalSize_); + Adds(repeatTimesTmpL, repeatTimesTmpL, gamma, vocabLocalSize_); + + Select(repeatTimesFp32L, repeatTimesUint8L, repeatTimesTmpL, 0.0f, + SELMODE::VSEL_TENSOR_SCALAR_MODE, vocabLocalSize_); + pipe_barrier(PIPE_ALL); + + Maxs(logitsInTmpL, logitsInL, 0.0f, vocabLocalSize_); + Duplicate(logitsInTmpL2, alphaR, vocabLocalSize_); + + Select(repeatTimesTmpL, repeatTimesUint8L, logitsInTmpL2, 1.0f, + SELMODE::VSEL_TENSOR_SCALAR_MODE, vocabLocalSize_); + Mul(logitsInTmpL, logitsInTmpL, repeatTimesTmpL, vocabLocalSize_); + Mins(logitsInL, logitsInL, 0.0f, vocabLocalSize_); + Duplicate(logitsInTmpL2, alpha, vocabLocalSize_); + + Select(repeatTimesTmpL, repeatTimesUint8L, logitsInTmpL2, 1.0f, + SELMODE::VSEL_TENSOR_SCALAR_MODE, vocabLocalSize_); + Mul(logitsInL, logitsInL, repeatTimesTmpL, vocabLocalSize_); + + Add(logitsInL, logitsInL, logitsInTmpL, vocabLocalSize_); + Sub(logitsOutL, logitsInL, repeatTimesFp32L, vocabLocalSize_); + Muls(logitsOutL, 
logitsOutL, temperaturesR, vocabLocalSize_); + + // pipe_barrier(PIPE_ALL); + + logitsOutQueue_.EnQue(logitsOutL); + logitsInQueue_.FreeTensor(logitsInL); + repeatTimesInQueue_.FreeTensor(repeatTimesL); + } + + __aicore__ inline void CopyOut() + { + LocalTensor logitsOutL = logitsOutQueue_.DeQue(); + DataCopy(logitsOutG_, logitsOutL, vocabLocalSize_); + logitsOutQueue_.FreeTensor(logitsOutL); + } + +private: + __gm__ int64_t *preIdsGm_; + __gm__ float *logitsInGm_; + __gm__ int32_t *repeatTimesGm_; + __gm__ float *penaltyScoresGm_; + __gm__ float *frequencyScoresGm_; + __gm__ float *presenceScoresGm_; + __gm__ float *temperaturesGm_; + __gm__ int64_t *badTokenIdGm_; + __gm__ int64_t *curLenGm_; + __gm__ int64_t *minLenGm_; + __gm__ int64_t *eosTokenIdGm_; + __gm__ float *logitsOutGm_; + + TPipe pipe_; + TQue logitsInQueue_; + TQue repeatTimesInQueue_; + TQue logitsOutQueue_; + TQue repeatTimesOutQueue_; + TBuf repeatTimesFp32Buff_; + TBuf repeatTimesUint8Buff_; + TBuf repeatTimesTmpBuff_; + TBuf temperaturesTmpBuff_; + TBuf logitsTmpBuff_; + TBuf logitsTmpBuff2_; + + GlobalTensor logitsInG_; + GlobalTensor logitsOutG_; + GlobalTensor repeateTimesG_; + + int32_t vocabSize_; + int32_t vocabLocalSize_; + int32_t seqLen_; + int32_t eosTokenIdLength_; + int32_t bs_; + int32_t bsEachCore_; + int32_t badWLen_; +}; + +/** + 0 preIds: bs, length_id + 1 logits: bs, length + 2 repeatTimes: bs, length + 3 penaltyScores: bs + 4 frequencyScores: bs + 5 presenceScores: bs + 6 temperatures: bs + 7 badWords: bs, badWordsLen + 8 curLen: bs + 9 minLen: bs + 10 eosTokenId: end_length, + + 0 logitsOut: bs, length +*/ +extern "C" __global__ __aicore__ void token_penalty_multi_scores_v2( + GM_ADDR preIds, GM_ADDR logits, GM_ADDR repeatTimes, GM_ADDR penaltyScores, GM_ADDR frequencyScores, + GM_ADDR presenceScores, GM_ADDR temperatures, GM_ADDR badWords, GM_ADDR curLen, GM_ADDR minLen, GM_ADDR eosTokenId, + GM_ADDR logitsOut, GM_ADDR workspace, GM_ADDR tiling) +{ + 
GET_TILING_DATA(tilingData, tiling); + TokenPenaltyMultiScoresV2 op( + tilingData.vs, tilingData.vsBlock, tilingData.seqLen, tilingData.etil, tilingData.bs, tilingData.bsBlock, tilingData.badWLen); + op.Init(preIds, logits, repeatTimes, penaltyScores, frequencyScores, presenceScores, temperatures, badWords, curLen, minLen, eosTokenId, logitsOut); + op.Process(); +} \ No newline at end of file diff --git a/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_kernel/token_penalty_multi_scores_with_stop_seqs.cpp b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_kernel/token_penalty_multi_scores_with_stop_seqs.cpp new file mode 100644 index 00000000000..4d82b6fa150 --- /dev/null +++ b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_kernel/token_penalty_multi_scores_with_stop_seqs.cpp @@ -0,0 +1,239 @@ +#include "kernel_operator.h" +using namespace AscendC; + +const int32_t BUFFER_NUM = 1; + +class TokenPenaltyMultiScoresWithStopSeqs { +public: + __aicore__ inline TokenPenaltyMultiScoresWithStopSeqs( + int32_t vs, int32_t vsBlock, int32_t sl, int32_t stop_seqs_num, int32_t stop_seqs_max_len, int32_t eos_len, int32_t bs, int32_t bsBlock) + { + vocabSize_ = vs; + vocabLocalSize_ = vsBlock; + seqLen_ = sl; + stopSeqsNum = stop_seqs_num; + stopSeqsMaxLen = stop_seqs_max_len; + eosLen = eos_len; + bs_ = bs; + bsEachCore_ = bsBlock; + } + + __aicore__ inline void Init(__gm__ uint8_t *preIds, + __gm__ uint8_t *logitsIn, + __gm__ uint8_t *repeatTimes, + __gm__ uint8_t *penaltyScores, + __gm__ uint8_t *frequencyScores, + __gm__ uint8_t *presenceScores, + __gm__ uint8_t *curLen, + __gm__ uint8_t *minLen, + __gm__ uint8_t *stopSeqs, + __gm__ uint8_t *stopSeqsLen, + __gm__ uint8_t *eosTokenId, + __gm__ uint8_t *logitsOut) + { + preIdsGm_ = (__gm__ int64_t *)preIds; + logitsInGm_ = (__gm__ float *)logitsIn; + repeatTimesGm_ = (__gm__ int32_t *)repeatTimes; + penaltyScoresGm_ = (__gm__ float *)penaltyScores; + frequencyScoresGm_ = (__gm__ float *)frequencyScores; + 
presenceScoresGm_ = (__gm__ float *)presenceScores; + curLenGm_ = (__gm__ int64_t *)curLen; + minLenGm_ = (__gm__ int64_t *)minLen; + stopSeqsGm_ = (__gm__ int64_t *)stopSeqs; + stopSeqsLenGm_ = (__gm__ int32_t *)stopSeqsLen; + eosTokenIdGm_ = (__gm__ int64_t *)eosTokenId; + logitsOutGm_ = (__gm__ float *)logitsOut; + + pipe_.InitBuffer(logitsInQueue_, BUFFER_NUM, vocabLocalSize_ * sizeof(float)); + pipe_.InitBuffer(repeatTimesInQueue_, BUFFER_NUM, vocabLocalSize_ * sizeof(int32_t)); + pipe_.InitBuffer(logitsOutQueue_, BUFFER_NUM, vocabLocalSize_ * sizeof(float)); + pipe_.InitBuffer(repeatTimesFp32Buff_, vocabLocalSize_ * sizeof(float)); + pipe_.InitBuffer(repeatTimesUint8Buff_, vocabLocalSize_ * sizeof(uint8_t)); + pipe_.InitBuffer(repeatTimesTmpBuff_, vocabLocalSize_ * sizeof(float)); + pipe_.InitBuffer(logitsTmpBuff_, vocabLocalSize_ * sizeof(float)); + pipe_.InitBuffer(logitsTmpBuff2_, vocabLocalSize_ * sizeof(float)); + } +/** + preIds: bs, seqLen + logits: bs, vocabSize + repeatTimes: bs, vocabSize + penaltyScores: bs + frequencyScores: bs + presenceScores: bs + curLen: bs + minLen: bs + eosTokenId: end_length, + logitsOut: bs, vocabSize +*/ + __aicore__ inline void Process() + { + int32_t vocLoop = (vocabSize_ + vocabLocalSize_ - 1) / vocabLocalSize_; + for (int32_t bsIdEachCore = 0; bsIdEachCore < bsEachCore_; bsIdEachCore++) { + for (int32_t i = 0; i < vocLoop; i++) { + int32_t bsOffset = GetBlockIdx() * bsEachCore_ + bsIdEachCore; + if (bsOffset >= bs_) { + break; + } + int32_t gmOffset = bsOffset * vocabSize_ + i * vocabLocalSize_; + logitsInG_.SetGlobalBuffer(logitsInGm_ + gmOffset); + logitsOutG_.SetGlobalBuffer(logitsOutGm_ + gmOffset); + repeateTimesG_.SetGlobalBuffer(repeatTimesGm_ + gmOffset); + CopyIn(); + Compute(bsOffset, i, vocabLocalSize_); + CopyOut(); + } + } + } + + __aicore__ inline void CopyIn() + { + LocalTensor logitsInL = logitsInQueue_.AllocTensor(); + LocalTensor repeatTimesL = repeatTimesInQueue_.AllocTensor(); + DataCopy(logitsInL, 
logitsInG_, vocabLocalSize_); + DataCopy(repeatTimesL, repeateTimesG_, vocabLocalSize_); + logitsInQueue_.EnQue(logitsInL); + repeatTimesInQueue_.EnQue(repeatTimesL); + } + + __aicore__ inline void Compute(int32_t bsId, int32_t vocSizeId, int32_t vocabLocalSize_) + { + LocalTensor logitsInL = logitsInQueue_.DeQue(); + LocalTensor repeatTimesL = repeatTimesInQueue_.DeQue(); + + int64_t startVocId = vocabLocalSize_ * vocSizeId; + if (*(curLenGm_ + bsId) >= 0) { + // min_length_logits_process + if (*(curLenGm_ + bsId) < *(minLenGm_ + bsId)) { + for (int32_t i = 0; i < stopSeqsNum; i++) { + for (int32_t j = 0; j < *(stopSeqsLenGm_ + i); j++) { + int64_t eosTokenIdOffset = *(stopSeqsGm_ + i * stopSeqsMaxLen + j) - startVocId; + pipe_barrier(PIPE_ALL); + if (eosTokenIdOffset >= 0 && eosTokenIdOffset < vocabLocalSize_) { + logitsInL.SetValue(eosTokenIdOffset, (float)-1e10); + pipe_barrier(PIPE_ALL); + } + } + } + for (int i = 0; i < eosLen; ++i) { + int64_t eosTokenIdOffset = *(eosTokenIdGm_ + i) - startVocId; + pipe_barrier(PIPE_ALL); + if (eosTokenIdOffset >= 0 && eosTokenIdOffset < vocabLocalSize_) { + logitsInL.SetValue(eosTokenIdOffset, (float)-1e10); + pipe_barrier(PIPE_ALL); + } + } + } + + // update_repeat_times + for (int i = 0; i < seqLen_; i++) { + int64_t predId = *(preIdsGm_ + bsId * seqLen_ + i); + if (predId < 0) { + break; + } + int64_t predIdOffset = predId - startVocId; + pipe_barrier(PIPE_ALL); + if (predIdOffset >= 0 && predIdOffset < vocabLocalSize_) { + int32_t repeatNew = repeatTimesL.GetValue(predIdOffset) + 1; + repeatTimesL.SetValue(predIdOffset, repeatNew); + pipe_barrier(PIPE_ALL); + } + } + } + + LocalTensor logitsOutL = logitsOutQueue_.AllocTensor(); + LocalTensor logitsInTmpL = logitsTmpBuff_.Get(); + LocalTensor logitsInTmpL2 = logitsTmpBuff2_.Get(); + LocalTensor repeatTimesFp32L = repeatTimesFp32Buff_.Get(); + LocalTensor repeatTimesUint8L = repeatTimesUint8Buff_.Get(); + LocalTensor repeatTimesTmpL = repeatTimesTmpBuff_.Get(); + + // 
update_value_by_repeat_times + float alpha = *(penaltyScoresGm_ + bsId); + float alphaR = 1.0f / alpha; + float beta = *(frequencyScoresGm_ + bsId); + float gamma = *(presenceScoresGm_ + bsId); + + Cast(repeatTimesFp32L, repeatTimesL, RoundMode::CAST_NONE, vocabLocalSize_); + Duplicate(repeatTimesTmpL, 0.5f, vocabLocalSize_); + Compare(repeatTimesUint8L, repeatTimesFp32L, repeatTimesTmpL, CMPMODE::GT, vocabLocalSize_); + Muls(repeatTimesTmpL, repeatTimesFp32L, beta, vocabLocalSize_); + Adds(repeatTimesTmpL, repeatTimesTmpL, gamma, vocabLocalSize_); + Select(repeatTimesFp32L, repeatTimesUint8L, repeatTimesTmpL, 0.0f, + SELMODE::VSEL_TENSOR_SCALAR_MODE, vocabLocalSize_); + pipe_barrier(PIPE_ALL); + + Maxs(logitsInTmpL, logitsInL, 0.0f, vocabLocalSize_); + Duplicate(logitsInTmpL2, alphaR, vocabLocalSize_); + Select(repeatTimesTmpL, repeatTimesUint8L, logitsInTmpL2, 1.0f, + SELMODE::VSEL_TENSOR_SCALAR_MODE, vocabLocalSize_); + Mul(logitsInTmpL, logitsInTmpL, repeatTimesTmpL, vocabLocalSize_); + Mins(logitsInL, logitsInL, 0.0f, vocabLocalSize_); + Duplicate(logitsInTmpL2, alpha, vocabLocalSize_); + Select(repeatTimesTmpL, repeatTimesUint8L, logitsInTmpL2, 1.0f, + SELMODE::VSEL_TENSOR_SCALAR_MODE, vocabLocalSize_); + Mul(logitsInL, logitsInL, repeatTimesTmpL, vocabLocalSize_); + Add(logitsInL, logitsInL, logitsInTmpL, vocabLocalSize_); + Sub(logitsOutL, logitsInL, repeatTimesFp32L, vocabLocalSize_); + + logitsOutQueue_.EnQue(logitsOutL); + logitsInQueue_.FreeTensor(logitsInL); + repeatTimesInQueue_.FreeTensor(repeatTimesL); + } + + __aicore__ inline void CopyOut() + { + LocalTensor logitsOutL = logitsOutQueue_.DeQue(); + DataCopy(logitsOutG_, logitsOutL, vocabLocalSize_); + logitsOutQueue_.FreeTensor(logitsOutL); + } + +private: + __gm__ int64_t *preIdsGm_; + __gm__ float *logitsInGm_; + __gm__ int32_t *repeatTimesGm_; + __gm__ float *penaltyScoresGm_; + __gm__ float *frequencyScoresGm_; + __gm__ float *presenceScoresGm_; + __gm__ int64_t *curLenGm_; + __gm__ int64_t 
*minLenGm_; + __gm__ int64_t *stopSeqsGm_; + __gm__ int32_t *stopSeqsLenGm_; + __gm__ int64_t *eosTokenIdGm_; + __gm__ float *logitsOutGm_; + + TPipe pipe_; + TQue logitsInQueue_; + TQue repeatTimesInQueue_; + TQue logitsOutQueue_; + TQue repeatTimesOutQueue_; + TBuf repeatTimesFp32Buff_; + TBuf repeatTimesUint8Buff_; + TBuf repeatTimesTmpBuff_; + TBuf logitsTmpBuff_; + TBuf logitsTmpBuff2_; + + GlobalTensor logitsInG_; + GlobalTensor logitsOutG_; + GlobalTensor repeateTimesG_; + + int32_t vocabSize_; + int32_t vocabLocalSize_; + int32_t seqLen_; + int32_t stopSeqsNum; + int32_t stopSeqsMaxLen; + int32_t eosLen; + int32_t bs_; + int32_t bsEachCore_; +}; + +extern "C" __global__ __aicore__ void token_penalty_multi_scores_with_stop_seqs( + GM_ADDR preIds, GM_ADDR logits, GM_ADDR repeatTimes, GM_ADDR penaltyScores, GM_ADDR frequencyScores, + GM_ADDR presenceScores, GM_ADDR curLen, GM_ADDR minLen, GM_ADDR stopSeqs, GM_ADDR stopSeqsLen, GM_ADDR eosTokenIds, + GM_ADDR logitsOut, GM_ADDR workspace, GM_ADDR tiling) +{ + GET_TILING_DATA(tilingData, tiling); + TokenPenaltyMultiScoresWithStopSeqs op( + tilingData.vs, tilingData.vsBlock, tilingData.seqLen, tilingData.stop_seqs_num, tilingData.stop_seqs_max_len, tilingData.eos_len, tilingData.bs, tilingData.bsBlock); + op.Init(preIds, logits, repeatTimes, penaltyScores, frequencyScores, presenceScores, curLen, minLen, stopSeqs, stopSeqsLen, eosTokenIds, logitsOut); + op.Process(); +} diff --git a/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_kernel/update_inputs.cpp b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_kernel/update_inputs.cpp new file mode 100644 index 00000000000..bdf08324d86 --- /dev/null +++ b/backends/npu/opp/ascendc_custom_ops/src/ops/ascendc/op_kernel/update_inputs.cpp @@ -0,0 +1,86 @@ +#include "kernel_operator.h" +using namespace AscendC; + +class UpdateInputs { +public: + __aicore__ inline UpdateInputs(int32_t bs, int32_t max_bsz, int32_t length) + { + set_atomic_none(); + set_mask_norm(); 
+ this->lengthNum = length; + this->bsz = bs; + this->max_bsz = max_bsz; + pipe_barrier(PIPE_ALL); + } + + __aicore__ inline void Init( __gm__ uint8_t *stop_flags, __gm__ uint8_t *not_need_stop, __gm__ uint8_t *seq_lens_this_time, + __gm__ uint8_t *seq_lens_encoder, __gm__ uint8_t *seq_lens_decoder, __gm__ uint8_t *input_ids, + __gm__ uint8_t *stop_nums, __gm__ uint8_t *is_block_step, __gm__ uint8_t *next_tokens) + { + this->stop_flags = (__gm__ bool *)stop_flags; + this->not_need_stop = (__gm__ bool *)not_need_stop; + this->seq_lens_this_time = (__gm__ int *)seq_lens_this_time; + this->seq_lens_encoder = (__gm__ int *)seq_lens_encoder; + this->seq_lens_decoder = (__gm__ int *)seq_lens_decoder; + this->input_ids = (__gm__ int64_t *)input_ids; + this->stop_nums = (__gm__ int64_t *)stop_nums; + this->is_block_step = (__gm__ bool *)is_block_step; + this->next_tokens = (__gm__ int64_t *)next_tokens; + pipe_barrier(PIPE_ALL); + } + + __aicore__ inline void Process() + { + int64_t stop_sum = 0; + __gm__ int64_t *input_ids_now; + for (int32_t i = 0; i < max_bsz; i++) { + if (i < bsz) { + if (!is_block_step[i]) { + stop_sum += stop_flags[i] ? 1 : 0; + } + } else { + stop_sum += 1; + } + pipe_barrier(PIPE_ALL); + if (i < bsz) { + seq_lens_decoder[i] = stop_flags[i] ? 0 : (seq_lens_decoder[i] == 0 ? \ + seq_lens_encoder[i] : seq_lens_decoder[i] + 1); + seq_lens_this_time[i] = stop_flags[i] ? 
0 : 1; + seq_lens_encoder[i] = 0; + input_ids_now = input_ids + i * lengthNum; + *input_ids_now = next_tokens[i]; + } + pipe_barrier(PIPE_ALL); + } + not_need_stop[0] = stop_sum < stop_nums[0]; + pipe_barrier(PIPE_ALL); + } + +private: + __gm__ bool *stop_flags; + __gm__ bool *not_need_stop; + __gm__ int *seq_lens_this_time; + __gm__ int *seq_lens_encoder; + __gm__ int *seq_lens_decoder; + __gm__ int64_t *input_ids; + __gm__ int64_t *stop_nums; + __gm__ bool *is_block_step; + __gm__ int64_t *next_tokens; + + int32_t bsz; + int32_t max_bsz; + int32_t lengthNum; +}; + +extern "C" __global__ __aicore__ void update_inputs(GM_ADDR stop_flags, GM_ADDR not_need_stop, GM_ADDR seq_lens_this_time, + GM_ADDR seq_lens_encoder, GM_ADDR seq_lens_decoder, GM_ADDR input_ids, + GM_ADDR stop_nums, GM_ADDR next_tokens, GM_ADDR is_block_step, + GM_ADDR not_need_stop_out, GM_ADDR seq_lens_this_time_out, + GM_ADDR seq_lens_encoder_out, GM_ADDR seq_lens_decoder_out, GM_ADDR input_ids_out, + GM_ADDR workspace, GM_ADDR tiling) +{ + GET_TILING_DATA(tilingData, tiling); + UpdateInputs op(tilingData.bs, tilingData.max_bs, tilingData.length); + op.Init(stop_flags, not_need_stop, seq_lens_this_time, seq_lens_encoder, seq_lens_decoder, input_ids, stop_nums, is_block_step, next_tokens); + op.Process(); +} \ No newline at end of file diff --git a/backends/npu/opp/ascendc_custom_ops/tests/ascendc/utils/common.h b/backends/npu/opp/ascendc_custom_ops/tests/ascendc/utils/common.h new file mode 100644 index 00000000000..c7ce512d0fe --- /dev/null +++ b/backends/npu/opp/ascendc_custom_ops/tests/ascendc/utils/common.h @@ -0,0 +1,106 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __TESTS_ASCENDC_UTILS_H__ +#define __TESTS_ASCENDC_UTILS_H__ + +#include +#include +#include +#include +#include +#include +#include "../../../src/ops/ascendc/op_host/common_utils.cpp" + +#define CALL_RT(x) \ + if (auto ret = (x) != 0) { \ + std::cout << "[ERROR] Failed to exec acl api " << #x << ", result: " << ret << std::endl; \ + return -1; \ + } + +#define LOG_INFO(FMT, ...) printf("[INFO] " FMT "\n", __VA_ARGS__) + +#define ERROR_LOG(fmt, args...) fprintf(stdout, "[ERROR] " fmt "\n", ##args) + +namespace utils { + +double getMillisecs() +{ + struct timeval tv; + gettimeofday(&tv, nullptr); + const double sec2msec = 1e3; + const double usec2msec = 1e-3; + return tv.tv_sec * sec2msec + tv.tv_usec * usec2msec; +} + +bool WriteFile(const std::string &filePath, const void *buffer, size_t size) +{ + if (buffer == nullptr) { + ERROR_LOG("Write file failed. buffer is nullptr"); + return false; + } + int fd = open(filePath.c_str(), O_RDWR | O_CREAT | O_TRUNC, S_IRUSR | S_IWRITE); + if (fd < 0) { + ERROR_LOG("Open file failed. 
path = %s", filePath.c_str()); + return false; + } + auto writeSize = write(fd, buffer, size); + (void)close(fd); + if (writeSize != size) { + ERROR_LOG("Write file Failed."); + return false; + } + return true; +} + +bool ReadFile(const std::string &filePath, size_t &fileSize, void *buffer, size_t bufferSize) +{ + struct stat sBuf; + int fileStatus = stat(filePath.data(), &sBuf); + if (fileStatus == -1) { + ERROR_LOG("failed to get file"); + return false; + } + if (S_ISREG(sBuf.st_mode) == 0) { + ERROR_LOG("%s is not a file, please enter a file", filePath.c_str()); + return false; + } + std::ifstream file; + file.open(filePath, std::ios::binary); + if (!file.is_open()) { + ERROR_LOG("Open file failed. path = %s", filePath.c_str()); + return false; + } + std::filebuf *buf = file.rdbuf(); + size_t size = buf->pubseekoff(0, std::ios::end, std::ios::in); + if (size == 0) { + ERROR_LOG("file size is 0"); + file.close(); + return false; + } + if (size > bufferSize) { + ERROR_LOG("file size is larger than buffer size"); + file.close(); + return false; + } + buf->pubseekpos(0, std::ios::in); + buf->sgetn(static_cast(buffer), size); + fileSize = size; + file.close(); + return true; +} +} /* namespace */ +#endif // __TESTS_ASCENDC_UTILS_H__ \ No newline at end of file diff --git a/backends/npu/tools/set_env.sh b/backends/npu/tools/set_env.sh new file mode 100644 index 00000000000..b71369600c3 --- /dev/null +++ b/backends/npu/tools/set_env.sh @@ -0,0 +1 @@ +export ASCEND_CUSTOM_OPP_PATH=${ASCEND_OPP_PATH}