art/estimators/classification: 1 file changed, +6 −3 lines

@@ -392,18 +392,21 @@ def fit(  # pylint: disable=W0221
         # Check label shape
         y_preprocessed = self.reduce_labels(y_preprocessed)

-        num_batch = int(np.ceil(len(x_preprocessed) / float(batch_size)))
+        num_batch = int(np.floor(len(x_preprocessed) / float(batch_size)))
         ind = np.arange(len(x_preprocessed))

+        x_preprocessed = torch.from_numpy(x_preprocessed).to(self._device)
+        y_preprocessed = torch.from_numpy(y_preprocessed).to(self._device)
+
         # Start training
         for _ in range(nb_epochs):
             # Shuffle the examples
             random.shuffle(ind)

             # Train for one epoch
             for m in range(num_batch):
-                i_batch = torch.from_numpy(x_preprocessed[ind[m * batch_size : (m + 1) * batch_size]]).to(self._device)
-                o_batch = torch.from_numpy(y_preprocessed[ind[m * batch_size : (m + 1) * batch_size]]).to(self._device)
+                i_batch = x_preprocessed[ind[m * batch_size : (m + 1) * batch_size]]
+                o_batch = y_preprocessed[ind[m * batch_size : (m + 1) * batch_size]]

                 # Zero the parameter gradients
                 self._optimizer.zero_grad()
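The diff moves the NumPy-to-tensor conversion and the device transfer out of the per-batch loop, so the whole dataset is copied to the device once per fit() call rather than once per batch, and switches np.ceil to np.floor so only full batches are drawn. Below is a minimal, self-contained sketch of the resulting loop structure; the data, model, optimizer, and loss used here are illustrative placeholders and are not part of the patch or of ART.

    # Sketch of the patched batching pattern (placeholder data/model, not ART code).
    import random

    import numpy as np
    import torch
    import torch.nn as nn

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Placeholder inputs: 1000 samples, 10 features, 2 classes.
    x_preprocessed = np.random.rand(1000, 10).astype(np.float32)
    y_preprocessed = np.random.randint(0, 2, size=1000).astype(np.int64)

    model = nn.Linear(10, 2).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    loss_fn = nn.CrossEntropyLoss()

    batch_size, nb_epochs = 64, 2

    # floor() drops the final partial batch; ceil() would keep it.
    num_batch = int(np.floor(len(x_preprocessed) / float(batch_size)))
    ind = np.arange(len(x_preprocessed))

    # Convert and move to the device once, instead of once per batch.
    x_t = torch.from_numpy(x_preprocessed).to(device)
    y_t = torch.from_numpy(y_preprocessed).to(device)

    for _ in range(nb_epochs):
        random.shuffle(ind)
        for m in range(num_batch):
            batch_ind = ind[m * batch_size : (m + 1) * batch_size]
            i_batch = x_t[batch_ind]  # already a tensor on the device
            o_batch = y_t[batch_ind]

            optimizer.zero_grad()
            loss = loss_fn(model(i_batch), o_batch)
            loss.backward()
            optimizer.step()

Note that indexing the pre-moved tensors with the shuffled index array still produces per-batch copies on the device, but it avoids repeated host-to-device transfers inside the inner loop.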