Skip to content

Commit b266f50

Browse files
authored
Merge pull request #5441 from reyoung/feature/fill_constant_force_cpu
Add force_cpu for fill_constant op
2 parents 7f22a6d + 2ac5772 commit b266f50

File tree

7 files changed

+58
-86
lines changed

7 files changed

+58
-86
lines changed

paddle/framework/backward_test.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
#include "paddle/framework/var_desc.h"
2222
#include "paddle/operators/net_op.h"
2323

24-
USE_OP(fill_constant);
24+
USE_NO_KERNEL_OP(fill_constant);
2525

2626
namespace paddle {
2727
namespace framework {

paddle/framework/data_type.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,21 @@ inline DataType ToDataType(std::type_index type) {
3434
}
3535
}
3636

37+
inline std::type_index ToTypeIndex(DataType type) {
38+
switch (type) {
39+
case DataType::FP32:
40+
return typeid(float);
41+
case DataType::FP64:
42+
return typeid(double);
43+
case DataType::INT32:
44+
return typeid(int);
45+
case DataType::INT64:
46+
return typeid(int64_t);
47+
default:
48+
PADDLE_THROW("Not support type %d", type);
49+
}
50+
}
51+
3752
template <typename Visitor>
3853
inline void VisitDataType(DataType type, Visitor visitor) {
3954
switch (type) {

paddle/framework/ddim.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,13 @@ DDim make_ddim(const std::vector<int64_t>& dims) {
7979
return result;
8080
}
8181

82+
DDim make_ddim(const std::vector<int>& dims) {
83+
std::vector<int64_t> res(dims.size());
84+
std::transform(dims.begin(), dims.end(), res.begin(),
85+
[](int d) { return static_cast<int64_t>(d); });
86+
return make_ddim(res);
87+
}
88+
8289
/// @cond HIDDEN
8390
// XXX For some reason, putting this in an anonymous namespace causes errors
8491
class DynamicMutableIndexer : public boost::static_visitor<int64_t&> {

paddle/framework/ddim.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,8 @@ struct DDim {
8181
*/
8282
DDim make_ddim(const std::vector<int64_t>& dims);
8383

84+
DDim make_ddim(const std::vector<int>& dims);
85+
8486
/**
8587
* \brief Make a DDim from an initializer list
8688
*

paddle/operators/fill_constant_op.cc

Lines changed: 33 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -12,33 +12,41 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

15-
#include "paddle/operators/fill_constant_op.h"
15+
#include "paddle/framework/data_type.h"
16+
#include "paddle/framework/op_registry.h"
17+
#include "paddle/operators/math/math_function.h"
1618

1719
namespace paddle {
1820
namespace operators {
1921

20-
class FillConstantOp : public framework::OperatorWithKernel {
22+
class FillConstantInferShape : public framework::InferShapeBase {
2123
public:
22-
using framework::OperatorWithKernel::OperatorWithKernel;
23-
24-
void InferShape(framework::InferShapeContext *ctx) const override {
24+
void operator()(framework::InferShapeContext *ctx) const override {
2525
PADDLE_ENFORCE(ctx->HasOutput("Out"),
2626
"Output(Out) of FillConstantOp should not be null.");
2727
auto &shape = ctx->Attrs().Get<std::vector<int>>("shape");
28-
std::vector<int64_t> shape_int64(shape.size(), 0);
29-
std::transform(shape.begin(), shape.end(), shape_int64.begin(),
30-
[](int a) { return static_cast<int64_t>(a); });
31-
auto dims = framework::make_ddim(shape_int64);
32-
ctx->SetOutputDim("Out", dims);
28+
ctx->SetOutputDim("Out", framework::make_ddim(shape));
3329
}
30+
};
3431

35-
protected:
36-
framework::OpKernelType GetKernelType(
37-
const framework::ExecutionContext &ctx) const override {
38-
int data_type = ctx.Attr<int>("data_type");
39-
VLOG(10) << " FillConstant data_type = " << data_type;
40-
return framework::OpKernelType(static_cast<framework::DataType>(data_type),
41-
ctx.device_context());
32+
class FillConstantOp : public framework::OperatorBase {
33+
public:
34+
using framework::OperatorBase::OperatorBase;
35+
void Run(const framework::Scope &scope,
36+
const platform::DeviceContext &dev_ctx) const override {
37+
auto data_type = static_cast<framework::DataType>(Attr<int>("data_type"));
38+
auto value = Attr<float>("value");
39+
auto force_cpu = Attr<bool>("force_cpu");
40+
auto &out =
41+
*scope.FindVar(Output("Out"))->GetMutable<framework::LoDTensor>();
42+
out.Resize(framework::make_ddim(Attr<std::vector<int>>("shape")));
43+
if (force_cpu) {
44+
auto cpu = platform::CPUPlace();
45+
out.mutable_data(cpu, framework::ToTypeIndex(data_type));
46+
} else {
47+
out.mutable_data(dev_ctx.GetPlace(), framework::ToTypeIndex(data_type));
48+
}
49+
math::set_constant(dev_ctx, &out, value);
4250
}
4351
};
4452

@@ -54,6 +62,11 @@ class FillConstantOpMaker : public framework::OpProtoAndCheckerMaker {
5462
AddAttr<std::vector<int>>("shape", "(vector<int>) The shape of the output");
5563
AddAttr<float>("value", "(float, default 0) The value to be filled")
5664
.SetDefault(0.0f);
65+
AddAttr<bool>("force_cpu",
66+
"(bool, default false) Force fill output variable to cpu "
67+
"memory. Otherwise, fill output variable to the running "
68+
"device")
69+
.SetDefault(false);
5770
AddOutput("Out",
5871
"(Tensor) Tensor of specified shape will be filled "
5972
"with the specified value");
@@ -69,10 +82,6 @@ Fill up a variable with specified constant value.
6982
} // namespace paddle
7083

7184
namespace ops = paddle::operators;
72-
REGISTER_OP_WITHOUT_GRADIENT(fill_constant, ops::FillConstantOp,
73-
ops::FillConstantOpMaker);
74-
REGISTER_OP_CPU_KERNEL(
75-
fill_constant, ops::FillConstantOpKernel<paddle::platform::CPUPlace, float>,
76-
ops::FillConstantOpKernel<paddle::platform::CPUPlace, double>,
77-
ops::FillConstantOpKernel<paddle::platform::CPUPlace, int>,
78-
ops::FillConstantOpKernel<paddle::platform::CPUPlace, int64_t>);
85+
REGISTER_OPERATOR(fill_constant, ops::FillConstantOp,
86+
ops::FillConstantInferShape, ops::FillConstantOpMaker,
87+
paddle::framework::EmptyGradOpMaker);

paddle/operators/fill_constant_op.cu

Lines changed: 0 additions & 24 deletions
This file was deleted.

paddle/operators/fill_constant_op.h

Lines changed: 0 additions & 37 deletions
This file was deleted.

0 commit comments

Comments
 (0)