
Commit 70b436a

[Stride] Integrate more binary elementwise operators into DenseTensorIterator, Part 2: maximum / minimum / floordiv / heaviside / fmax / fmin (#74740)
* add binary_elementwise_part2
* allow merge
* allow merge
* refine
1 parent 2ae2249 commit 70b436a
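For context, a minimal sketch of how one of the newly registered strided entry points is reached on the GPU backend. The wrapper function `Example` and the float instantiation are illustrative only; `phi::MaximumStrideKernel` and its signature come from the file added in this commit.

// Sketch (assumes a live phi::GPUContext and GPU-allocated DenseTensors; the
// declaration of MaximumStrideKernel comes from the file added below).
void Example(const phi::GPUContext &dev_ctx,
             const phi::DenseTensor &x,  // may be non-contiguous / strided
             const phi::DenseTensor &y,
             phi::DenseTensor *out) {
  // This entry point is only expected to run with FLAGS_use_stride_kernel
  // enabled (the kernel throws otherwise); it either falls back to the
  // contiguous MaximumKernel or runs the DenseTensorIterator-based path.
  phi::MaximumStrideKernel<float, phi::GPUContext>(dev_ctx, x, y, out);
}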

7 files changed, +1028 -0 lines changed
Lines changed: 303 additions & 0 deletions
@@ -0,0 +1,303 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)

#include "paddle/common/flags.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/contiguous_kernel.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/funcs/dense_tensor_iterator.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/phi/kernels/funcs/index_elementwise.cu.h"
#include "paddle/phi/kernels/impl/elementwise_kernel_impl.h"

#if defined(__NVCC__) || defined(__HIPCC__) || defined(__xpu__)
#include "paddle/phi/kernels/funcs/dims_simplifier.h"

#endif

COMMON_DECLARE_bool(use_stride_kernel);
COMMON_DECLARE_bool(use_stride_compute_kernel);

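// FLAGS_use_stride_kernel gates whether the *StrideKernel entry points below
// may be called at all; FLAGS_use_stride_compute_kernel additionally decides
// whether non-contiguous inputs are computed on directly through
// DenseTensorIterator instead of first being materialized as contiguous
// copies (see DEFINE_CUDA_MATH_ELEMENTWISE_STRIDE_OP near the end of the file).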
namespace phi {

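// Applies a binary functor element by element over possibly non-contiguous
// operands. Each thread walks up to `vt` logical indices spaced BLOCK_NUM_X
// apart; for each index it asks the OffsetCalculator for the byte offsets of
// the output (offsets[0]) and the two inputs (offsets[1], offsets[2]), loads
// the operands, applies `func`, and stores the result. VEC_SIZE is fixed to 1
// below, so only slot 0 of `args` / `result` is used.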
template <typename Functor,
          typename OutT,
          int Arity,
          int NumOuts,
          int VecSize,
          int vt>
__global__ void BinaryElementwiseKernel(
    Array<const _ptr_ char *__restrict__, Arity> ins,
    Array<_ptr_ OutT *, NumOuts> outs,
    uint32_t numel,
    int read_lens,
    Functor func,
    funcs::OffsetCalculator<Arity + NumOuts> offset_calc) {
  int64_t tid = THREAD_ID_X;
  int64_t nv = BLOCK_NUM_X * vt;
  int64_t idx = nv * BLOCK_ID_X + tid;
#pragma unroll
  for (int i = 0; i < vt; i++) {
    if (idx < numel) {
      auto offsets = offset_calc.get(idx);
      using Traits = phi::funcs::FunctionTraits<Functor>;
      using ArgsT = typename Traits::ArgsTuple;
      __simd__ ArgsT args[VecSize];
      __simd__ ConditionalT<OutT, NumOuts> result[VecSize];
      std::get<0>(args[0]) =
          *(reinterpret_cast<const _ptr_ std::tuple_element_t<0, ArgsT> *>(
              reinterpret_cast<const _ptr_ char *>(ins[0]) + offsets[1]));
      std::get<1>(args[0]) =
          *(reinterpret_cast<const _ptr_ std::tuple_element_t<1, ArgsT> *>(
              reinterpret_cast<const _ptr_ char *>(ins[1]) + offsets[2]));
      funcs::SameDimsElementwisePrimitiveCaller<ConditionalT<OutT, NumOuts>,
                                                VecSize,
                                                Functor,
                                                ArgsT,
                                                Arity>()(
          func, args, result, read_lens);
      char *out_ptr = reinterpret_cast<char *>(outs[0]) + offsets[0];
      *reinterpret_cast<OutT *>(out_ptr) =
          *reinterpret_cast<const OutT *>(&(result[0]));
      idx += BLOCK_NUM_X;
    }
  }
}

// Vectorized kernels are not supported for now.
#define VEC_SIZE 1

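// Allocates the outputs (enforcing that all output shapes match), returns
// early on empty outputs, derives the default broadcast axis from the input
// ranks, builds a DenseTensorIterator over the strided output and the two
// inputs, and launches BinaryElementwiseKernel with 128-thread blocks and an
// unroll factor of 2 or 4 depending on sizeof(OutT).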
template <typename OutT, typename Context, typename Functor, int NumOuts = 1>
void BinaryStrideBroadcastKernel(const Context &dev_ctx,
                                 const std::vector<const DenseTensor *> &ins,
                                 std::vector<DenseTensor *> *outs,
                                 Functor func,
                                 int axis = -1) {
  using Traits = phi::funcs::FunctionTraits<Functor>;
  const int Arity = Traits::arity;
  for (auto i = 0; i < outs->size(); ++i) {
    if (i > 0) {
      PADDLE_ENFORCE_EQ(
          (*outs)[i]->dims(),
          (*outs)[0]->dims(),
          common::errors::InvalidArgument(
              "The shape of each output tensor shall be identical yet, but "
              "%d-th output tensor`s shape is not.",
              i));
    }
    dev_ctx.template Alloc<OutT>((*outs)[i]);
  }
  if ((*outs)[0]->numel() == 0) {
    return;
  }
  int max_rank = 0;
  int min_rank = phi::DDim::kMaxRank;
  for (auto *in : ins) {
    max_rank = std::max(max_rank, in->dims().size());
    min_rank = std::min(min_rank, in->dims().size());
  }
  if (ins.size() == 1) {
    max_rank = std::max(max_rank, (*outs)[0]->dims().size());
  }
  axis = axis == -1 ? max_rank - min_rank : axis;
  auto classifier =
      funcs::BroadcastTypeClassifier<OutT, Functor, Arity, NumOuts>(
          ins, outs, axis);
  DenseTensorIteratorConfig config;
  config.add_output(*((*outs)[0]));
  config.add_const_input(*(ins[0]));
  config.add_const_input(*(ins[1]));
  DenseTensorIterator iter = config.build();
  const int &numel = iter.numel();
  funcs::OffsetCalculator offset_calc = funcs::make_offset_calculator<3>(iter);
  constexpr int unroll_factor = sizeof(OutT) >= 4 ? 2 : 4;
  auto stream = dev_ctx.stream();
  auto threads = 128;
  auto blocks = (numel + 128 * unroll_factor - 1) / (128 * unroll_factor);
  int vec_size = VEC_SIZE;
  BinaryElementwiseKernel<Functor,
                          OutT,
                          Arity,
                          NumOuts,
                          VEC_SIZE,
                          unroll_factor>
      <<<blocks, threads, 0, stream>>>(classifier.ins_data,
                                       classifier.outs_data,
                                       numel,
                                       vec_size,
                                       func,
                                       offset_calc);
}

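// Thin wrapper: packs the two inputs and the single output into vectors,
// allocates the output, and forwards everything to BinaryStrideBroadcastKernel.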
template <typename T, typename Context, typename Functor>
void LaunchBinaryElementwiseStrideKernel(const Context &dev_ctx,
                                         const DenseTensor &x,
                                         const DenseTensor &y,
                                         Functor func,
                                         int axis,
                                         DenseTensor *out) {
  std::vector<const DenseTensor *> inputs = {&x, &y};
  std::vector<DenseTensor *> outputs = {out};
  dev_ctx.template Alloc<T>(out);
  BinaryStrideBroadcastKernel<T, Context>(
      dev_ctx, inputs, &outputs, func, axis);
}

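// Returns a contiguous copy of `tensor`: infers the output meta via
// UnchangedInferMeta, then dispatches ContiguousKernel on the runtime dtype.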
template <typename Context>
phi::DenseTensor Tensor2Contiguous(const Context &dev_ctx,
                                   const phi::DenseTensor &tensor) {
  phi::DenseTensor dense_out;
  phi::MetaTensor meta_input(tensor);
  phi::MetaTensor meta_out(&dense_out);
  UnchangedInferMeta(meta_input, &meta_out);
  PD_VISIT_ALL_TYPES(tensor.dtype(), "Tensor2Contiguous", ([&] {
                       phi::ContiguousKernel<data_t, Context>(
                           dev_ctx, tensor, &dense_out);
                     }));
  return dense_out;
}

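// Defines the name##StrideKernel entry point for a binary math op:
//   1. If FLAGS_use_stride_compute_kernel is off, or either input carries a
//      non-zero storage offset, non-contiguous inputs are first copied to
//      contiguous tensors via Tensor2Contiguous.
//   2. If both (possibly copied) inputs are contiguous, the regular
//      contiguous kernel is called on a contiguous output.
//   3. Otherwise the DenseTensorIterator-based strided path is launched.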
#define DEFINE_CUDA_MATH_ELEMENTWISE_STRIDE_OP(name, functor_name)            \
  template <typename T, typename Context>                                     \
  void name##StrideKernel(const Context &dev_ctx,                             \
                          const DenseTensor &x,                               \
                          const DenseTensor &y,                               \
                          DenseTensor *out) {                                 \
    if (!FLAGS_use_stride_kernel) {                                           \
      PADDLE_THROW(common::errors::Fatal(                                     \
          "FLAGS_use_stride_kernel is closed. Strided kernel "                \
          "be called, something wrong has happened!"));                       \
    }                                                                          \
    DenseTensor x_;                                                            \
    DenseTensor y_;                                                            \
    if (!FLAGS_use_stride_compute_kernel || x.offset() != 0 ||                 \
        y.offset() != 0) {                                                     \
      if (!x.meta().is_contiguous() || x.offset() != 0) {                      \
        x_ = Tensor2Contiguous<Context>(dev_ctx, x);                           \
      } else {                                                                 \
        x_ = x;                                                                \
      }                                                                        \
      if (!y.meta().is_contiguous() || y.offset() != 0) {                      \
        y_ = Tensor2Contiguous<Context>(dev_ctx, y);                           \
      } else {                                                                 \
        y_ = y;                                                                \
      }                                                                        \
    } else {                                                                   \
      x_ = x;                                                                  \
      y_ = y;                                                                  \
    }                                                                          \
    if (x_.meta().is_contiguous() && y_.meta().is_contiguous()) {              \
      auto meta = out->meta();                                                 \
      meta.strides = meta.calc_strides(out->dims());                           \
      out->set_meta(meta);                                                     \
      phi::name##Kernel<T, Context>(dev_ctx, x_, y_, out);                     \
      return;                                                                  \
    }                                                                          \
    if (!FLAGS_use_stride_compute_kernel) {                                    \
      PADDLE_THROW(                                                            \
          common::errors::Fatal("FLAGS_use_stride_compute_kernel is closed. "  \
                                "Kernel using DenseTensorIterator "            \
                                "be called, something wrong has happened!"));  \
    }                                                                          \
    LaunchBinaryElementwiseStrideKernel<T, Context>(                           \
        dev_ctx, x_, y_, funcs::functor_name##Functor<T>(), -1, out);          \
  }

DEFINE_CUDA_MATH_ELEMENTWISE_STRIDE_OP(Maximum, Maximum)
DEFINE_CUDA_MATH_ELEMENTWISE_STRIDE_OP(Minimum, Minimum)
DEFINE_CUDA_MATH_ELEMENTWISE_STRIDE_OP(FloorDivide, FloorDivide)
DEFINE_CUDA_MATH_ELEMENTWISE_STRIDE_OP(Heaviside, ElementwiseHeaviside)
DEFINE_CUDA_MATH_ELEMENTWISE_STRIDE_OP(FMax, FMax)
DEFINE_CUDA_MATH_ELEMENTWISE_STRIDE_OP(FMin, FMin)
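// Illustrative expansion (sketch derived from the macro above): the Maximum
// instantiation, for example, produces
//   template <typename T, typename Context>
//   void MaximumStrideKernel(const Context &dev_ctx,
//                            const DenseTensor &x,
//                            const DenseTensor &y,
//                            DenseTensor *out);
// dispatching to funcs::MaximumFunctor<T>. These are the symbols registered
// under the STRIDED layout by the PD_REGISTER_KERNEL calls below.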

}  // namespace phi

using float16 = phi::dtype::float16;
using bfloat16 = phi::dtype::bfloat16;
using complex64 = ::phi::dtype::complex<float>;
using complex128 = ::phi::dtype::complex<double>;

PD_REGISTER_KERNEL(maximum,
                   GPU,
                   STRIDED,
                   phi::MaximumStrideKernel,
                   float,
                   double,
                   int,
                   int64_t,
                   phi::dtype::float16,
                   phi::dtype::bfloat16) {}

PD_REGISTER_KERNEL(minimum,
                   GPU,
                   STRIDED,
                   phi::MinimumStrideKernel,
                   float,
                   double,
                   int,
                   int64_t,
                   phi::dtype::float16,
                   phi::dtype::bfloat16) {}

PD_REGISTER_KERNEL(floor_divide,
                   GPU,
                   STRIDED,
                   phi::FloorDivideStrideKernel,
                   uint8_t,
                   int8_t,
                   int16_t,
                   int,
                   int64_t,
                   float,
                   double,
                   phi::dtype::float16,
                   phi::dtype::bfloat16) {}

PD_REGISTER_KERNEL(heaviside,
                   GPU,
                   STRIDED,
                   phi::HeavisideStrideKernel,
                   float,
                   double,
                   int,
                   float16,
                   bfloat16,
                   int64_t) {}

PD_REGISTER_KERNEL(fmax,
                   GPU,
                   STRIDED,
                   phi::FMaxStrideKernel,
                   float,
                   double,
                   int,
                   float16,
                   bfloat16,
                   int64_t) {}

PD_REGISTER_KERNEL(fmin,
                   GPU,
                   STRIDED,
                   phi::FMinStrideKernel,
                   float,
                   double,
                   int,
                   float16,
                   bfloat16,
                   int64_t) {}

#endif

0 commit comments
