From f120fd3acda823957e4a606ae69d514dbf6cab93 Mon Sep 17 00:00:00 2001 From: zhiweichen <30588754+zhiweichen12@users.noreply.github.com> Date: Wed, 21 Nov 2018 10:26:32 +0800 Subject: [PATCH] Update util.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 原code报错:RuntimeError: Expected object of type torch.FloatTensor but found type torch.cuda.FloatTensor for argument #4 'other' 新增代码:prediction = prediction.cuda() --- util.py | 1 + 1 file changed, 1 insertion(+) diff --git a/util.py b/util.py index 2ef4673e..6dbdb6b2 100644 --- a/util.py +++ b/util.py @@ -73,6 +73,7 @@ def predict_transform(prediction, inp_dim, anchors, num_classes, CUDA = True): if CUDA: x_offset = x_offset.cuda() y_offset = y_offset.cuda() + prediction = prediction.cuda() x_y_offset = torch.cat((x_offset, y_offset), 1).repeat(1,num_anchors).view(-1,2).unsqueeze(0)