diff --git a/examples/mnist_lightning.py b/examples/mnist_lightning.py
index 42d93a3d..9313dbcd 100644
--- a/examples/mnist_lightning.py
+++ b/examples/mnist_lightning.py
@@ -99,11 +99,8 @@ def configure_optimizers(self):
         optimizer = optim.SGD(self.parameters(), lr=self.lr, momentum=0)
 
         if self.enable_dp:
-            data_loader = (
-                # soon there will be a fancy way to access train dataloader,
-                # see https://github.com/PyTorchLightning/pytorch-lightning/issues/10430
-                self.trainer._data_connector._train_dataloader_source.dataloader()
-            )
+            self.trainer.fit_loop.setup_data()
+            data_loader = self.trainer.train_dataloader
 
             # transform (model, optimizer, dataloader) to DP-versions
             if hasattr(self, "dp"):
diff --git a/opacus/utils/batch_memory_manager.py b/opacus/utils/batch_memory_manager.py
index c5d6dcc0..a2e2de62 100644
--- a/opacus/utils/batch_memory_manager.py
+++ b/opacus/utils/batch_memory_manager.py
@@ -16,12 +16,13 @@
 from typing import List
 
 import numpy as np
+from torch.utils.data import BatchSampler, DataLoader, Sampler
+
 from opacus.optimizers import DPOptimizer
 from opacus.utils.uniform_sampler import (
     DistributedUniformWithReplacementSampler,
     UniformWithReplacementSampler,
 )
-from torch.utils.data import BatchSampler, DataLoader, Sampler
 
 
 class BatchSplittingSampler(Sampler[List[int]]):
@@ -71,13 +72,17 @@ def __iter__(self):
     def __len__(self):
         if isinstance(self.sampler, BatchSampler):
             return int(
-                len(self.sampler) * (self.sampler.batch_size / self.max_batch_size)
+                np.ceil(
+                    len(self.sampler) * (self.sampler.batch_size / self.max_batch_size)
+                )
            )
         elif isinstance(self.sampler, UniformWithReplacementSampler) or isinstance(
             self.sampler, DistributedUniformWithReplacementSampler
         ):
             expected_batch_size = self.sampler.sample_rate * self.sampler.num_samples
-            return int(len(self.sampler) * (expected_batch_size / self.max_batch_size))
+            return int(
+                np.ceil(len(self.sampler) * (expected_batch_size / self.max_batch_size))
+            )
         return len(self.sampler)
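
For context on the `batch_memory_manager.py` change: `BatchSplittingSampler.__len__` previously truncated its estimate of the number of physical batches with a plain `int()`, so it could report fewer batches than `__iter__` actually yields. The standalone sketch below (not part of the patch; the batch sizes are made-up numbers) illustrates the effect of switching to `np.ceil`:

```python
# Standalone sketch of the __len__ rounding change; the numbers below are
# illustrative assumptions, not values taken from Opacus.
import numpy as np

num_batches = 10      # logical batches produced by the wrapped sampler
batch_size = 100      # samples per logical batch
max_batch_size = 64   # physical batch size limit enforced by the splitter

# Iteration splits each logical batch into ceil(100 / 64) = 2 physical batches,
# so 20 physical batches are actually yielded.
actual = num_batches * int(np.ceil(batch_size / max_batch_size))

# Old __len__: int() truncates 10 * (100 / 64) = 15.625 down to 15.
old_len = int(num_batches * (batch_size / max_batch_size))

# New __len__: the same estimate is rounded up to 16 instead of down.
new_len = int(np.ceil(num_batches * (batch_size / max_batch_size)))

print(actual, old_len, new_len)  # 20 15 16
```

The reported length remains an estimate either way, since each logical batch is split and rounded up independently during iteration, but rounding up brings it closer to the true count than truncation does.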