Skip to content

Commit 7d1332f

Browse files
author
chengduo
authored
Merge pull request #11006 from chengduoZH/fix_add_check_nan_inf_in_operator
Move check_nan_inf to operator
2 parents 15db5a5 + cb1c657 commit 7d1332f

File tree

2 files changed

+28
-27
lines changed

2 files changed

+28
-27
lines changed

paddle/fluid/framework/executor.cc

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,6 @@ limitations under the License. */
2424
#include "paddle/fluid/platform/profiler.h"
2525

2626
DECLARE_bool(benchmark);
27-
DEFINE_bool(check_nan_inf, false,
28-
"Checking whether operator produce NAN/INF or not. It will be "
29-
"extremely slow so please use this flag wisely.");
3027

3128
namespace paddle {
3229
namespace framework {
@@ -78,21 +75,6 @@ void InitializeVariable(Variable* var, proto::VarType::Type var_type) {
7875
}
7976
}
8077

81-
static void CheckTensorNANOrInf(const std::string& name,
82-
const framework::Tensor& tensor) {
83-
if (tensor.memory_size() == 0) {
84-
return;
85-
}
86-
if (tensor.type().hash_code() != typeid(float).hash_code() && // NOLINT
87-
tensor.type().hash_code() != typeid(double).hash_code()) { // NOLINT
88-
return;
89-
}
90-
PADDLE_ENFORCE(!framework::TensorContainsInf(tensor),
91-
"Tensor %s contains Inf", name);
92-
PADDLE_ENFORCE(!framework::TensorContainsNAN(tensor),
93-
"Tensor %s contains NAN", name);
94-
}
95-
9678
void Executor::CreateVariables(const ProgramDesc& pdesc, Scope* scope,
9779
int block_id) {
9880
auto& global_block = pdesc.Block(block_id);
@@ -340,15 +322,6 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
340322
VLOG(2) << "Memory used after operator " + op->Type() + " running: "
341323
<< memory::memory_usage(place_);
342324
}
343-
if (FLAGS_check_nan_inf) {
344-
for (auto& vname : op->OutputVars(true)) {
345-
auto* var = local_scope->FindVar(vname);
346-
if (var == nullptr) continue;
347-
if (var->IsType<framework::LoDTensor>()) {
348-
CheckTensorNANOrInf(vname, var->Get<framework::LoDTensor>());
349-
}
350-
}
351-
}
352325
}
353326
platform::DeviceContextPool::Instance().Get(place_)->Wait();
354327
if (create_vars && create_local_scope) {

paddle/fluid/framework/operator.cc

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@ limitations under the License. */
2424
#include "paddle/fluid/platform/profiler.h"
2525

2626
DECLARE_bool(benchmark);
27+
DEFINE_bool(check_nan_inf, false,
28+
"Checking whether operator produce NAN/INF or not. It will be "
29+
"extremely slow so please use this flag wisely.");
2730

2831
namespace paddle {
2932
namespace framework {
@@ -513,6 +516,21 @@ class RuntimeInferShapeContext : public InferShapeContext {
513516
const Scope& scope_;
514517
};
515518

519+
static void CheckTensorNANOrInf(const std::string& name,
520+
const framework::Tensor& tensor) {
521+
if (tensor.memory_size() == 0) {
522+
return;
523+
}
524+
if (tensor.type().hash_code() != typeid(float).hash_code() && // NOLINT
525+
tensor.type().hash_code() != typeid(double).hash_code()) { // NOLINT
526+
return;
527+
}
528+
PADDLE_ENFORCE(!framework::TensorContainsInf(tensor),
529+
"Tensor %s contains Inf", name);
530+
PADDLE_ENFORCE(!framework::TensorContainsNAN(tensor),
531+
"Tensor %s contains NAN", name);
532+
}
533+
516534
void OperatorWithKernel::RunImpl(const Scope& scope,
517535
const platform::Place& place) const {
518536
RuntimeInferShapeContext infer_shape_ctx(*this, scope);
@@ -597,6 +615,16 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
597615
if (FLAGS_benchmark) {
598616
new_dev_ctx->Wait();
599617
}
618+
619+
if (FLAGS_check_nan_inf) {
620+
for (auto& vname : OutputVars(true)) {
621+
auto* var = new_scope.FindVar(vname);
622+
if (var == nullptr) continue;
623+
if (var->IsType<framework::LoDTensor>()) {
624+
CheckTensorNANOrInf(vname, var->Get<framework::LoDTensor>());
625+
}
626+
}
627+
}
600628
}
601629

602630
proto::VarType::Type OperatorWithKernel::IndicateDataType(

0 commit comments

Comments
 (0)