Skip to content

Commit d29e9aa

Browse files
authored
[Cherry-pick to 1.6] Block part of "tensor should not be null" error message (#20845)
* Add IndicateVarDataType interface to block tensor is not initialized problem in OP GetExceptedKernelType (#20044) * add indicate_var_data_type inferface, test=develop * add unittests & polish error message, test=develop * remove needless include, test=develop * extract public function & polish message, test=develop * delete empty var check, test=develop * change data_type to pointer parameter, test=develop * polish details, test=develop * Replace risky GetInputType method with secure IndicateVarDataType interface (#20668) * replace part of the old implementation, test=develop * restore concat op, test=develop * update all ops implemention & delete GetDataTypeOfVar func, test=develop test=release/1.6
1 parent e1e5845 commit d29e9aa

File tree

167 files changed

+882
-565
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

167 files changed

+882
-565
lines changed

paddle/fluid/framework/operator.cc

Lines changed: 52 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -48,16 +48,6 @@ std::vector<std::tuple<platform::Place, LibraryType>> kKernelPriority = {
4848
std::make_tuple(platform::CPUPlace(), LibraryType::kPlain),
4949
};
5050

51-
proto::VarType::Type GetDataTypeOfVar(const Variable* var) {
52-
if (var->IsType<framework::LoDTensor>()) {
53-
return var->Get<framework::LoDTensor>().type();
54-
} else if (var->IsType<framework::SelectedRows>()) {
55-
return var->Get<framework::SelectedRows>().value().type();
56-
} else {
57-
PADDLE_THROW("Var should be LoDTensor or SelectedRows");
58-
}
59-
}
60-
6151
static DDim GetDimsDebug(const Scope& scope, const std::string& name,
6252
bool get_actual_dim = false) {
6353
Variable* var = scope.FindVar(name);
@@ -1152,40 +1142,65 @@ Scope* OperatorWithKernel::PrepareData(
11521142
return new_scope;
11531143
}
11541144

1145+
void OperatorWithKernel::ParseInputDataType(
1146+
const ExecutionContext& ctx, const std::string& name,
1147+
proto::VarType::Type* data_type) const {
1148+
proto::VarType::Type dafault_data_type =
1149+
static_cast<proto::VarType::Type>(-1);
1150+
const std::vector<const Variable*> vars = ctx.MultiInputVar(name);
1151+
for (size_t i = 0; i < vars.size(); ++i) {
1152+
const Variable* var = vars[i];
1153+
if (var != nullptr) {
1154+
const Tensor* t = nullptr;
1155+
if (var->IsType<Tensor>()) {
1156+
t = &var->Get<Tensor>();
1157+
} else if (var->IsType<LoDTensor>()) {
1158+
t = &var->Get<LoDTensor>();
1159+
} else if (var->IsType<SelectedRows>()) {
1160+
t = &(var->Get<SelectedRows>().value());
1161+
}
1162+
if (t != nullptr) {
1163+
PADDLE_ENFORCE_EQ(t->IsInitialized(), true,
1164+
"The Tensor in the %s Op's Input Variable %s(%s) is "
1165+
"not initialized.",
1166+
Type(), name, ctx.Inputs(name).at(i));
1167+
proto::VarType::Type tmp = t->type();
1168+
PADDLE_ENFORCE(tmp == *data_type || *data_type == dafault_data_type,
1169+
"The DataType of %s Op's duplicable Variable %s must be "
1170+
"consistent. The current variable type is (%s), but the "
1171+
"previous variable type is (%s).",
1172+
Type(), name, DataTypeToString(tmp),
1173+
DataTypeToString(*data_type));
1174+
*data_type = tmp;
1175+
}
1176+
}
1177+
}
1178+
}
1179+
11551180
proto::VarType::Type OperatorWithKernel::IndicateDataType(
11561181
const ExecutionContext& ctx) const {
11571182
proto::VarType::Type dafault_data_type =
11581183
static_cast<proto::VarType::Type>(-1);
11591184
proto::VarType::Type data_type = dafault_data_type;
11601185
for (auto& input : this->inputs_) {
1161-
const std::vector<const Variable*> vars = ctx.MultiInputVar(input.first);
1162-
for (size_t i = 0; i < vars.size(); ++i) {
1163-
const Variable* var = vars[i];
1164-
if (var != nullptr) {
1165-
const Tensor* t = nullptr;
1166-
if (var->IsType<Tensor>()) {
1167-
t = &var->Get<Tensor>();
1168-
} else if (var->IsType<LoDTensor>()) {
1169-
t = &var->Get<LoDTensor>();
1170-
} else if (var->IsType<SelectedRows>()) {
1171-
t = &(var->Get<SelectedRows>().value());
1172-
}
1173-
if (t != nullptr) {
1174-
PADDLE_ENFORCE(t->IsInitialized(), "Input %s(%lu) is not initialized",
1175-
input.first, i);
1176-
proto::VarType::Type tmp = t->type();
1177-
PADDLE_ENFORCE(
1178-
tmp == data_type || data_type == dafault_data_type,
1179-
"DataType of Paddle Op %s %s must be the same. Get (%s) != (%s)",
1180-
Type(), input.first, DataTypeToString(data_type),
1181-
DataTypeToString(tmp));
1182-
data_type = tmp;
1183-
}
1184-
}
1185-
}
1186+
ParseInputDataType(ctx, input.first, &data_type);
11861187
}
1187-
PADDLE_ENFORCE(data_type != dafault_data_type,
1188-
"DataType should be indicated by input");
1188+
PADDLE_ENFORCE_NE(data_type, dafault_data_type,
1189+
"DataType should be indicated by input Variable.");
1190+
return data_type;
1191+
}
1192+
1193+
proto::VarType::Type OperatorWithKernel::IndicateVarDataType(
1194+
const ExecutionContext& ctx, const std::string& name) const {
1195+
proto::VarType::Type dafault_data_type =
1196+
static_cast<proto::VarType::Type>(-1);
1197+
proto::VarType::Type data_type = dafault_data_type;
1198+
ParseInputDataType(ctx, name, &data_type);
1199+
PADDLE_ENFORCE_NE(
1200+
data_type, dafault_data_type,
1201+
"The Input Variable(%s) of %s Op used to determine kernel data type "
1202+
"is empty or not LoDTensor or SelectedRows.",
1203+
name, Type());
11891204
return data_type;
11901205
}
11911206

paddle/fluid/framework/operator.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,6 @@ inline std::string GradOriginalVarName(const std::string& grad_var_name) {
102102
}
103103
}
104104

105-
proto::VarType::Type GetDataTypeOfVar(const Variable* var);
106105
const Tensor* GetLoDTensorOrSelectedRowsValueFromVar(const Variable& var);
107106
Tensor* GetMutableLoDTensorOrSelectedRowsValueFromVar(Variable* var);
108107

@@ -459,6 +458,9 @@ class OperatorWithKernel : public OperatorBase {
459458
void RuntimeInferShape(const Scope& scope, const platform::Place& place,
460459
const RuntimeContext& ctx) const override;
461460

461+
proto::VarType::Type IndicateVarDataType(const ExecutionContext& ctx,
462+
const std::string& name) const;
463+
462464
virtual OpKernelType GetExpectedKernelType(const ExecutionContext& ctx) const;
463465

464466
std::vector<KernelConfig>* GetKernelConfig(const OpKernelType& key) const;
@@ -470,6 +472,8 @@ class OperatorWithKernel : public OperatorBase {
470472
const OpKernelType& expected_kernel_type) const;
471473

472474
private:
475+
void ParseInputDataType(const ExecutionContext& ctx, const std::string& name,
476+
proto::VarType::Type* type) const;
473477
// indicate kernel DataType by input data. By default all input data must be
474478
// same.
475479
proto::VarType::Type IndicateDataType(const ExecutionContext& ctx) const;

paddle/fluid/framework/operator_test.cc

Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,3 +315,182 @@ TEST(VarNameTest, all) {
315315
original_var_name = paddle::framework::GradOriginalVarName(original_var_name);
316316
ASSERT_EQ(original_var_name, "");
317317
}
318+
319+
namespace paddle {
320+
namespace framework {
321+
322+
class IndicateLoDTensorDataTypeTest : public OperatorWithKernel {
323+
public:
324+
using OperatorWithKernel::OperatorWithKernel;
325+
326+
protected:
327+
void InferShape(framework::InferShapeContext* ctx) const override {}
328+
OpKernelType GetExpectedKernelType(
329+
const ExecutionContext& ctx) const override {
330+
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "LoDTensor");
331+
return framework::OpKernelType(data_type, ctx.device_context());
332+
}
333+
};
334+
class IndicateLoDTensorDataTypeTestProtoMaker : public OpProtoAndCheckerMaker {
335+
public:
336+
void Make() {
337+
AddInput("LoDTensor", "Input of Tensor type Variable.");
338+
AddComment("This Op is only for IndicateVarDataType inferface test.");
339+
}
340+
};
341+
342+
class IndicateSelectedRowsDataTypeTest : public OperatorWithKernel {
343+
public:
344+
using OperatorWithKernel::OperatorWithKernel;
345+
346+
protected:
347+
void InferShape(framework::InferShapeContext* ctx) const override {}
348+
OpKernelType GetExpectedKernelType(
349+
const ExecutionContext& ctx) const override {
350+
auto data_type =
351+
OperatorWithKernel::IndicateVarDataType(ctx, "SelectedRows");
352+
return framework::OpKernelType(data_type, ctx.device_context());
353+
}
354+
};
355+
class IndicateSelectedRowsDataTypeTestProtoMaker
356+
: public OpProtoAndCheckerMaker {
357+
public:
358+
void Make() {
359+
AddInput("SelectedRows", "Input of SelectedRows type Variable.");
360+
AddComment("This Op is only for IndicateVarDataType inferface test.");
361+
}
362+
};
363+
364+
class IndicateOtherDataTypeTest : public OperatorWithKernel {
365+
public:
366+
using OperatorWithKernel::OperatorWithKernel;
367+
368+
protected:
369+
void InferShape(framework::InferShapeContext* ctx) const override {}
370+
OpKernelType GetExpectedKernelType(
371+
const ExecutionContext& ctx) const override {
372+
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Other");
373+
return framework::OpKernelType(data_type, ctx.device_context());
374+
}
375+
};
376+
class IndicateOtherDataTypeTestProtoMaker : public OpProtoAndCheckerMaker {
377+
public:
378+
void Make() {
379+
AddInput("Other", "Input of Other type Variable");
380+
AddComment("This Op is only for IndicateVarDataType inferface test.");
381+
}
382+
};
383+
384+
template <typename DeviceContext, typename T>
385+
class IndicateVarDataTypeKernelTest : public OpKernel<T> {
386+
public:
387+
void Compute(const ExecutionContext& ctx) const {}
388+
};
389+
390+
} // namespace framework
391+
} // namespace paddle
392+
393+
REGISTER_OP_WITHOUT_GRADIENT(
394+
indicate_lod_tensor_data_type_test,
395+
paddle::framework::IndicateLoDTensorDataTypeTest,
396+
paddle::framework::IndicateLoDTensorDataTypeTestProtoMaker);
397+
REGISTER_OP_WITHOUT_GRADIENT(
398+
indicate_selected_rows_data_type_test,
399+
paddle::framework::IndicateSelectedRowsDataTypeTest,
400+
paddle::framework::IndicateSelectedRowsDataTypeTestProtoMaker);
401+
REGISTER_OP_WITHOUT_GRADIENT(
402+
indicate_other_data_type_test, paddle::framework::IndicateOtherDataTypeTest,
403+
paddle::framework::IndicateOtherDataTypeTestProtoMaker);
404+
405+
REGISTER_OP_CPU_KERNEL(indicate_lod_tensor_data_type_test,
406+
paddle::framework::IndicateVarDataTypeKernelTest<
407+
paddle::platform::CPUDeviceContext, int>);
408+
REGISTER_OP_CPU_KERNEL(indicate_selected_rows_data_type_test,
409+
paddle::framework::IndicateVarDataTypeKernelTest<
410+
paddle::platform::CPUDeviceContext, int>);
411+
REGISTER_OP_CPU_KERNEL(indicate_other_data_type_test,
412+
paddle::framework::IndicateVarDataTypeKernelTest<
413+
paddle::platform::CPUDeviceContext, int>);
414+
415+
TEST(IndicateVarDataTypeTest, lodtensor) {
416+
paddle::framework::InitDevices(true);
417+
paddle::framework::proto::OpDesc op_desc;
418+
op_desc.set_type("indicate_lod_tensor_data_type_test");
419+
BuildVar("LoDTensor", {"lodtensor_1"}, op_desc.add_inputs());
420+
421+
paddle::platform::CPUPlace cpu_place;
422+
paddle::framework::Scope scope;
423+
424+
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
425+
auto* var = scope.Var("lodtensor_1");
426+
var->GetMutable<paddle::framework::LoDTensor>();
427+
428+
bool caught = false;
429+
try {
430+
op->Run(scope, cpu_place);
431+
} catch (paddle::platform::EnforceNotMet err) {
432+
caught = true;
433+
std::string ex_msg = err.what();
434+
EXPECT_TRUE(
435+
ex_msg.find(
436+
"The Tensor in the indicate_lod_tensor_data_type_test Op's "
437+
"Input Variable LoDTensor(lodtensor_1) is not initialized") !=
438+
std::string::npos);
439+
}
440+
ASSERT_TRUE(caught);
441+
}
442+
443+
TEST(IndicateVarDataTypeTest, selectedrows) {
444+
paddle::framework::InitDevices(true);
445+
paddle::framework::proto::OpDesc op_desc;
446+
op_desc.set_type("indicate_selected_rows_data_type_test");
447+
BuildVar("SelectedRows", {"selected_rows_1"}, op_desc.add_inputs());
448+
449+
paddle::platform::CPUPlace cpu_place;
450+
paddle::framework::Scope scope;
451+
452+
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
453+
auto* var = scope.Var("selected_rows_1");
454+
var->GetMutable<paddle::framework::SelectedRows>();
455+
456+
bool caught = false;
457+
try {
458+
op->Run(scope, cpu_place);
459+
} catch (paddle::platform::EnforceNotMet err) {
460+
caught = true;
461+
std::string ex_msg = err.what();
462+
EXPECT_TRUE(
463+
ex_msg.find("The Tensor in the indicate_selected_rows_data_type_test "
464+
"Op's Input Variable SelectedRows(selected_rows_1) is not "
465+
"initialized") != std::string::npos);
466+
}
467+
ASSERT_TRUE(caught);
468+
}
469+
470+
TEST(IndicateVarDataTypeTest, other) {
471+
paddle::framework::InitDevices(true);
472+
paddle::framework::proto::OpDesc op_desc;
473+
op_desc.set_type("indicate_other_data_type_test");
474+
BuildVar("Other", {"lod_tensor_array_1"}, op_desc.add_inputs());
475+
476+
paddle::platform::CPUPlace cpu_place;
477+
paddle::framework::Scope scope;
478+
479+
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
480+
auto* var = scope.Var("lod_tensor_array_1");
481+
var->GetMutable<paddle::framework::LoDTensorArray>();
482+
483+
bool caught = false;
484+
try {
485+
op->Run(scope, cpu_place);
486+
} catch (paddle::platform::EnforceNotMet err) {
487+
caught = true;
488+
std::string ex_msg = err.what();
489+
EXPECT_TRUE(ex_msg.find("The Input Variable(Other) of "
490+
"indicate_other_data_type_test Op used to "
491+
"determine kernel data type "
492+
"is empty or not LoDTensor or SelectedRows") !=
493+
std::string::npos);
494+
}
495+
ASSERT_TRUE(caught);
496+
}

paddle/fluid/framework/variable.h

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,9 @@ class Variable {
3030
static_assert(
3131
IsRegisteredVarType<T>(),
3232
"Not registered type. Please register T inside var_type_traits.h");
33-
PADDLE_ENFORCE(holder_ != nullptr, "Variable must hold some thing");
33+
PADDLE_ENFORCE(holder_ != nullptr, "Variable is not initialized.");
3434
PADDLE_ENFORCE(holder_->Type() == VarTypeTrait<T>::kId,
35-
"Variable must be type %s, the holding type is %s",
35+
"The Variable type must be %s, but the type it holds is %s.",
3636
ToTypeName(VarTypeTrait<T>::kId),
3737
ToTypeName(holder_->Type()));
3838
return *static_cast<const T*>(holder_->Ptr());
@@ -45,10 +45,10 @@ class Variable {
4545
if (!holder_) {
4646
holder_.reset(new PlaceholderImpl<T>());
4747
} else {
48-
PADDLE_ENFORCE(holder_->Type() == VarTypeTrait<T>::kId,
49-
"Variable must be type %s, the holding type is %s",
50-
ToTypeName(VarTypeTrait<T>::kId),
51-
ToTypeName(holder_->Type()));
48+
PADDLE_ENFORCE(
49+
holder_->Type() == VarTypeTrait<T>::kId,
50+
"The Variable type must be %s, but the type it holds is %s.",
51+
ToTypeName(VarTypeTrait<T>::kId), ToTypeName(holder_->Type()));
5252
}
5353
return static_cast<T*>(holder_->Ptr());
5454
}
@@ -61,7 +61,7 @@ class Variable {
6161
void Clear() { holder_.reset(); }
6262

6363
int Type() const {
64-
PADDLE_ENFORCE(holder_ != nullptr, "Must hold memory");
64+
PADDLE_ENFORCE(holder_ != nullptr, "Variable is not initialized.");
6565
return holder_->Type();
6666
}
6767

paddle/fluid/operators/activation_op.cc

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -114,9 +114,8 @@ framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx,
114114
layout = framework::DataLayout::kMKLDNN;
115115
}
116116
#endif
117-
return framework::OpKernelType(
118-
framework::GetDataTypeOfVar(ctx.InputVar(name)), ctx.GetPlace(), layout,
119-
library);
117+
return framework::OpKernelType(oper.IndicateVarDataType(ctx, name),
118+
ctx.GetPlace(), layout, library);
120119
}
121120

122121
class ActivationOp : public framework::OperatorWithKernel {

paddle/fluid/operators/add_position_encoding_op.cc

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,9 @@ class AddPositionEncodingOp : public framework::OperatorWithKernel {
3737
protected:
3838
framework::OpKernelType GetExpectedKernelType(
3939
const framework::ExecutionContext& ctx) const override {
40-
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
41-
platform::CPUPlace());
40+
return framework::OpKernelType(
41+
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
42+
platform::CPUPlace());
4243
}
4344
};
4445

@@ -56,9 +57,9 @@ class AddPositionEncodingOpGrad : public framework::OperatorWithKernel {
5657
protected:
5758
framework::OpKernelType GetExpectedKernelType(
5859
const framework::ExecutionContext& ctx) const override {
59-
return framework::OpKernelType(
60-
ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"))->type(),
61-
platform::CPUPlace());
60+
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
61+
ctx, framework::GradVarName("Out")),
62+
platform::CPUPlace());
6263
}
6364
};
6465

0 commit comments

Comments
 (0)