We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 9afad68 commit 707ed46Copy full SHA for 707ed46
art/estimators/object_detection/pytorch_yolo.py
@@ -92,10 +92,11 @@ def translate_labels_x1y1x2y2_to_xcycwh(
92
import torch
93
94
labels_xcycwh_list = []
95
+ device = labels_x1y1x2y2[0]["boxes"].device
96
97
for i, label_dict in enumerate(labels_x1y1x2y2):
98
# create 2D tensor to encode labels and bounding boxes
- labels = torch.zeros(len(label_dict["boxes"]), 6)
99
+ labels = torch.zeros(len(label_dict["boxes"]), 6, device=device)
100
labels[:, 0] = i
101
labels[:, 1] = label_dict["labels"]
102
labels[:, 2:6] = label_dict["boxes"]
0 commit comments