@@ -97,7 +97,7 @@ def test_case5():
9797 """
9898 print ("\n Test Case 5: Merging simple parent-child nodes when ref_counter is 0\n " )
9999 tree = RadixCache ("unique_name" , 100 , 0 )
100-
100+
101101 _ , node_a = tree .insert (torch .tensor ([1 , 2 , 3 ], dtype = torch .int64 ))
102102 _ , node_b = tree .insert (torch .tensor ([1 , 2 , 3 , 4 , 5 ], dtype = torch .int64 ))
103103 tree .print_self ()
@@ -114,10 +114,11 @@ def test_case5():
114114 tree .merge_unreferenced_nodes ()
115115 tree .print_self ()
116116
117- assert torch .equal (node_a .token_id_key , torch .tensor ([1 , 2 , 3 , 4 , 5 ], dtype = torch .int64 ))
118- assert node_a .is_leaf ()
117+ assert torch .equal (node_b .token_id_key , torch .tensor ([1 , 2 , 3 , 4 , 5 ], dtype = torch .int64 ))
118+ assert node_b .is_leaf ()
119119 assert tree .get_tree_total_tokens_num () == 5
120- assert tree .root_node .children [1 ] is node_a
120+ assert tree .root_node .children [1 ] is node_b
121+
121122
122123def test_case6 ():
123124 """
@@ -137,18 +138,19 @@ def test_case6():
137138 tree .print_self ()
138139
139140 assert len (tree .root_node .children ) == 1
140- # 节点 A 的 key 应该是完整的 [1, 2, 3, 4]
141- assert torch .equal (node_a .token_id_key , torch .tensor ([1 , 2 , 3 , 4 ], dtype = torch .int64 ))
142- assert node_a .is_leaf ()
141+ # 节点 C 的 key 应该是完整的 [1, 2, 3, 4]
142+ assert torch .equal (node_c .token_id_key , torch .tensor ([1 , 2 , 3 , 4 ], dtype = torch .int64 ))
143+ assert node_c .is_leaf ()
143144 assert tree .get_tree_total_tokens_num () == 4
144145
146+
145147def test_case7 ():
146148 """
147149 测试场景:由于父节点或子节点的 ref_counter > 0,合并不应该发生。
148150 """
149151 print ("\n Test Case 7: Merging when parent or child ref_counter > 0\n " )
150152 tree = RadixCache ("unique_name" , 100 , 0 )
151-
153+
152154 _ , node_a = tree .insert (torch .tensor ([1 , 2 , 3 ], dtype = torch .int64 ))
153155 _ , node_b = tree .insert (torch .tensor ([1 , 2 , 3 , 4 , 5 ], dtype = torch .int64 ))
154156 tree .print_self ()
@@ -157,14 +159,15 @@ def test_case7():
157159 assert matched_node is node_a
158160 assert node_a .ref_counter == 1
159161 assert node_b .ref_counter == 0
160-
162+
161163 tree .merge_unreferenced_nodes ()
162164 tree .print_self ()
163165
164166 assert torch .equal (node_a .token_id_key , torch .tensor ([1 , 2 , 3 ], dtype = torch .int64 ))
165167 assert not node_a .is_leaf ()
166168 assert node_b .parent is node_a
167169
170+
168171def test_case8 ():
169172 """
170173 测试场景:由于父节点有多个子节点,合并不应该发生。
@@ -184,12 +187,13 @@ def test_case8():
184187
185188 tree .merge_unreferenced_nodes ()
186189 tree .print_self ()
187-
190+
188191 assert len (node_a .children ) == 2
189192 assert torch .equal (node_a .token_id_key , torch .tensor ([1 , 2 ], dtype = torch .int64 ))
190193 assert tree .root_node .children [1 ].children [3 ] is node_b
191194 assert tree .root_node .children [1 ].children [4 ] is node_c
192195
196+
193197def test_case9 ():
194198 """
195199 测试场景:在一个复杂的树中,只有满足条件的分支被合并。
@@ -204,7 +208,7 @@ def test_case9():
204208 # 分支2: 不可合并的链 C -> D (因为 C 被引用)
205209 _ , node_c = tree .insert (torch .tensor ([4 , 5 ], dtype = torch .int64 ))
206210 _ , node_d = tree .insert (torch .tensor ([4 , 5 , 6 ], dtype = torch .int64 ))
207-
211+
208212 # 增加 C 的引用计数
209213 tree .match_prefix (torch .tensor ([4 , 5 ], dtype = torch .int64 ), update_refs = True )
210214 assert node_c .ref_counter == 1
@@ -213,15 +217,15 @@ def test_case9():
213217 tree .merge_unreferenced_nodes ()
214218 tree .print_self ()
215219
216- merged_node_a = tree .root_node .children [1 ]
217- assert torch .equal (merged_node_a .token_id_key , torch .tensor ([1 , 2 , 3 ], dtype = torch .int64 ))
218- assert merged_node_a .is_leaf ()
219-
220+ merged_node_b = tree .root_node .children [1 ]
221+ assert torch .equal (merged_node_b .token_id_key , torch .tensor ([1 , 2 , 3 ], dtype = torch .int64 ))
222+ assert merged_node_b .is_leaf ()
223+
220224 unmerged_node_c = tree .root_node .children [4 ]
221225 assert torch .equal (unmerged_node_c .token_id_key , torch .tensor ([4 , 5 ], dtype = torch .int64 ))
222226 assert not unmerged_node_c .is_leaf ()
223227 assert len (unmerged_node_c .children ) == 1
224-
228+
225229 unmerged_node_d = unmerged_node_c .children [6 ]
226230 assert torch .equal (unmerged_node_d .token_id_key , torch .tensor ([6 ], dtype = torch .int64 ))
227231
0 commit comments