Skip to content

Commit 6d5a04c

Browse files
authored
add op type in check nan/inf (#15986)
* add op name in check nan/inf, test=develop
1 parent 187cffd commit 6d5a04c

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

paddle/fluid/framework/operator.cc

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -882,7 +882,8 @@ class RuntimeInferShapeContext : public InferShapeContext {
882882
const RuntimeContext& ctx_;
883883
};
884884

885-
static void CheckTensorNANOrInf(const std::string& name,
885+
static void CheckTensorNANOrInf(const std::string& op_type,
886+
const std::string& name,
886887
const framework::Tensor& tensor) {
887888
if (tensor.memory_size() == 0) {
888889
return;
@@ -892,9 +893,9 @@ static void CheckTensorNANOrInf(const std::string& name,
892893
return;
893894
}
894895
PADDLE_ENFORCE(!framework::TensorContainsInf(tensor),
895-
"Tensor %s contains Inf", name);
896+
"Operator %s output Tensor %s contains Inf", op_type, name);
896897
PADDLE_ENFORCE(!framework::TensorContainsNAN(tensor),
897-
"Tensor %s contains NAN", name);
898+
"Operator %s output Tensor %s contains NAN", op_type, name);
898899
}
899900

900901
void OperatorWithKernel::RuntimeInferShape(const Scope& scope,
@@ -988,9 +989,10 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
988989
auto* var = exec_scope.FindVar(vname);
989990
if (var == nullptr) continue;
990991
if (var->IsType<framework::LoDTensor>()) {
991-
CheckTensorNANOrInf(vname, var->Get<framework::LoDTensor>());
992+
CheckTensorNANOrInf(type_, vname, var->Get<framework::LoDTensor>());
992993
} else if (var->IsType<framework::SelectedRows>()) {
993-
CheckTensorNANOrInf(vname, var->Get<framework::SelectedRows>().value());
994+
CheckTensorNANOrInf(type_, vname,
995+
var->Get<framework::SelectedRows>().value());
994996
}
995997
}
996998
}

0 commit comments

Comments
 (0)