Skip to content

Commit 896c8a6

Browse files
authored
fix(video): append models only if they exist (#1117)
1 parent fef8455 commit 896c8a6

File tree

4 files changed

+91
-12
lines changed

4 files changed

+91
-12
lines changed

src/datachain/model/ultralytics/bbox.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def from_result(result: "Results") -> "YoloBBox":
3333
name = summary[0].get("name", "")
3434
box = (
3535
BBox.from_dict(summary[0]["box"], title=name)
36-
if "box" in summary[0]
36+
if summary[0].get("box")
3737
else BBox()
3838
)
3939
return YoloBBox(
@@ -69,7 +69,8 @@ def from_results(results: list["Results"]) -> "YoloBBoxes":
6969
cls.append(s["class"])
7070
names.append(name)
7171
confidence.append(s["confidence"])
72-
box.append(BBox.from_dict(s.get("box", {}), title=name))
72+
if s.get("box"):
73+
box.append(BBox.from_dict(s.get("box"), title=name))
7374
return YoloBBoxes(
7475
cls=cls,
7576
name=names,
@@ -102,7 +103,7 @@ def from_result(result: "Results") -> "YoloOBBox":
102103
name = summary[0].get("name", "")
103104
box = (
104105
OBBox.from_dict(summary[0]["box"], title=name)
105-
if "box" in summary[0]
106+
if summary[0].get("box")
106107
else OBBox()
107108
)
108109
return YoloOBBox(
@@ -138,7 +139,8 @@ def from_results(results: list["Results"]) -> "YoloOBBoxes":
138139
cls.append(s["class"])
139140
names.append(name)
140141
confidence.append(s["confidence"])
141-
box.append(OBBox.from_dict(s.get("box", {}), title=name))
142+
if s.get("box"):
143+
box.append(OBBox.from_dict(s.get("box"), title=name))
142144
return YoloOBBoxes(
143145
cls=cls,
144146
name=names,

src/datachain/model/ultralytics/pose.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,12 +58,12 @@ def from_result(result: "Results") -> "YoloPose":
5858
name = summary[0].get("name", "")
5959
box = (
6060
BBox.from_dict(summary[0]["box"], title=name)
61-
if "box" in summary[0]
61+
if summary[0].get("box")
6262
else BBox()
6363
)
6464
pose = (
6565
Pose3D.from_dict(summary[0]["keypoints"])
66-
if "keypoints" in summary[0]
66+
if summary[0].get("keypoints")
6767
else Pose3D()
6868
)
6969
return YoloPose(
@@ -102,8 +102,10 @@ def from_results(results: list["Results"]) -> "YoloPoses":
102102
cls.append(s["class"])
103103
names.append(name)
104104
confidence.append(s["confidence"])
105-
box.append(BBox.from_dict(s.get("box", {}), title=name))
106-
pose.append(Pose3D.from_dict(s.get("keypoints", {})))
105+
if s.get("box"):
106+
box.append(BBox.from_dict(s.get("box"), title=name))
107+
if s.get("keypoints"):
108+
pose.append(Pose3D.from_dict(s.get("keypoints")))
107109
return YoloPoses(
108110
cls=cls,
109111
name=names,

src/datachain/model/ultralytics/segment.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,12 @@ def from_result(result: "Results") -> "YoloSegment":
3636
name = summary[0].get("name", "")
3737
box = (
3838
BBox.from_dict(summary[0]["box"], title=name)
39-
if "box" in summary[0]
39+
if summary[0].get("box")
4040
else BBox()
4141
)
4242
segment = (
4343
Segment.from_dict(summary[0]["segments"], title=name)
44-
if "segments" in summary[0]
44+
if summary[0].get("segments")
4545
else Segment()
4646
)
4747
return YoloSegment(
@@ -80,8 +80,10 @@ def from_results(results: list["Results"]) -> "YoloSegments":
8080
cls.append(s["class"])
8181
names.append(name)
8282
confidence.append(s["confidence"])
83-
box.append(BBox.from_dict(s.get("box", {}), title=name))
84-
segment.append(Segment.from_dict(s.get("segments", {}), title=name))
83+
if s.get("box"):
84+
box.append(BBox.from_dict(s.get("box"), title=name))
85+
if s.get("segments"):
86+
segment.append(Segment.from_dict(s.get("segments"), title=name))
8587
return YoloSegments(
8688
cls=cls,
8789
name=names,

tests/func/model/test_yolo.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,45 @@ def test_yolo_segment_from_results_empty(running_img):
205205
}
206206

207207

208+
def test_yolo_segment_from_results_empty_segments(running_img):
209+
result = Results(
210+
orig_img=running_img,
211+
path="running.jpeg",
212+
names={0: "person"},
213+
boxes=torch.tensor([[102.0, 84.0, 183.0, 238.0, 0.9078, 0.0]]),
214+
)
215+
216+
model = YoloSegment.from_result(result)
217+
assert model.model_dump() == {
218+
"cls": 0,
219+
"name": "person",
220+
"confidence": 0.9078,
221+
"box": {
222+
"coords": [102, 84, 183, 238],
223+
"title": "person",
224+
},
225+
"segment": {
226+
"title": "",
227+
"x": [],
228+
"y": [],
229+
},
230+
}
231+
232+
model = YoloSegments.from_results([result])
233+
assert model.model_dump() == {
234+
"cls": [0],
235+
"name": ["person"],
236+
"confidence": [0.9078],
237+
"box": [
238+
{
239+
"coords": [102, 84, 183, 238],
240+
"title": "person",
241+
}
242+
],
243+
"segment": [],
244+
}
245+
246+
208247
def test_yolo_seg_from_results(running_img, running_img_masks):
209248
result = Results(
210249
orig_img=running_img,
@@ -2239,6 +2278,40 @@ def test_yolo_pose_from_results_empty(running_img):
22392278
}
22402279

22412280

2281+
def test_yolo_pose_from_results_empty_poses(running_img):
2282+
result = Results(
2283+
orig_img=running_img,
2284+
path="running.jpeg",
2285+
names={0: "person"},
2286+
boxes=torch.tensor([[102.0, 84.0, 183.0, 238.0, 0.9078, 0.0]]),
2287+
)
2288+
2289+
model = YoloPose.from_result(result)
2290+
assert model.model_dump() == {
2291+
"cls": 0,
2292+
"name": "person",
2293+
"confidence": 0.9078,
2294+
"box": {
2295+
"coords": [102, 84, 183, 238],
2296+
"title": "person",
2297+
},
2298+
"pose": {
2299+
"x": [],
2300+
"y": [],
2301+
"visible": [],
2302+
},
2303+
}
2304+
2305+
model = YoloPoses.from_results([result])
2306+
assert model.model_dump() == {
2307+
"cls": [0],
2308+
"name": ["person"],
2309+
"confidence": [0.9078],
2310+
"box": [{"coords": [102, 84, 183, 238], "title": "person"}],
2311+
"pose": [],
2312+
}
2313+
2314+
22422315
def test_yolo_pose_from_results(running_img, running_img_masks):
22432316
result = Results(
22442317
orig_img=running_img,

0 commit comments

Comments
 (0)