@@ -325,3 +325,30 @@ def test_label_graph_backward(self):
325325 }
326326 with pytest .raises (ValueError ):
327327 LabelTreeMapper .backward (instance = forward , all_labels = labels )
328+ # Checking label deletion case
329+ forward = {
330+ "type" : "tree" ,
331+ "directed" : True ,
332+ "nodes" : ["0" , "0_1" , "0_2" , "0_1_1" , "0_2_1" ],
333+ "edges" : [("0_1" , "0" ), ("0_2" , "0" ), ("0_1_1" , "0_1" ), ("0_2_1" , "0_2" )],
334+ }
335+ labels = {
336+ ID ("0" ): self .label_0 ,
337+ ID ("0_1" ): self .label_0_1 ,
338+ # ID("0_2"): self.label_0_2,
339+ ID ("0_1_1" ): self .label_0_1_1 ,
340+ # ID("0_1_2"): self.label_0_1_2,
341+ ID ("0_2_1" ): self .label_0_2_1 ,
342+ }
343+ expected_backward = LabelTree ()
344+ for node in labels .values ():
345+ expected_backward .add_node (node )
346+ for parent , child in [
347+ (self .label_0 , self .label_0_1 ),
348+ # (self.label_0, self.label_0_2),
349+ (self .label_0_1 , self .label_0_1_1 ),
350+ # (self.label_0_2, self.label_0_2_1),
351+ ]:
352+ expected_backward .add_child (parent , child )
353+ actual_backward = LabelTreeMapper .backward (instance = forward , all_labels = labels )
354+ assert LabelTreeMapper .forward (actual_backward ) == LabelTreeMapper .forward (expected_backward )
0 commit comments