Skip to content

Commit 0cb910f

Browse files
committed
impl radix_cache node merge
1 parent 77a92be commit 0cb910f

File tree

2 files changed

+191
-0
lines changed

2 files changed

+191
-0
lines changed

lightllm/server/router/dynamic_prompt/radix_cache.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,62 @@ def evict(self, need_remove_tokens, evict_callback):
342342

343343
return
344344

345+
def _try_merge(self, child_node: TreeNode):
    """
    Try to fold ``child_node`` into its parent node.

    The merge only happens when all of the following hold:
    1. The parent exists and is not the root node.
    2. The parent's reference counter is 0.
    3. The child's reference counter is 0.
    4. The parent has exactly one child (i.e. ``child_node``).

    On success the parent takes over the child's key/value tensors (appended
    after its own) and the child's children, and ``child_node`` is detached.

    Returns:
        True if the merge was performed, False otherwise.
    """
    absorber = child_node.parent

    # Guard: a detached node or a direct child of the root is never merged.
    if absorber is None or absorber == self.root_node:
        return False

    can_merge = (
        absorber.ref_counter == 0
        and child_node.ref_counter == 0
        and len(absorber.children) == 1
    )
    if not can_merge:
        return False

    # The child is about to disappear, so it must no longer be an eviction
    # candidate.
    if child_node.is_leaf():
        self.evict_tree_set.discard(child_node)

    # Both concatenations are evaluated before either attribute is rebound.
    absorber.token_id_key, absorber.token_mem_index_value = (
        torch.cat([absorber.token_id_key, child_node.token_id_key]),
        torch.cat([absorber.token_mem_index_value, child_node.token_mem_index_value]),
    )

    # Adopt the grandchildren and re-point them at their new parent.
    absorber.children = child_node.children
    for adopted in absorber.children.values():
        adopted.parent = absorber

    # NOTE(review): only node_value_len and time_id are refreshed here; if
    # TreeNode keeps other length-derived fields, confirm whether they need
    # updating as well.
    absorber.node_value_len = len(absorber.token_mem_index_value)
    absorber.time_id = max(absorber.time_id, child_node.time_id)

    # If the merged node ended up without children, register it as a leaf in
    # the eviction set.
    if absorber.is_leaf():
        self.evict_tree_set.add(absorber)

    child_node.parent = None
    return True
384+
385+
def merge_unreferenced_nodes(self):
    """
    Walk every node below the root and attempt to merge it into its parent.

    Nodes are collected with a stack-based depth-first walk and then handled
    in reverse visit order, so every node is processed after all of its
    descendants. Nodes whose ``parent`` is ``None`` at processing time are
    skipped.
    """
    top_level = self.root_node.children
    if not top_level:
        return

    # Depth-first walk with an explicit stack, recording the visit order.
    pending = list(top_level.values())
    visit_order = []
    while pending:
        current = pending.pop()
        visit_order.append(current)
        pending += current.children.values()

    # Reversed pre-order: descendants come before their ancestors.
    for candidate in reversed(visit_order):
        if candidate.parent is None:
            continue
        self._try_merge(candidate)
400+
345401
def assert_leafs_is_right(self):
346402
for node in self.evict_tree_set:
347403
if node.is_leaf() and node.ref_counter == 0:

unit_tests/server/router/dynamic_prompt/test_radix_cache.py

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,5 +91,140 @@ def test_case4():
9191
return
9292

9393

94+
def test_case5():
    """
    Scenario: a simple parent/child chain (A -> B) whose ref counters are both 0
    should be merged into one node.
    """

    def ids(*tokens):
        # Build a fresh int64 token tensor on every call.
        return torch.tensor(list(tokens), dtype=torch.int64)

    print("\nTest Case 5: Merging simple parent-child nodes when ref_counter is 0\n")
    cache = RadixCache("unique_name", 100, 0)

    _, head = cache.insert(ids(1, 2, 3))
    _, tail = cache.insert(ids(1, 2, 3, 4, 5))
    cache.print_self()

    # Starting shape: A -> B, with both ref counters at 0.
    assert tail.parent == head
    assert torch.equal(head.token_id_key, ids(1, 2, 3))
    assert len(head.children) == 1
    for node in (head, tail):
        assert node.ref_counter == 0
    assert cache.get_tree_total_tokens_num() == 5

    # Run the merge pass.
    cache.merge_unreferenced_nodes()
    cache.print_self()

    # A now holds the full key, is a leaf, and is still the root's child.
    assert torch.equal(head.token_id_key, ids(1, 2, 3, 4, 5))
    assert head.is_leaf()
    assert cache.get_tree_total_tokens_num() == 5
    assert cache.root_node.children[1] is head
121+
122+
def test_case6():
    """
    Scenario: a longer chain (A -> B -> C) whose ref counters are all 0 should
    cascade-merge into a single node.
    """

    def ids(*tokens):
        # Build a fresh int64 token tensor on every call.
        return torch.tensor(list(tokens), dtype=torch.int64)

    print("\nTest Case 6: Merging long nodes when ref_counter is 0\n")
    cache = RadixCache("unique_name", 100, 0)

    # Insert progressively longer prefixes to build the chain.
    chain = []
    for tokens in ((1,), (1, 2), (1, 2, 3, 4)):
        _, inserted = cache.insert(ids(*tokens))
        chain.append(inserted)
    head, middle, tail = chain
    cache.print_self()

    assert tail.parent == middle
    assert middle.parent == head
    assert cache.get_tree_total_tokens_num() == 4

    cache.merge_unreferenced_nodes()
    cache.print_self()

    assert len(cache.root_node.children) == 1
    # Node A's key should now be the complete sequence [1, 2, 3, 4].
    assert torch.equal(head.token_id_key, ids(1, 2, 3, 4))
    assert head.is_leaf()
    assert cache.get_tree_total_tokens_num() == 4
144+
145+
def test_case7():
    """
    Scenario: the parent's (or child's) ref_counter is greater than 0, so no
    merge should take place.
    """

    def ids(*tokens):
        # Build a fresh int64 token tensor on every call.
        return torch.tensor(list(tokens), dtype=torch.int64)

    print("\nTest Case 7: Merging when parent or child ref_counter > 0\n")
    cache = RadixCache("unique_name", 100, 0)

    _, head = cache.insert(ids(1, 2, 3))
    _, tail = cache.insert(ids(1, 2, 3, 4, 5))
    cache.print_self()

    # Matching the short prefix with update_refs=True pins node A.
    hit, _, _ = cache.match_prefix(ids(1, 2, 3), update_refs=True)
    assert hit is head
    assert head.ref_counter == 1
    assert tail.ref_counter == 0

    cache.merge_unreferenced_nodes()
    cache.print_self()

    # The A -> B structure must be untouched.
    assert torch.equal(head.token_id_key, ids(1, 2, 3))
    assert not head.is_leaf()
    assert tail.parent is head
167+
168+
def test_case8():
    """
    Scenario: the parent has more than one child, so no merge should take place.
    """

    def ids(*tokens):
        # Build a fresh int64 token tensor on every call.
        return torch.tensor(list(tokens), dtype=torch.int64)

    print("\nTest Case 8: Merging when parent has multiple children\n")
    cache = RadixCache("unique_name", 100, 0)

    _, fork = cache.insert(ids(1, 2))
    _, left = cache.insert(ids(1, 2, 3))
    _, right = cache.insert(ids(1, 2, 4))
    cache.print_self()

    assert len(fork.children) == 2
    for node in (fork, left, right):
        assert node.ref_counter == 0

    cache.merge_unreferenced_nodes()
    cache.print_self()

    # The fork and both of its branches must be left exactly as they were.
    assert len(fork.children) == 2
    assert torch.equal(fork.token_id_key, ids(1, 2))
    assert cache.root_node.children[1].children[3] is left
    assert cache.root_node.children[1].children[4] is right
192+
193+
def test_case9():
    """
    Scenario: in a more complex tree, only the branches that satisfy the merge
    conditions get merged.
    """

    def ids(*tokens):
        # Build a fresh int64 token tensor on every call.
        return torch.tensor(list(tokens), dtype=torch.int64)

    print("\nTest Case 9: Merging in a complex tree with mixed conditions\n")
    cache = RadixCache("unique_name", 100, 0)

    # Branch 1: a mergeable chain A -> B.
    _, free_head = cache.insert(ids(1, 2))
    _, free_tail = cache.insert(ids(1, 2, 3))

    # Branch 2: a chain C -> D that must not merge because C is referenced.
    _, pinned_head = cache.insert(ids(4, 5))
    _, pinned_tail = cache.insert(ids(4, 5, 6))

    # Bump C's reference counter.
    cache.match_prefix(ids(4, 5), update_refs=True)
    assert pinned_head.ref_counter == 1
    cache.print_self()

    cache.merge_unreferenced_nodes()
    cache.print_self()

    # Branch 1 collapsed into one leaf carrying the full key.
    merged = cache.root_node.children[1]
    assert torch.equal(merged.token_id_key, ids(1, 2, 3))
    assert merged.is_leaf()

    # Branch 2 kept its two-node shape.
    kept_parent = cache.root_node.children[4]
    assert torch.equal(kept_parent.token_id_key, ids(4, 5))
    assert not kept_parent.is_leaf()
    assert len(kept_parent.children) == 1

    kept_child = kept_parent.children[6]
    assert torch.equal(kept_child.token_id_key, ids(6))
227+
228+
94229
if __name__ == "__main__":
95230
pytest.main()

0 commit comments

Comments
 (0)