@@ -91,5 +91,140 @@ def test_case4():
9191 return
9292
9393
94+ def test_case5 ():
95+ """
96+ 测试场景:一个简单的父子节点链 (A -> B),在 ref_counter 都为 0 时,应该成功合并。
97+ """
98+ print ("\n Test Case 5: Merging simple parent-child nodes when ref_counter is 0\n " )
99+ tree = RadixCache ("unique_name" , 100 , 0 )
100+
101+ _ , node_a = tree .insert (torch .tensor ([1 , 2 , 3 ], dtype = torch .int64 ))
102+ _ , node_b = tree .insert (torch .tensor ([1 , 2 , 3 , 4 , 5 ], dtype = torch .int64 ))
103+ tree .print_self ()
104+
105+ # 验证初始状态:A -> B 结构,且 ref_counter 均为 0
106+ assert node_b .parent == node_a
107+ assert torch .equal (node_a .token_id_key , torch .tensor ([1 , 2 , 3 ], dtype = torch .int64 ))
108+ assert len (node_a .children ) == 1
109+ assert node_a .ref_counter == 0
110+ assert node_b .ref_counter == 0
111+ assert tree .get_tree_total_tokens_num () == 5
112+
113+ # 执行合并
114+ tree .merge_unreferenced_nodes ()
115+ tree .print_self ()
116+
117+ assert torch .equal (node_a .token_id_key , torch .tensor ([1 , 2 , 3 , 4 , 5 ], dtype = torch .int64 ))
118+ assert node_a .is_leaf ()
119+ assert tree .get_tree_total_tokens_num () == 5
120+ assert tree .root_node .children [1 ] is node_a
121+
122+ def test_case6 ():
123+ """
124+ 测试场景:一个长的节点链 (A -> B -> C),在 ref_counter 都为 0 时,应该级联合并成一个节点。
125+ """
126+ print ("\n Test Case 6: Merging long nodes when ref_counter is 0\n " )
127+ tree = RadixCache ("unique_name" , 100 , 0 )
128+ _ , node_a = tree .insert (torch .tensor ([1 ], dtype = torch .int64 ))
129+ _ , node_b = tree .insert (torch .tensor ([1 , 2 ], dtype = torch .int64 ))
130+ _ , node_c = tree .insert (torch .tensor ([1 , 2 , 3 , 4 ], dtype = torch .int64 ))
131+ tree .print_self ()
132+
133+ assert node_c .parent == node_b
134+ assert node_b .parent == node_a
135+ assert tree .get_tree_total_tokens_num () == 4
136+ tree .merge_unreferenced_nodes ()
137+ tree .print_self ()
138+
139+ assert len (tree .root_node .children ) == 1
140+ # 节点 A 的 key 应该是完整的 [1, 2, 3, 4]
141+ assert torch .equal (node_a .token_id_key , torch .tensor ([1 , 2 , 3 , 4 ], dtype = torch .int64 ))
142+ assert node_a .is_leaf ()
143+ assert tree .get_tree_total_tokens_num () == 4
144+
145+ def test_case7 ():
146+ """
147+ 测试场景:由于父节点或子节点的 ref_counter > 0,合并不应该发生。
148+ """
149+ print ("\n Test Case 7: Merging when parent or child ref_counter > 0\n " )
150+ tree = RadixCache ("unique_name" , 100 , 0 )
151+
152+ _ , node_a = tree .insert (torch .tensor ([1 , 2 , 3 ], dtype = torch .int64 ))
153+ _ , node_b = tree .insert (torch .tensor ([1 , 2 , 3 , 4 , 5 ], dtype = torch .int64 ))
154+ tree .print_self ()
155+
156+ matched_node , _ , _ = tree .match_prefix (torch .tensor ([1 , 2 , 3 ], dtype = torch .int64 ), update_refs = True )
157+ assert matched_node is node_a
158+ assert node_a .ref_counter == 1
159+ assert node_b .ref_counter == 0
160+
161+ tree .merge_unreferenced_nodes ()
162+ tree .print_self ()
163+
164+ assert torch .equal (node_a .token_id_key , torch .tensor ([1 , 2 , 3 ], dtype = torch .int64 ))
165+ assert not node_a .is_leaf ()
166+ assert node_b .parent is node_a
167+
168+ def test_case8 ():
169+ """
170+ 测试场景:由于父节点有多个子节点,合并不应该发生。
171+ """
172+ print ("\n Test Case 8: Merging when parent has multiple children\n " )
173+ tree = RadixCache ("unique_name" , 100 , 0 )
174+
175+ _ , node_a = tree .insert (torch .tensor ([1 , 2 ], dtype = torch .int64 ))
176+ _ , node_b = tree .insert (torch .tensor ([1 , 2 , 3 ], dtype = torch .int64 ))
177+ _ , node_c = tree .insert (torch .tensor ([1 , 2 , 4 ], dtype = torch .int64 ))
178+ tree .print_self ()
179+
180+ assert len (node_a .children ) == 2
181+ assert node_a .ref_counter == 0
182+ assert node_b .ref_counter == 0
183+ assert node_c .ref_counter == 0
184+
185+ tree .merge_unreferenced_nodes ()
186+ tree .print_self ()
187+
188+ assert len (node_a .children ) == 2
189+ assert torch .equal (node_a .token_id_key , torch .tensor ([1 , 2 ], dtype = torch .int64 ))
190+ assert tree .root_node .children [1 ].children [3 ] is node_b
191+ assert tree .root_node .children [1 ].children [4 ] is node_c
192+
193+ def test_case9 ():
194+ """
195+ 测试场景:在一个复杂的树中,只有满足条件的分支被合并。
196+ """
197+ print ("\n Test Case 9: Merging in a complex tree with mixed conditions\n " )
198+ tree = RadixCache ("unique_name" , 100 , 0 )
199+
200+ # 分支1: 可合并的链 A -> B
201+ _ , node_a = tree .insert (torch .tensor ([1 , 2 ], dtype = torch .int64 ))
202+ _ , node_b = tree .insert (torch .tensor ([1 , 2 , 3 ], dtype = torch .int64 ))
203+
204+ # 分支2: 不可合并的链 C -> D (因为 C 被引用)
205+ _ , node_c = tree .insert (torch .tensor ([4 , 5 ], dtype = torch .int64 ))
206+ _ , node_d = tree .insert (torch .tensor ([4 , 5 , 6 ], dtype = torch .int64 ))
207+
208+ # 增加 C 的引用计数
209+ tree .match_prefix (torch .tensor ([4 , 5 ], dtype = torch .int64 ), update_refs = True )
210+ assert node_c .ref_counter == 1
211+ tree .print_self ()
212+
213+ tree .merge_unreferenced_nodes ()
214+ tree .print_self ()
215+
216+ merged_node_a = tree .root_node .children [1 ]
217+ assert torch .equal (merged_node_a .token_id_key , torch .tensor ([1 , 2 , 3 ], dtype = torch .int64 ))
218+ assert merged_node_a .is_leaf ()
219+
220+ unmerged_node_c = tree .root_node .children [4 ]
221+ assert torch .equal (unmerged_node_c .token_id_key , torch .tensor ([4 , 5 ], dtype = torch .int64 ))
222+ assert not unmerged_node_c .is_leaf ()
223+ assert len (unmerged_node_c .children ) == 1
224+
225+ unmerged_node_d = unmerged_node_c .children [6 ]
226+ assert torch .equal (unmerged_node_d .token_id_key , torch .tensor ([6 ], dtype = torch .int64 ))
227+
228+
94229if __name__ == "__main__" :
95230 pytest .main ()
0 commit comments