Skip to content

Commit 7bed09d

Browse files
committed
revert check on dim for fit-generator and move to a separate PR
Signed-off-by: GiulioZizzo <[email protected]>
1 parent d8bab78 commit 7bed09d

File tree

1 file changed

+1
-4
lines changed

1 file changed

+1
-4
lines changed

art/estimators/classification/pytorch.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -500,10 +500,7 @@ def fit_generator(self, generator: "DataGenerator", nb_epochs: int = 20, **kwarg
500500
if isinstance(o_batch, np.ndarray):
501501
o_batch = torch.argmax(torch.from_numpy(o_batch).to(self._device), dim=1)
502502
else:
503-
if o_batch.dim() > 1:
504-
o_batch = torch.argmax(o_batch.to(self._device), dim=1)
505-
else:
506-
o_batch = o_batch.to(self._device)
503+
o_batch = torch.argmax(o_batch.to(self._device), dim=1)
507504

508505
# Zero the parameter gradients
509506
self._optimizer.zero_grad()

0 commit comments

Comments
 (0)