Skip to content

Commit a0af374

Browse files
baojun-nervanatensor-tang
authored andcommitted
fix training validation test=release/1.4 (#16716)
1 parent 266cdf7 commit a0af374

File tree

2 files changed

+7
-2
lines changed

2 files changed

+7
-2
lines changed

paddle/fluid/operators/ngraph/ngraph_engine.cc

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ std::vector<std::string> NgraphEngine::feed_vars = {};
7575
std::vector<std::string> NgraphEngine::fetch_vars = {};
7676
framework::Variable* NgraphEngine::pre_var_ptr = nullptr;
7777
const framework::BlockDesc* NgraphEngine::p_bdesc = nullptr;
78+
bool NgraphEngine::is_training = false;
7879

7980
std::unordered_map<std::string, EngineCache> NgraphEngine::engine_cache = {};
8081
std::unordered_map<std::string,
@@ -93,11 +94,13 @@ static std::vector<std::vector<int>> NgraphOpIntervals(
9394
int size = ops->size();
9495
int left = 0;
9596
while (left < size && ops->at(left)->Type() != framework::kFeedOpType &&
97+
ops->at(left)->Type() != "read" &&
9698
ops->at(left)->Type() != framework::kFetchOpType) {
9799
++left;
98100
}
99101

100-
while (left < size && ops->at(left)->Type() == framework::kFeedOpType) {
102+
while (left < size && (ops->at(left)->Type() == framework::kFeedOpType ||
103+
ops->at(left)->Type() == "read")) {
101104
for (auto& var_name_item : ops->at(left)->Outputs()) {
102105
for (auto& var_name : var_name_item.second) {
103106
NgraphEngine::feed_vars.emplace_back(var_name);
@@ -270,6 +273,7 @@ void NgraphEngine::Prepare(const std::vector<int>& interval) {
270273

271274
for (auto op_desc : ops_desc) {
272275
if (op_desc->Type().find("_grad") != std::string::npos) {
276+
is_training = true;
273277
this->is_test_ = false;
274278
break;
275279
}
@@ -590,7 +594,7 @@ void NgraphEngine::Run(const framework::Scope& scope,
590594
}
591595
bool is_persistable =
592596
(p_persistables->find(vi) != p_persistables->end()) ? true : false;
593-
if (is_test && is_persistable) {
597+
if (!is_training && is_test && is_persistable) {
594598
ti->set_stale(false);
595599
}
596600
(*p_t_in).emplace_back(ti);

paddle/fluid/operators/ngraph/ngraph_engine.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ class NgraphEngine {
5757

5858
void Run(const framework::Scope& scope, const platform::Place& place) const;
5959

60+
static bool is_training;
6061
static const framework::BlockDesc* p_bdesc;
6162
static std::vector<std::string> feed_vars, fetch_vars;
6263

0 commit comments

Comments
 (0)