Skip to content

Commit 707ed46

Browse files
committed
fix yolo missing device
Signed-off-by: Farhan Ahmed <[email protected]>
1 parent 9afad68 commit 707ed46

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

art/estimators/object_detection/pytorch_yolo.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,10 +92,11 @@ def translate_labels_x1y1x2y2_to_xcycwh(
9292
import torch
9393

9494
labels_xcycwh_list = []
95+
device = labels_x1y1x2y2[0]["boxes"].device
9596

9697
for i, label_dict in enumerate(labels_x1y1x2y2):
9798
# create 2D tensor to encode labels and bounding boxes
98-
labels = torch.zeros(len(label_dict["boxes"]), 6)
99+
labels = torch.zeros(len(label_dict["boxes"]), 6, device=device)
99100
labels[:, 0] = i
100101
labels[:, 1] = label_dict["labels"]
101102
labels[:, 2:6] = label_dict["boxes"]

0 commit comments

Comments
 (0)