Skip to content

Commit 50d670e

Browse files
authored
Unify dtype and datatype (#5869)
* Change all `data_type` in Python to `dtype` * Change `date_type` in C++ to `dtype` * Refine
1 parent 1ab1b09 commit 50d670e

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

46 files changed

+225
-239
lines changed

paddle/framework/backward.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -522,7 +522,7 @@ ParamGradInfoMap AppendBackward(
522522
new OpDescBind("fill_constant", {}, {{"Out", {fill_one_op_out}}},
523523
{{"shape", std::vector<int>{1}},
524524
{"value", static_cast<float>(1.0)},
525-
{"data_type", target.GetDataType()}}));
525+
{"dtype", target.GetDataType()}}));
526526
// infer var type of fill_one_op
527527
fill_one_op->InferVarType(root_block);
528528

paddle/framework/tensor_array.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,7 @@ LoDTensor TensorArray::Stack() const {
302302

303303
const auto& first_dims = values_.front().dims();
304304
// check all the values have the same shape
305-
// TODO(superjom) check the same dtypes
305+
// TODO(superjom) check the same data_type
306306
for (size_t idx = 1; idx < size(); idx++) {
307307
const auto& value_dims = values_[idx].dims();
308308
PADDLE_ENFORCE_EQ(first_dims, value_dims);

paddle/operators/cast_op.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ class CastOpProtoMaker : public framework::OpProtoAndCheckerMaker {
2525
: OpProtoAndCheckerMaker(proto, op_checker) {
2626
AddInput("X", "The input tensor of cast op");
2727
AddOutput("Out", "The output tensor of cast op");
28-
AddAttr<int>("out_data_type", "output data type");
29-
AddAttr<int>("in_data_type", "input data type");
28+
AddAttr<int>("out_dtype", "output data type");
29+
AddAttr<int>("in_dtype", "input data type");
3030
AddComment(R"DOC(
3131
Cast Operator.
3232
@@ -58,8 +58,8 @@ class CastOpGradMaker : public framework::SingleGradOpDescMaker {
5858
grad->SetType("cast");
5959
grad->SetInput("X", OutputGrad("Out"));
6060
grad->SetOutput("Out", InputGrad("X"));
61-
grad->SetAttr("out_data_type", GetAttr("in_data_type"));
62-
grad->SetAttr("in_data_type", GetAttr("out_data_type"));
61+
grad->SetAttr("out_dtype", GetAttr("in_dtype"));
62+
grad->SetAttr("in_dtype", GetAttr("out_dtype"));
6363
return std::unique_ptr<framework::OpDescBind>(grad);
6464
}
6565
};

paddle/operators/cast_op.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ class CastOpKernel : public framework::OpKernel<InT> {
5555
auto* in = context.Input<framework::Tensor>("X");
5656
auto* out = context.Output<framework::Tensor>("Out");
5757
framework::VisitDataType(
58-
static_cast<framework::DataType>(context.Attr<int>("out_data_type")),
58+
static_cast<framework::DataType>(context.Attr<int>("out_dtype")),
5959
CastOpFunctor<Place, InT>(in, out, context.device_context()));
6060
}
6161
};

paddle/operators/fill_constant_batch_size_like_op.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ class FillConstantBatchSizeLikeOp : public framework::OperatorWithKernel {
5252
framework::OpKernelType GetKernelType(
5353
const framework::ExecutionContext &ctx) const override {
5454
return framework::OpKernelType(
55-
static_cast<framework::DataType>(ctx.Attr<int>("data_type")),
55+
static_cast<framework::DataType>(ctx.Attr<int>("dtype")),
5656
ctx.device_context());
5757
}
5858
};
@@ -63,7 +63,7 @@ class FillConstantBatchSizeLikeOpMaker
6363
FillConstantBatchSizeLikeOpMaker(framework::OpProto *proto,
6464
framework::OpAttrChecker *op_checker)
6565
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
66-
AddAttr<int>("data_type",
66+
AddAttr<int>("dtype",
6767
"(int, default 5 (FP32)) "
6868
"Output data type")
6969
.SetDefault(framework::DataType::FP32);

paddle/operators/fill_constant_op.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ class FillConstantOp : public framework::OperatorBase {
3434
using framework::OperatorBase::OperatorBase;
3535
void Run(const framework::Scope &scope,
3636
const platform::DeviceContext &dev_ctx) const override {
37-
auto data_type = static_cast<framework::DataType>(Attr<int>("data_type"));
37+
auto data_type = static_cast<framework::DataType>(Attr<int>("dtype"));
3838
auto value = Attr<float>("value");
3939
auto force_cpu = Attr<bool>("force_cpu");
4040
auto &out =
@@ -55,7 +55,7 @@ class FillConstantOpMaker : public framework::OpProtoAndCheckerMaker {
5555
FillConstantOpMaker(framework::OpProto *proto,
5656
framework::OpAttrChecker *op_checker)
5757
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
58-
AddAttr<int>("data_type",
58+
AddAttr<int>("dtype",
5959
"(int, default 5 (FP32)) "
6060
"Output data type")
6161
.SetDefault(framework::DataType::FP32);

paddle/operators/gaussian_random_op.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
6060
framework::OpKernelType GetKernelType(
6161
const framework::ExecutionContext& ctx) const override {
6262
return framework::OpKernelType(
63-
static_cast<framework::DataType>(ctx.Attr<int>("data_type")),
63+
static_cast<framework::DataType>(ctx.Attr<int>("dtype")),
6464
ctx.device_context());
6565
}
6666
};
@@ -88,7 +88,7 @@ class GaussianRandomOpMaker : public framework::OpProtoAndCheckerMaker {
8888
"Random seed of generator."
8989
"0 means use system wide seed.")
9090
.SetDefault(0);
91-
AddAttr<int>("data_type",
91+
AddAttr<int>("dtype",
9292
"(int, default 5(FP32)) "
9393
"Output data type.")
9494
.SetDefault(framework::DataType::FP32);

paddle/operators/nccl_op.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ class NCCLInitOpMaker : public framework::OpProtoAndCheckerMaker {
4949
AddOutput("Communicator",
5050
"Create Communicator for communicating between gpus");
5151
AddAttr<std::vector<int>>("gpus", "(vector<int>) GPU id lists");
52-
AddAttr<int>("data_type",
52+
AddAttr<int>("dtype",
5353
"(int, default 5 (FP32)) "
5454
"Output data type")
5555
.SetDefault(framework::DataType::FP32);

paddle/operators/recurrent_op.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -401,7 +401,7 @@ class RecurrentGradOp : public RecurrentBase {
401401
auto &inside_tensor = cur_scope.FindVar(inside_grad_name)
402402
->Get<framework::LoDTensor>();
403403
framework::AttributeMap attrs;
404-
attrs["data_type"] = framework::ToDataType(inside_tensor.type());
404+
attrs["dtype"] = framework::ToDataType(inside_tensor.type());
405405
attrs["shape"] = framework::vectorize2int(inside_tensor.dims());
406406
attrs["value"] = 0.0f;
407407

paddle/operators/rnn_memory_helper_op.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ class RNNMemoryHelperOpInfoMaker : public framework::OpProtoAndCheckerMaker {
6262
: OpProtoAndCheckerMaker(proto, op_checker) {
6363
AddInput("X", "");
6464
AddOutput("Out", "");
65-
AddAttr<int>("data_type",
65+
AddAttr<int>("dtype",
6666
"(int, default 5 (FP32)) "
6767
"Output data type")
6868
.SetDefault(framework::DataType::FP32);
@@ -95,7 +95,7 @@ class RNNMemoryHelperGradOp : public framework::OperatorBase {
9595
auto &in_var_tensor = in_var->Get<framework::LoDTensor>();
9696

9797
framework::AttributeMap attrs;
98-
attrs["data_type"] = framework::ToDataType(in_var_tensor.type());
98+
attrs["dtype"] = framework::ToDataType(in_var_tensor.type());
9999
attrs["shape"] = framework::vectorize2int(in_var_tensor.dims());
100100
attrs["value"] = 0.0f;
101101

@@ -121,7 +121,7 @@ class RNNMemoryHelperGradOpInfoMaker
121121
AddInput("X", "");
122122
AddInput("Out", "");
123123
AddOutput(framework::GradVarName("X"), "");
124-
AddAttr<int>("data_type",
124+
AddAttr<int>("dtype",
125125
"(int, default 5 (FP32)) "
126126
"Output data type")
127127
.SetDefault(framework::DataType::FP32);

0 commit comments

Comments
 (0)