
Commit fbb6cd7

Release/2.0 rc1 (#29388)
* fix random failure of complex matmul
* Make transpose, trace, kron, reshape, sum op support complex type (#29321)
* add complex64 and complex128 types; add +, -, *, /, @ and slice operators for complex types
* add test cases for complex elementwise, matmul and getitem unittests
* add test cases for complex types
* kron, reshape, transpose support complex types
* sum and trace op support complex types
* add test cases for sum and trace op
* fix the bug of the imag part of a complex number not being initialized
* format file
* format code style
* kron supports type promotion; modify test cases
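The sketch below is not part of the commit; it is a hypothetical usage example of what the kernels registered in this change enable from the Python API. It assumes a Paddle 2.0 build that contains this patch and that paddle.to_tensor accepts complex NumPy arrays (producing complex64/complex128 tensors).

# Hypothetical usage sketch -- not part of the commit.
import numpy as np
import paddle

x_np = np.array([[1 + 2j, 3 - 1j],
                 [0 + 1j, 2 + 0j]], dtype=np.complex64)
x = paddle.to_tensor(x_np)

# Ops touched by this commit, now registered for complex64/complex128:
xt = paddle.transpose(x, perm=[1, 0])   # transpose / transpose2 kernels
xr = paddle.reshape(x, shape=[4])       # reshape2 kernels
xs = paddle.sum(x)                      # reduce_sum kernels
xtr = paddle.trace(x)                   # trace kernels: (1+2j) + (2+0j) = 3+2j
xk = paddle.kron(x, x)                  # kron kernels

print(xtr.numpy())  # expected: (3+2j)

The per-operator diffs below register exactly these kernels for complex64 and complex128 on CPU, CUDA and (for reshape) XPU.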
1 parent 4a8aef4 commit fbb6cd7

15 files changed: +360 lines added, −137 lines removed


paddle/fluid/operators/kron_op.cc

Lines changed: 28 additions & 4 deletions
@@ -18,6 +18,8 @@ limitations under the License. */
 #include <vector>
 
 #include "paddle/fluid/operators/kron_op.h"
+#include "paddle/fluid/platform/complex128.h"
+#include "paddle/fluid/platform/complex64.h"
 #include "paddle/fluid/platform/float16.h"
 
 namespace paddle {
@@ -51,8 +53,22 @@ class KronOp : public framework::OperatorWithKernel {
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
+    auto data_type =
+        OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "X", "Y");
+    return framework::OpKernelType(data_type, ctx.GetPlace());
+  }
+
+  framework::OpKernelType GetKernelTypeForVar(
+      const std::string& var_name, const framework::Tensor& tensor,
+      const framework::OpKernelType& expected_kernel_type) const {
+    if (framework::IsComplexType(expected_kernel_type.data_type_)) {
+      // only promote inputs’s types when contains complex input
+      return framework::OpKernelType(tensor.type(), tensor.place(),
+                                     tensor.layout());
+    } else {
+      return framework::OpKernelType(expected_kernel_type.data_type_,
+                                     tensor.place(), tensor.layout());
+    }
   }
 };
 
@@ -154,7 +170,11 @@ REGISTER_OP_CPU_KERNEL(
     ops::KronKernel<paddle::platform::CPUDeviceContext,
                     paddle::platform::float16>,
     ops::KronKernel<paddle::platform::CPUDeviceContext, int>,
-    ops::KronKernel<paddle::platform::CPUDeviceContext, int64_t>);
+    ops::KronKernel<paddle::platform::CPUDeviceContext, int64_t>,
+    ops::KronKernel<paddle::platform::CPUDeviceContext,
+                    paddle::platform::complex64>,
+    ops::KronKernel<paddle::platform::CPUDeviceContext,
+                    paddle::platform::complex128>);
 
 REGISTER_OPERATOR(kron_grad, ops::KronGradOp);
 REGISTER_OP_CPU_KERNEL(
@@ -163,4 +183,8 @@ REGISTER_OP_CPU_KERNEL(
     ops::KronGradKernel<paddle::platform::CPUDeviceContext,
                         paddle::platform::float16>,
     ops::KronGradKernel<paddle::platform::CPUDeviceContext, int>,
-    ops::KronGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
+    ops::KronGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
+    ops::KronGradKernel<paddle::platform::CPUDeviceContext,
+                        paddle::platform::complex64>,
+    ops::KronGradKernel<paddle::platform::CPUDeviceContext,
+                        paddle::platform::complex128>);
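A hypothetical sketch (not part of the commit) of the promotion path added above: IndicateOrPromoteVarDataTypes picks a promoted kernel data type from X and Y, and GetKernelTypeForVar keeps each input's own type when the expected type is complex, so only the real-valued input is transformed. The resulting dtype below is an assumption about the intended behavior.

# Hypothetical sketch of kron type promotion; assumes a build with this patch
# and that float32 x complex64 promotes to complex64.
import numpy as np
import paddle

real_in = paddle.to_tensor(np.array([[1.0, 2.0]], dtype=np.float32))
cplx_in = paddle.to_tensor(np.array([[1 + 1j]], dtype=np.complex64))

# GetExpectedKernelType promotes the kernel dtype via
# IndicateOrPromoteVarDataTypes; GetKernelTypeForVar keeps each input's own
# dtype, so only the real input is cast before the kernel runs.
out = paddle.kron(real_in, cplx_in)
print(out.dtype)    # assumption: paddle.complex64
print(out.numpy())  # [[1.+1.j 2.+2.j]]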

paddle/fluid/operators/kron_op.cu

Lines changed: 12 additions & 2 deletions
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/kron_op.h"
+#include "paddle/fluid/platform/complex128.h"
+#include "paddle/fluid/platform/complex64.h"
 #include "paddle/fluid/platform/float16.h"
 
 namespace ops = paddle::operators;
@@ -22,12 +24,20 @@ REGISTER_OP_CUDA_KERNEL(
     ops::KronKernel<paddle::platform::CUDADeviceContext,
                     paddle::platform::float16>,
     ops::KronKernel<paddle::platform::CUDADeviceContext, int>,
-    ops::KronKernel<paddle::platform::CUDADeviceContext, int64_t>);
+    ops::KronKernel<paddle::platform::CUDADeviceContext, int64_t>,
+    ops::KronKernel<paddle::platform::CUDADeviceContext,
+                    paddle::platform::complex64>,
+    ops::KronKernel<paddle::platform::CUDADeviceContext,
+                    paddle::platform::complex128>);
 
 REGISTER_OP_CUDA_KERNEL(
     kron_grad, ops::KronGradKernel<paddle::platform::CUDADeviceContext, float>,
     ops::KronGradKernel<paddle::platform::CUDADeviceContext, double>,
     ops::KronGradKernel<paddle::platform::CUDADeviceContext,
                         paddle::platform::float16>,
     ops::KronGradKernel<paddle::platform::CUDADeviceContext, int>,
-    ops::KronGradKernel<paddle::platform::CUDADeviceContext, int64_t>);
+    ops::KronGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
+    ops::KronGradKernel<paddle::platform::CUDADeviceContext,
+                        paddle::platform::complex64>,
+    ops::KronGradKernel<paddle::platform::CUDADeviceContext,
+                        paddle::platform::complex128>);

paddle/fluid/operators/reduce_ops/reduce_sum_op.cc

Lines changed: 9 additions & 1 deletion
@@ -115,6 +115,12 @@ REGISTER_OP_CPU_KERNEL(
                       ops::SumFunctor>,
     ops::ReduceKernel<paddle::platform::CPUDeviceContext, int, ops::SumFunctor>,
     ops::ReduceKernel<paddle::platform::CPUDeviceContext, int64_t,
+                      ops::SumFunctor>,
+    ops::ReduceKernel<paddle::platform::CPUDeviceContext,
+                      paddle::platform::complex64, ops::SumFunctor>,
+    ops::ReduceKernel<paddle::platform::CPUDeviceContext,
+                      paddle::platform::complex128,
                       ops::SumFunctor>);
 
 template <typename T>
@@ -125,4 +131,6 @@ using CPUReduceSumGradKernel =
 REGISTER_OP_CPU_KERNEL(reduce_sum_grad, CPUReduceSumGradKernel<float>,
                        CPUReduceSumGradKernel<double>,
                        CPUReduceSumGradKernel<int>,
-                       CPUReduceSumGradKernel<int64_t>);
+                       CPUReduceSumGradKernel<int64_t>,
+                       CPUReduceSumGradKernel<paddle::platform::complex64>,
+                       CPUReduceSumGradKernel<paddle::platform::complex128>);

paddle/fluid/operators/reduce_ops/reduce_sum_op.cu

Lines changed: 3 additions & 1 deletion
@@ -72,4 +72,6 @@ class ReduceSumKernel : public framework::OpKernel<T> {
 
 REGISTER_OP_CUDA_KERNEL(reduce_sum, ops::ReduceSumKernel<float>,
                         ops::ReduceSumKernel<double>, ops::ReduceSumKernel<int>,
-                        ops::ReduceSumKernel<int64_t>);
+                        ops::ReduceSumKernel<int64_t>,
+                        ops::ReduceSumKernel<paddle::platform::complex64>,
+                        ops::ReduceSumKernel<paddle::platform::complex128>);
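Because the registrations above reuse the generic ReduceKernel with SumFunctor, a complex reduction is plain element-wise complex addition: real and imaginary parts accumulate independently. A small sketch (not part of the commit, same build assumptions as the earlier example):

# Hypothetical sketch for the complex reduce_sum kernels registered above.
import numpy as np
import paddle

x = paddle.to_tensor(np.array([1 + 2j, 3 - 1j, -2 + 0.5j], dtype=np.complex128))

# Real and imaginary parts accumulate independently:
# (1 + 3 - 2) + (2 - 1 + 0.5)j = 2 + 1.5j
s = paddle.sum(x)
print(s.numpy())  # expected: (2+1.5j)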

paddle/fluid/operators/reshape_op.cc

File mode changed: 100755 → 100644
Lines changed: 42 additions & 38 deletions
@@ -618,26 +618,26 @@ REGISTER_OPERATOR(reshape2_grad_grad, ops::Reshape2DoubleGradOp,
                   ops::ReshapeDoubleGradInplaceInferer,
                   ops::ReshapeDoubleGradOpNoNeedBufferVarInferer);
 
-REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel, double,
-                               ops::ReshapeKernel, int8_t, ops::ReshapeKernel,
-                               uint8_t, ops::ReshapeKernel, int,
-                               ops::ReshapeKernel, int64_t, ops::ReshapeKernel,
-                               bool, ops::ReshapeKernel,
-                               paddle::platform::bfloat16, ops::ReshapeKernel);
-
-REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2_grad, float, ops::ReshapeGradKernel,
-                               double, ops::ReshapeGradKernel, int,
-                               ops::ReshapeGradKernel, uint8_t,
-                               ops::ReshapeGradKernel, int64_t,
-                               ops::ReshapeGradKernel, bool,
-                               ops::ReshapeGradKernel);
-REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2_grad_grad, float,
-                               ops::ReshapeDoubleGradKernel, double,
-                               ops::ReshapeDoubleGradKernel, int,
-                               ops::ReshapeDoubleGradKernel, uint8_t,
-                               ops::ReshapeDoubleGradKernel, int64_t,
-                               ops::ReshapeDoubleGradKernel, bool,
-                               ops::ReshapeDoubleGradKernel);
+REGISTER_OP_CPU_KERNEL_FUNCTOR(
+    reshape2, float, ops::ReshapeKernel, double, ops::ReshapeKernel, int8_t,
+    ops::ReshapeKernel, uint8_t, ops::ReshapeKernel, int, ops::ReshapeKernel,
+    int64_t, ops::ReshapeKernel, bool, ops::ReshapeKernel,
+    paddle::platform::bfloat16, ops::ReshapeKernel, paddle::platform::complex64,
+    ops::ReshapeKernel, paddle::platform::complex128, ops::ReshapeKernel);
+
+REGISTER_OP_CPU_KERNEL_FUNCTOR(
+    reshape2_grad, float, ops::ReshapeGradKernel, double,
+    ops::ReshapeGradKernel, int, ops::ReshapeGradKernel, uint8_t,
+    ops::ReshapeGradKernel, int64_t, ops::ReshapeGradKernel, bool,
+    ops::ReshapeGradKernel, paddle::platform::complex64, ops::ReshapeGradKernel,
+    paddle::platform::complex128, ops::ReshapeGradKernel);
+REGISTER_OP_CPU_KERNEL_FUNCTOR(
+    reshape2_grad_grad, float, ops::ReshapeDoubleGradKernel, double,
+    ops::ReshapeDoubleGradKernel, int, ops::ReshapeDoubleGradKernel, uint8_t,
+    ops::ReshapeDoubleGradKernel, int64_t, ops::ReshapeDoubleGradKernel, bool,
+    ops::ReshapeDoubleGradKernel, paddle::platform::complex64,
+    ops::ReshapeDoubleGradKernel, paddle::platform::complex128,
+    ops::ReshapeDoubleGradKernel);
 
 #ifdef PADDLE_WITH_CUDA
 REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel, double,
@@ -656,34 +656,38 @@ REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel, double,
                                 ops::ReshapeKernel, int, ops::ReshapeKernel,
                                 uint8_t, ops::ReshapeKernel, int64_t,
                                 ops::ReshapeKernel, plat::float16,
-                                ops::ReshapeKernel, bool, ops::ReshapeKernel);
-REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2_grad, float, ops::ReshapeGradKernel,
-                                double, ops::ReshapeGradKernel, int,
-                                ops::ReshapeGradKernel, uint8_t,
-                                ops::ReshapeGradKernel, int64_t,
-                                ops::ReshapeGradKernel, plat::float16,
-                                ops::ReshapeGradKernel, bool,
-                                ops::ReshapeGradKernel);
-
-REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2_grad_grad, float,
-                                ops::ReshapeDoubleGradKernel, double,
-                                ops::ReshapeDoubleGradKernel, int,
-                                ops::ReshapeDoubleGradKernel, uint8_t,
-                                ops::ReshapeDoubleGradKernel, int64_t,
-                                ops::ReshapeDoubleGradKernel, plat::float16,
-                                ops::ReshapeDoubleGradKernel, bool,
-                                ops::ReshapeDoubleGradKernel);
+                                ops::ReshapeKernel, bool, ops::ReshapeKernel,
+                                plat::complex64, ops::ReshapeKernel,
+                                plat::complex128, ops::ReshapeKernel);
+REGISTER_OP_CUDA_KERNEL_FUNCTOR(
+    reshape2_grad, float, ops::ReshapeGradKernel, double,
+    ops::ReshapeGradKernel, int, ops::ReshapeGradKernel, uint8_t,
+    ops::ReshapeGradKernel, int64_t, ops::ReshapeGradKernel, plat::float16,
+    ops::ReshapeGradKernel, bool, ops::ReshapeGradKernel, plat::complex64,
+    ops::ReshapeGradKernel, plat::complex128, ops::ReshapeGradKernel);
+
+REGISTER_OP_CUDA_KERNEL_FUNCTOR(
+    reshape2_grad_grad, float, ops::ReshapeDoubleGradKernel, double,
+    ops::ReshapeDoubleGradKernel, int, ops::ReshapeDoubleGradKernel, uint8_t,
+    ops::ReshapeDoubleGradKernel, int64_t, ops::ReshapeDoubleGradKernel,
+    plat::float16, ops::ReshapeDoubleGradKernel, bool,
+    ops::ReshapeDoubleGradKernel, plat::complex64, ops::ReshapeDoubleGradKernel,
+    plat::complex128, ops::ReshapeDoubleGradKernel);
 #endif
 
 #ifdef PADDLE_WITH_XPU
 REGISTER_OP_XPU_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel, double,
                                ops::ReshapeKernel, int, ops::ReshapeKernel,
                                int64_t, ops::ReshapeKernel, plat::float16,
-                               ops::ReshapeKernel, bool, ops::ReshapeKernel);
+                               ops::ReshapeKernel, bool, ops::ReshapeKernel,
+                               plat::complex64, ops::ReshapeKernel,
+                               plat::complex128, ops::ReshapeKernel);
 REGISTER_OP_XPU_KERNEL_FUNCTOR(reshape2_grad, float, ops::ReshapeGradKernel,
                                double, ops::ReshapeGradKernel, int,
                                ops::ReshapeGradKernel, int64_t,
                                ops::ReshapeGradKernel, plat::float16,
                                ops::ReshapeGradKernel, bool,
+                               ops::ReshapeGradKernel, plat::complex64,
+                               ops::ReshapeGradKernel, plat::complex128,
                                ops::ReshapeGradKernel);
 #endif

paddle/fluid/operators/trace_op.cc

Lines changed: 10 additions & 2 deletions
@@ -163,9 +163,17 @@ REGISTER_OP_CPU_KERNEL(
     trace, ops::TraceKernel<paddle::platform::CPUDeviceContext, int>,
     ops::TraceKernel<paddle::platform::CPUDeviceContext, float>,
     ops::TraceKernel<paddle::platform::CPUDeviceContext, double>,
-    ops::TraceKernel<paddle::platform::CPUDeviceContext, int64_t>);
+    ops::TraceKernel<paddle::platform::CPUDeviceContext, int64_t>,
+    ops::TraceKernel<paddle::platform::CPUDeviceContext,
+                     paddle::platform::complex64>,
+    ops::TraceKernel<paddle::platform::CPUDeviceContext,
+                     paddle::platform::complex128>);
 REGISTER_OP_CPU_KERNEL(
     trace_grad, ops::TraceGradKernel<paddle::platform::CPUDeviceContext, int>,
     ops::TraceGradKernel<paddle::platform::CPUDeviceContext, float>,
     ops::TraceGradKernel<paddle::platform::CPUDeviceContext, double>,
-    ops::TraceGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
+    ops::TraceGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
+    ops::TraceGradKernel<paddle::platform::CPUDeviceContext,
+                         paddle::platform::complex64>,
+    ops::TraceGradKernel<paddle::platform::CPUDeviceContext,
+                         paddle::platform::complex128>);

paddle/fluid/operators/trace_op.cu

Lines changed: 10 additions & 2 deletions
@@ -60,11 +60,19 @@ REGISTER_OP_CUDA_KERNEL(
     ops::TraceCUDAKernel<paddle::platform::CUDADeviceContext,
                          platform::float16>,
     ops::TraceCUDAKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::TraceCUDAKernel<paddle::platform::CUDADeviceContext, double>);
+    ops::TraceCUDAKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::TraceCUDAKernel<paddle::platform::CUDADeviceContext,
+                         paddle::platform::complex64>,
+    ops::TraceCUDAKernel<paddle::platform::CUDADeviceContext,
+                         paddle::platform::complex128>);
 REGISTER_OP_CUDA_KERNEL(
     trace_grad, ops::TraceGradKernel<paddle::platform::CUDADeviceContext, int>,
     ops::TraceGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
     ops::TraceGradKernel<paddle::platform::CUDADeviceContext,
                          platform::float16>,
     ops::TraceGradKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::TraceGradKernel<paddle::platform::CUDADeviceContext, double>);
+    ops::TraceGradKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::TraceGradKernel<paddle::platform::CUDADeviceContext,
+                         paddle::platform::complex64>,
+    ops::TraceGradKernel<paddle::platform::CUDADeviceContext,
+                         paddle::platform::complex128>);

paddle/fluid/operators/transpose_op.cc

Lines changed: 20 additions & 4 deletions
@@ -321,11 +321,19 @@ REGISTER_OPERATOR(transpose_grad, ops::TransposeOpGrad);
 
 REGISTER_OP_CPU_KERNEL(
     transpose, ops::TransposeKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::TransposeKernel<paddle::platform::CPUDeviceContext, double>);
+    ops::TransposeKernel<paddle::platform::CPUDeviceContext, double>,
+    ops::TransposeKernel<paddle::platform::CPUDeviceContext,
+                         paddle::platform::complex64>,
+    ops::TransposeKernel<paddle::platform::CPUDeviceContext,
+                         paddle::platform::complex128>);
 REGISTER_OP_CPU_KERNEL(
     transpose_grad,
     ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, double>);
+    ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, double>,
+    ops::TransposeGradKernel<paddle::platform::CPUDeviceContext,
+                             paddle::platform::complex64>,
+    ops::TransposeGradKernel<paddle::platform::CPUDeviceContext,
+                             paddle::platform::complex128>);
 
 REGISTER_OPERATOR(transpose2, ops::Transpose2Op, ops::Transpose2OpMaker,
                   ops::Transpose2GradMaker<paddle::framework::OpDesc>,
@@ -336,10 +344,18 @@ REGISTER_OP_CPU_KERNEL(
     transpose2, ops::TransposeKernel<paddle::platform::CPUDeviceContext, float>,
     ops::TransposeKernel<paddle::platform::CPUDeviceContext, int32_t>,
     ops::TransposeKernel<paddle::platform::CPUDeviceContext, int64_t>,
-    ops::TransposeKernel<paddle::platform::CPUDeviceContext, double>);
+    ops::TransposeKernel<paddle::platform::CPUDeviceContext, double>,
+    ops::TransposeKernel<paddle::platform::CPUDeviceContext,
+                         paddle::platform::complex64>,
+    ops::TransposeKernel<paddle::platform::CPUDeviceContext,
+                         paddle::platform::complex128>);
 REGISTER_OP_CPU_KERNEL(
     transpose2_grad,
     ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, int32_t>,
    ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
     ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, double>);
+    ops::TransposeGradKernel<paddle::platform::CPUDeviceContext, double>,
+    ops::TransposeGradKernel<paddle::platform::CPUDeviceContext,
+                             paddle::platform::complex64>,
+    ops::TransposeGradKernel<paddle::platform::CPUDeviceContext,
+                             paddle::platform::complex128>);

paddle/fluid/operators/transpose_op.cu

Lines changed: 18 additions & 4 deletions
@@ -730,28 +730,42 @@ REGISTER_OP_CUDA_KERNEL(
     transpose,
     ops::TransposeGPUKernel<paddle::platform::CUDADeviceContext, float>,
     ops::TransposeGPUKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::TransposeGPUKernel<paddle::platform::CUDADeviceContext, plat::float16>,
     ops::TransposeGPUKernel<paddle::platform::CUDADeviceContext,
-                            plat::float16>);
+                            paddle::platform::complex64>,
+    ops::TransposeGPUKernel<paddle::platform::CUDADeviceContext,
+                            paddle::platform::complex128>);
 REGISTER_OP_CUDA_KERNEL(
     transpose_grad,
     ops::TransposeGradGPUKernel<paddle::platform::CUDADeviceContext, float>,
     ops::TransposeGradGPUKernel<paddle::platform::CUDADeviceContext, double>,
     ops::TransposeGradGPUKernel<paddle::platform::CUDADeviceContext,
-                                plat::float16>);
+                                plat::float16>,
+    ops::TransposeGradGPUKernel<paddle::platform::CUDADeviceContext,
+                                paddle::platform::complex64>,
+    ops::TransposeGradGPUKernel<paddle::platform::CUDADeviceContext,
+                                paddle::platform::complex128>);
 
 REGISTER_OP_CUDA_KERNEL(
     transpose2,
     ops::TransposeGPUKernel<paddle::platform::CUDADeviceContext, int32_t>,
     ops::TransposeGPUKernel<paddle::platform::CUDADeviceContext, int64_t>,
     ops::TransposeGPUKernel<paddle::platform::CUDADeviceContext, float>,
     ops::TransposeGPUKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::TransposeGPUKernel<paddle::platform::CUDADeviceContext, plat::float16>,
     ops::TransposeGPUKernel<paddle::platform::CUDADeviceContext,
-                            plat::float16>);
+                            paddle::platform::complex64>,
+    ops::TransposeGPUKernel<paddle::platform::CUDADeviceContext,
+                            paddle::platform::complex128>);
 REGISTER_OP_CUDA_KERNEL(
     transpose2_grad,
     ops::TransposeGradGPUKernel<paddle::platform::CUDADeviceContext, int32_t>,
     ops::TransposeGradGPUKernel<paddle::platform::CUDADeviceContext, int64_t>,
     ops::TransposeGradGPUKernel<paddle::platform::CUDADeviceContext, float>,
     ops::TransposeGradGPUKernel<paddle::platform::CUDADeviceContext, double>,
     ops::TransposeGradGPUKernel<paddle::platform::CUDADeviceContext,
-                                plat::float16>);
+                                plat::float16>,
+    ops::TransposeGradGPUKernel<paddle::platform::CUDADeviceContext,
+                                paddle::platform::complex64>,
+    ops::TransposeGradGPUKernel<paddle::platform::CUDADeviceContext,
+                                paddle::platform::complex128>);
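One behavioral note implied by the transpose registrations above: the same TransposeKernel/TransposeGPUKernel is instantiated for complex types, so transposing a complex tensor only permutes elements and does not conjugate them; a Hermitian transpose would need a separate conjugation step that is not part of this commit. A minimal sketch (not part of the commit, same assumptions as the earlier examples):

# Hypothetical sketch: complex transpose permutes data only, no conjugation.
import numpy as np
import paddle

x = paddle.to_tensor(np.array([[1 + 1j, 2 - 3j],
                               [4 + 0j, 5 + 2j]], dtype=np.complex64))

xt = paddle.transpose(x, perm=[1, 0])
print(xt.numpy())
# expected:
# [[1.+1.j 4.+0.j]
#  [2.-3.j 5.+2.j]]   <- same values in a new layout, imaginary signs untouched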

paddle/fluid/platform/complex64.h

Lines changed: 1 addition & 0 deletions
@@ -124,6 +124,7 @@ struct PADDLE_ALIGN(8) complex64 {
 
   HOSTDEVICE inline complex64& operator=(int32_t val) {
     real = static_cast<float>(val);
+    imag = 0;
     return *this;
   }
 
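The added `imag = 0;` makes complex64::operator=(int32_t) zero the imaginary part instead of leaving whatever value was previously stored, which is the "imag part of complex not initialized" bug named in the commit message. The block below is an illustrative analogue only (NumPy semantics, not the Paddle C++ API), kept in Python for consistency with the earlier sketches:

# Illustrative analogue only (NumPy, not paddle::platform::complex64):
# assigning an integer-valued number to a complex value must yield a zero
# imaginary part, which the added `imag = 0;` now guarantees when an
# int32_t is assigned to a complex64 in C++.
import numpy as np

c = np.complex64(3)          # integer-valued assignment
assert c.real == 3.0 and c.imag == 0.0
print(c)                     # (3+0j)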