Skip to content

Commit d3ad194

Browse files
authored
Merge pull request #18 from msamogh/torch-1.2
Support Torch 1.2
2 parents 32955e7 + a9aff15 commit d3ad194

File tree

5 files changed

+51
-12
lines changed

5 files changed

+51
-12
lines changed

nonechucks/__init__.py

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,43 @@
1-
from .dataset import SafeDataset
2-
from .sampler import SafeSampler
3-
from .dataloader import SafeDataLoader
1+
import logging
42

5-
__all__ = ["SafeDataset", "SafeSampler", "SafeDataLoader"]
3+
import torch
4+
import torch.utils.data
5+
6+
7+
logger = logging.getLogger(__name__)
8+
9+
10+
def _get_pytorch_version():
11+
version = torch.__version__
12+
major, minor, patch = [int(x) for x in version.split(".")]
13+
if major != 1:
14+
raise RuntimeError(
15+
"nonechucks only supports PyTorch major version 1 at the moment."
16+
)
17+
if minor > 2:
18+
logger.warn(
19+
"nonechucks may not work properly with this version of PyTorch ({}). "
20+
"It has only been tested on PyTorch versions 1.0, 1.1, and 1.2".format(
21+
version
22+
)
23+
)
24+
return major, minor
25+
26+
27+
MAJOR, MINOR = _get_pytorch_version()
28+
29+
if MINOR > 1:
30+
SingleProcessDataLoaderIter = (
31+
torch.utils.data.dataloader._SingleProcessDataLoaderIter
32+
)
33+
MultiProcessingDataLoaderIter = (
34+
torch.utils.data.dataloader._MultiProcessingDataLoaderIter
35+
)
36+
else:
37+
SingleProcessDataLoaderIter = torch.utils.data.dataloader._DataLoaderIter
38+
MultiProcessingDataLoaderIter = torch.utils.data.dataloader._DataLoaderIter
39+
40+
41+
from nonechucks.dataset import SafeDataset
42+
from nonechucks.sampler import SafeSampler
43+
from nonechucks.dataloader import SafeDataLoader

nonechucks/dataloader.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,10 @@
88
except ImportError:
99
from torch.utils.data._utils.collate import default_collate
1010

11-
from .dataset import SafeDataset
12-
from .sampler import SafeSampler
13-
from .utils import batch_len, collate_batches, slice_batch
11+
from nonechucks import SingleProcessDataLoaderIter, MultiProcessingDataLoaderIter
12+
from nonechucks.dataset import SafeDataset
13+
from nonechucks.sampler import SafeSampler
14+
from nonechucks.utils import batch_len, collate_batches, slice_batch
1415

1516

1617
class _SafeDataLoaderCaller(type):
@@ -44,7 +45,7 @@ def _restore_default_samplers(cls):
4445
data.dataloader.RandomSampler = cls.random
4546

4647

47-
class _SafeDataLoaderIter(data.dataloader._DataLoaderIter):
48+
class _SafeDataLoaderIter(MultiProcessingDataLoaderIter):
4849
def __init__(self, loader):
4950
super().__init__(loader)
5051
self.batch_size = loader.batch_size
@@ -145,4 +146,4 @@ def __init__(self, dataset, **kwargs):
145146
def __iter__(self):
146147
if self.num_workers > 0:
147148
return _SafeDataLoaderIter(self)
148-
return data.dataloader._DataLoaderIter(self)
149+
return SingleProcessDataLoaderIter(self)

nonechucks/dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import torch
22
import torch.utils.data
33

4-
from .utils import memoize
4+
from nonechucks.utils import memoize
55

66

77
class SafeDataset(torch.utils.data.Dataset):

nonechucks/sampler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import torch
22
import torch.utils.data
33

4-
from .dataset import SafeDataset
4+
from nonechucks.dataset import SafeDataset
55

66

77
class SafeSampler(torch.utils.data.sampler.Sampler):

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
setup(
55
name="nonechucks",
6-
version="0.3.1",
6+
version="0.4.0",
77
url="https://github.com/msamogh/nonechucks",
88
license="MIT",
99
author="Amogh Mannekote",

0 commit comments

Comments
 (0)