Skip to content

Commit c077a6d

Browse files
authored
Feature/support int64 for sum (#5832)
* Support int64 for sum op
* Refine code
1 parent e800c0d commit c077a6d

File tree

5 files changed

+24
-2
lines changed

5 files changed

+24
-2
lines changed

paddle/operators/math/selected_rows_functor.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,8 @@ struct SelectedRowsAddTo<platform::CPUPlace, T> {
145145

146146
template struct SelectedRowsAddTo<platform::CPUPlace, float>;
147147
template struct SelectedRowsAddTo<platform::CPUPlace, double>;
148+
template struct SelectedRowsAddTo<platform::CPUPlace, int>;
149+
template struct SelectedRowsAddTo<platform::CPUPlace, int64_t>;
148150

149151
template <typename T>
150152
struct SelectedRowsAddToTensor<platform::CPUPlace, T> {
@@ -175,6 +177,8 @@ struct SelectedRowsAddToTensor<platform::CPUPlace, T> {
175177

176178
template struct SelectedRowsAddToTensor<platform::CPUPlace, float>;
177179
template struct SelectedRowsAddToTensor<platform::CPUPlace, double>;
180+
template struct SelectedRowsAddToTensor<platform::CPUPlace, int>;
181+
template struct SelectedRowsAddToTensor<platform::CPUPlace, int64_t>;
178182

179183
} // namespace math
180184
} // namespace operators

paddle/operators/math/selected_rows_functor.cu

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,8 @@ struct SelectedRowsAddTo<platform::GPUPlace, T> {
173173

174174
template struct SelectedRowsAddTo<platform::GPUPlace, float>;
175175
template struct SelectedRowsAddTo<platform::GPUPlace, double>;
176+
template struct SelectedRowsAddTo<platform::GPUPlace, int>;
177+
template struct SelectedRowsAddTo<platform::GPUPlace, int64_t>;
176178

177179
namespace {
178180
template <typename T, int block_size>
@@ -223,6 +225,8 @@ struct SelectedRowsAddToTensor<platform::GPUPlace, T> {
223225

224226
template struct SelectedRowsAddToTensor<platform::GPUPlace, float>;
225227
template struct SelectedRowsAddToTensor<platform::GPUPlace, double>;
228+
template struct SelectedRowsAddToTensor<platform::GPUPlace, int>;
229+
template struct SelectedRowsAddToTensor<platform::GPUPlace, int64_t>;
226230

227231
} // namespace math
228232
} // namespace operators

paddle/operators/sum_op.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,4 +176,6 @@ namespace ops = paddle::operators;
176176
REGISTER_OPERATOR(sum, ops::SumOp, ops::SumOpMaker, ops::SumGradMaker,
177177
ops::SumOpVarTypeInference);
178178
REGISTER_OP_CPU_KERNEL(sum, ops::SumKernel<paddle::platform::CPUPlace, float>,
179-
ops::SumKernel<paddle::platform::CPUPlace, double>);
179+
ops::SumKernel<paddle::platform::CPUPlace, double>,
180+
ops::SumKernel<paddle::platform::CPUPlace, int>,
181+
ops::SumKernel<paddle::platform::CPUPlace, int64_t>);

paddle/operators/sum_op.cu

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,6 @@ limitations under the License. */
1414

1515
namespace ops = paddle::operators;
1616
REGISTER_OP_GPU_KERNEL(sum, ops::SumKernel<paddle::platform::GPUPlace, float>,
17-
ops::SumKernel<paddle::platform::GPUPlace, double>);
17+
ops::SumKernel<paddle::platform::GPUPlace, double>,
18+
ops::SumKernel<paddle::platform::GPUPlace, int>,
19+
ops::SumKernel<paddle::platform::GPUPlace, int64_t>);

paddle/platform/cuda_helper.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,16 @@ constexpr int PADDLE_CUDA_NUM_THREADS = 512;
3131

3232
// For atomicAdd.
3333
USE_CUDA_ATOMIC(Add, float);
34+
USE_CUDA_ATOMIC(Add, int);
35+
USE_CUDA_ATOMIC(Add, unsigned int);
36+
USE_CUDA_ATOMIC(Add, unsigned long long int);
37+
38+
// Atomic add for int64_t. CUDA's hardware atomicAdd has no signed 64-bit
// overload, so the operand is reinterpreted as unsigned long long (the
// 64-bit type atomicAdd does support); two's-complement arithmetic makes
// the unsigned addition produce the correct signed result.
CUDA_ATOMIC_WRAPPER(Add, int64_t) {
  // Ensure the reinterpret_cast below is a same-width pun on this platform.
  static_assert(sizeof(int64_t) == sizeof(long long int),
                "long long should be int64");
  // Cast the returned old value back explicitly: unsigned -> signed
  // conversion is implementation-defined before C++20, so make it visible
  // rather than relying on the implicit conversion to the wrapper's
  // int64_t return type.
  return static_cast<int64_t>(
      CudaAtomicAdd(reinterpret_cast<unsigned long long int*>(address),
                    static_cast<unsigned long long int>(val)));
}
3444

3545
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600
3646
USE_CUDA_ATOMIC(Add, double);

0 commit comments

Comments (0)