@@ -75,6 +75,7 @@ std::vector<std::string> NgraphEngine::feed_vars = {};
75
75
std::vector<std::string> NgraphEngine::fetch_vars = {};
76
76
framework::Variable* NgraphEngine::pre_var_ptr = nullptr ;
77
77
const framework::BlockDesc* NgraphEngine::p_bdesc = nullptr ;
78
+ bool NgraphEngine::is_training = false ;
78
79
79
80
std::unordered_map<std::string, EngineCache> NgraphEngine::engine_cache = {};
80
81
std::unordered_map<std::string,
@@ -93,11 +94,13 @@ static std::vector<std::vector<int>> NgraphOpIntervals(
93
94
int size = ops->size ();
94
95
int left = 0 ;
95
96
while (left < size && ops->at (left)->Type () != framework::kFeedOpType &&
97
+ ops->at (left)->Type () != " read" &&
96
98
ops->at (left)->Type () != framework::kFetchOpType ) {
97
99
++left;
98
100
}
99
101
100
- while (left < size && ops->at (left)->Type () == framework::kFeedOpType ) {
102
+ while (left < size && (ops->at (left)->Type () == framework::kFeedOpType ||
103
+ ops->at (left)->Type () == " read" )) {
101
104
for (auto & var_name_item : ops->at (left)->Outputs ()) {
102
105
for (auto & var_name : var_name_item.second ) {
103
106
NgraphEngine::feed_vars.emplace_back (var_name);
@@ -270,6 +273,7 @@ void NgraphEngine::Prepare(const std::vector<int>& interval) {
270
273
271
274
for (auto op_desc : ops_desc) {
272
275
if (op_desc->Type ().find (" _grad" ) != std::string::npos) {
276
+ is_training = true ;
273
277
this ->is_test_ = false ;
274
278
break ;
275
279
}
@@ -590,7 +594,7 @@ void NgraphEngine::Run(const framework::Scope& scope,
590
594
}
591
595
bool is_persistable =
592
596
(p_persistables->find (vi) != p_persistables->end ()) ? true : false ;
593
- if (is_test && is_persistable) {
597
+ if (!is_training && is_test && is_persistable) {
594
598
ti->set_stale (false );
595
599
}
596
600
(*p_t_in).emplace_back (ti);
0 commit comments