@@ -1690,30 +1690,27 @@ void BindInsertionPoint(pybind11::module *m) {
1690
1690
return_value_policy::reference);
1691
1691
}
1692
1692
1693
- std::list<Operation *>::const_iterator list_offset (const Block *block,
1694
- int start_idx) {
1695
- auto it = block->begin ();
1696
- while (it != block->end () && start_idx--) ++it;
1697
- return it;
1698
- }
1699
-
1700
1693
template <typename F, typename S>
1701
1694
void range_block_do (const Block *block,
1702
- std::vector< int > range,
1695
+ std::pair< size_t , size_t > range,
1703
1696
F fn,
1704
1697
S skip_fn) {
1705
- for (auto it = list_offset (block, range[0 ]);
1706
- it != list_offset (block, range[1 ]);
1707
- ++it) {
1708
- if (skip_fn (*it)) {
1698
+ auto [start, end] = range;
1699
+ if (start >= end) {
1700
+ return ;
1701
+ }
1702
+ auto it = block->begin ();
1703
+ std::advance (it, start);
1704
+ for (size_t i = start; i < end && it != block->end (); ++i, ++it) {
1705
+ if (skip_fn (it)) {
1709
1706
continue ;
1710
1707
}
1711
- fn (* it);
1708
+ fn (it);
1712
1709
}
1713
1710
}
1714
1711
1715
1712
template <typename F>
1716
- void range_block_do (const Block *block, std::vector< int > range, F fn) {
1713
+ void range_block_do (const Block *block, std::pair< size_t , size_t > range, F fn) {
1717
1714
range_block_do (block, range, fn, [](Operation *op) { return false ; });
1718
1715
}
1719
1716
@@ -1754,8 +1751,8 @@ std::pair<std::vector<pir::Value>, std::unordered_set<pir::Value>>
1754
1751
AnalysisMiddleVariable (const Program &program,
1755
1752
const std::vector<pir::Value> &forward_inputs,
1756
1753
const std::vector<pir::Value> &backward_outputs,
1757
- const std::vector< int > &forward_range,
1758
- const std::vector< int > &backward_range) {
1754
+ const std::pair< size_t , size_t > &forward_range,
1755
+ const std::pair< size_t , size_t > &backward_range) {
1759
1756
std::vector<pir::Value> middle_values;
1760
1757
1761
1758
std::unordered_set<pir::Value> backward_used_values;
@@ -1811,7 +1808,7 @@ using SplitedAttribute = std::map<std::string, std::vector<pir::Value>>;
1811
1808
using SplitedResult = std::pair<SplitedProgram, SplitedAttribute>;
1812
1809
1813
1810
static auto GetNoNeedBufferValue (const ::pir::Block *whole_block,
1814
- std::vector< int > range) {
1811
+ std::pair< size_t , size_t > range) {
1815
1812
// filter no need buffer values.
1816
1813
std::unordered_set<::pir::Value> need_buffer_values;
1817
1814
std::unordered_set<::pir::Value> no_need_buffer_values;
@@ -1926,8 +1923,8 @@ SplitedResult SplitForwardBackward(
1926
1923
const std::vector<pir::Value> &forward_inputs_grads,
1927
1924
const std::vector<pir::Value> &forward_params_grads,
1928
1925
const std::vector<pir::Value> &forward_outputs_grads,
1929
- const std::vector< int > &forward_range,
1930
- const std::vector< int > &backward_range) {
1926
+ const std::pair< size_t , size_t > &forward_range,
1927
+ const std::pair< size_t , size_t > &backward_range) {
1931
1928
std::vector<pir::Value> forward_in_out_values;
1932
1929
for (auto &v :
1933
1930
std::vector ({&forward_inputs, &forward_outputs, &forward_params})) {
0 commit comments