@@ -12,33 +12,41 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
See the License for the specific language governing permissions and
13
13
limitations under the License. */
14
14
15
- #include " paddle/operators/fill_constant_op.h"
15
+ #include " paddle/framework/data_type.h"
16
+ #include " paddle/framework/op_registry.h"
17
+ #include " paddle/operators/math/math_function.h"
16
18
17
19
namespace paddle {
18
20
namespace operators {
19
21
20
- class FillConstantOp : public framework ::OperatorWithKernel {
22
+ class FillConstantInferShape : public framework ::InferShapeBase {
21
23
public:
22
- using framework::OperatorWithKernel::OperatorWithKernel;
23
-
24
- void InferShape (framework::InferShapeContext *ctx) const override {
24
+ void operator ()(framework::InferShapeContext *ctx) const override {
25
25
PADDLE_ENFORCE (ctx->HasOutput (" Out" ),
26
26
" Output(Out) of FillConstantOp should not be null." );
27
27
auto &shape = ctx->Attrs ().Get <std::vector<int >>(" shape" );
28
- std::vector<int64_t > shape_int64 (shape.size (), 0 );
29
- std::transform (shape.begin (), shape.end (), shape_int64.begin (),
30
- [](int a) { return static_cast <int64_t >(a); });
31
- auto dims = framework::make_ddim (shape_int64);
32
- ctx->SetOutputDim (" Out" , dims);
28
+ ctx->SetOutputDim (" Out" , framework::make_ddim (shape));
33
29
}
30
+ };
34
31
35
- protected:
36
- framework::OpKernelType GetKernelType (
37
- const framework::ExecutionContext &ctx) const override {
38
- int data_type = ctx.Attr <int >(" data_type" );
39
- VLOG (10 ) << " FillConstant data_type = " << data_type;
40
- return framework::OpKernelType (static_cast <framework::DataType>(data_type),
41
- ctx.device_context ());
32
+ class FillConstantOp : public framework ::OperatorBase {
33
+ public:
34
+ using framework::OperatorBase::OperatorBase;
35
+ void Run (const framework::Scope &scope,
36
+ const platform::DeviceContext &dev_ctx) const override {
37
+ auto data_type = static_cast <framework::DataType>(Attr<int >(" data_type" ));
38
+ auto value = Attr<float >(" value" );
39
+ auto force_cpu = Attr<bool >(" force_cpu" );
40
+ auto &out =
41
+ *scope.FindVar (Output (" Out" ))->GetMutable <framework::LoDTensor>();
42
+ out.Resize (framework::make_ddim (Attr<std::vector<int >>(" shape" )));
43
+ if (force_cpu) {
44
+ auto cpu = platform::CPUPlace ();
45
+ out.mutable_data (cpu, framework::ToTypeIndex (data_type));
46
+ } else {
47
+ out.mutable_data (dev_ctx.GetPlace (), framework::ToTypeIndex (data_type));
48
+ }
49
+ math::set_constant (dev_ctx, &out, value);
42
50
}
43
51
};
44
52
@@ -54,6 +62,11 @@ class FillConstantOpMaker : public framework::OpProtoAndCheckerMaker {
54
62
AddAttr<std::vector<int >>(" shape" , " (vector<int>) The shape of the output" );
55
63
AddAttr<float >(" value" , " (float, default 0) The value to be filled" )
56
64
.SetDefault (0 .0f );
65
+ AddAttr<bool >(" force_cpu" ,
66
+ " (bool, default false) Force fill output variable to cpu "
67
+ " memory. Otherwise, fill output variable to the running "
68
+ " device" )
69
+ .SetDefault (false );
57
70
AddOutput (" Out" ,
58
71
" (Tensor) Tensor of specified shape will be filled "
59
72
" with the specified value" );
@@ -69,10 +82,6 @@ Fill up a variable with specified constant value.
69
82
} // namespace paddle
70
83
71
84
namespace ops = paddle::operators;
72
- REGISTER_OP_WITHOUT_GRADIENT (fill_constant, ops::FillConstantOp,
73
- ops::FillConstantOpMaker);
74
- REGISTER_OP_CPU_KERNEL (
75
- fill_constant, ops::FillConstantOpKernel<paddle::platform::CPUPlace, float >,
76
- ops::FillConstantOpKernel<paddle::platform::CPUPlace, double >,
77
- ops::FillConstantOpKernel<paddle::platform::CPUPlace, int >,
78
- ops::FillConstantOpKernel<paddle::platform::CPUPlace, int64_t >);
85
+ REGISTER_OPERATOR (fill_constant, ops::FillConstantOp,
86
+ ops::FillConstantInferShape, ops::FillConstantOpMaker,
87
+ paddle::framework::EmptyGradOpMaker);
0 commit comments