Skip to content

Commit 8273a5f

Browse files
Add lazy segment tree template
1 parent db4937d commit 8273a5f

File tree

8 files changed

+176
-13
lines changed

8 files changed

+176
-13
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,4 @@ node_modules
66
library/test.exe
77
library/test.d
88
.yarn/
9+
*.gch

.vscode/settings.json

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,5 +76,9 @@
7676
"typeinfo": "cpp",
7777
"variant": "cpp",
7878
"text_encoding": "cpp"
79-
}
80-
}
79+
},
80+
"clangd.fallbackFlags": [
81+
"-std=c++14",
82+
"-I${workspaceFolder}/library"
83+
]
84+
}
File renamed without changes.
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
#include "mrpython/bits.hpp"
2+
3+
#ifndef MP_LIBRARY_LAZY_SEGMENT_TREE_HPP
4+
#define MP_LIBRARY_LAZY_SEGMENT_TREE_HPP
5+
#include <algorithm>
6+
#include <functional>
7+
#include <vector>
8+
9+
namespace mrpython {
10+
using std::size_t;
11+
template <typename T, typename MergeFunction, typename Lazy,
12+
typename OperateFunction, typename MergeLazyFunction>
13+
class lazy_segment_tree {
14+
std::vector<T> data;
15+
std::vector<Lazy> lazy;
16+
std::vector<size_t> size;
17+
size_t n;
18+
Lazy const lazyInit;
19+
MergeFunction const mergeData;
20+
OperateFunction const operate;
21+
MergeLazyFunction const mergeLazy;
22+
void build(void) {
23+
data.reserve(2 * n - 1), size.reserve(2 * n - 1);
24+
for (size_t i = n; i < 2 * n - 1; ++i) {
25+
size_t d = 2 * n - 1 - i;
26+
size_t l = d * 2 + 1, r = d * 2;
27+
data.emplace_back(mergeData(data[2 * n - 1 - l], data[2 * n - 1 - r]));
28+
size.emplace_back(size[2 * n - 1 - l] + size[2 * n - 1 - r]);
29+
}
30+
std::reverse(data.begin(), data.end());
31+
std::reverse(size.begin(), size.end());
32+
}
33+
void pushdown(size_t pos) {
34+
if (size[pos] == 1) {
35+
lazy[pos] = lazyInit;
36+
return;
37+
}
38+
add_tag_for_node(2 * pos + 1, lazy[pos]);
39+
add_tag_for_node(2 * pos + 2, lazy[pos]);
40+
lazy[pos] = lazyInit;
41+
}
42+
void add_tag_for_node(size_t pos, Lazy const& lazyVal) {
43+
data[pos] = operate(lazyVal, data[pos], size[pos]);
44+
lazy[pos] = mergeLazy(lazy[pos], lazyVal);
45+
}
46+
T get_impl(size_t l, size_t r, size_t pos) {
47+
if (l == 0 && r == size[pos]) return data[pos];
48+
pushdown(pos);
49+
size_t m = size[pos * 2 + 1];
50+
if (l < m && r > m)
51+
return mergeData(get_impl(l, m, pos * 2 + 1),
52+
get_impl(0, r - m, pos * 2 + 2));
53+
else if (l < m)
54+
return get_impl(l, r, pos * 2 + 1);
55+
else if (r > m)
56+
return get_impl(l - m, r - m, pos * 2 + 2);
57+
else
58+
__builtin_unreachable();
59+
}
60+
void set_impl(size_t l, size_t r, Lazy const& operateVal, size_t pos) {
61+
if (l == 0 && r == size[pos]) return add_tag_for_node(pos, operateVal);
62+
pushdown(pos);
63+
size_t m = size[pos * 2 + 1];
64+
if (l < m && r > m)
65+
set_impl(l, m, operateVal, pos * 2 + 1),
66+
set_impl(0, r - m, operateVal, pos * 2 + 2);
67+
else if (l < m)
68+
set_impl(l, r, operateVal, pos * 2 + 1);
69+
else if (r > m)
70+
set_impl(l - m, r - m, operateVal, pos * 2 + 2);
71+
else
72+
__builtin_unreachable();
73+
data[pos] = mergeData(data[pos * 2 + 1], data[pos * 2 + 2]);
74+
}
75+
76+
public:
77+
template <typename InputIterator>
78+
lazy_segment_tree(InputIterator first, InputIterator last,
79+
Lazy const& lazyInitVal,
80+
MergeFunction const& mergeDataFun = MergeFunction(),
81+
OperateFunction const& OperateFun = OperateFunction(),
82+
MergeLazyFunction const& mergeTagFun = MergeLazyFunction())
83+
: data(first, last),
84+
lazy(2 * data.size() - 1, lazyInitVal),
85+
size(data.size(), 1),
86+
n(data.size()),
87+
lazyInit(lazyInitVal),
88+
mergeData(mergeDataFun),
89+
operate(OperateFun),
90+
mergeLazy(mergeTagFun) {
91+
rotate(data.begin(), data.begin() + (2 * n - 1) - (highbit(2 * n - 1) - 1),
92+
data.end());
93+
reverse(data.begin(), data.end());
94+
build();
95+
}
96+
lazy_segment_tree(size_t len, T const& init, Lazy const& lazyInitVal,
97+
MergeFunction const& mergeDataFun = MergeFunction(),
98+
OperateFunction const& OperateFun = OperateFunction(),
99+
MergeLazyFunction const& mergeTagFun = MergeLazyFunction())
100+
: data(len, init),
101+
lazy(2 * len - 1, lazyInitVal),
102+
size(len, 1),
103+
n(len),
104+
lazyInit(lazyInitVal),
105+
mergeData(mergeDataFun),
106+
mergeLazy(mergeTagFun) {
107+
build();
108+
}
109+
T get(size_t l, size_t r) { return get_impl(l, r, 0); }
110+
void set(size_t l, size_t r, Lazy const& operateVal) {
111+
set_impl(l, r, operateVal, 0);
112+
}
113+
};
114+
template <typename T> struct lazy_segment_tree_add_add_operate_function {
115+
T operator()(T const& lazy, T const& data, size_t size) const {
116+
return data + lazy * size;
117+
}
118+
};
119+
template <typename T>
120+
using lazy_segment_tree_add_add =
121+
lazy_segment_tree<T, std::plus<T>, T,
122+
lazy_segment_tree_add_add_operate_function<T>,
123+
std::plus<T>>;
124+
} // namespace mrpython
125+
#endif

library/mrpython/typical_segment_tree.hpp

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,14 @@
44
#ifndef MP_LIBRARY_SEGMENT_TREE_HPP
55
#define MP_LIBRARY_SEGMENT_TREE_HPP
66
#include <algorithm>
7-
#include <iterator>
8-
#include <utility>
97
#include <vector>
108

119
namespace mrpython {
1210
using std::size_t;
1311
/**
1412
* 一颗 无标记线段树 模板。
1513
* @tparam T 元素类型
16-
* @tparam MergeFunction 合并运算函数,对于函数 `F` 应满足结合律
14+
* @tparam MergeFunction 合并运算函数
1715
*/
1816
template <typename T, typename MergeFunction> class typical_segment_tree {
1917
std::vector<T> data;
@@ -31,10 +29,10 @@ template <typename T, typename MergeFunction> class typical_segment_tree {
3129
std::reverse(data.begin(), data.end());
3230
std::reverse(size.begin(), size.end());
3331
}
34-
template <typename Fun>
35-
void set_impl(size_t c, Fun const& operate, size_t pos) {
32+
template <typename Operate>
33+
void set_impl(size_t c, Operate const& operate, size_t pos) {
3634
if (size[pos] == 1) {
37-
data[pos] = operate((typename std::vector<T>::const_reference)data[pos]);
35+
data[pos] = operate((T const&)data[pos]);
3836
return;
3937
}
4038
size_t m = size[pos * 2 + 1];
@@ -50,15 +48,18 @@ template <typename T, typename MergeFunction> class typical_segment_tree {
5048
if (l < m && r > m)
5149
return merge(get_impl(l, m, pos * 2 + 1),
5250
get_impl(0, r - m, pos * 2 + 2));
53-
if (l < m) return get_impl(l, r, pos * 2 + 1);
54-
if (r > m) return get_impl(l - m, r - m, pos * 2 + 2);
55-
throw;
51+
else if (l < m)
52+
return get_impl(l, r, pos * 2 + 1);
53+
else if (r > m)
54+
return get_impl(l - m, r - m, pos * 2 + 2);
55+
else
56+
__builtin_unreachable();
5657
}
5758

5859
public:
5960
template <typename InputIterator>
6061
typical_segment_tree(InputIterator first, InputIterator last,
61-
MergeFunction mergeFun = MergeFunction())
62+
MergeFunction const& mergeFun = MergeFunction())
6263
: data(first, last),
6364
size(data.size(), 1),
6465
n(data.size()),
@@ -69,7 +70,7 @@ template <typename T, typename MergeFunction> class typical_segment_tree {
6970
build();
7071
}
7172
typical_segment_tree(size_t len, T const& init,
72-
MergeFunction mergeFun = MergeFunction())
73+
MergeFunction const& mergeFun = MergeFunction())
7374
: data(len, init), size(len, 1), n(len), merge(mergeFun) {
7475
build();
7576
}

library/test/all.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
1+
#include "lazy_segment_tree/all.hpp"
12
#include "sparse_table/all.hpp"
23
#include "typical_segment_tree/all.hpp"
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
#include "lazy_segment_tree_add_add.cpp"
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
#include <gtest/gtest.h>
2+
3+
#include <mrpython/lazy_segment_tree.hpp>
4+
#include <random>
5+
6+
TEST(lazy_segment_tree, add) {
7+
std::mt19937_64 gen(std::random_device{}());
8+
std::size_t n = std::uniform_int_distribution<std::size_t>{1, 5000}(gen),
9+
q = std::uniform_int_distribution<std::size_t>{1, 5000}(gen);
10+
std::uniform_int_distribution<unsigned> val_dist(
11+
std::numeric_limits<unsigned>::min(),
12+
std::numeric_limits<unsigned>::max()),
13+
size_dist(0, n - 1), operator_dist(0, 1);
14+
std::vector<unsigned> a(n);
15+
std::generate(a.begin(), a.end(), [&] { return val_dist(gen); });
16+
mrpython::lazy_segment_tree_add_add<unsigned> tree(a.begin(), a.end(), 0);
17+
while (q--) {
18+
std::size_t l = size_dist(gen), r = size_dist(gen);
19+
if (l > r) std::swap(l, r);
20+
assert(l < r + 1);
21+
if (operator_dist(gen)) {
22+
int ans = std::accumulate(a.begin() + l, a.begin() + r + 1, (unsigned)0);
23+
EXPECT_EQ(tree.get(l, r + 1), ans);
24+
} else {
25+
unsigned value = val_dist(gen);
26+
for (std::size_t i = l; i < r; ++i) a[i] += value;
27+
tree.set(l, r, value);
28+
}
29+
}
30+
}

0 commit comments

Comments
 (0)