Skip to content

Commit 239715a

Browse files
authored
[Dy2St] Optimize range_block_do performance (#69834)
1 parent 035be36 commit 239715a

File tree

2 files changed

+18
-21
lines changed

2 files changed

+18
-21
lines changed

paddle/fluid/pybind/pir.cc

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1690,30 +1690,27 @@ void BindInsertionPoint(pybind11::module *m) {
16901690
return_value_policy::reference);
16911691
}
16921692

1693-
std::list<Operation *>::const_iterator list_offset(const Block *block,
1694-
int start_idx) {
1695-
auto it = block->begin();
1696-
while (it != block->end() && start_idx--) ++it;
1697-
return it;
1698-
}
1699-
17001693
template <typename F, typename S>
17011694
void range_block_do(const Block *block,
1702-
std::vector<int> range,
1695+
std::pair<size_t, size_t> range,
17031696
F fn,
17041697
S skip_fn) {
1705-
for (auto it = list_offset(block, range[0]);
1706-
it != list_offset(block, range[1]);
1707-
++it) {
1708-
if (skip_fn(*it)) {
1698+
auto [start, end] = range;
1699+
if (start >= end) {
1700+
return;
1701+
}
1702+
auto it = block->begin();
1703+
std::advance(it, start);
1704+
for (size_t i = start; i < end && it != block->end(); ++i, ++it) {
1705+
if (skip_fn(it)) {
17091706
continue;
17101707
}
1711-
fn(*it);
1708+
fn(it);
17121709
}
17131710
}
17141711

17151712
template <typename F>
1716-
void range_block_do(const Block *block, std::vector<int> range, F fn) {
1713+
void range_block_do(const Block *block, std::pair<size_t, size_t> range, F fn) {
17171714
range_block_do(block, range, fn, [](Operation *op) { return false; });
17181715
}
17191716

@@ -1754,8 +1751,8 @@ std::pair<std::vector<pir::Value>, std::unordered_set<pir::Value>>
17541751
AnalysisMiddleVariable(const Program &program,
17551752
const std::vector<pir::Value> &forward_inputs,
17561753
const std::vector<pir::Value> &backward_outputs,
1757-
const std::vector<int> &forward_range,
1758-
const std::vector<int> &backward_range) {
1754+
const std::pair<size_t, size_t> &forward_range,
1755+
const std::pair<size_t, size_t> &backward_range) {
17591756
std::vector<pir::Value> middle_values;
17601757

17611758
std::unordered_set<pir::Value> backward_used_values;
@@ -1811,7 +1808,7 @@ using SplitedAttribute = std::map<std::string, std::vector<pir::Value>>;
18111808
using SplitedResult = std::pair<SplitedProgram, SplitedAttribute>;
18121809

18131810
static auto GetNoNeedBufferValue(const ::pir::Block *whole_block,
1814-
std::vector<int> range) {
1811+
std::pair<size_t, size_t> range) {
18151812
// filter no need buffer values.
18161813
std::unordered_set<::pir::Value> need_buffer_values;
18171814
std::unordered_set<::pir::Value> no_need_buffer_values;
@@ -1926,8 +1923,8 @@ SplitedResult SplitForwardBackward(
19261923
const std::vector<pir::Value> &forward_inputs_grads,
19271924
const std::vector<pir::Value> &forward_params_grads,
19281925
const std::vector<pir::Value> &forward_outputs_grads,
1929-
const std::vector<int> &forward_range,
1930-
const std::vector<int> &backward_range) {
1926+
const std::pair<size_t, size_t> &forward_range,
1927+
const std::pair<size_t, size_t> &backward_range) {
19311928
std::vector<pir::Value> forward_in_out_values;
19321929
for (auto &v :
19331930
std::vector({&forward_inputs, &forward_outputs, &forward_params})) {

python/paddle/jit/dy2static/pir_partial_program.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -260,8 +260,8 @@ def split_forward_backward(self):
260260
self.x_grad_values,
261261
self.param_grad_values,
262262
self.out_grad_values,
263-
list(self.forward_range),
264-
list(self.backward_range),
263+
self.forward_range,
264+
self.backward_range,
265265
)
266266
return [fwd_prog, bwd_prog], prog_attr
267267

0 commit comments

Comments
 (0)