@@ -378,6 +378,8 @@ def fit(  # pylint: disable=W0221
         """
         import torch  # lgtm [py/repeated-import]
 
+        use_ffcv = kwargs.get("ffcv")
+
         # Set model mode
         self._model.train(mode=training_mode)
 
@@ -395,15 +397,157 @@ def fit(  # pylint: disable=W0221
         num_batch = int(np.ceil(len(x_preprocessed) / float(batch_size)))
         ind = np.arange(len(x_preprocessed))
 
+        if use_ffcv:
+            self._fit_ffcv(
+                x=x_preprocessed,
+                y=y_preprocessed,
+                batch_size=batch_size,
+                nb_epochs=nb_epochs,
+                training_mode=training_mode,
+                **kwargs,
+            )
+        else:
+            # Start training
+            for _ in range(nb_epochs):
+                # Shuffle the examples
+                random.shuffle(ind)
+
+                # Train for one epoch
+                for m in range(num_batch):
+                    i_batch = torch.from_numpy(x_preprocessed[ind[m * batch_size : (m + 1) * batch_size]]).to(
+                        self._device
+                    )
+                    o_batch = torch.from_numpy(y_preprocessed[ind[m * batch_size : (m + 1) * batch_size]]).to(
+                        self._device
+                    )
+
+                    # Zero the parameter gradients
+                    self._optimizer.zero_grad()
+
+                    # Perform prediction
+                    model_outputs = self._model(i_batch)
+
+                    # Form the loss function
+                    loss = self._loss(model_outputs[-1], o_batch)  # lgtm [py/call-to-non-callable]
+
+                    # Do training
+                    if self._use_amp:  # pragma: no cover
+                        from apex import amp  # pylint: disable=E0611
+
+                        with amp.scale_loss(loss, self._optimizer) as scaled_loss:
+                            scaled_loss.backward()
+
+                    else:
+                        loss.backward()
+
+                    self._optimizer.step()
+
+    def _fit_ffcv(
+        self,
+        x: np.ndarray,
+        y: np.ndarray,
+        batch_size: int = 128,
+        nb_epochs: int = 10,
+        training_mode: bool = True,
+        **kwargs,
+    ) -> None:
+        """
+        Fit the classifier on the training set `(x, y)`.
+
+        :param x: Training data.
+        :param y: Target values (class labels) one-hot-encoded of shape (nb_samples, nb_classes) or index labels of
+                  shape (nb_samples,).
+        :param batch_size: Size of batches.
+        :param nb_epochs: Number of epochs to use for training.
+        :param training_mode: `True` for model set to training mode and `False` for model set to evaluation mode.
+        :param kwargs: Dictionary of framework-specific arguments. This parameter is not currently supported for
+                       PyTorch and providing it has no effect.
+        """
+        ind = np.arange(len(x))
+
+        # FFCV - prepare
+        from ffcv.writer import DatasetWriter
+        from ffcv.fields import NDArrayField
+
+        # Your dataset (`torch.utils.data.Dataset`) of (image, label) pairs
+        # my_dataset = make_my_dataset()
+
+        class NumpyDataset:
+            def __init__(self, x, y):
+                self.X = x
+                self.Y = y
+
+            def __getitem__(self, idx):
+                return (self.X[idx], self.Y[idx])
+
+            def __len__(self):
+                return len(self.X)
+
+        my_dataset = NumpyDataset(x, y)
+
+        write_path = "/home/bbuesser/tmp/ffcv/ds.beton"
+
+        # Pass a type for each data field
+        jpeg_quality = 50
+
+        writer = DatasetWriter(
+            write_path,
+            {
+                # Tune options to optimize dataset size, throughput at train-time
+                # 'image': RGBImageField(max_resolution=256, jpeg_quality=jpeg_quality),
+                "image": NDArrayField(dtype=x.dtype, shape=(1, 28, 28)),
+                "label": NDArrayField(dtype=y.dtype, shape=(10,)),
+            },
+        )
+
+        # Write dataset
+        writer.from_indexed_dataset(my_dataset)
+
+        # FFCV
+        from ffcv.loader import Loader, OrderOption
+        from ffcv.transforms import ToTensor, ToDevice, ToTorchImage, Cutout
+        from ffcv.fields.decoders import IntDecoder, RandomResizedCropRGBImageDecoder, NDArrayDecoder
+
+        # Random resized crop
+        # decoder = RandomResizedCropRGBImageDecoder((224, 224))
+
+        # Data decoding and augmentation
+        # image_pipeline = [decoder, Cutout(), ToTensor(), ToTorchImage(), ToDevice(0)]
+        image_pipeline = [NDArrayDecoder(), ToTensor()]
+        label_pipeline = [NDArrayDecoder(), ToTensor()]
+
+        # Pipeline for each data field
+        pipelines = {"image": image_pipeline, "label": label_pipeline}
+
+        # Replaces PyTorch data loader (`torch.utils.data.Dataloader`)
+        # write_path = "/home/bbuesser/tmp/ffcv/"
+        bs = batch_size
+        num_workers = 1
+        # loader = Loader(
+        #     write_path, batch_size=bs, num_workers=num_workers, order=OrderOption.RANDOM, pipelines=pipelines
+        # )
+        loader = Loader(
+            write_path,
+            batch_size=bs,
+            num_workers=num_workers,
+            order=OrderOption.RANDOM,
+            pipelines=pipelines,
+            os_cache=True,
+        )
+
         # Start training
         for _ in range(nb_epochs):
+            print(_)
             # Shuffle the examples
             random.shuffle(ind)
 
             # Train for one epoch
-            for m in range(num_batch):
-                i_batch = torch.from_numpy(x_preprocessed[ind[m * batch_size : (m + 1) * batch_size]]).to(self._device)
-                o_batch = torch.from_numpy(y_preprocessed[ind[m * batch_size : (m + 1) * batch_size]]).to(self._device)
+            # for m in range(num_batch):
+            #     i_batch = torch.from_numpy(x_preprocessed[ind[m * batch_size : (m + 1) * batch_size]]).to(self._device)
+            #     o_batch = torch.from_numpy(y_preprocessed[ind[m * batch_size : (m + 1) * batch_size]]).to(self._device)
+            from tqdm import tqdm
+
+            for i, (i_batch, o_batch) in enumerate(tqdm(loader)):
 
                 # Zero the parameter gradients
                 self._optimizer.zero_grad()
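For context, a minimal usage sketch of the new `ffcv` keyword argument introduced by this diff. The classifier construction, model, optimizer, and random data below are illustrative assumptions, not part of this commit; they are chosen only to match the hard-coded (1, 28, 28) image shape and 10-class one-hot labels that `_fit_ffcv` currently writes into its `.beton` dataset.

import numpy as np
import torch.nn as nn
import torch.optim as optim

from art.estimators.classification import PyTorchClassifier

# Illustrative model matching the (1, 28, 28) inputs and 10 classes assumed by _fit_ffcv
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))

classifier = PyTorchClassifier(
    model=model,
    loss=nn.CrossEntropyLoss(),
    optimizer=optim.Adam(model.parameters(), lr=1e-3),
    input_shape=(1, 28, 28),
    nb_classes=10,
)

# Placeholder training data: float32 images and one-hot float32 labels
x_train = np.random.rand(256, 1, 28, 28).astype(np.float32)
y_train = np.eye(10, dtype=np.float32)[np.random.randint(0, 10, size=256)]

# Without the kwarg, fit() keeps the existing NumPy batching loop;
# passing ffcv=True routes training through the new _fit_ffcv() path.
classifier.fit(x_train, y_train, batch_size=128, nb_epochs=1, ffcv=True)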