Skip to content

Commit c00332d

Browse files
author
Beat Buesser
committed
Update method fit of PyTorchClassifier
Signed-off-by: Beat Buesser <[email protected]>
1 parent 3e3a438 commit c00332d

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

art/estimators/classification/pytorch.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -392,18 +392,21 @@ def fit( # pylint: disable=W0221
392392
# Check label shape
393393
y_preprocessed = self.reduce_labels(y_preprocessed)
394394

395-
num_batch = int(np.ceil(len(x_preprocessed) / float(batch_size)))
395+
num_batch = int(np.floor(len(x_preprocessed) / float(batch_size)))
396396
ind = np.arange(len(x_preprocessed))
397397

398+
x_preprocessed = torch.from_numpy(x_preprocessed).to(self._device)
399+
y_preprocessed = torch.from_numpy(y_preprocessed).to(self._device)
400+
398401
# Start training
399402
for _ in range(nb_epochs):
400403
# Shuffle the examples
401404
random.shuffle(ind)
402405

403406
# Train for one epoch
404407
for m in range(num_batch):
405-
i_batch = torch.from_numpy(x_preprocessed[ind[m * batch_size : (m + 1) * batch_size]]).to(self._device)
406-
o_batch = torch.from_numpy(y_preprocessed[ind[m * batch_size : (m + 1) * batch_size]]).to(self._device)
408+
i_batch = x_preprocessed[ind[m * batch_size : (m + 1) * batch_size]]
409+
o_batch = y_preprocessed[ind[m * batch_size : (m + 1) * batch_size]]
407410

408411
# Zero the parameter gradients
409412
self._optimizer.zero_grad()

0 commit comments

Comments
 (0)