Commit 32549d9
[Stride] Integrate more binary elementwise operators into DenseTensorIterator, Part 3: bitwise_and / bitwise_or / bitwise_xor / logical_and / logical_or / logical_xor (#74769)
* add support to binary_elementwise_part3
* refine
* aloow merge
* allow merge
1 parent 70b436a commit 32549d9

File tree

4 files changed: +919 -0 lines changed
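Before the diffs, a quick reminder of what the six operators named in the commit title compute per element. The snippet below is a minimal, framework-free C++ illustration written for this note; it is not Paddle code and uses no Paddle types. logical_xor in particular has no built-in C++ operator, so it is spelled out explicitly.

// Illustration only: per-element semantics of the six operators,
// applied to two small int arrays. Not Paddle code.
#include <cstdio>

int main() {
  const int n = 4;
  int x[n] = {0b1100, 0, 7, 1};
  int y[n] = {0b1010, 3, 0, 1};
  for (int i = 0; i < n; ++i) {
    int band = x[i] & y[i];                   // bitwise_and
    int bor = x[i] | y[i];                    // bitwise_or
    int bxor = x[i] ^ y[i];                   // bitwise_xor
    bool land = x[i] && y[i];                 // logical_and
    bool lor = x[i] || y[i];                  // logical_or
    bool lxor = (x[i] != 0) != (y[i] != 0);   // logical_xor
    std::printf("x=%d y=%d  &=%d |=%d ^=%d  &&=%d ||=%d xor=%d\n",
                x[i], y[i], band, bor, bxor, land, lor, lxor);
  }
  return 0;
}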
Lines changed: 111 additions & 0 deletions
@@ -0,0 +1,111 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#include "paddle/phi/kernels/bitwise_kernel.h"
#include "paddle/common/flags.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/bitwise_functors.h"
#include "paddle/phi/kernels/stride/elementwise_stride_base.cu.h"
#if defined(__NVCC__) || defined(__HIPCC__) || defined(__xpu__)
#include "paddle/phi/kernels/funcs/dims_simplifier.h"
#endif

COMMON_DECLARE_bool(use_stride_kernel);
COMMON_DECLARE_bool(use_stride_compute_kernel);

namespace phi {

#define DEFINE_CUDA_BINARY_ELEMENTWISE_STRIDE_OP(name)                        \
  template <typename T, typename Context>                                     \
  void name##StrideKernel(const Context &dev_ctx,                             \
                          const DenseTensor &x,                               \
                          const DenseTensor &y,                               \
                          DenseTensor *out) {                                 \
    if (!FLAGS_use_stride_kernel) {                                           \
      PADDLE_THROW(common::errors::Fatal(                                     \
          "FLAGS_use_stride_kernel is closed. Strided kernel "                \
          "be called, something wrong has happened!"));                       \
    }                                                                          \
    DenseTensor x_;                                                           \
    DenseTensor y_;                                                           \
    if (!FLAGS_use_stride_compute_kernel || x.offset() != 0 ||                \
        y.offset() != 0) {                                                    \
      if (!x.meta().is_contiguous() || x.offset() != 0) {                     \
        x_ = Tensor2Contiguous<Context>(dev_ctx, x);                          \
      } else {                                                                \
        x_ = x;                                                               \
      }                                                                        \
      if (!y.meta().is_contiguous() || y.offset() != 0) {                     \
        y_ = Tensor2Contiguous<Context>(dev_ctx, y);                          \
      } else {                                                                \
        y_ = y;                                                               \
      }                                                                        \
    } else {                                                                  \
      x_ = x;                                                                 \
      y_ = y;                                                                 \
    }                                                                          \
    if (x_.meta().is_contiguous() && y_.meta().is_contiguous()) {             \
      auto meta = out->meta();                                                \
      meta.strides = meta.calc_strides(out->dims());                          \
      out->set_meta(meta);                                                    \
      phi::name##Kernel<T, Context>(dev_ctx, x_, y_, out);                    \
      return;                                                                 \
    }                                                                          \
    if (!FLAGS_use_stride_compute_kernel) {                                   \
      PADDLE_THROW(                                                           \
          common::errors::Fatal("FLAGS_use_stride_compute_kernel is closed. " \
                                "Kernel using DenseTensorIterator "           \
                                "be called, something wrong has happened!")); \
    }                                                                          \
    LaunchBinaryElementwiseStrideKernel<T, Context>(                          \
        dev_ctx, x_, y_, funcs::name##Functor<T>(), -1, out);                 \
  }

DEFINE_CUDA_BINARY_ELEMENTWISE_STRIDE_OP(BitwiseAnd)
DEFINE_CUDA_BINARY_ELEMENTWISE_STRIDE_OP(BitwiseOr)
DEFINE_CUDA_BINARY_ELEMENTWISE_STRIDE_OP(BitwiseXor)

}  // namespace phi

using float16 = phi::dtype::float16;
using bfloat16 = phi::dtype::bfloat16;
using complex64 = ::phi::dtype::complex<float>;
using complex128 = ::phi::dtype::complex<double>;

PD_REGISTER_KERNEL(bitwise_and,
                   GPU,
                   STRIDED,
                   phi::BitwiseAndStrideKernel,
                   bool,
                   uint8_t,
                   int8_t,
                   int16_t,
                   int,
                   int64_t) {}

PD_REGISTER_KERNEL(bitwise_or,
                   GPU,
                   STRIDED,
                   phi::BitwiseOrStrideKernel,
                   bool,
                   uint8_t,
                   int8_t,
                   int16_t,
                   int,
                   int64_t) {}

PD_REGISTER_KERNEL(bitwise_xor,
                   GPU,
                   STRIDED,
                   phi::BitwiseXorStrideKernel,
                   bool,
                   uint8_t,
                   int8_t,
                   int16_t,
                   int,
                   int64_t) {}
#endif
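The macro above packs the whole dispatch into backslash-continued lines, so here is a framework-free restatement of the same control flow. Tensor, make_contiguous, dense_kernel, and strided_kernel below are hypothetical stand-ins invented for illustration, not Paddle's types or kernels; only the branching mirrors the macro: materialize contiguous copies when the strided compute path is off or an input carries a storage offset, reuse the existing contiguous kernel when both operands end up contiguous, and otherwise fall through to the DenseTensorIterator-based launch.

// Sketch of the dispatch generated by DEFINE_CUDA_BINARY_ELEMENTWISE_STRIDE_OP.
// All names below are illustrative stand-ins, not Paddle's API.
#include <cstdint>
#include <stdexcept>

struct Tensor {
  bool contiguous = true;   // stand-in for meta().is_contiguous()
  std::int64_t offset = 0;  // stand-in for offset()
};

static bool use_stride_kernel = true;          // FLAGS_use_stride_kernel
static bool use_stride_compute_kernel = true;  // FLAGS_use_stride_compute_kernel

Tensor make_contiguous(const Tensor &) { return Tensor{true, 0}; }   // like Tensor2Contiguous
void dense_kernel(const Tensor &, const Tensor &, Tensor *) {}       // existing contiguous kernel
void strided_kernel(const Tensor &, const Tensor &, Tensor *) {}     // iterator-based path

void bitwise_and_stride(const Tensor &x, const Tensor &y, Tensor *out) {
  if (!use_stride_kernel) throw std::runtime_error("stride kernels disabled");
  Tensor x_ = x, y_ = y;
  // Fall back to contiguous copies when the strided compute path is off
  // or either input carries a nonzero storage offset.
  if (!use_stride_compute_kernel || x.offset != 0 || y.offset != 0) {
    if (!x.contiguous || x.offset != 0) x_ = make_contiguous(x);
    if (!y.contiguous || y.offset != 0) y_ = make_contiguous(y);
  }
  if (x_.contiguous && y_.contiguous) {
    dense_kernel(x_, y_, out);  // fast path: reuse the contiguous kernel
    return;
  }
  if (!use_stride_compute_kernel) throw std::runtime_error("unexpected strided input");
  strided_kernel(x_, y_, out);  // DenseTensorIterator-based strided launch
}

int main() {
  Tensor x, y, out;
  bitwise_and_stride(x, y, &out);
  return 0;
}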
Lines changed: 176 additions & 0 deletions
@@ -0,0 +1,176 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)

#include "paddle/common/flags.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/contiguous_kernel.h"
#include "paddle/phi/kernels/elementwise_add_kernel.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/funcs/dense_tensor_iterator.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/phi/kernels/funcs/index_elementwise.cu.h"
#include "paddle/phi/kernels/impl/elementwise_kernel_impl.h"

#if defined(__NVCC__) || defined(__HIPCC__) || defined(__xpu__)
#include "paddle/phi/kernels/funcs/dims_simplifier.h"
#endif

namespace phi {

template <typename Functor,
          typename OutT,
          int Arity,
          int NumOuts,
          int VecSize,
          int vt>
__global__ void BinaryElementwiseKernel(
    Array<const _ptr_ char *__restrict__, Arity> ins,
    Array<_ptr_ OutT *, NumOuts> outs,
    uint32_t numel,
    int read_lens,
    Functor func,
    funcs::OffsetCalculator<Arity + NumOuts> offset_calc) {
  int64_t tid = THREAD_ID_X;
  int64_t nv = BLOCK_NUM_X * vt;
  int64_t idx = nv * BLOCK_ID_X + tid;
#pragma unroll
  for (int i = 0; i < vt; i++) {
    if (idx < numel) {
      auto offsets = offset_calc.get(idx);
      using Traits = phi::funcs::FunctionTraits<Functor>;
      using ArgsT = typename Traits::ArgsTuple;
      __simd__ ArgsT args[VecSize];
      __simd__ ConditionalT<OutT, NumOuts> result[VecSize];
      std::get<0>(args[idx]) =
          *(reinterpret_cast<const _ptr_ std::tuple_element_t<0, ArgsT> *>(
              reinterpret_cast<const _ptr_ char *>(ins[0]) + offsets[1]));
      std::get<1>(args[idx]) =
          *(reinterpret_cast<const _ptr_ std::tuple_element_t<1, ArgsT> *>(
              reinterpret_cast<const _ptr_ char *>(ins[1]) + offsets[2]));
      funcs::SameDimsElementwisePrimitiveCaller<ConditionalT<OutT, NumOuts>,
                                                VecSize,
                                                Functor,
                                                ArgsT,
                                                Arity>()(
          func, args, result, read_lens);
      char *out_ptr = reinterpret_cast<char *>(outs[0]) + offsets[0];
      *reinterpret_cast<OutT *>(out_ptr) =
          *reinterpret_cast<const OutT *>(&(result[0]));
      idx += BLOCK_NUM_X;
    }
  }
}

// Not Support Vectorized Kernel For Now
#define VEC_SIZE 1

template <typename OutT, typename Context, typename Functor, int NumOuts = 1>
void BinaryStrideBroadcastKernel(const Context &dev_ctx,
                                 const std::vector<const DenseTensor *> &ins,
                                 std::vector<DenseTensor *> *outs,
                                 Functor func,
                                 int axis = -1) {
  using Traits = phi::funcs::FunctionTraits<Functor>;
  const int Arity = Traits::arity;
  for (auto i = 0; i < outs->size(); ++i) {
    if (i > 0) {
      PADDLE_ENFORCE_EQ(
          (*outs)[i]->dims(),
          (*outs)[0]->dims(),
          common::errors::InvalidArgument(
              "The shape of each output tensor shall be identical yet, but "
              "%d-th output tensor`s shape is not.",
              i));
    }
    dev_ctx.template Alloc<OutT>((*outs)[i]);
  }
  if ((*outs)[0]->numel() == 0) {
    return;
  }
  int max_rank = 0;
  int min_rank = phi::DDim::kMaxRank;
  for (auto *in : ins) {
    max_rank = std::max(max_rank, in->dims().size());
    min_rank = std::min(min_rank, in->dims().size());
  }
  if (ins.size() == 1) {
    max_rank = std::max(max_rank, (*outs)[0]->dims().size());
  }
  axis = axis == -1 ? max_rank - min_rank : axis;
  auto classifier =
      funcs::BroadcastTypeClassifier<OutT, Functor, Arity, NumOuts>(
          ins, outs, axis);
  DenseTensorIteratorConfig config;
  config.add_output(*((*outs)[0]));
  config.add_const_input(*(ins[0]));
  config.add_const_input(*(ins[1]));
  DenseTensorIterator iter = config.build();
  const int &numel = iter.numel();
  funcs::OffsetCalculator offset_calc = funcs::make_offset_calculator<3>(iter);
  constexpr int unroll_factor = sizeof(OutT) >= 4 ? 2 : 4;
  auto stream = dev_ctx.stream();
  auto threads = 128;
  auto blocks = (numel + 128 * unroll_factor - 1) / (128 * unroll_factor);
  int vec_size = VEC_SIZE;
  BinaryElementwiseKernel<Functor,
                          OutT,
                          Arity,
                          NumOuts,
                          VEC_SIZE,
                          unroll_factor>
      <<<blocks, threads, 0, stream>>>(classifier.ins_data,
                                       classifier.outs_data,
                                       numel,
                                       vec_size,
                                       func,
                                       offset_calc);
}

template <typename T, typename Context, typename Functor>
void LaunchBinaryElementwiseStrideKernel(const Context &dev_ctx,
                                         const DenseTensor &x,
                                         const DenseTensor &y,
                                         Functor func,
                                         int axis,
                                         DenseTensor *out) {
  std::vector<const DenseTensor *> inputs = {&x, &y};
  std::vector<DenseTensor *> outputs = {out};
  dev_ctx.template Alloc<T>(out);
  BinaryStrideBroadcastKernel<T, Context>(
      dev_ctx, inputs, &outputs, func, axis);
}

template <typename Context>
phi::DenseTensor Tensor2Contiguous(const Context &dev_ctx,
                                   const phi::DenseTensor &tensor) {
  phi::DenseTensor dense_out;
  phi::MetaTensor meta_input(tensor);
  phi::MetaTensor meta_out(&dense_out);
  UnchangedInferMeta(meta_input, &meta_out);
  PD_VISIT_ALL_TYPES(tensor.dtype(), "Tensor2Contiguous", ([&] {
                       phi::ContiguousKernel<data_t, Context>(
                           dev_ctx, tensor, &dense_out);
                     }));
  return dense_out;
}

}  // namespace phi

#endif
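The GPU kernel above loads each operand through char* arithmetic using per-element byte offsets produced by the OffsetCalculator built from the DenseTensorIterator. The sketch below replays that addressing scheme on the CPU for a bitwise_and over one contiguous operand and one transposed (non-contiguous) view. The Operand struct, the 2x3 shape, and the byte strides are assumptions made up for illustration; they are not Paddle's data structures, but the offset arithmetic is the same idea.

// CPU-side illustration of byte-strided binary elementwise addressing.
// All types and values here are invented for the example.
#include <array>
#include <cstdint>
#include <cstdio>
#include <vector>

struct Operand {
  const char *data;                    // base pointer, addressed in bytes
  std::array<std::int64_t, 2> stride;  // byte stride per dimension
};

int main() {
  // x is a contiguous 2x3 int32 tensor; y is a 3x2 buffer viewed transposed
  // to 2x3, which gives it non-contiguous byte strides.
  std::vector<std::int32_t> x = {1, 2, 3, 4, 5, 6};
  std::vector<std::int32_t> y = {7, 8, 9, 10, 11, 12};
  std::vector<std::int32_t> out(6, 0);
  const std::int64_t shape[2] = {2, 3};

  Operand xo{reinterpret_cast<const char *>(x.data()), {12, 4}};  // contiguous
  Operand yo{reinterpret_cast<const char *>(y.data()), {4, 8}};   // transposed view
  char *out_base = reinterpret_cast<char *>(out.data());
  const std::int64_t out_stride[2] = {12, 4};

  for (std::int64_t idx = 0; idx < shape[0] * shape[1]; ++idx) {
    std::int64_t i = idx / shape[1], j = idx % shape[1];
    // Per-operand byte offsets for logical element (i, j), as an
    // OffsetCalculator-style lookup would produce them.
    std::int64_t off_out = i * out_stride[0] + j * out_stride[1];
    std::int64_t off_x = i * xo.stride[0] + j * xo.stride[1];
    std::int64_t off_y = i * yo.stride[0] + j * yo.stride[1];
    auto a = *reinterpret_cast<const std::int32_t *>(xo.data + off_x);
    auto b = *reinterpret_cast<const std::int32_t *>(yo.data + off_y);
    *reinterpret_cast<std::int32_t *>(out_base + off_out) = a & b;  // bitwise_and
  }
  for (auto v : out) std::printf("%d ", v);
  std::printf("\n");
  return 0;
}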
