Skip to content

Commit f1b934d

Browse files
committed
Optimize tensor copies on inference path
1 parent c115709 commit f1b934d

File tree

1 file changed

+20
-10
lines changed

1 file changed

+20
-10
lines changed

include/graph/graph.hpp

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,13 @@ class Graph {
247247
if (it != branch_list_.rend()) {
248248
for (size_t f = 0; f < it->distribution.size(); ++f) {
249249
if (it->distribution[f].first == current_layer) {
250-
inten_.push_back(it->give_for_all[it->distribution[f].second]);
250+
bool last_use = (it->count_used_ten == 1);
251+
auto& src = it->give_for_all[it->distribution[f].second];
252+
if (last_use) {
253+
inten_.push_back(std::move(src));
254+
} else {
255+
inten_.push_back(src);
256+
}
251257
}
252258
}
253259
}
@@ -262,6 +268,9 @@ class Graph {
262268
}
263269
}
264270
}
271+
if (outten_.empty()) {
272+
outten_.resize(1);
273+
}
265274
layers_[current_layer]->run(inten_, outten_);
266275

267276
#ifdef ENABLE_STATISTIC_TENSORS
@@ -272,18 +281,18 @@ class Graph {
272281
weights_.push_back(layers_[current_layer]->get_weights());
273282
#endif
274283

275-
inten_ = outten_;
284+
inten_.swap(outten_);
276285

277286
if (layers_[current_layer]->postops.count > 0) {
278287
for (unsigned int j = 0; j < layers_[current_layer]->postops.count;
279288
j++) {
280289
layers_[current_layer]->postops.layers[j]->run(inten_, outten_);
281290
}
282-
inten_ = outten_;
291+
inten_.swap(outten_);
283292
}
284293

285294
BranchState new_branch;
286-
new_branch.give_for_all = inten_;
295+
new_branch.give_for_all = std::move(inten_);
287296
new_branch.count_used_ten = countinout[current_layer].second;
288297
new_branch.ind_layer = current_layer;
289298
new_branch.split = layers_[current_layer]->getName() == kSplit;
@@ -308,7 +317,12 @@ class Graph {
308317
}
309318
new_branch.distribution = dis;
310319
}
311-
branch_list_.push_back(new_branch);
320+
branch_list_.push_back(std::move(new_branch));
321+
if (outtenres_ && current_layer == end_ &&
322+
!branch_list_.back().give_for_all.empty() &&
323+
countinout[current_layer].second == 0) {
324+
*outtenres_ = std::move(branch_list_.back().give_for_all[0]);
325+
}
312326

313327
#ifdef ENABLE_STATISTIC_TIME
314328
auto end = std::chrono::high_resolution_clock::now();
@@ -318,10 +332,6 @@ class Graph {
318332
time_layer_.push_back(layers_[current_layer]->getName());
319333
#endif
320334
}
321-
322-
if (outtenres_ && !outten_.empty()) {
323-
*outtenres_ = outten_[0];
324-
}
325335
}
326336

327337
void setOutput(Layer* layer, Tensor& vec) {
@@ -437,4 +447,4 @@ class Graph {
437447
return traversal;
438448
}
439449
};
440-
} // namespace it_lab_ai
450+
} // namespace it_lab_ai

0 commit comments

Comments
 (0)