Skip to content

Commit 574ce80

Browse files
committed
fix merge efficiency bug
1 parent: 0cb910f · commit: 574ce80

File tree

3 files changed

+67
-62
lines changed

3 files changed

+67
-62
lines changed

.pre-commit-config.yaml

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -7,8 +7,7 @@ repos:
77
args: [--line-length=120]
88
additional_dependencies: ['click==8.0.4']
99
- repo: https://github.com/pycqa/flake8
10-
rev: 3.9.0
10+
rev: 6.1.0
1111
hooks:
1212
- id: flake8
13-
additional_dependencies: [flake8-typing-imports==1.9.0]
14-
args: ['--config=.flake8', '--max-line-length=120', '--ignore=TYP001, E722, C901, E203, E266, E402, E302, E241, E902, E731, F403, E701, F405, F401, W292, W293, W503, W606, E231']
13+
args: ['--max-line-length=120', '--ignore=TYP001, E722, C901, E203, E266, E402, E302, E241, E902, E731, F403, E701, F405, F401, W292, W293, W503, W606, E231']

lightllm/server/router/dynamic_prompt/radix_cache.py

Lines changed: 45 additions & 43 deletions
Original file line number | Diff line number | Diff line change
@@ -342,7 +342,7 @@ def evict(self, need_remove_tokens, evict_callback):
342342

343343
return
344344

345-
def _try_merge(self, child_node: TreeNode):
345+
def _try_merge(self, child_node: TreeNode) -> Optional[TreeNode]:
346346
"""
347347
合并条件:
348348
1. 父节点不是根节点。
@@ -352,52 +352,54 @@ def _try_merge(self, child_node: TreeNode):
352352
"""
353353
parent_node = child_node.parent
354354
# 条件检查
355-
if parent_node is None or parent_node == self.root_node:
356-
return False
357-
358-
if parent_node.ref_counter == 0 and \
359-
child_node.ref_counter == 0 and \
360-
len(parent_node.children) == 1:
361-
362-
if child_node.is_leaf():
363-
self.evict_tree_set.discard(child_node)
364-
365-
new_key = torch.cat([parent_node.token_id_key, child_node.token_id_key])
366-
new_value = torch.cat([parent_node.token_mem_index_value, child_node.token_mem_index_value])
367-
368-
parent_node.token_id_key = new_key
369-
parent_node.token_mem_index_value = new_value
370-
parent_node.children = child_node.children
371-
for grandchild in parent_node.children.values():
372-
grandchild.parent = parent_node
373-
374-
parent_node.node_value_len = len(parent_node.token_mem_index_value)
375-
parent_node.time_id = max(parent_node.time_id, child_node.time_id)
355+
if (
356+
parent_node is None
357+
or parent_node == self.root_node
358+
or parent_node.ref_counter != 0
359+
or len(parent_node.children) != 1
360+
or child_node.ref_counter != 0
361+
):
362+
return None
376363

377-
if parent_node.is_leaf():
378-
self.evict_tree_set.add(parent_node)
379-
380-
child_node.parent = None
381-
return True
364+
if child_node.is_leaf():
365+
self.evict_tree_set.discard(child_node)
366+
367+
child_node.token_id_key = torch.cat([parent_node.token_id_key, child_node.token_id_key])
368+
child_node.token_mem_index_value = torch.cat(
369+
[parent_node.token_mem_index_value, child_node.token_mem_index_value]
370+
)
371+
child_node.node_value_len = len(child_node.token_mem_index_value)
372+
child_node.time_id = max(parent_node.time_id, child_node.time_id)
373+
374+
grandparent_node = parent_node.parent
375+
key_in_grandparent = parent_node.token_id_key[0].item()
376+
grandparent_node.children[key_in_grandparent] = child_node
377+
child_node.parent = grandparent_node
382378

383-
return False
379+
parent_node.parent = None
380+
381+
if child_node.is_leaf():
382+
self.evict_tree_set.add(child_node)
383+
384+
return child_node
384385

385386
def merge_unreferenced_nodes(self):
386-
if not self.root_node.children:
387-
return
388-
nodes_to_process = []
389-
traversal_stack = list(self.root_node.children.values())
390-
391-
while traversal_stack:
392-
node = traversal_stack.pop()
393-
nodes_to_process.append(node)
394-
traversal_stack.extend(list(node.children.values()))
395-
396-
nodes_to_process.reverse()
397-
for node in nodes_to_process:
398-
if node.parent is not None:
399-
self._try_merge(node)
400-
387+
worklist = collections.deque(
388+
[
389+
node
390+
for node in self.evict_tree_set
391+
if node.ref_counter == 0 and node.parent is not None and node.parent != self.root_node
392+
]
393+
)
394+
395+
while worklist:
396+
node = worklist.popleft()
397+
if node.parent is None:
398+
continue
399+
merged_node = self._try_merge(node)
400+
if merged_node:
401+
worklist.append(merged_node)
402+
401403
def assert_leafs_is_right(self):
402404
for node in self.evict_tree_set:
403405
if node.is_leaf() and node.ref_counter == 0:

unit_tests/server/router/dynamic_prompt/test_radix_cache.py

Lines changed: 20 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -97,7 +97,7 @@ def test_case5():
9797
"""
9898
print("\nTest Case 5: Merging simple parent-child nodes when ref_counter is 0\n")
9999
tree = RadixCache("unique_name", 100, 0)
100-
100+
101101
_, node_a = tree.insert(torch.tensor([1, 2, 3], dtype=torch.int64))
102102
_, node_b = tree.insert(torch.tensor([1, 2, 3, 4, 5], dtype=torch.int64))
103103
tree.print_self()
@@ -114,10 +114,11 @@ def test_case5():
114114
tree.merge_unreferenced_nodes()
115115
tree.print_self()
116116

117-
assert torch.equal(node_a.token_id_key, torch.tensor([1, 2, 3, 4, 5], dtype=torch.int64))
118-
assert node_a.is_leaf()
117+
assert torch.equal(node_b.token_id_key, torch.tensor([1, 2, 3, 4, 5], dtype=torch.int64))
118+
assert node_b.is_leaf()
119119
assert tree.get_tree_total_tokens_num() == 5
120-
assert tree.root_node.children[1] is node_a
120+
assert tree.root_node.children[1] is node_b
121+
121122

122123
def test_case6():
123124
"""
@@ -137,18 +138,19 @@ def test_case6():
137138
tree.print_self()
138139

139140
assert len(tree.root_node.children) == 1
140-
# 节点 A 的 key 应该是完整的 [1, 2, 3, 4]
141-
assert torch.equal(node_a.token_id_key, torch.tensor([1, 2, 3, 4], dtype=torch.int64))
142-
assert node_a.is_leaf()
141+
# 节点 C 的 key 应该是完整的 [1, 2, 3, 4]
142+
assert torch.equal(node_c.token_id_key, torch.tensor([1, 2, 3, 4], dtype=torch.int64))
143+
assert node_c.is_leaf()
143144
assert tree.get_tree_total_tokens_num() == 4
144145

146+
145147
def test_case7():
146148
"""
147149
测试场景:由于父节点或子节点的 ref_counter > 0,合并不应该发生。
148150
"""
149151
print("\nTest Case 7: Merging when parent or child ref_counter > 0\n")
150152
tree = RadixCache("unique_name", 100, 0)
151-
153+
152154
_, node_a = tree.insert(torch.tensor([1, 2, 3], dtype=torch.int64))
153155
_, node_b = tree.insert(torch.tensor([1, 2, 3, 4, 5], dtype=torch.int64))
154156
tree.print_self()
@@ -157,14 +159,15 @@ def test_case7():
157159
assert matched_node is node_a
158160
assert node_a.ref_counter == 1
159161
assert node_b.ref_counter == 0
160-
162+
161163
tree.merge_unreferenced_nodes()
162164
tree.print_self()
163165

164166
assert torch.equal(node_a.token_id_key, torch.tensor([1, 2, 3], dtype=torch.int64))
165167
assert not node_a.is_leaf()
166168
assert node_b.parent is node_a
167169

170+
168171
def test_case8():
169172
"""
170173
测试场景:由于父节点有多个子节点,合并不应该发生。
@@ -184,12 +187,13 @@ def test_case8():
184187

185188
tree.merge_unreferenced_nodes()
186189
tree.print_self()
187-
190+
188191
assert len(node_a.children) == 2
189192
assert torch.equal(node_a.token_id_key, torch.tensor([1, 2], dtype=torch.int64))
190193
assert tree.root_node.children[1].children[3] is node_b
191194
assert tree.root_node.children[1].children[4] is node_c
192195

196+
193197
def test_case9():
194198
"""
195199
测试场景:在一个复杂的树中,只有满足条件的分支被合并。
@@ -204,7 +208,7 @@ def test_case9():
204208
# 分支2: 不可合并的链 C -> D (因为 C 被引用)
205209
_, node_c = tree.insert(torch.tensor([4, 5], dtype=torch.int64))
206210
_, node_d = tree.insert(torch.tensor([4, 5, 6], dtype=torch.int64))
207-
211+
208212
# 增加 C 的引用计数
209213
tree.match_prefix(torch.tensor([4, 5], dtype=torch.int64), update_refs=True)
210214
assert node_c.ref_counter == 1
@@ -213,15 +217,15 @@ def test_case9():
213217
tree.merge_unreferenced_nodes()
214218
tree.print_self()
215219

216-
merged_node_a = tree.root_node.children[1]
217-
assert torch.equal(merged_node_a.token_id_key, torch.tensor([1, 2, 3], dtype=torch.int64))
218-
assert merged_node_a.is_leaf()
219-
220+
merged_node_b = tree.root_node.children[1]
221+
assert torch.equal(merged_node_b.token_id_key, torch.tensor([1, 2, 3], dtype=torch.int64))
222+
assert merged_node_b.is_leaf()
223+
220224
unmerged_node_c = tree.root_node.children[4]
221225
assert torch.equal(unmerged_node_c.token_id_key, torch.tensor([4, 5], dtype=torch.int64))
222226
assert not unmerged_node_c.is_leaf()
223227
assert len(unmerged_node_c.children) == 1
224-
228+
225229
unmerged_node_d = unmerged_node_c.children[6]
226230
assert torch.equal(unmerged_node_d.token_id_key, torch.tensor([6], dtype=torch.int64))
227231

0 commit comments

Comments
 (0)