11#pragma once
22#include < vector>
33
4- // CUT begin
5- struct BinaryTrie {
6- using Int = int ;
4+ template <class Int , class Count = int > struct BinaryTrie {
75 int maxD;
8- std::vector<int > deg, sz ;
6+ std::vector<Count > deg, subtree_sum ;
97 std::vector<int > ch0, ch1, par;
108
119 int _new_node (int id_par) {
12- deg.emplace_back (0 );
13- sz .emplace_back (0 );
10+ deg.emplace_back (Count () );
11+ subtree_sum .emplace_back (Count () );
1412 ch0.emplace_back (-1 );
1513 ch1.emplace_back (-1 );
1614 par.emplace_back (id_par);
17- return ch0.size () - 1 ;
15+ return ( int ) ch0.size () - 1 ;
1816 }
1917
2018 BinaryTrie (int maxD = 0 ) : maxD(maxD) { _new_node (-1 ); }
21- int _goto (Int x) {
19+
20+ // Return index of x.
21+ // Create nodes to locate x if needed.
22+ int _get_create_index (Int x) {
2223 int now = 0 ;
23- for (int d = maxD - 1 ; d >= 0 ; d-- ) {
24+ for (int d = maxD - 1 ; d >= 0 ; --d ) {
2425 int nxt = ((x >> d) & 1 ) ? ch1[now] : ch0[now];
2526 if (nxt == -1 ) {
2627 nxt = _new_node (now);
@@ -31,34 +32,77 @@ struct BinaryTrie {
3132 return now;
3233 }
3334
34- void insert (Int x) {
35- int now = _goto (x);
36- if (deg[now] == 0 ) {
37- deg[now] = 1 ;
38- while (now >= 0 ) { sz[now]++, now = par[now]; }
35+ // Return node index of x.
36+ // Return -1 if x is not in trie.
37+ int _get_index (Int x) const {
38+ int now = 0 ;
39+ for (int d = maxD - 1 ; d >= 0 ; --d) {
40+ now = ((x >> d) & 1 ) ? ch1[now] : ch0[now];
41+ if (now == -1 ) return -1 ;
3942 }
43+ return now;
44+ }
45+
46+ // insert x
47+ void insert (Int x, Count c = Count(1 )) {
48+ int now = _get_create_index (x);
49+ deg[now] += c;
50+ while (now >= 0 ) subtree_sum[now] += c, now = par[now];
4051 }
4152
53+ // delete x if exists
4254 void erase (Int x) {
43- int now = _goto (x);
44- if (deg[now] > 0 ) {
55+ int now = _get_index (x);
56+ if (now >= 0 and deg[now] != 0 ) {
57+ Count r = deg[now];
4558 deg[now] = 0 ;
46- while (now >= 0 ) { sz [now]-- , now = par[now]; }
59+ while (now >= 0 ) subtree_sum [now] -= r , now = par[now];
4760 }
4861 }
4962
50- Int xor_min (Int x) {
63+ Count count (Int x) const {
64+ int now = _get_index (x);
65+ return now == -1 ? Count () : deg[now];
66+ }
67+
68+ bool contains (Int x) const { return count (x) > Count (); }
69+
70+ // min(y ^ x) for y in trie
71+ Int xor_min (Int x) const {
5172 Int ret = 0 ;
5273 int now = 0 ;
53- if (!sz [now]) return -1 ;
74+ if (!subtree_sum [now]) return -1 ;
5475 for (int d = maxD - 1 ; d >= 0 ; d--) {
5576 int y = ((x >> d) & 1 ) ? ch1[now] : ch0[now];
56- if (y != -1 and sz [y]) {
77+ if (y != -1 and subtree_sum [y]) {
5778 now = y;
5879 } else {
5980 ret += Int (1 ) << d, now = ch0[now] ^ ch1[now] ^ y;
6081 }
6182 }
6283 return ret;
6384 }
85+
86+ // Count elements y such that x ^ y < thres
87+ Count count_less_xor (Int x, Int thres) const {
88+ Count ret = Count ();
89+ int now = 0 ;
90+
91+ for (int d = maxD - 1 ; d >= 0 ; d--) {
92+ if (now == -1 ) break ;
93+
94+ const bool bit_x = (x >> d) & 1 ;
95+
96+ if ((thres >> d) & 1 ) {
97+ const int child = bit_x ? ch1[now] : ch0[now];
98+ if (child != -1 ) ret += subtree_sum[child];
99+
100+ now = bit_x ? ch0[now] : ch1[now];
101+ } else {
102+ now = bit_x ? ch1[now] : ch0[now];
103+ }
104+ }
105+
106+ return ret;
107+ }
64108};
0 commit comments