// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)

#include "paddle/common/flags.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/contiguous_kernel.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/funcs/dense_tensor_iterator.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/phi/kernels/funcs/index_elementwise.cu.h"
#include "paddle/phi/kernels/impl/elementwise_kernel_impl.h"

#if defined(__NVCC__) || defined(__HIPCC__) || defined(__xpu__)
#include "paddle/phi/kernels/funcs/dims_simplifier.h"
#endif

COMMON_DECLARE_bool(use_stride_kernel);
COMMON_DECLARE_bool(use_stride_compute_kernel);

namespace phi {
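// Binary elementwise kernel for strided (possibly non-contiguous) tensors.
// Each block covers a tile of BLOCK_NUM_X * vt elements and each thread steps
// through the tile with a stride of BLOCK_NUM_X. Per-element byte offsets of
// the output and the two inputs come from `offset_calc`, so arbitrary strides
// are supported; with VecSize == 1 only lane 0 of `args`/`result` is used.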
template <typename Functor,
          typename OutT,
          int Arity,
          int NumOuts,
          int VecSize,
          int vt>
__global__ void BinaryElementwiseKernel(
    Array<const _ptr_ char *__restrict__, Arity> ins,
    Array<_ptr_ OutT *, NumOuts> outs,
    uint32_t numel,
    int read_lens,
    Functor func,
    funcs::OffsetCalculator<Arity + NumOuts> offset_calc) {
  int64_t tid = THREAD_ID_X;
  int64_t nv = BLOCK_NUM_X * vt;
  int64_t idx = nv * BLOCK_ID_X + tid;
#pragma unroll
  for (int i = 0; i < vt; i++) {
    if (idx < numel) {
      auto offsets = offset_calc.get(idx);
      using Traits = phi::funcs::FunctionTraits<Functor>;
      using ArgsT = typename Traits::ArgsTuple;
      __simd__ ArgsT args[VecSize];
      __simd__ ConditionalT<OutT, NumOuts> result[VecSize];
      // Only lane 0 is populated, since vectorization is disabled (VecSize == 1).
      std::get<0>(args[0]) =
          *(reinterpret_cast<const _ptr_ std::tuple_element_t<0, ArgsT> *>(
              reinterpret_cast<const _ptr_ char *>(ins[0]) + offsets[1]));
      std::get<1>(args[0]) =
          *(reinterpret_cast<const _ptr_ std::tuple_element_t<1, ArgsT> *>(
              reinterpret_cast<const _ptr_ char *>(ins[1]) + offsets[2]));
      funcs::SameDimsElementwisePrimitiveCaller<ConditionalT<OutT, NumOuts>,
                                                VecSize,
                                                Functor,
                                                ArgsT,
                                                Arity>()(
          func, args, result, read_lens);
      char *out_ptr = reinterpret_cast<char *>(outs[0]) + offsets[0];
      *reinterpret_cast<OutT *>(out_ptr) =
          *reinterpret_cast<const OutT *>(&(result[0]));
      idx += BLOCK_NUM_X;
    }
  }
}

// Vectorized loads/stores are not supported for now.
#define VEC_SIZE 1

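// Allocates and checks the outputs, builds a DenseTensorIterator over
// (out, x, y), derives a per-element OffsetCalculator from it, and launches
// BinaryElementwiseKernel so broadcast/strided inputs are read in place
// without first being made contiguous.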
template <typename OutT, typename Context, typename Functor, int NumOuts = 1>
void BinaryStrideBroadcastKernel(const Context &dev_ctx,
                                 const std::vector<const DenseTensor *> &ins,
                                 std::vector<DenseTensor *> *outs,
                                 Functor func,
                                 int axis = -1) {
  using Traits = phi::funcs::FunctionTraits<Functor>;
  const int Arity = Traits::arity;
  for (auto i = 0; i < outs->size(); ++i) {
    if (i > 0) {
      PADDLE_ENFORCE_EQ(
          (*outs)[i]->dims(),
          (*outs)[0]->dims(),
          common::errors::InvalidArgument(
              "The shape of each output tensor should be identical, but the "
              "shape of the %d-th output tensor is not.",
              i));
    }
    dev_ctx.template Alloc<OutT>((*outs)[i]);
  }
  if ((*outs)[0]->numel() == 0) {
    return;
  }
  int max_rank = 0;
  int min_rank = phi::DDim::kMaxRank;
  for (auto *in : ins) {
    max_rank = std::max(max_rank, in->dims().size());
    min_rank = std::min(min_rank, in->dims().size());
  }
  if (ins.size() == 1) {
    max_rank = std::max(max_rank, (*outs)[0]->dims().size());
  }
  axis = axis == -1 ? max_rank - min_rank : axis;
  auto classifier =
      funcs::BroadcastTypeClassifier<OutT, Functor, Arity, NumOuts>(
          ins, outs, axis);
  DenseTensorIteratorConfig config;
  config.add_output(*((*outs)[0]));
  config.add_const_input(*(ins[0]));
  config.add_const_input(*(ins[1]));
  DenseTensorIterator iter = config.build();
  const int numel = iter.numel();
  funcs::OffsetCalculator offset_calc = funcs::make_offset_calculator<3>(iter);
  constexpr int unroll_factor = sizeof(OutT) >= 4 ? 2 : 4;
  auto stream = dev_ctx.stream();
  auto threads = 128;
  auto blocks =
      (numel + threads * unroll_factor - 1) / (threads * unroll_factor);
  int vec_size = VEC_SIZE;
  BinaryElementwiseKernel<Functor,
                          OutT,
                          Arity,
                          NumOuts,
                          VEC_SIZE,
                          unroll_factor>
      <<<blocks, threads, 0, stream>>>(classifier.ins_data,
                                       classifier.outs_data,
                                       numel,
                                       vec_size,
                                       func,
                                       offset_calc);
}

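// Thin wrapper that packs (x, y) and out into the vector form expected by
// BinaryStrideBroadcastKernel.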
template <typename T, typename Context, typename Functor>
void LaunchBinaryElementwiseStrideKernel(const Context &dev_ctx,
                                         const DenseTensor &x,
                                         const DenseTensor &y,
                                         Functor func,
                                         int axis,
                                         DenseTensor *out) {
  std::vector<const DenseTensor *> inputs = {&x, &y};
  std::vector<DenseTensor *> outputs = {out};
  dev_ctx.template Alloc<T>(out);
  BinaryStrideBroadcastKernel<T, Context>(
      dev_ctx, inputs, &outputs, func, axis);
}

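// Copies `tensor` into a new, contiguous DenseTensor with the same dtype and
// shape by dispatching ContiguousKernel over all supported element types.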
template <typename Context>
phi::DenseTensor Tensor2Contiguous(const Context &dev_ctx,
                                   const phi::DenseTensor &tensor) {
  phi::DenseTensor dense_out;
  phi::MetaTensor meta_input(tensor);
  phi::MetaTensor meta_out(&dense_out);
  UnchangedInferMeta(meta_input, &meta_out);
  PD_VISIT_ALL_TYPES(tensor.dtype(), "Tensor2Contiguous", ([&] {
                       phi::ContiguousKernel<data_t, Context>(
                           dev_ctx, tensor, &dense_out);
                     }));
  return dense_out;
}

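// Defines <Name>StrideKernel for a binary math op. If the strided compute
// path is disabled, or either input carries a non-zero storage offset, any
// non-contiguous operand is first copied to contiguous memory. When both
// operands end up contiguous, the regular <Name>Kernel is reused; otherwise
// the DenseTensorIterator-based strided kernel above is launched.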
#define DEFINE_CUDA_MATH_ELEMENTWISE_STRIDE_OP(name, functor_name)           \
  template <typename T, typename Context>                                    \
  void name##StrideKernel(const Context &dev_ctx,                            \
                          const DenseTensor &x,                              \
                          const DenseTensor &y,                              \
                          DenseTensor *out) {                                \
    if (!FLAGS_use_stride_kernel) {                                          \
      PADDLE_THROW(common::errors::Fatal(                                    \
          "FLAGS_use_stride_kernel is disabled, but a strided kernel was "   \
          "called; something has gone wrong!"));                             \
    }                                                                        \
    DenseTensor x_;                                                          \
    DenseTensor y_;                                                          \
    if (!FLAGS_use_stride_compute_kernel || x.offset() != 0 ||               \
        y.offset() != 0) {                                                   \
      if (!x.meta().is_contiguous() || x.offset() != 0) {                    \
        x_ = Tensor2Contiguous<Context>(dev_ctx, x);                         \
      } else {                                                               \
        x_ = x;                                                              \
      }                                                                      \
      if (!y.meta().is_contiguous() || y.offset() != 0) {                    \
        y_ = Tensor2Contiguous<Context>(dev_ctx, y);                         \
      } else {                                                               \
        y_ = y;                                                              \
      }                                                                      \
    } else {                                                                 \
      x_ = x;                                                                \
      y_ = y;                                                                \
    }                                                                        \
    if (x_.meta().is_contiguous() && y_.meta().is_contiguous()) {            \
      auto meta = out->meta();                                               \
      meta.strides = meta.calc_strides(out->dims());                         \
      out->set_meta(meta);                                                   \
      phi::name##Kernel<T, Context>(dev_ctx, x_, y_, out);                   \
      return;                                                                \
    }                                                                        \
    if (!FLAGS_use_stride_compute_kernel) {                                  \
      PADDLE_THROW(common::errors::Fatal(                                    \
          "FLAGS_use_stride_compute_kernel is disabled, but the "            \
          "DenseTensorIterator-based kernel was called; something has gone " \
          "wrong!"));                                                        \
    }                                                                        \
    LaunchBinaryElementwiseStrideKernel<T, Context>(                         \
        dev_ctx, x_, y_, funcs::functor_name##Functor<T>(), -1, out);        \
  }

DEFINE_CUDA_MATH_ELEMENTWISE_STRIDE_OP(Maximum, Maximum)
DEFINE_CUDA_MATH_ELEMENTWISE_STRIDE_OP(Minimum, Minimum)
DEFINE_CUDA_MATH_ELEMENTWISE_STRIDE_OP(FloorDivide, FloorDivide)
DEFINE_CUDA_MATH_ELEMENTWISE_STRIDE_OP(Heaviside, ElementwiseHeaviside)
DEFINE_CUDA_MATH_ELEMENTWISE_STRIDE_OP(FMax, FMax)
DEFINE_CUDA_MATH_ELEMENTWISE_STRIDE_OP(FMin, FMin)

}  // namespace phi

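// Dtype aliases and STRIDED-layout GPU registrations for the kernels above.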
using float16 = phi::dtype::float16;
using bfloat16 = phi::dtype::bfloat16;
using complex64 = ::phi::dtype::complex<float>;
using complex128 = ::phi::dtype::complex<double>;

PD_REGISTER_KERNEL(maximum,
                   GPU,
                   STRIDED,
                   phi::MaximumStrideKernel,
                   float,
                   double,
                   int,
                   int64_t,
                   phi::dtype::float16,
                   phi::dtype::bfloat16) {}

PD_REGISTER_KERNEL(minimum,
                   GPU,
                   STRIDED,
                   phi::MinimumStrideKernel,
                   float,
                   double,
                   int,
                   int64_t,
                   phi::dtype::float16,
                   phi::dtype::bfloat16) {}

PD_REGISTER_KERNEL(floor_divide,
                   GPU,
                   STRIDED,
                   phi::FloorDivideStrideKernel,
                   uint8_t,
                   int8_t,
                   int16_t,
                   int,
                   int64_t,
                   float,
                   double,
                   phi::dtype::float16,
                   phi::dtype::bfloat16) {}

PD_REGISTER_KERNEL(heaviside,
                   GPU,
                   STRIDED,
                   phi::HeavisideStrideKernel,
                   float,
                   double,
                   int,
                   float16,
                   bfloat16,
                   int64_t) {}

PD_REGISTER_KERNEL(fmax,
                   GPU,
                   STRIDED,
                   phi::FMaxStrideKernel,
                   float,
                   double,
                   int,
                   float16,
                   bfloat16,
                   int64_t) {}

PD_REGISTER_KERNEL(fmin,
                   GPU,
                   STRIDED,
                   phi::FMinStrideKernel,
                   float,
                   double,
                   int,
                   float16,
                   bfloat16,
                   int64_t) {}

#endif