Skip to content

Commit effb70f

Browse files
authored
[cherry-pick] CPU forward calculation replaces Eigen with LAPACK (#35916) (#36091)
cherry-pick #35916,CPU前向计算将Eigen替换为Lapack,修改linalg暴露规则
1 parent 14cdcde commit effb70f

File tree

9 files changed

+217
-179
lines changed

9 files changed

+217
-179
lines changed

paddle/fluid/operators/eigh_op.cc

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -147,18 +147,17 @@ REGISTER_OPERATOR(eigh, ops::EighOp, ops::EignOpMaker,
147147
REGISTER_OPERATOR(eigh_grad, ops::EighGradOp);
148148

149149
REGISTER_OP_CPU_KERNEL(
150-
eigh, ops::EighKernel<paddle::platform::CPUDeviceContext, float, float>,
151-
ops::EighKernel<paddle::platform::CPUDeviceContext, double, double>,
152-
ops::EighKernel<paddle::platform::CPUDeviceContext, float,
150+
eigh, ops::EighKernel<paddle::platform::CPUDeviceContext, float>,
151+
ops::EighKernel<paddle::platform::CPUDeviceContext, double>,
152+
ops::EighKernel<paddle::platform::CPUDeviceContext,
153153
paddle::platform::complex<float>>,
154-
ops::EighKernel<paddle::platform::CPUDeviceContext, double,
154+
ops::EighKernel<paddle::platform::CPUDeviceContext,
155155
paddle::platform::complex<double>>);
156156

157157
REGISTER_OP_CPU_KERNEL(
158-
eigh_grad,
159-
ops::EighGradKernel<paddle::platform::CPUDeviceContext, float, float>,
160-
ops::EighGradKernel<paddle::platform::CPUDeviceContext, double, double>,
161-
ops::EighGradKernel<paddle::platform::CPUDeviceContext, float,
158+
eigh_grad, ops::EighGradKernel<paddle::platform::CPUDeviceContext, float>,
159+
ops::EighGradKernel<paddle::platform::CPUDeviceContext, double>,
160+
ops::EighGradKernel<paddle::platform::CPUDeviceContext,
162161
paddle::platform::complex<float>>,
163-
ops::EighGradKernel<paddle::platform::CPUDeviceContext, double,
162+
ops::EighGradKernel<paddle::platform::CPUDeviceContext,
164163
paddle::platform::complex<double>>);

paddle/fluid/operators/eigh_op.cu

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,18 +16,17 @@ limitations under the License. */
1616

1717
namespace ops = paddle::operators;
1818
REGISTER_OP_CUDA_KERNEL(
19-
eigh, ops::EighKernel<paddle::platform::CUDADeviceContext, float, float>,
20-
ops::EighKernel<paddle::platform::CUDADeviceContext, double, double>,
21-
ops::EighKernel<paddle::platform::CUDADeviceContext, float,
19+
eigh, ops::EighKernel<paddle::platform::CUDADeviceContext, float>,
20+
ops::EighKernel<paddle::platform::CUDADeviceContext, double>,
21+
ops::EighKernel<paddle::platform::CUDADeviceContext,
2222
paddle::platform::complex<float>>,
23-
ops::EighKernel<paddle::platform::CUDADeviceContext, double,
23+
ops::EighKernel<paddle::platform::CUDADeviceContext,
2424
paddle::platform::complex<double>>);
2525

2626
REGISTER_OP_CUDA_KERNEL(
27-
eigh_grad,
28-
ops::EighGradKernel<paddle::platform::CUDADeviceContext, float, float>,
29-
ops::EighGradKernel<paddle::platform::CUDADeviceContext, double, double>,
30-
ops::EighGradKernel<paddle::platform::CUDADeviceContext, float,
27+
eigh_grad, ops::EighGradKernel<paddle::platform::CUDADeviceContext, float>,
28+
ops::EighGradKernel<paddle::platform::CUDADeviceContext, double>,
29+
ops::EighGradKernel<paddle::platform::CUDADeviceContext,
3130
paddle::platform::complex<float>>,
32-
ops::EighGradKernel<paddle::platform::CUDADeviceContext, double,
31+
ops::EighGradKernel<paddle::platform::CUDADeviceContext,
3332
paddle::platform::complex<double>>);

paddle/fluid/operators/eigh_op.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ namespace operators {
2222

2323
using Tensor = framework::Tensor;
2424

25-
template <typename DeviceContext, typename ValueType, typename T>
25+
template <typename DeviceContext, typename T>
2626
class EighKernel : public framework::OpKernel<T> {
2727
public:
2828
void Compute(const framework::ExecutionContext& ctx) const override {
@@ -31,15 +31,16 @@ class EighKernel : public framework::OpKernel<T> {
3131
auto output_v = ctx.Output<Tensor>("Eigenvectors");
3232
std::string lower = ctx.Attr<std::string>("UPLO");
3333
bool is_lower = (lower == "L");
34-
math::MatrixEighFunctor<DeviceContext, ValueType, T> functor;
34+
math::MatrixEighFunctor<DeviceContext, T> functor;
3535
functor(ctx, *input, output_w, output_v, is_lower, true);
3636
}
3737
};
3838

39-
template <typename DeviceContext, typename ValueType, typename T>
39+
template <typename DeviceContext, typename T>
4040
class EighGradKernel : public framework::OpKernel<T> {
4141
public:
4242
void Compute(const framework::ExecutionContext& ctx) const override {
43+
using ValueType = math::Real<T>;
4344
auto& x_grad = *ctx.Output<framework::Tensor>(framework::GradVarName("X"));
4445
x_grad.mutable_data<T>(ctx.GetPlace());
4546
auto& output_w = *ctx.Input<Tensor>("Eigenvalues");

0 commit comments

Comments
 (0)