Merged

Commits (43)
7d58b91
Merge pull request #1 from PaddlePaddle/develop
AnnaTrainingG Mar 25, 2021
1021e08
Merge pull request #2 from PaddlePaddle/develop
AnnaTrainingG Mar 29, 2021
43f53fe
Merge pull request #3 from PaddlePaddle/develop
AnnaTrainingG Apr 19, 2021
d25ab26
Merge pull request #4 from PaddlePaddle/develop
AnnaTrainingG May 7, 2021
a244f18
max_min_prod_all_any
May 18, 2021
af4db5d
Update reduce_any_op.cu
AnnaTrainingG May 24, 2021
d804066
modified
AnnaTrainingG May 24, 2021
c7826e8
Merge branch 'reduce_max_min_prod_all_any' of https://github.com/niul…
AnnaTrainingG May 24, 2021
6ea9e9a
copyright
AnnaTrainingG May 24, 2021
8c8717f
Merge pull request #5 from PaddlePaddle/develop
AnnaTrainingG May 25, 2021
ff0a6e9
modified and {} for loop
AnnaTrainingG May 25, 2021
7ddaf91
max_min_prod_all_any
May 18, 2021
a43af7d
Update reduce_any_op.cu
AnnaTrainingG May 24, 2021
0a70b82
modified
AnnaTrainingG May 24, 2021
c91b26b
copyright
AnnaTrainingG May 24, 2021
54651e0
modified and {} for loop
AnnaTrainingG May 25, 2021
37fbd4c
Merge branch 'reduce_max_min_prod_all_any' of https://github.com/niul…
May 25, 2021
35411f7
add notes for reduce_op.cuh
May 25, 2021
8cea954
update
May 25, 2021
a719c3c
update
May 25, 2021
2e8ad8f
update
May 25, 2021
a60b90a
fix a bug in reduce_Op.cuh
AnnaTrainingG May 27, 2021
4bd9644
reset reduce_any and reduce_all
May 28, 2021
bf701a2
delete __forceinline__ in reduce_functor_op.h
May 31, 2021
6174b50
from DEVICE to HOSTTDEVICE
May 31, 2021
59c32d6
add DataBound struct for reduce_max and reduce_min
AnnaTrainingG Jun 1, 2021
790173a
Update reduce_functor_op.h
AnnaTrainingG Jun 1, 2021
8700894
update TensorReduceFunc
AnnaTrainingG Jun 2, 2021
9e32b0f
add reduce_functor_op.h pragma once
Jun 3, 2021
17dcaf8
update BOUND and kMaxTHread
AnnaTrainingG Jun 7, 2021
cb2b619
modified max min prod for cu.h
AnnaTrainingG Jun 9, 2021
6541ffb
update for struct
AnnaTrainingG Jun 9, 2021
719e435
code style reduce_op.cu.h
AnnaTrainingG Jun 9, 2021
5045a49
device to HOSTDEVICE
AnnaTrainingG Jun 10, 2021
a5dedb1
Merge branch 'reduce_max_min_prod_all_any' of https://github.com/niul…
AnnaTrainingG Jun 10, 2021
fb69e3d
ReduceCudaKernel
AnnaTrainingG Jun 15, 2021
24633a5
Merge pull request #15 from PaddlePaddle/develop
AnnaTrainingG Jun 15, 2021
b841b34
REDUCE_SPLIT_BOUNDARY
AnnaTrainingG Jun 15, 2021
1fda4d5
Update reduce_op.cu.h
AnnaTrainingG Jun 15, 2021
c85ca05
rename reduceTensorFunctor
AnnaTrainingG Jun 16, 2021
9cc8ac3
rename TensorReduceFunc
AnnaTrainingG Jun 16, 2021
140779d
delete HOSTDEVICE
AnnaTrainingG Jun 17, 2021
fa3411c
add left_num * grid.z * grid.y
AnnaTrainingG Jun 17, 2021
84 changes: 68 additions & 16 deletions paddle/fluid/operators/reduce_ops/reduce_functor_op.h
@@ -13,46 +13,98 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #pragma once
-#include <string>
-#include <vector>
+#include <math.h>
+#include <limits>
 
-#include "paddle/fluid/framework/eigen.h"
-#include "paddle/fluid/framework/tensor.h"
-#include "paddle/fluid/platform/device_context.h"
+#include "paddle/fluid/operators/amp/fp16_type_traits.h"
+#include "paddle/fluid/platform/hostdevice.h"
 #include "paddle/fluid/platform/macros.h"
+#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
+#ifdef __HIPCC__
+#include <hip/hip_runtime.h>
+#endif
 
 namespace paddle {
 namespace operators {
 
-template <typename T>
+template <typename Tx, typename Ty = Tx>
 struct CustomMin {
-  __device__ __forceinline__ T operator()(const T &a, const T &b) const {
+  using Transformer = detail::IdentityFunctor<Tx>;
+
+  HOSTDEVICE __forceinline__ Ty initial() {
+    return std::numeric_limits<Ty>::max();
+  }
+
+  __device__ __forceinline__ Ty operator()(const Ty &a, const Ty &b) const {
     return (b < a) ? b : a;
   }
 };
 
-template <typename T>
+template <typename Tx, typename Ty = Tx>
 struct CustomMax {
-  __device__ __forceinline__ T operator()(const T &a, const T &b) const {
+  using Transformer = detail::IdentityFunctor<Tx>;
+
+  HOSTDEVICE __forceinline__ Ty initial() {
+    return std::numeric_limits<Ty>::lowest();
+  }
+
+  __device__ __forceinline__ Ty operator()(const Ty &a, const Ty &b) const {
     return (b > a) ? b : a;
   }
 };
 
-template <typename T>
+// for cub::Reduce
+template <typename Tx, typename Ty = Tx>
 struct CustomSum {
-  __device__ __forceinline__ T operator()(const T &a, const T &b) const {
+  using Transformer = detail::IdentityFunctor<Tx, Ty>;
+
+  HOSTDEVICE __forceinline__ Ty initial() { return static_cast<Ty>(0.0f); }
+
+  __device__ __forceinline__ Ty operator()(const Ty &a, const Ty &b) const {
     return b + a;
   }
 };
 
+template <typename Tx, typename Ty = Tx>
+struct CustomMean {
+  using Transformer = detail::DivideFunctor<Tx>;
+
+  HOSTDEVICE __forceinline__ Ty initial() { return static_cast<Ty>(0.0f); }
+
+  __device__ __forceinline__ Ty operator()(const Ty &a, const Ty &b) const {
+    return b + a;
+  }
+};
+
-template <typename T>
+template <typename Tx, typename Ty = Tx>
 struct CustomMul {
-  __device__ __forceinline__ T operator()(const T &a, const T &b) const {
+  using Transformer = detail::IdentityFunctor<Tx>;
+
+  HOSTDEVICE __forceinline__ Ty initial() { return static_cast<Ty>(1.0f); }
+
+  __device__ __forceinline__ Ty operator()(const Ty &a, const Ty &b) const {
     return b * a;
   }
 };
 
+template <typename Tx, typename Ty = Tx>
+struct CustomLogicalOr {
+  using Transformer = detail::IdentityFunctor<Tx>;
+
+  HOSTDEVICE __forceinline__ Ty initial() { return static_cast<Ty>(false); }
+
+  __device__ __forceinline__ Ty operator()(const Ty &a, const Ty &b) const {
+    return b || a;
+  }
+};
+
+template <typename Tx, typename Ty = Tx>
+struct CustomLogicalAnd {
+  using Transformer = detail::IdentityFunctor<Tx>;
+
+  HOSTDEVICE __forceinline__ Ty initial() { return static_cast<Ty>(true); }
+
+  __device__ __forceinline__ Ty operator()(const Ty &a, const Ty &b) const {
+    return b && a;
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
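The new functor contract has three parts: a Transformer type that preprocesses each input element (detail::IdentityFunctor passes values through), an initial() that supplies the identity element of the reduction, and a binary operator() that folds two partial results. Below is a minimal, self-contained CUDA sketch of how such a functor can drive a single-block reduction; MyMax and BlockReduce are illustrative stand-ins, not the kernel this PR touches in reduce_op.cu.h, and the Transformer step is omitted for brevity.

#include <cfloat>
#include <cstdio>

// Stand-in functor with the same shape as CustomMax above. initial()
// returns -FLT_MAX directly to keep the sketch self-contained; the PR
// uses std::numeric_limits<Ty>::lowest().
struct MyMax {
  __host__ __device__ float initial() const { return -FLT_MAX; }
  __device__ float operator()(const float &a, const float &b) const {
    return (b > a) ? b : a;
  }
};

// Single-block tree reduction; assumes blockDim.x is a power of two <= 256.
template <typename ReduceOp>
__global__ void BlockReduce(const float *in, float *out, int n, ReduceOp op) {
  __shared__ float smem[256];
  float acc = op.initial();  // seed every thread with the identity element
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    acc = op(acc, in[i]);  // fold the elements owned by this thread
  }
  smem[threadIdx.x] = acc;
  __syncthreads();
  for (int s = blockDim.x / 2; s > 0; s >>= 1) {  // pairwise tree combine
    if (threadIdx.x < s) {
      smem[threadIdx.x] = op(smem[threadIdx.x], smem[threadIdx.x + s]);
    }
    __syncthreads();
  }
  if (threadIdx.x == 0) {
    *out = smem[0];  // thread 0 holds the block-wide result
  }
}

Launched as BlockReduce<<<1, 256>>>(d_in, d_out, n, MyMax{}), this computes the maximum of n device floats; swapping in a CustomMin-shaped functor changes only the identity element and the fold, which is exactly the flexibility the header above is after.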
20 changes: 9 additions & 11 deletions paddle/fluid/operators/reduce_ops/reduce_max_op.cu
@@ -11,15 +11,13 @@
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.
+#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
+#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
+#include "paddle/fluid/operators/reduce_ops/reduce_op.h"
 
-#include "paddle/fluid/operators/reduce_ops/reduce_min_max_op.h"
-
-REGISTER_OP_CUDA_KERNEL(reduce_max,
-                        ops::ReduceKernel<paddle::platform::CUDADeviceContext,
-                                          float, ops::MaxFunctor>,
-                        ops::ReduceKernel<paddle::platform::CUDADeviceContext,
-                                          double, ops::MaxFunctor>,
-                        ops::ReduceKernel<paddle::platform::CUDADeviceContext,
-                                          int, ops::MaxFunctor>,
-                        ops::ReduceKernel<paddle::platform::CUDADeviceContext,
-                                          int64_t, ops::MaxFunctor>);
+// reduce_max
+REGISTER_OP_CUDA_KERNEL(
+    reduce_max, ops::ReduceCudaKernel<float, paddle::operators::CustomMax>,
+    ops::ReduceCudaKernel<double, paddle::operators::CustomMax>,
+    ops::ReduceCudaKernel<int, paddle::operators::CustomMax>,
+    ops::ReduceCudaKernel<int64_t, paddle::operators::CustomMax>);
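One detail of the new registration is worth noting: CustomMax is passed with no template arguments, which implies ReduceCudaKernel accepts the functor as a template-template parameter and instantiates it once per element type. A hedged host-side sketch of that shape follows (ReduceCudaKernelSketch, MaxOpSketch, and the Run method are assumptions for illustration, not the actual class in reduce_op.cu.h):

#include <cstdio>

// Stand-in functor template with the CustomMax shape.
template <typename Tx, typename Ty = Tx>
struct MaxOpSketch {
  Ty initial() const { return static_cast<Ty>(-1e30); }
  Ty operator()(const Ty &a, const Ty &b) const { return (b > a) ? b : a; }
};

// The functor template itself is the second argument, so each registration
// line instantiates the matching functor specialization for one data type.
template <typename T, template <typename Tx, typename Ty = Tx> class ReduceOp>
struct ReduceCudaKernelSketch {
  T Run(const T *data, int n) const {
    ReduceOp<T> op;  // e.g. CustomMax<float> in the real kernel
    T acc = op.initial();
    for (int i = 0; i < n; ++i) acc = op(acc, data[i]);
    return acc;
  }
};

int main() {
  const float xs[] = {1.f, 5.f, 3.f};
  ReduceCudaKernelSketch<float, MaxOpSketch> kernel;
  printf("%f\n", kernel.Run(xs, 3));  // prints 5.000000
}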
20 changes: 9 additions & 11 deletions paddle/fluid/operators/reduce_ops/reduce_min_op.cu
@@ -11,15 +11,13 @@
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.
+#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
+#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
+#include "paddle/fluid/operators/reduce_ops/reduce_op.h"
 
-#include "paddle/fluid/operators/reduce_ops/reduce_min_max_op.h"
-
-REGISTER_OP_CUDA_KERNEL(reduce_min,
-                        ops::ReduceKernel<paddle::platform::CUDADeviceContext,
-                                          float, ops::MinFunctor>,
-                        ops::ReduceKernel<paddle::platform::CUDADeviceContext,
-                                          double, ops::MinFunctor>,
-                        ops::ReduceKernel<paddle::platform::CUDADeviceContext,
-                                          int, ops::MinFunctor>,
-                        ops::ReduceKernel<paddle::platform::CUDADeviceContext,
-                                          int64_t, ops::MinFunctor>);
+// reduce_min
+REGISTER_OP_CUDA_KERNEL(
+    reduce_min, ops::ReduceCudaKernel<float, paddle::operators::CustomMin>,
+    ops::ReduceCudaKernel<double, paddle::operators::CustomMin>,
+    ops::ReduceCudaKernel<int, paddle::operators::CustomMin>,
+    ops::ReduceCudaKernel<int64_t, paddle::operators::CustomMin>);
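reduce_min mirrors reduce_max exactly, only swapping in CustomMin. One last design point from the functor header: CustomMean keeps the same sum-style operator() as CustomSum and gets the averaging entirely from its Transformer, detail::DivideFunctor, which by its name scales each element before the fold. A host-side sketch of that idea under that assumption (DivideSketch and MeanSketch are stand-ins, not Paddle's classes):

#include <cstdio>

// Stand-in for detail::DivideFunctor: pre-scales each element by 1/n so
// that a plain sum over the transformed elements yields the mean.
struct DivideSketch {
  explicit DivideSketch(int n) : inv_n_(1.0f / static_cast<float>(n)) {}
  float operator()(float x) const { return x * inv_n_; }
  float inv_n_;
};

float MeanSketch(const float *data, int n) {
  DivideSketch transform(n);  // plays the role of CustomMean::Transformer
  float acc = 0.0f;           // CustomMean::initial()
  for (int i = 0; i < n; ++i) {
    acc = acc + transform(data[i]);  // transform, then the sum-style fold
  }
  return acc;
}

int main() {
  const float xs[] = {2.f, 4.f, 6.f};
  printf("%f\n", MeanSketch(xs, 3));  // prints 4.000000
}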