Skip to content

Commit fe4a1aa

Browse files
authored
Merge pull request #1785 from abigailgold/dev_1.11.1_fix_pytorch
Fix exception in pytorch classifier predict method
2 parents 48f7609 + 9aad906 commit fe4a1aa

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

art/estimators/classification/pytorch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,7 @@ def predict( # pylint: disable=W0221
322322
output = model_outputs[-1]
323323
output = output.detach().cpu().numpy().astype(np.float32)
324324
if len(output.shape) == 1:
325-
output = np.expand_dims(output.detach().cpu().numpy(), axis=1).astype(np.float32)
325+
output = np.expand_dims(output, axis=1).astype(np.float32)
326326

327327
results_list.append(output)
328328

0 commit comments

Comments
 (0)