Skip to content

Commit 4fc2e8d

Browse files
author
Songki Choi
authored
Fix 'None' node issue in label schema mapping in case of label deletion (#2300)
Signed-off-by: Songki Choi <[email protected]>
1 parent f1baed1 commit 4fc2e8d

File tree

2 files changed

+33
-2
lines changed

2 files changed

+33
-2
lines changed

otx/api/serialization/label_mapper.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,9 +142,13 @@ def backward(instance: dict, all_labels: Dict[ID, LabelEntity]) -> Union[LabelTr
142142

143143
label_map = {label_id: all_labels.get(IDMapper().backward(label_id)) for label_id in instance["nodes"]}
144144
for label in label_map.values():
145-
output.add_node(label)
145+
if label:
146+
output.add_node(label)
146147
for edge in instance["edges"]:
147-
output.add_edge(label_map[edge[0]], label_map[edge[1]])
148+
node1 = label_map.get(edge[0])
149+
node2 = label_map.get(edge[1])
150+
if node1 and node2:
151+
output.add_edge(node1, node2)
148152

149153
return output
150154

tests/unit/api/serialization/test_label_mapper.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -374,3 +374,30 @@ def test_label_graph_backward(self):
374374
}
375375
with pytest.raises(ValueError):
376376
LabelGraphMapper.backward(instance=forward, all_labels=labels)
377+
# Checking label deletion case
378+
forward = {
379+
"type": "tree",
380+
"directed": True,
381+
"nodes": ["0", "0_1", "0_2", "0_1_1", "0_2_1"],
382+
"edges": [("0_1", "0"), ("0_2", "0"), ("0_1_1", "0_1"), ("0_2_1", "0_2")],
383+
}
384+
labels = {
385+
ID("0"): self.label_0,
386+
ID("0_1"): self.label_0_1,
387+
# ID("0_2"): self.label_0_2,
388+
ID("0_1_1"): self.label_0_1_1,
389+
# ID("0_1_2"): self.label_0_1_2,
390+
ID("0_2_1"): self.label_0_2_1,
391+
}
392+
expected_backward = LabelTree()
393+
for node in labels.values():
394+
expected_backward.add_node(node)
395+
for parent, child in [
396+
(self.label_0, self.label_0_1),
397+
# (self.label_0, self.label_0_2),
398+
(self.label_0_1, self.label_0_1_1),
399+
# (self.label_0_2, self.label_0_2_1),
400+
]:
401+
expected_backward.add_child(parent, child)
402+
actual_backward = LabelGraphMapper.backward(instance=forward, all_labels=labels)
403+
assert LabelGraphMapper.forward(actual_backward) == LabelGraphMapper.forward(expected_backward)

0 commit comments

Comments
 (0)