Skip to content

Commit b2c8c2a

Browse files
committed
bug fix for prediction
Signed-off-by: Farhan Ahmed <[email protected]>
1 parent 292910b commit b2c8c2a

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

art/estimators/classification/pytorch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,7 @@ def predict( # pylint: disable=W0221
322322
dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False)
323323

324324
results_list = []
325-
for x_batch in dataloader:
325+
for (x_batch,) in dataloader:
326326
# Move inputs to device
327327
x_batch = x_batch.to(self._device)
328328

art/estimators/regression/pytorch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,7 @@ def predict( # pylint: disable=W0221
260260
dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False)
261261

262262
results_list = []
263-
for x_batch in dataloader:
263+
for (x_batch,) in dataloader:
264264
# Move inputs to device
265265
x_batch = x_batch.to(self._device)
266266

0 commit comments

Comments
 (0)