Skip to content

Commit 50e4f35

Browse files
author
Songki Choi
authored
Fix 'None' node issue in label schema mapping in case of label deletion (develop) (#2308)
Signed-off-by: Songki Choi <[email protected]>
1 parent ee7610e commit 50e4f35

File tree

2 files changed

+33
-2
lines changed

2 files changed

+33
-2
lines changed

src/otx/api/serialization/label_mapper.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,9 +139,13 @@ def backward(instance: dict, all_labels: Dict[ID, LabelEntity]) -> LabelTree:
139139

140140
label_map = {label_id: all_labels.get(IDMapper().backward(label_id)) for label_id in instance["nodes"]}
141141
for label in label_map.values():
142-
output.add_node(label)
142+
if label:
143+
output.add_node(label)
143144
for edge in instance["edges"]:
144-
output.add_edge(label_map[edge[0]], label_map[edge[1]])
145+
node1 = label_map.get(edge[0])
146+
node2 = label_map.get(edge[1])
147+
if node1 and node2:
148+
output.add_edge(node1, node2)
145149

146150
return output
147151

tests/unit/api/serialization/test_label_mapper.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,3 +325,30 @@ def test_label_graph_backward(self):
325325
}
326326
with pytest.raises(ValueError):
327327
LabelTreeMapper.backward(instance=forward, all_labels=labels)
328+
# Checking label deletion case
329+
forward = {
330+
"type": "tree",
331+
"directed": True,
332+
"nodes": ["0", "0_1", "0_2", "0_1_1", "0_2_1"],
333+
"edges": [("0_1", "0"), ("0_2", "0"), ("0_1_1", "0_1"), ("0_2_1", "0_2")],
334+
}
335+
labels = {
336+
ID("0"): self.label_0,
337+
ID("0_1"): self.label_0_1,
338+
# ID("0_2"): self.label_0_2,
339+
ID("0_1_1"): self.label_0_1_1,
340+
# ID("0_1_2"): self.label_0_1_2,
341+
ID("0_2_1"): self.label_0_2_1,
342+
}
343+
expected_backward = LabelTree()
344+
for node in labels.values():
345+
expected_backward.add_node(node)
346+
for parent, child in [
347+
(self.label_0, self.label_0_1),
348+
# (self.label_0, self.label_0_2),
349+
(self.label_0_1, self.label_0_1_1),
350+
# (self.label_0_2, self.label_0_2_1),
351+
]:
352+
expected_backward.add_child(parent, child)
353+
actual_backward = LabelTreeMapper.backward(instance=forward, all_labels=labels)
354+
assert LabelTreeMapper.forward(actual_backward) == LabelTreeMapper.forward(expected_backward)

0 commit comments

Comments
 (0)