diff --git a/content/data-structures/LazySegmentTree.h b/content/data-structures/LazySegmentTree.h index ecad9a0a5..0a9125ecc 100644 --- a/content/data-structures/LazySegmentTree.h +++ b/content/data-structures/LazySegmentTree.h @@ -16,51 +16,54 @@ const int inf = 1e9; struct Node { + typedef int T; // data type + struct L { int mset, madd; }; // lazy type + const T tneut = -inf; // neutral elements + const L lneut = {inf, 0}; + T f (T a, T b) { return max(a, b); } // (any associative fn) + T apply (T a, L b) { + return b.mset != inf ? b.mset + b.madd : a + b.madd; + } // Apply lazy + L comb(L a, L b) { + if (b.mset != inf) return b; + return {a.mset, a.madd + b.madd}; + } // Combine lazy + Node *l = 0, *r = 0; - int lo, hi, mset = inf, madd = 0, val = -inf; - Node(int lo,int hi):lo(lo),hi(hi){} // Large interval of -inf - Node(vi& v, int lo, int hi) : lo(lo), hi(hi) { + int lo, hi; T val = tneut; L lazy = lneut; + Node(int lo,int hi):lo(lo),hi(hi){}//Large interval of tneut + Node(vector& v, int lo, int hi) : lo(lo), hi(hi) { if (lo + 1 < hi) { int mid = lo + (hi - lo)/2; l = new Node(v, lo, mid); r = new Node(v, mid, hi); - val = max(l->val, r->val); + val = f(l->val, r->val); } else val = v[lo]; } - int query(int L, int R) { - if (R <= lo || hi <= L) return -inf; - if (L <= lo && hi <= R) return val; + T query(int L, int R) { + if (R <= lo || hi <= L) return tneut; + if (L <= lo && hi <= R) return apply(val, lazy); push(); - return max(l->query(L, R), r->query(L, R)); + return f(l->query(L, R), r->query(L, R)); } - void set(int L, int R, int x) { - if (R <= lo || hi <= L) return; - if (L <= lo && hi <= R) mset = val = x, madd = 0; - else { - push(), l->set(L, R, x), r->set(L, R, x); - val = max(l->val, r->val); - } - } - void add(int L, int R, int x) { - if (R <= lo || hi <= L) return; - if (L <= lo && hi <= R) { - if (mset != inf) mset += x; - else madd += x; - val += x; - } + void upd(int Le, int Ri, L x) { + if (Ri <= lo || hi <= Le) return; + if (Le <= lo && hi <= Ri) lazy = comb(lazy, x); else { - push(), l->add(L, R, x), r->add(L, R, x); - val = max(l->val, r->val); + push(), l->upd(Le, Ri, x), r->upd(Le, Ri, x); + val = f(l->query(lo, hi), r->query(lo, hi)); } } + void set(int L, int R, int x) { upd(L, R, {x, 0}); } + void add(int L, int R, int x) { upd(L, R, {inf, x}); } void push() { if (!l) { int mid = lo + (hi - lo)/2; - l = new Node(lo, mid); r = new Node(mid, hi); + l = new Node(lo, mid), r = new Node(mid, hi); } - if (mset != inf) - l->set(lo,hi,mset), r->set(lo,hi,mset), mset = inf; - else if (madd) - l->add(lo,hi,madd), r->add(lo,hi,madd), madd = 0; + l->lazy = comb(l->lazy, lazy); + r->lazy = comb(r->lazy, lazy); + lazy = lneut; + val = f(l->query(lo, hi), r->query(lo, hi)); } };