File tree Expand file tree Collapse file tree 8 files changed +56
-0
lines changed Expand file tree Collapse file tree 8 files changed +56
-0
lines changed Original file line number Diff line number Diff line change @@ -17,6 +17,7 @@ limitations under the License. */
17
17
namespace paddle {
18
18
namespace operators {
19
19
20
+ using Tensor = framework::Tensor;
20
21
class AdadeltaOp : public framework ::OperatorWithKernel {
21
22
public:
22
23
using framework::OperatorWithKernel::OperatorWithKernel;
@@ -55,6 +56,12 @@ class AdadeltaOp : public framework::OperatorWithKernel {
55
56
ctx->SetOutputDim (" AvgSquaredGradOut" , param_dim);
56
57
ctx->SetOutputDim (" AvgSquaredUpdateOut" , param_dim);
57
58
}
59
+ framework::OpKernelType GetExpectedKernelType (
60
+ const framework::ExecutionContext &ctx) const override {
61
+ auto input_data_type =
62
+ framework::ToDataType (ctx.Input <Tensor>(" Param" )->type ());
63
+ return framework::OpKernelType (input_data_type, ctx.GetPlace ());
64
+ }
58
65
};
59
66
60
67
class AdadeltaOpMaker : public framework ::OpProtoAndCheckerMaker {
Original file line number Diff line number Diff line change @@ -23,6 +23,7 @@ limitations under the License. */
23
23
namespace paddle {
24
24
namespace operators {
25
25
26
+ using Tensor = framework::Tensor;
26
27
class AdagradOp : public framework ::OperatorWithKernel {
27
28
public:
28
29
using framework::OperatorWithKernel::OperatorWithKernel;
@@ -56,6 +57,12 @@ class AdagradOp : public framework::OperatorWithKernel {
56
57
ctx->SetOutputDim (" ParamOut" , param_dims);
57
58
ctx->SetOutputDim (" MomentOut" , param_dims);
58
59
}
60
+ framework::OpKernelType GetExpectedKernelType (
61
+ const framework::ExecutionContext& ctx) const override {
62
+ auto input_data_type =
63
+ framework::ToDataType (ctx.Input <Tensor>(" Param" )->type ());
64
+ return framework::OpKernelType (input_data_type, ctx.GetPlace ());
65
+ }
59
66
};
60
67
61
68
class AdagradOpMaker : public framework ::OpProtoAndCheckerMaker {
Original file line number Diff line number Diff line change @@ -17,6 +17,7 @@ limitations under the License. */
17
17
namespace paddle {
18
18
namespace operators {
19
19
20
+ using Tensor = framework::Tensor;
20
21
class AdamOp : public framework ::OperatorWithKernel {
21
22
public:
22
23
using framework::OperatorWithKernel::OperatorWithKernel;
@@ -69,6 +70,12 @@ class AdamOp : public framework::OperatorWithKernel {
69
70
ctx->SetOutputDim (" Moment1Out" , param_dims);
70
71
ctx->SetOutputDim (" Moment2Out" , param_dims);
71
72
}
73
+ framework::OpKernelType GetExpectedKernelType (
74
+ const framework::ExecutionContext &ctx) const override {
75
+ auto input_data_type =
76
+ framework::ToDataType (ctx.Input <Tensor>(" Param" )->type ());
77
+ return framework::OpKernelType (input_data_type, ctx.GetPlace ());
78
+ }
72
79
};
73
80
74
81
class AdamOpMaker : public framework ::OpProtoAndCheckerMaker {
Original file line number Diff line number Diff line change @@ -17,6 +17,7 @@ limitations under the License. */
17
17
namespace paddle {
18
18
namespace operators {
19
19
20
+ using Tensor = framework::Tensor;
20
21
class AdamaxOp : public framework ::OperatorWithKernel {
21
22
public:
22
23
using framework::OperatorWithKernel::OperatorWithKernel;
@@ -63,6 +64,12 @@ class AdamaxOp : public framework::OperatorWithKernel {
63
64
ctx->SetOutputDim (" MomentOut" , param_dims);
64
65
ctx->SetOutputDim (" InfNormOut" , param_dims);
65
66
}
67
+ framework::OpKernelType GetExpectedKernelType (
68
+ const framework::ExecutionContext &ctx) const override {
69
+ auto input_data_type =
70
+ framework::ToDataType (ctx.Input <Tensor>(" Param" )->type ());
71
+ return framework::OpKernelType (input_data_type, ctx.GetPlace ());
72
+ }
66
73
};
67
74
68
75
class AdamaxOpMaker : public framework ::OpProtoAndCheckerMaker {
Original file line number Diff line number Diff line change @@ -17,6 +17,7 @@ limitations under the License. */
17
17
namespace paddle {
18
18
namespace operators {
19
19
20
+ using Tensor = framework::Tensor;
20
21
class DecayedAdagradOp : public framework ::OperatorWithKernel {
21
22
public:
22
23
using framework::OperatorWithKernel::OperatorWithKernel;
@@ -51,6 +52,12 @@ class DecayedAdagradOp : public framework::OperatorWithKernel {
51
52
ctx->SetOutputDim (" ParamOut" , param_dims);
52
53
ctx->SetOutputDim (" MomentOut" , param_dims);
53
54
}
55
+ framework::OpKernelType GetExpectedKernelType (
56
+ const framework::ExecutionContext &ctx) const override {
57
+ auto input_data_type =
58
+ framework::ToDataType (ctx.Input <Tensor>(" Param" )->type ());
59
+ return framework::OpKernelType (input_data_type, ctx.GetPlace ());
60
+ }
54
61
};
55
62
56
63
class DecayedAdagradOpMaker : public framework ::OpProtoAndCheckerMaker {
Original file line number Diff line number Diff line change @@ -17,6 +17,7 @@ limitations under the License. */
17
17
namespace paddle {
18
18
namespace operators {
19
19
20
+ using Tensor = framework::Tensor;
20
21
class FTRLOp : public framework ::OperatorWithKernel {
21
22
public:
22
23
using framework::OperatorWithKernel::OperatorWithKernel;
@@ -53,6 +54,12 @@ class FTRLOp : public framework::OperatorWithKernel {
53
54
ctx->SetOutputDim (" SquaredAccumOut" , param_dim);
54
55
ctx->SetOutputDim (" LinearAccumOut" , param_dim);
55
56
}
57
+ framework::OpKernelType GetExpectedKernelType (
58
+ const framework::ExecutionContext &ctx) const override {
59
+ auto input_data_type =
60
+ framework::ToDataType (ctx.Input <Tensor>(" Param" )->type ());
61
+ return framework::OpKernelType (input_data_type, ctx.GetPlace ());
62
+ }
56
63
};
57
64
58
65
class FTRLOpMaker : public framework ::OpProtoAndCheckerMaker {
Original file line number Diff line number Diff line change @@ -17,6 +17,7 @@ limitations under the License. */
17
17
namespace paddle {
18
18
namespace operators {
19
19
20
+ using Tensor = framework::Tensor;
20
21
class ProximalAdagradOp : public framework ::OperatorWithKernel {
21
22
public:
22
23
using framework::OperatorWithKernel::OperatorWithKernel;
@@ -55,6 +56,12 @@ class ProximalAdagradOp : public framework::OperatorWithKernel {
55
56
ctx->SetOutputDim (" ParamOut" , param_dim);
56
57
ctx->SetOutputDim (" MomentOut" , param_dim);
57
58
}
59
+ framework::OpKernelType GetExpectedKernelType (
60
+ const framework::ExecutionContext &ctx) const override {
61
+ auto input_data_type =
62
+ framework::ToDataType (ctx.Input <Tensor>(" Param" )->type ());
63
+ return framework::OpKernelType (input_data_type, ctx.GetPlace ());
64
+ }
58
65
};
59
66
60
67
class ProximalAdagradOpMaker : public framework ::OpProtoAndCheckerMaker {
Original file line number Diff line number Diff line change @@ -17,6 +17,7 @@ limitations under the License. */
17
17
namespace paddle {
18
18
namespace operators {
19
19
20
+ using Tensor = framework::Tensor;
20
21
class ProximalGDOp : public framework ::OperatorWithKernel {
21
22
public:
22
23
using framework::OperatorWithKernel::OperatorWithKernel;
@@ -43,6 +44,12 @@ class ProximalGDOp : public framework::OperatorWithKernel {
43
44
44
45
ctx->SetOutputDim (" ParamOut" , param_dim);
45
46
}
47
+ framework::OpKernelType GetExpectedKernelType (
48
+ const framework::ExecutionContext &ctx) const override {
49
+ auto input_data_type =
50
+ framework::ToDataType (ctx.Input <Tensor>(" Param" )->type ());
51
+ return framework::OpKernelType (input_data_type, ctx.GetPlace ());
52
+ }
46
53
};
47
54
48
55
class ProximalGDOpMaker : public framework ::OpProtoAndCheckerMaker {
You can’t perform that action at this time.
0 commit comments