@@ -374,3 +374,30 @@ def test_label_graph_backward(self):
374374 }
375375 with pytest .raises (ValueError ):
376376 LabelGraphMapper .backward (instance = forward , all_labels = labels )
377+ # Checking label deletion case
378+ forward = {
379+ "type" : "tree" ,
380+ "directed" : True ,
381+ "nodes" : ["0" , "0_1" , "0_2" , "0_1_1" , "0_2_1" ],
382+ "edges" : [("0_1" , "0" ), ("0_2" , "0" ), ("0_1_1" , "0_1" ), ("0_2_1" , "0_2" )],
383+ }
384+ labels = {
385+ ID ("0" ): self .label_0 ,
386+ ID ("0_1" ): self .label_0_1 ,
387+ # ID("0_2"): self.label_0_2,
388+ ID ("0_1_1" ): self .label_0_1_1 ,
389+ # ID("0_1_2"): self.label_0_1_2,
390+ ID ("0_2_1" ): self .label_0_2_1 ,
391+ }
392+ expected_backward = LabelTree ()
393+ for node in labels .values ():
394+ expected_backward .add_node (node )
395+ for parent , child in [
396+ (self .label_0 , self .label_0_1 ),
397+ # (self.label_0, self.label_0_2),
398+ (self .label_0_1 , self .label_0_1_1 ),
399+ # (self.label_0_2, self.label_0_2_1),
400+ ]:
401+ expected_backward .add_child (parent , child )
402+ actual_backward = LabelGraphMapper .backward (instance = forward , all_labels = labels )
403+ assert LabelGraphMapper .forward (actual_backward ) == LabelGraphMapper .forward (expected_backward )
0 commit comments