@@ -882,7 +882,8 @@ class RuntimeInferShapeContext : public InferShapeContext {
882
882
const RuntimeContext& ctx_;
883
883
};
884
884
885
- static void CheckTensorNANOrInf (const std::string& name,
885
+ static void CheckTensorNANOrInf (const std::string& op_type,
886
+ const std::string& name,
886
887
const framework::Tensor& tensor) {
887
888
if (tensor.memory_size () == 0 ) {
888
889
return ;
@@ -892,9 +893,9 @@ static void CheckTensorNANOrInf(const std::string& name,
892
893
return ;
893
894
}
894
895
PADDLE_ENFORCE (!framework::TensorContainsInf (tensor),
895
- " Tensor %s contains Inf" , name);
896
+ " Operator %s output Tensor %s contains Inf" , op_type , name);
896
897
PADDLE_ENFORCE (!framework::TensorContainsNAN (tensor),
897
- " Tensor %s contains NAN" , name);
898
+ " Operator %s output Tensor %s contains NAN" , op_type , name);
898
899
}
899
900
900
901
void OperatorWithKernel::RuntimeInferShape (const Scope& scope,
@@ -988,9 +989,10 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
988
989
auto * var = exec_scope.FindVar (vname);
989
990
if (var == nullptr ) continue ;
990
991
if (var->IsType <framework::LoDTensor>()) {
991
- CheckTensorNANOrInf (vname, var->Get <framework::LoDTensor>());
992
+ CheckTensorNANOrInf (type_, vname, var->Get <framework::LoDTensor>());
992
993
} else if (var->IsType <framework::SelectedRows>()) {
993
- CheckTensorNANOrInf (vname, var->Get <framework::SelectedRows>().value ());
994
+ CheckTensorNANOrInf (type_, vname,
995
+ var->Get <framework::SelectedRows>().value ());
994
996
}
995
997
}
996
998
}
0 commit comments