// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)

#include "paddle/common/flags.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/contiguous_kernel.h"
#include "paddle/phi/kernels/elementwise_add_kernel.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/funcs/dense_tensor_iterator.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/phi/kernels/funcs/index_elementwise.cu.h"
#include "paddle/phi/kernels/impl/elementwise_kernel_impl.h"

#if defined(__NVCC__) || defined(__HIPCC__) || defined(__xpu__)
#include "paddle/phi/kernels/funcs/dims_simplifier.h"
#endif

namespace phi {
template <typename Functor,
          typename OutT,
          int Arity,
          int NumOuts,
          int VecSize,
          int vt>
__global__ void BinaryElementwiseKernel(
    Array<const _ptr_ char *__restrict__, Arity> ins,
    Array<_ptr_ OutT *, NumOuts> outs,
    uint32_t numel,
    int read_lens,
    Functor func,
    funcs::OffsetCalculator<Arity + NumOuts> offset_calc) {
  // Each thread handles up to `vt` elements, strided by the block size.
  int64_t tid = THREAD_ID_X;
  int64_t nv = BLOCK_NUM_X * vt;
  int64_t idx = nv * BLOCK_ID_X + tid;
#pragma unroll
  for (int i = 0; i < vt; i++) {
    if (idx < numel) {
      // offsets[0] is the output's byte offset; offsets[1] and offsets[2]
      // are the byte offsets of the two inputs (the iterator registers the
      // output operand first).
      auto offsets = offset_calc.get(idx);
      using Traits = phi::funcs::FunctionTraits<Functor>;
      using ArgsT = typename Traits::ArgsTuple;
      __simd__ ArgsT args[VecSize];
      __simd__ ConditionalT<OutT, NumOuts> result[VecSize];
      std::get<0>(args[0]) =
          *(reinterpret_cast<const _ptr_ std::tuple_element_t<0, ArgsT> *>(
              reinterpret_cast<const _ptr_ char *>(ins[0]) + offsets[1]));
      std::get<1>(args[0]) =
          *(reinterpret_cast<const _ptr_ std::tuple_element_t<1, ArgsT> *>(
              reinterpret_cast<const _ptr_ char *>(ins[1]) + offsets[2]));
      funcs::SameDimsElementwisePrimitiveCaller<ConditionalT<OutT, NumOuts>,
                                                VecSize,
                                                Functor,
                                                ArgsT,
                                                Arity>()(
          func, args, result, read_lens);
      char *out_ptr = reinterpret_cast<char *>(outs[0]) + offsets[0];
      *reinterpret_cast<OutT *>(out_ptr) =
          *reinterpret_cast<const OutT *>(&(result[0]));
      idx += BLOCK_NUM_X;
    }
  }
}

// Vectorized memory access is not supported for now.
#define VEC_SIZE 1

template <typename OutT, typename Context, typename Functor, int NumOuts = 1>
void BinaryStrideBroadcastKernel(const Context &dev_ctx,
                                 const std::vector<const DenseTensor *> &ins,
                                 std::vector<DenseTensor *> *outs,
                                 Functor func,
                                 int axis = -1) {
  using Traits = phi::funcs::FunctionTraits<Functor>;
  const int Arity = Traits::arity;
  for (auto i = 0; i < outs->size(); ++i) {
    if (i > 0) {
      PADDLE_ENFORCE_EQ(
          (*outs)[i]->dims(),
          (*outs)[0]->dims(),
          common::errors::InvalidArgument(
              "The shape of each output tensor should be identical, but the "
              "%d-th output tensor's shape is not.",
              i));
    }
    dev_ctx.template Alloc<OutT>((*outs)[i]);
  }
  if ((*outs)[0]->numel() == 0) {
    return;
  }
  int max_rank = 0;
  int min_rank = phi::DDim::kMaxRank;
  for (auto *in : ins) {
    max_rank = std::max(max_rank, in->dims().size());
    min_rank = std::min(min_rank, in->dims().size());
  }
  if (ins.size() == 1) {
    max_rank = std::max(max_rank, (*outs)[0]->dims().size());
  }
  axis = axis == -1 ? max_rank - min_rank : axis;
  auto classifier =
      funcs::BroadcastTypeClassifier<OutT, Functor, Arity, NumOuts>(
          ins, outs, axis);
  // Register the output first, then the two inputs, so operand 0 of the
  // iterator (and of the OffsetCalculator below) is the output tensor.
  DenseTensorIteratorConfig config;
  config.add_output(*((*outs)[0]));
  config.add_const_input(*(ins[0]));
  config.add_const_input(*(ins[1]));
  DenseTensorIterator iter = config.build();
  const int64_t numel = iter.numel();
  funcs::OffsetCalculator offset_calc = funcs::make_offset_calculator<3>(iter);
  // Smaller element types are unrolled more aggressively per thread.
  constexpr int unroll_factor = sizeof(OutT) >= 4 ? 2 : 4;
  auto stream = dev_ctx.stream();
  auto threads = 128;
  auto blocks =
      (numel + threads * unroll_factor - 1) / (threads * unroll_factor);
  int vec_size = VEC_SIZE;
  BinaryElementwiseKernel<Functor,
                          OutT,
                          Arity,
                          NumOuts,
                          VEC_SIZE,
                          unroll_factor>
      <<<blocks, threads, 0, stream>>>(classifier.ins_data,
                                       classifier.outs_data,
                                       numel,
                                       vec_size,
                                       func,
                                       offset_calc);
}

template <typename T, typename Context, typename Functor>
void LaunchBinaryElementwiseStrideKernel(const Context &dev_ctx,
                                         const DenseTensor &x,
                                         const DenseTensor &y,
                                         Functor func,
                                         int axis,
                                         DenseTensor *out) {
  std::vector<const DenseTensor *> inputs = {&x, &y};
  std::vector<DenseTensor *> outputs = {out};
  dev_ctx.template Alloc<T>(out);
  BinaryStrideBroadcastKernel<T, Context>(
      dev_ctx, inputs, &outputs, func, axis);
}

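// Usage sketch (illustrative only, not part of this header's API): the
// launcher is expected to be called from a concrete binary kernel with a phi
// elementwise functor such as funcs::AddFunctor<T> from
// elementwise_functor.h. The wrapper name below is hypothetical.
//
//   template <typename T, typename Context>
//   void AddStrideKernel(const Context &dev_ctx,
//                        const DenseTensor &x,
//                        const DenseTensor &y,
//                        DenseTensor *out) {
//     LaunchBinaryElementwiseStrideKernel<T, Context>(
//         dev_ctx, x, y, funcs::AddFunctor<T>(), /*axis=*/-1, out);
//   }
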
// Returns a contiguous copy of `tensor` with unchanged meta, dispatching the
// ContiguousKernel on the tensor's dtype.
template <typename Context>
phi::DenseTensor Tensor2Contiguous(const Context &dev_ctx,
                                   const phi::DenseTensor &tensor) {
  phi::DenseTensor dense_out;
  phi::MetaTensor meta_input(tensor);
  phi::MetaTensor meta_out(&dense_out);
  UnchangedInferMeta(meta_input, &meta_out);
  PD_VISIT_ALL_TYPES(tensor.dtype(), "Tensor2Contiguous", ([&] {
                       phi::ContiguousKernel<data_t, Context>(
                           dev_ctx, tensor, &dense_out);
                     }));
  return dense_out;
}
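
// Usage sketch (illustrative only): callers that require dense memory can
// materialize a contiguous copy of a strided input before dispatching to a
// contiguous-only code path, e.g.
//
//   phi::DenseTensor x_contig = Tensor2Contiguous<Context>(dev_ctx, x);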

}  // namespace phi

#endif