Skip to content

Commit 008857b

Browse files
authored
fix error message for scatter and scatter_nd (#24514)
1 parent 1437648 commit 008857b

File tree

8 files changed

+155
-91
lines changed

8 files changed

+155
-91
lines changed

paddle/fluid/operators/scatter.cu.h

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -95,11 +95,17 @@ void GPUScatterAssign(const framework::ExecutionContext& context,
9595
const auto& ctx = context.device_context();
9696
if (index.dims().size() == 2) {
9797
PADDLE_ENFORCE_EQ(index.dims()[1], 1,
98-
"index.dims()[1] should be 1 when index.dims().size() == "
99-
"2 in scatter_op.");
98+
platform::errors::InvalidArgument(
99+
"index.dims()[1] should be 1 when "
100+
"index.dims().size() = 2 in scatter_op."
101+
"But received value is [%d]",
102+
index.dims()[1]));
100103
} else {
101104
PADDLE_ENFORCE_EQ(index.dims().size(), 1,
102-
"index.dims().size() should be 1 or 2 in scatter_op.");
105+
platform::errors::InvalidArgument(
106+
"index.dims().size() should be 1 or 2 in scatter_op."
107+
"But received value is [%d]",
108+
index.dims().size()));
103109
}
104110
int index_size = index.dims()[0];
105111

paddle/fluid/operators/scatter.h

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -73,15 +73,23 @@ elementwise_inner_add(const framework::ExecutionContext& ctx,
7373
template <typename T, typename IndexT = int>
7474
void ScatterAssign(const platform::DeviceContext& ctx, const Tensor& src,
7575
const Tensor& index, Tensor* output) {
76-
PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true);
76+
PADDLE_ENFORCE_EQ(
77+
platform::is_cpu_place(ctx.GetPlace()), true,
78+
platform::errors::PreconditionNotMet("This kernel only runs on CPU."));
7779
// check index of shape 1-D
7880
if (index.dims().size() == 2) {
7981
PADDLE_ENFORCE_EQ(index.dims()[1], 1,
80-
"index.dims()[1] should be 1 when index.dims().size() == "
81-
"2 in scatter_op.");
82+
platform::errors::InvalidArgument(
83+
"index.dims()[1] should be 1 when "
84+
"index.dims().size() =2 in scatter_op."
85+
"But received value is [%d]",
86+
index.dims()[1]));
8287
} else {
8388
PADDLE_ENFORCE_EQ(index.dims().size(), 1,
84-
"index.dims().size() should be 1 or 2 in scatter_op.");
89+
platform::errors::InvalidArgument(
90+
"index.dims().size() should be 1 or 2 in scatter_op."
91+
"But received value is [%d]",
92+
index.dims().size()));
8593
}
8694
int index_size = index.dims()[0];
8795

@@ -94,7 +102,9 @@ void ScatterAssign(const platform::DeviceContext& ctx, const Tensor& src,
94102

95103
// check src shape and dst shape should match
96104
for (int i = 1; i < src_dims.size(); i++)
97-
PADDLE_ENFORCE_EQ(src_dims[i], dst_dims[i]);
105+
PADDLE_ENFORCE_EQ(src_dims[i], dst_dims[i],
106+
platform::errors::InvalidArgument(
107+
"src shape and dst shape should match"));
98108

99109
// slice size
100110
size_t slice_size = 1;
@@ -111,12 +121,14 @@ void ScatterAssign(const platform::DeviceContext& ctx, const Tensor& src,
111121
template <typename T, typename IndexT = int>
112122
void ScatterAssignAdd(const framework::ExecutionContext& ctx, const Tensor& src,
113123
const Tensor& index, Tensor* output) {
114-
PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.device_context().GetPlace()),
115-
true);
124+
PADDLE_ENFORCE_EQ(
125+
platform::is_cpu_place(ctx.device_context().GetPlace()), true,
126+
platform::errors::PreconditionNotMet("This kernel only runs on CPU."));
116127
// check index of shape 1-D
117-
PADDLE_ENFORCE(index.dims().size() == 1 ||
118-
(index.dims().size() == 2 && index.dims()[1] == 1),
119-
"");
128+
PADDLE_ENFORCE_EQ(
129+
index.dims().size() == 1 ||
130+
(index.dims().size() == 2 && index.dims()[1] == 1),
131+
true, platform::errors::InvalidArgument("index's shape is error."));
120132
int index_size = index.dims()[0];
121133

122134
auto src_dims = src.dims();
@@ -130,7 +142,9 @@ void ScatterAssignAdd(const framework::ExecutionContext& ctx, const Tensor& src,
130142

131143
// check src shape and dst shape should match
132144
for (int i = 1; i < src_dims.size(); i++)
133-
PADDLE_ENFORCE_EQ(src_dims[i], dst_dims[i]);
145+
PADDLE_ENFORCE_EQ(src_dims[i], dst_dims[i],
146+
platform::errors::InvalidArgument(
147+
"src shape and dst shape should match"));
134148

135149
// slice size
136150
size_t slice_size = 1;
@@ -156,8 +170,9 @@ void ScatterAssignAdd(const framework::ExecutionContext& ctx, const Tensor& src,
156170
template <typename T, typename IndexT = int>
157171
void ScatterNdAdd(const framework::ExecutionContext& ctx, const Tensor& update,
158172
const Tensor& index, Tensor* output) {
159-
PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.device_context().GetPlace()),
160-
true, "It should be running on the CPU");
173+
PADDLE_ENFORCE_EQ(
174+
platform::is_cpu_place(ctx.device_context().GetPlace()), true,
175+
platform::errors::PreconditionNotMet("It should be running on the CPU"));
161176

162177
// update.shape = index.shape[:-1] + output.shape[index.shape[-1]:]
163178
auto index_dims = index.dims();

paddle/fluid/operators/scatter_nd_add_op.cc

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,19 @@ class ScatterNdAddOp : public framework::OperatorWithKernel {
2626

2727
void InferShape(framework::InferShapeContext* ctx) const override {
2828
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
29-
"Input(X) of ScatterNdAddOp should not be null.");
30-
PADDLE_ENFORCE_EQ(ctx->HasInput("Index"), true,
31-
"Input(Index) of ScatterNdAddOp should not be null.");
32-
PADDLE_ENFORCE_EQ(ctx->HasInput("Updates"), true,
33-
"Input(Updates) of ScatterNdAddOp should not be null.");
29+
platform::errors::InvalidArgument(
30+
"Input(X) of ScatterNdAddOp should not be null."));
31+
PADDLE_ENFORCE_EQ(
32+
ctx->HasInput("Index"), true,
33+
platform::errors::InvalidArgument(
34+
"Input(Index) of ScatterNdAddOp should not be null."));
35+
PADDLE_ENFORCE_EQ(
36+
ctx->HasInput("Updates"), true,
37+
platform::errors::InvalidArgument(
38+
"Input(Updates) of ScatterNdAddOp should not be null."));
3439
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
35-
"Output(Out) of ScatterNdAddOp should not be null.");
40+
platform::errors::InvalidArgument(
41+
"Output(Out) of ScatterNdAddOp should not be null."));
3642

3743
auto ref_dims = ctx->GetInputDim("X");
3844
auto ref_dims_size = ref_dims.size();
@@ -43,9 +49,11 @@ class ScatterNdAddOp : public framework::OperatorWithKernel {
4349

4450
PADDLE_ENFORCE_LE(
4551
index_dims[index_dims_size - 1], ref_dims_size,
46-
"Input(Index).shape[-1] should be no greater than Input(X).rank");
52+
platform::errors::InvalidArgument(
53+
"Input(Index).shape[-1] should be no greater than Input(X).rank"));
4754
PADDLE_ENFORCE_GE(index_dims_size, 2UL,
48-
"The rank of Input(Index) should be greater than 1");
55+
platform::errors::InvalidArgument(
56+
"The rank of Input(Index) should be greater than 1"));
4957

5058
// update.shape = index.shape[:-1] + output.shape[index.shape[-1]:]
5159
std::vector<int64_t> r_updates_dims;
@@ -56,12 +64,14 @@ class ScatterNdAddOp : public framework::OperatorWithKernel {
5664
r_updates_dims.emplace_back(ref_dims[i]);
5765
}
5866

59-
PADDLE_ENFORCE_EQ(r_updates_dims.size(), updates_dims_size,
60-
"Updates has wrong shape");
67+
PADDLE_ENFORCE_EQ(
68+
r_updates_dims.size(), updates_dims_size,
69+
platform::errors::InvalidArgument("Updates has wrong shape"));
6170

6271
for (int64_t i = 0; i < updates_dims_size; ++i) {
63-
PADDLE_ENFORCE_EQ(r_updates_dims[i], updates_dims[i],
64-
"Updates has wrong shape");
72+
PADDLE_ENFORCE_EQ(
73+
r_updates_dims[i], updates_dims[i],
74+
platform::errors::InvalidArgument("Updates has wrong shape"));
6575
}
6676
ctx->SetOutputDim("Out", ref_dims);
6777
ctx->ShareLoD("X", /*->*/ "Out");
@@ -72,7 +82,8 @@ class ScatterNdAddOp : public framework::OperatorWithKernel {
7282
const framework::ExecutionContext& ctx) const override {
7383
PADDLE_ENFORCE_EQ(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
7484
OperatorWithKernel::IndicateVarDataType(ctx, "Updates"),
75-
"Ref and Updates must have same type");
85+
platform::errors::InvalidArgument(
86+
"Ref and Updates must have same type"));
7687
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
7788
ctx.device_context());
7889
}

paddle/fluid/operators/scatter_nd_add_op.cu

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@ class ScatterNdAddOpCUDAKernel : public framework::OpKernel<T> {
2525
public:
2626
void Compute(const framework::ExecutionContext &ctx) const override {
2727
PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
28-
"This kernel only runs on GPU device.");
28+
platform::errors::PreconditionNotMet(
29+
"This kernel only runs on GPU device."));
2930
auto *X = ctx.Input<Tensor>("X");
3031
auto *Ids = ctx.Input<Tensor>("Index");
3132
auto *Updates = ctx.Input<Tensor>("Updates");
@@ -35,12 +36,15 @@ class ScatterNdAddOpCUDAKernel : public framework::OpKernel<T> {
3536
const auto &index_type = Ids->type();
3637
bool index_type_match = index_type == framework::proto::VarType::INT32 ||
3738
index_type == framework::proto::VarType::INT64;
38-
PADDLE_ENFORCE_EQ(
39-
index_type_match, true,
40-
"Index holds the wrong type, it holds %s, but desires to be %s or %s",
41-
paddle::framework::DataTypeToString(index_type),
42-
paddle::framework::DataTypeToString(framework::proto::VarType::INT32),
43-
paddle::framework::DataTypeToString(framework::proto::VarType::INT64));
39+
PADDLE_ENFORCE_EQ(index_type_match, true,
40+
platform::errors::InvalidArgument(
41+
"Index holds the wrong type, it holds [%s], but "
42+
"desires to be [%s] or [%s].",
43+
paddle::framework::DataTypeToString(index_type),
44+
paddle::framework::DataTypeToString(
45+
framework::proto::VarType::INT32),
46+
paddle::framework::DataTypeToString(
47+
framework::proto::VarType::INT64)));
4448
if (index_type == framework::proto::VarType::INT32) {
4549
GPUScatterNdAdd<DeviceContext, T, int32_t>(ctx, *Updates, *Ids, Out);
4650
} else {
@@ -54,7 +58,8 @@ class ScatterNdAddGradOpCUDAKernel : public framework::OpKernel<T> {
5458
public:
5559
void Compute(const framework::ExecutionContext &ctx) const override {
5660
PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
57-
"This kernel only runs on GPU device.");
61+
platform::errors::PreconditionNotMet(
62+
"This kernel only runs on GPU device."));
5863
auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
5964
auto *dUpdates = ctx.Output<Tensor>(framework::GradVarName("Updates"));
6065
auto *Ids = ctx.Input<Tensor>("Index");

paddle/fluid/operators/scatter_nd_add_op.h

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,9 @@ template <typename T>
2727
class ScatterNdAddOpKernel : public framework::OpKernel<T> {
2828
public:
2929
void Compute(const framework::ExecutionContext &ctx) const override {
30-
PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
31-
"This kernel only runs on CPU.");
30+
PADDLE_ENFORCE_EQ(
31+
platform::is_cpu_place(ctx.GetPlace()), true,
32+
platform::errors::PreconditionNotMet("This kernel only runs on CPU."));
3233
auto *X = ctx.Input<Tensor>("X");
3334
auto *Ids = ctx.Input<Tensor>("Index");
3435
auto *Updates = ctx.Input<Tensor>("Updates");
@@ -39,12 +40,15 @@ class ScatterNdAddOpKernel : public framework::OpKernel<T> {
3940
const auto &index_type = Ids->type();
4041
bool index_type_match = index_type == framework::proto::VarType::INT32 ||
4142
index_type == framework::proto::VarType::INT64;
42-
PADDLE_ENFORCE_EQ(
43-
index_type_match, true,
44-
"Index holds the wrong type, it holds %s, but desires to be %s or %s",
45-
paddle::framework::DataTypeToString(index_type),
46-
paddle::framework::DataTypeToString(framework::proto::VarType::INT32),
47-
paddle::framework::DataTypeToString(framework::proto::VarType::INT64));
43+
PADDLE_ENFORCE_EQ(index_type_match, true,
44+
platform::errors::InvalidArgument(
45+
"Index holds the wrong type, it holds [%s], but "
46+
"desires to be [%s] or [%s].",
47+
paddle::framework::DataTypeToString(index_type),
48+
paddle::framework::DataTypeToString(
49+
framework::proto::VarType::INT32),
50+
paddle::framework::DataTypeToString(
51+
framework::proto::VarType::INT64)));
4852

4953
if (index_type == framework::proto::VarType::INT32) {
5054
ScatterNdAdd<T, int32_t>(ctx, *Updates, *Ids, Out);
@@ -58,8 +62,9 @@ template <typename T>
5862
class ScatterNdAddGradientOpKernel : public framework::OpKernel<T> {
5963
public:
6064
void Compute(const framework::ExecutionContext &ctx) const override {
61-
PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
62-
"This kernel only runs on CPU.");
65+
PADDLE_ENFORCE_EQ(
66+
platform::is_cpu_place(ctx.GetPlace()), true,
67+
platform::errors::PreconditionNotMet("This kernel only runs on CPU."));
6368
auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
6469
auto *dUpdates = ctx.Output<Tensor>(framework::GradVarName("Updates"));
6570
auto *Ids = ctx.Input<Tensor>("Index");

paddle/fluid/operators/scatter_op.cc

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -24,24 +24,32 @@ class ScatterOp : public framework::OperatorWithKernel {
2424
using framework::OperatorWithKernel::OperatorWithKernel;
2525

2626
void InferShape(framework::InferShapeContext* ctx) const override {
27-
PADDLE_ENFORCE(ctx->HasInput("X"),
28-
"Input(X) of ScatterOp should not be null.");
29-
PADDLE_ENFORCE(ctx->HasInput("Ids"),
30-
"Input(Ids) of ScatterOp should not be null.");
31-
PADDLE_ENFORCE(ctx->HasInput("Updates"),
32-
"Input(Updates) of ScatterOp should not be null.");
33-
PADDLE_ENFORCE(ctx->HasOutput("Out"),
34-
"Output(Out) of ScatterOp should not be null.");
27+
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
28+
platform::errors::InvalidArgument(
29+
"Input(X) of ScatterOp should not be null."));
30+
PADDLE_ENFORCE_EQ(ctx->HasInput("Ids"), true,
31+
platform::errors::InvalidArgument(
32+
"Input(Ids) of ScatterOp should not be null."));
33+
PADDLE_ENFORCE_EQ(ctx->HasInput("Updates"), true,
34+
platform::errors::InvalidArgument(
35+
"Input(Updates) of ScatterOp should not be null."));
36+
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
37+
platform::errors::InvalidArgument(
38+
"Output(Out) of ScatterOp should not be null."));
3539

3640
auto updates_dims = ctx->GetInputDim("Updates");
3741
auto ref_dims = ctx->GetInputDim("X");
38-
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Ids").size(), 1,
39-
"Update Ids should be 1-D.");
40-
PADDLE_ENFORCE_EQ(ref_dims.size(), updates_dims.size(),
41-
"Xerence and Updates should have the same shape size");
42+
PADDLE_ENFORCE_EQ(
43+
ctx->GetInputDim("Ids").size(), 1,
44+
platform::errors::InvalidArgument("Update Ids should be 1-D."));
45+
PADDLE_ENFORCE_EQ(
46+
ref_dims.size(), updates_dims.size(),
47+
platform::errors::InvalidArgument(
48+
"Rerence and Updates should have the same shape size."));
4249
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Updates")[0],
4350
ctx->GetInputDim("Ids")[0],
44-
"Updates and Ids should have same batch-size.");
51+
platform::errors::InvalidArgument(
52+
"Updates and Ids should have same batch-size."));
4553
ctx->SetOutputDim("Out", ref_dims);
4654
}
4755

paddle/fluid/operators/scatter_op.cu

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,9 @@ template <typename T>
2424
class ScatterOpCUDAKernel : public framework::OpKernel<T> {
2525
public:
2626
void Compute(const framework::ExecutionContext &ctx) const override {
27-
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
28-
"This kernel only runs on GPU device.");
27+
PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
28+
platform::errors::PreconditionNotMet(
29+
"This kernel only runs on GPU device."));
2930
auto *X = ctx.Input<Tensor>("X");
3031
auto *Ids = ctx.Input<Tensor>("Ids");
3132
auto *Updates = ctx.Input<Tensor>("Updates");
@@ -39,11 +40,14 @@ class ScatterOpCUDAKernel : public framework::OpKernel<T> {
3940
index_type == framework::proto::VarType::INT64;
4041
PADDLE_ENFORCE_EQ(
4142
index_type_match, true,
42-
"scatter_op Index holds the wrong type, it holds %s, but desires to be "
43-
"%s or %s",
44-
paddle::framework::DataTypeToString(index_type),
45-
paddle::framework::DataTypeToString(framework::proto::VarType::INT32),
46-
paddle::framework::DataTypeToString(framework::proto::VarType::INT64));
43+
platform::errors::InvalidArgument(
44+
"scatter_op Index holds the wrong type, it holds [%s],"
45+
"but desires to be [%s] or [%s].",
46+
paddle::framework::DataTypeToString(index_type),
47+
paddle::framework::DataTypeToString(
48+
framework::proto::VarType::INT32),
49+
paddle::framework::DataTypeToString(
50+
framework::proto::VarType::INT64)));
4751
if (index_type == framework::proto::VarType::INT32) {
4852
GPUScatterAssign<T, int32_t>(ctx, *Updates, *Ids, Out, overwrite);
4953
} else {
@@ -56,8 +60,9 @@ template <typename T>
5660
class ScatterGradOpCUDAKernel : public framework::OpKernel<T> {
5761
public:
5862
void Compute(const framework::ExecutionContext &ctx) const override {
59-
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
60-
"This kernel only runs on GPU device.");
63+
PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
64+
platform::errors::PreconditionNotMet(
65+
"This kernel only runs on GPU device."));
6166
auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
6267
auto *dUpdates = ctx.Output<Tensor>(framework::GradVarName("Updates"));
6368
auto *Ids = ctx.Input<Tensor>("Ids");
@@ -74,12 +79,14 @@ class ScatterGradOpCUDAKernel : public framework::OpKernel<T> {
7479
index_type == framework::proto::VarType::INT64;
7580
PADDLE_ENFORCE_EQ(
7681
index_type_match, true,
77-
"scatter_op Index holds the wrong type, it holds %s, but desires to "
78-
"be %s or %s",
79-
paddle::framework::DataTypeToString(index_type),
80-
paddle::framework::DataTypeToString(framework::proto::VarType::INT32),
81-
paddle::framework::DataTypeToString(
82-
framework::proto::VarType::INT64));
82+
platform::errors::InvalidArgument(
83+
"scatter_op Index holds the wrong type, it holds [%s], "
84+
"but desires to be [%s] or [%s]",
85+
paddle::framework::DataTypeToString(index_type),
86+
paddle::framework::DataTypeToString(
87+
framework::proto::VarType::INT32),
88+
paddle::framework::DataTypeToString(
89+
framework::proto::VarType::INT64)));
8390
// Gradient by Gather: dUpdates = dO[Ids]
8491
if (index_type == framework::proto::VarType::INT32) {
8592
GPUGather<T, int32_t>(ctx.device_context(), *dOut, *Ids, dUpdates);

0 commit comments

Comments
 (0)