Skip to content

Commit 9567cbd

Browse files
authored
[cherry-pick 2.1.1]2.1/fix concat (#33383)
* add unit8 for concat (#32850) * add bool type for tril api (#33402)
1 parent 1444090 commit 9567cbd

File tree

12 files changed

+32
-15
lines changed

12 files changed

+32
-15
lines changed

paddle/fluid/operators/concat_op.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,8 @@ REGISTER_OP_CPU_KERNEL(
233233
ops::ConcatKernel<paddle::platform::CPUDeviceContext, int64_t>,
234234
ops::ConcatKernel<paddle::platform::CPUDeviceContext,
235235
paddle::platform::float16>,
236-
ops::ConcatKernel<paddle::platform::CPUDeviceContext, int>);
236+
ops::ConcatKernel<paddle::platform::CPUDeviceContext, int>,
237+
ops::ConcatKernel<paddle::platform::CPUDeviceContext, uint8_t>);
237238
REGISTER_OP_CPU_KERNEL(
238239
concat_grad,
239240
ops::ConcatGradKernel<paddle::platform::CPUDeviceContext, double>,
@@ -242,4 +243,5 @@ REGISTER_OP_CPU_KERNEL(
242243
ops::ConcatGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
243244
ops::ConcatGradKernel<paddle::platform::CPUDeviceContext,
244245
paddle::platform::float16>,
245-
ops::ConcatGradKernel<paddle::platform::CPUDeviceContext, int>);
246+
ops::ConcatGradKernel<paddle::platform::CPUDeviceContext, int>,
247+
ops::ConcatKernel<paddle::platform::CPUDeviceContext, uint8_t>);

paddle/fluid/operators/concat_op.cu.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,14 @@ REGISTER_OP_CUDA_KERNEL(
2323
ops::ConcatKernel<paddle::platform::CUDADeviceContext, bool>,
2424
ops::ConcatKernel<paddle::platform::CUDADeviceContext, plat::float16>,
2525
ops::ConcatKernel<paddle::platform::CUDADeviceContext, int64_t>,
26-
ops::ConcatKernel<paddle::platform::CUDADeviceContext, int>);
26+
ops::ConcatKernel<paddle::platform::CUDADeviceContext, int>,
27+
ops::ConcatKernel<paddle::platform::CUDADeviceContext, uint8_t>);
2728
REGISTER_OP_CUDA_KERNEL(
2829
concat_grad,
2930
ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, double>,
3031
ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, float>,
3132
ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, bool>,
3233
ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, plat::float16>,
3334
ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
34-
ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, int>);
35+
ops::ConcatGradKernel<paddle::platform::CUDADeviceContext, int>,
36+
ops::ConcatKernel<paddle::platform::CUDADeviceContext, uint8_t>);

paddle/fluid/operators/reduce_ops/reduce_mean_op.cc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,8 @@ REGISTER_OPERATOR(reduce_mean_grad, ops::ReduceGradOp,
100100
ops::ReduceMeanDoubleGradOpBaseMaker,
101101
ops::ReduceMeanGradNoNeedBufferVarInferer);
102102
REGISTER_OP_CPU_KERNEL(reduce_mean,
103+
ops::ReduceKernel<paddle::platform::CPUDeviceContext,
104+
bool, ops::MeanFunctor>,
103105
ops::ReduceKernel<paddle::platform::CPUDeviceContext,
104106
float, ops::MeanFunctor>,
105107
ops::ReduceKernel<paddle::platform::CPUDeviceContext,
@@ -110,5 +112,6 @@ using CPUReduceMeanGradKernel =
110112
ops::ReduceGradKernel<paddle::platform::CPUDeviceContext, T,
111113
ops::MeanGradFunctor, true>;
112114

113-
REGISTER_OP_CPU_KERNEL(reduce_mean_grad, CPUReduceMeanGradKernel<float>,
115+
REGISTER_OP_CPU_KERNEL(reduce_mean_grad, CPUReduceMeanGradKernel<bool>,
116+
CPUReduceMeanGradKernel<float>,
114117
CPUReduceMeanGradKernel<double>);

paddle/fluid/operators/reduce_ops/reduce_mean_op.cu

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,5 +65,6 @@ class ReduceMeanKernel : public framework::OpKernel<T> {
6565
} // namespace operators
6666
} // namespace paddle
6767

68-
REGISTER_OP_CUDA_KERNEL(reduce_mean, ops::ReduceMeanKernel<float>,
68+
REGISTER_OP_CUDA_KERNEL(reduce_mean, ops::ReduceMeanKernel<bool>,
69+
ops::ReduceMeanKernel<float>,
6970
ops::ReduceMeanKernel<double>);

paddle/fluid/operators/reduce_ops/reduce_mean_op.part.cu

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,5 +20,6 @@ using CUDAReduceMeanGradKernel =
2020
ops::ReduceGradKernel<paddle::platform::CUDADeviceContext, T,
2121
ops::MeanGradFunctor, true>;
2222

23-
REGISTER_OP_CUDA_KERNEL(reduce_mean_grad, CUDAReduceMeanGradKernel<float>,
23+
REGISTER_OP_CUDA_KERNEL(reduce_mean_grad, CUDAReduceMeanGradKernel<bool>,
24+
CUDAReduceMeanGradKernel<float>,
2425
CUDAReduceMeanGradKernel<double>);

paddle/fluid/operators/reduce_ops/reduce_sum_op.cc

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,10 @@ REGISTER_OPERATOR(reduce_sum_grad, ops::ReduceGradOp,
109109
ops::ReduceSumGradNoNeedBufferVarInferer);
110110

111111
REGISTER_OP_CPU_KERNEL(
112-
reduce_sum, ops::ReduceKernel<paddle::platform::CPUDeviceContext, float,
112+
reduce_sum, ops::ReduceKernel<paddle::platform::CPUDeviceContext, bool,
113113
ops::SumFunctor>,
114+
ops::ReduceKernel<paddle::platform::CPUDeviceContext, float,
115+
ops::SumFunctor>,
114116
ops::ReduceKernel<paddle::platform::CPUDeviceContext, double,
115117
ops::SumFunctor>,
116118
ops::ReduceKernel<paddle::platform::CPUDeviceContext, int, ops::SumFunctor>,
@@ -128,7 +130,8 @@ using CPUReduceSumGradKernel =
128130
ops::ReduceSumGradKernel<paddle::platform::CPUDeviceContext, T,
129131
ops::SumGradFunctor, true>;
130132

131-
REGISTER_OP_CPU_KERNEL(reduce_sum_grad, CPUReduceSumGradKernel<float>,
133+
REGISTER_OP_CPU_KERNEL(reduce_sum_grad, CPUReduceSumGradKernel<bool>,
134+
CPUReduceSumGradKernel<float>,
132135
CPUReduceSumGradKernel<double>,
133136
CPUReduceSumGradKernel<int>,
134137
CPUReduceSumGradKernel<int64_t>,

paddle/fluid/operators/reduce_ops/reduce_sum_op.cu

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,8 @@ class ReduceSumKernel : public framework::OpKernel<T> {
7070
} // namespace operators
7171
} // namespace paddle
7272

73-
REGISTER_OP_CUDA_KERNEL(reduce_sum, ops::ReduceSumKernel<float>,
73+
REGISTER_OP_CUDA_KERNEL(reduce_sum, ops::ReduceSumKernel<bool>,
74+
ops::ReduceSumKernel<float>,
7475
ops::ReduceSumKernel<double>, ops::ReduceSumKernel<int>,
7576
ops::ReduceSumKernel<int64_t>,
7677
ops::ReduceSumKernel<paddle::platform::complex64>,

paddle/fluid/operators/reduce_ops/reduce_sum_op.part.cu

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@ using CUDAReduceSumGradKernel =
2020
ops::ReduceGradKernel<paddle::platform::CUDADeviceContext, T,
2121
ops::SumGradFunctor, true>;
2222

23-
REGISTER_OP_CUDA_KERNEL(reduce_sum_grad, CUDAReduceSumGradKernel<float>,
23+
REGISTER_OP_CUDA_KERNEL(reduce_sum_grad, CUDAReduceSumGradKernel<bool>,
24+
CUDAReduceSumGradKernel<float>,
2425
CUDAReduceSumGradKernel<double>,
2526
CUDAReduceSumGradKernel<int>,
2627
CUDAReduceSumGradKernel<int64_t>,

paddle/fluid/operators/tril_triu_op.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,13 +105,15 @@ REGISTER_OPERATOR(tril_triu, ops::TrilTriuOp, ops::TrilTriuOpMaker,
105105
ops::TrilTriuGradOpMaker<paddle::imperative::OpBase>);
106106
REGISTER_OPERATOR(tril_triu_grad, ops::TrilTriuGradOp);
107107
REGISTER_OP_CPU_KERNEL(
108-
tril_triu, ops::TrilTriuOpKernel<paddle::platform::CPUDeviceContext, float>,
108+
tril_triu, ops::TrilTriuOpKernel<paddle::platform::CPUDeviceContext, bool>,
109+
ops::TrilTriuOpKernel<paddle::platform::CPUDeviceContext, float>,
109110
ops::TrilTriuOpKernel<paddle::platform::CPUDeviceContext, double>,
110111
ops::TrilTriuOpKernel<paddle::platform::CPUDeviceContext, int>,
111112
ops::TrilTriuOpKernel<paddle::platform::CPUDeviceContext, int64_t>,
112113
ops::TrilTriuOpKernel<paddle::platform::CPUDeviceContext, plat::float16>);
113114
REGISTER_OP_CPU_KERNEL(
114115
tril_triu_grad,
116+
ops::TrilTriuGradOpKernel<paddle::platform::CPUDeviceContext, bool>,
115117
ops::TrilTriuGradOpKernel<paddle::platform::CPUDeviceContext, float>,
116118
ops::TrilTriuGradOpKernel<paddle::platform::CPUDeviceContext, double>,
117119
ops::TrilTriuGradOpKernel<paddle::platform::CPUDeviceContext, int>,

paddle/fluid/operators/tril_triu_op.cu

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,15 @@ namespace ops = paddle::operators;
1818
namespace plat = paddle::platform;
1919

2020
REGISTER_OP_CUDA_KERNEL(
21-
tril_triu,
21+
tril_triu, ops::TrilTriuOpKernel<paddle::platform::CUDADeviceContext, bool>,
2222
ops::TrilTriuOpKernel<paddle::platform::CUDADeviceContext, float>,
2323
ops::TrilTriuOpKernel<paddle::platform::CUDADeviceContext, double>,
2424
ops::TrilTriuOpKernel<paddle::platform::CUDADeviceContext, int>,
2525
ops::TrilTriuOpKernel<paddle::platform::CUDADeviceContext, int64_t>,
2626
ops::TrilTriuOpKernel<paddle::platform::CUDADeviceContext, plat::float16>);
2727
REGISTER_OP_CUDA_KERNEL(
2828
tril_triu_grad,
29+
ops::TrilTriuGradOpKernel<paddle::platform::CUDADeviceContext, bool>,
2930
ops::TrilTriuGradOpKernel<paddle::platform::CUDADeviceContext, float>,
3031
ops::TrilTriuGradOpKernel<paddle::platform::CUDADeviceContext, double>,
3132
ops::TrilTriuGradOpKernel<paddle::platform::CUDADeviceContext, int>,

0 commit comments

Comments
 (0)