|
| 1 | +#include "mrpython/bits.hpp" |
| 2 | +#include "mrpython/utility.hpp" |
| 3 | + |
| 4 | +#ifndef MP_LIBRARY_SEGMENT_TREE_HPP |
| 5 | +#define MP_LIBRARY_SEGMENT_TREE_HPP |
| 6 | +#include <algorithm> |
| 7 | +#include <iterator> |
| 8 | +#include <utility> |
| 9 | +#include <vector> |
| 10 | + |
| 11 | +namespace mrpython { |
| 12 | +using std::size_t; |
| 13 | +/** |
| 14 | + * 一颗 无标记线段树 模板。 |
| 15 | + * @tparam T 元素类型 |
| 16 | + * @tparam MergeFunction 合并运算函数,对于函数 `F` 应满足结合律 |
| 17 | + */ |
| 18 | +template <typename T, typename MergeFunction> class typical_segment_tree { |
| 19 | + std::vector<T> data; |
| 20 | + std::vector<size_t> size; |
| 21 | + size_t n; |
| 22 | + MergeFunction merge; |
| 23 | + void build(void) { |
| 24 | + data.reserve(2 * n - 1), size.reserve(2 * n - 1); |
| 25 | + for (size_t i = n; i < 2 * n - 1; ++i) { |
| 26 | + size_t d = 2 * n - 1 - i; |
| 27 | + size_t l = d * 2 + 1, r = d * 2; |
| 28 | + data.emplace_back(merge(data[2 * n - 1 - l], data[2 * n - 1 - r])); |
| 29 | + size.emplace_back(size[2 * n - 1 - l] + size[2 * n - 1 - r]); |
| 30 | + } |
| 31 | + std::reverse(data.begin(), data.end()); |
| 32 | + std::reverse(size.begin(), size.end()); |
| 33 | + } |
| 34 | + template <typename Fun> |
| 35 | + void set_impl(size_t c, Fun const& operate, size_t pos) { |
| 36 | + if (size[pos] == 1) { |
| 37 | + data[pos] = operate((T const&)data[pos]); |
| 38 | + return; |
| 39 | + } |
| 40 | + size_t m = size[pos * 2 + 1]; |
| 41 | + if (c < m) |
| 42 | + set_impl(c, operate, pos * 2 + 1); |
| 43 | + else |
| 44 | + set_impl(c - m, operate, pos * 2 + 2); |
| 45 | + data[pos] = merge(data[pos * 2 + 1], data[pos * 2 + 2]); |
| 46 | + } |
| 47 | + T get_impl(size_t l, size_t r, size_t pos) { |
| 48 | + if (l == 0 && r == size[pos]) return data[pos]; |
| 49 | + size_t m = size[pos * 2 + 1]; |
| 50 | + if (l < m && r > m) |
| 51 | + return merge(get_impl(l, m, pos * 2 + 1), |
| 52 | + get_impl(0, r - m, pos * 2 + 2)); |
| 53 | + if (l < m) return get_impl(l, r, pos * 2 + 1); |
| 54 | + if (r > m) return get_impl(l - m, r - m, pos * 2 + 2); |
| 55 | + throw; |
| 56 | + } |
| 57 | + |
| 58 | + public: |
| 59 | + template <typename InputIterator> |
| 60 | + typical_segment_tree(InputIterator first, InputIterator last, |
| 61 | + MergeFunction mergeFun = MergeFunction()) |
| 62 | + : data(first, last), |
| 63 | + size(data.size(), 1), |
| 64 | + n(data.size()), |
| 65 | + merge(mergeFun) { |
| 66 | + rotate(data.begin(), data.begin() + (2 * n - 1) - (highbit(2 * n - 1) - 1), |
| 67 | + data.end()); |
| 68 | + reverse(data.begin(), data.end()); |
| 69 | + build(); |
| 70 | + } |
| 71 | + typical_segment_tree(size_t n, T const& init, |
| 72 | + MergeFunction mergeFun = MergeFunction()) |
| 73 | + : data(n, init), size(n, 1), n(n), merge(mergeFun) { |
| 74 | + build(); |
| 75 | + } |
| 76 | + /** |
| 77 | + * 单点修改操作 |
| 78 | + * @param target 修改的位置 |
| 79 | + * @param operate 更新该点的函数 |
| 80 | + */ |
| 81 | + template <typename Fun> void set(size_t target, Fun const& operate) { |
| 82 | + set_impl(target, operate, 0); |
| 83 | + } |
| 84 | + T get(size_t l, size_t r) { return get_impl(l, r, 0); } |
| 85 | +}; |
| 86 | +template <typename T> |
| 87 | +using typical_segment_tree_add = typical_segment_tree<T, std::plus<T>>; |
| 88 | +template <typename T> |
| 89 | +using typical_segment_tree_max = typical_segment_tree<T, max>; |
| 90 | +template <typename T> |
| 91 | +using typical_segment_tree_min = typical_segment_tree<T, min>; |
| 92 | +} // namespace mrpython |
| 93 | +#endif // MP_LIBRARY_SEGMENT_TREE_HPP |
0 commit comments