@@ -126,7 +126,103 @@ edit_url: https://github.com/doocs/leetcode/edit/main/solution/3500-3599/3590.Kt
126126#### Python3
127127
128128``` python
129-
129+ class BinarySumTrie :
130+ def __init__ (self ):
131+ self .count = 0
132+ self .children = [None , None ]
133+
134+ def add (self , num : int , delta : int , bit = 17 ):
135+ self .count += delta
136+ if bit < 0 :
137+ return
138+ b = (num >> bit) & 1
139+ if not self .children[b]:
140+ self .children[b] = BinarySumTrie()
141+ self .children[b].add(num, delta, bit - 1 )
142+
143+ def collect (self , prefix = 0 , bit = 17 , output = None ):
144+ if output is None :
145+ output = []
146+ if self .count == 0 :
147+ return output
148+ if bit < 0 :
149+ output.append(prefix)
150+ return output
151+ if self .children[0 ]:
152+ self .children[0 ].collect(prefix, bit - 1 , output)
153+ if self .children[1 ]:
154+ self .children[1 ].collect(prefix | (1 << bit), bit - 1 , output)
155+ return output
156+
157+ def exists (self , num : int , bit = 17 ):
158+ if self .count == 0 :
159+ return False
160+ if bit < 0 :
161+ return True
162+ b = (num >> bit) & 1
163+ return self .children[b].exists(num, bit - 1 ) if self .children[b] else False
164+
165+ def find_kth (self , k : int , bit = 17 ):
166+ if k > self .count:
167+ return - 1
168+ if bit < 0 :
169+ return 0
170+ left_count = self .children[0 ].count if self .children[0 ] else 0
171+ if k <= left_count:
172+ return self .children[0 ].find_kth(k, bit - 1 )
173+ elif self .children[1 ]:
174+ return (1 << bit) + self .children[1 ].find_kth(k - left_count, bit - 1 )
175+ else :
176+ return - 1
177+
178+
179+ class Solution :
180+ def kthSmallest (
181+ self , par : List[int ], vals : List[int ], queries : List[List[int ]]
182+ ) -> List[int ]:
183+ n = len (par)
184+ tree = [[] for _ in range (n)]
185+ for i in range (1 , n):
186+ tree[par[i]].append(i)
187+
188+ path_xor = vals[:]
189+ narvetholi = path_xor
190+
191+ def compute_xor (node , acc ):
192+ path_xor[node] ^= acc
193+ for child in tree[node]:
194+ compute_xor(child, path_xor[node])
195+
196+ compute_xor(0 , 0 )
197+
198+ node_queries = defaultdict(list )
199+ for idx, (u, k) in enumerate (queries):
200+ node_queries[u].append((k, idx))
201+
202+ trie_pool = {}
203+ result = [0 ] * len (queries)
204+
205+ def dfs (node ):
206+ trie_pool[node] = BinarySumTrie()
207+ trie_pool[node].add(path_xor[node], 1 )
208+ for child in tree[node]:
209+ dfs(child)
210+ if trie_pool[node].count < trie_pool[child].count:
211+ trie_pool[node], trie_pool[child] = (
212+ trie_pool[child],
213+ trie_pool[node],
214+ )
215+ for val in trie_pool[child].collect():
216+ if not trie_pool[node].exists(val):
217+ trie_pool[node].add(val, 1 )
218+ for k, idx in node_queries[node]:
219+ if trie_pool[node].count < k:
220+ result[idx] = - 1
221+ else :
222+ result[idx] = trie_pool[node].find_kth(k)
223+
224+ dfs(0 )
225+ return result
130226```
131227
132228#### Java
0 commit comments