
Commit 042cdab

Fix workers creation
1 parent 9c2b618 commit 042cdab

File tree

hotpp/calibrate.py
hotpp/data/__init__.py
hotpp/data/module.py
hotpp/embed.py

4 files changed: 26 additions & 18 deletions


hotpp/calibrate.py

Lines changed: 2 additions & 3 deletions
@@ -5,16 +5,15 @@
 import torch
 from omegaconf import OmegaConf
 
-from hotpp.data import ShuffledDistributedDataset, DEFAULT_PARALLELIZM
+from hotpp.data import ShuffledDistributedDataset, DEFAULT_PARALLELIZM, get_default_loader_params
 from hotpp.data.module import HotppSampler
 from tqdm import tqdm
 
 logger = logging.getLogger(__name__)
 
 
 def get_loader(dm):
-    loader_params = {"drop_last": False,
-                     "pin_memory": torch.cuda.is_available()}
+    loader_params = get_default_loader_params()
     loader_params.update(dm.train_loader_params)
     dataset = ShuffledDistributedDataset(dm.val_data, rank=None, world_size=None,
                                          num_workers=loader_params.get("num_workers", 0),
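
The net effect in get_loader is a two-step merge: shared defaults first, then the datamodule's own loader settings, so user configuration still wins. A minimal sketch of the order (the override dict stands in for dm.train_loader_params):

    from hotpp.data import get_default_loader_params

    # Shared defaults: persistent_workers, pin_memory, multiprocessing_context.
    loader_params = get_default_loader_params()
    # User-level settings applied second override any default key.
    loader_params.update({"pin_memory": False, "num_workers": 2})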

hotpp/data/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,3 +1,3 @@
 from .dataset import HotppDataset, ShuffledDistributedDataset, DEFAULT_PARALLELIZM
-from .module import HotppDataModule
+from .module import HotppDataModule, get_default_loader_params
 from .padded_batch import PaddedBatch

hotpp/data/module.py

Lines changed: 20 additions & 12 deletions
@@ -1,17 +1,31 @@
-import torch
+import multiprocessing as mp
 import pytorch_lightning as pl
+import torch
 from .dataset import HotppDataset, ShuffledDistributedDataset, DEFAULT_PARALLELIZM
 
 
 def pop_loader_params(params):
     loader_params = {}
     for key in ["seed", "num_workers", "batch_size", "cache_size", "parallelize", "drop_last", "prefetch_factor",
-                "persistent_workers", "multiprocessing_context"]:
+                "pin_memory", "persistent_workers", "multiprocessing_context"]:
         if key in params:
             loader_params[key] = params.pop(key)
     return loader_params
 
 
+def get_default_loader_params():
+    default_loader_params = {
+        "persistent_workers": True,
+        "pin_memory": torch.cuda.is_available()
+    }
+    available_contexts = mp.get_all_start_methods()
+    for context in ["forkserver", "spawn", "fork"]:
+        if context in available_contexts:
+            default_loader_params["multiprocessing_context"] = context
+            break
+    return default_loader_params
+
+
 class HotppSampler(torch.utils.data.DistributedSampler):
     def __init__(self, dataset):
         # Skip super init.
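
A note on the new helper: mp.get_all_start_methods() lists the process start methods the platform supports, so the loop prefers "forkserver" where it exists and falls back to "spawn", then "fork". A standalone sketch of the outcome:

    import multiprocessing as mp

    # Linux reports ['fork', 'spawn', 'forkserver']; Windows only ['spawn'].
    print(mp.get_all_start_methods())

    # Mirror the helper's preference order.
    for context in ["forkserver", "spawn", "fork"]:
        if context in mp.get_all_start_methods():
            # 'forkserver' on typical Unix systems, 'spawn' on Windows.
            print("chosen context:", context)
            break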
@@ -106,10 +120,8 @@ def splits(self):
     def train_dataloader(self, rank=None, world_size=None):
         rank = self.trainer.global_rank if rank is None else rank
         world_size = self.trainer.world_size if world_size is None else world_size
-        loader_params = {"drop_last": True,
-                         "multiprocessing_context": "spawn",
-                         "persistent_workers": True,
-                         "pin_memory": torch.cuda.is_available()}
+        loader_params = get_default_loader_params()
+        loader_params.update({"drop_last": True})
         loader_params.update(self.train_loader_params)
         dataset = ShuffledDistributedDataset(self.train_data, rank=rank, world_size=world_size,
                                              cache_size=loader_params.pop("cache_size", 4096),

@@ -127,9 +139,7 @@ def train_dataloader(self, rank=None, world_size=None):
     def val_dataloader(self, rank=None, world_size=None):
         rank = self.trainer.global_rank if rank is None else rank
         world_size = self.trainer.world_size if world_size is None else world_size
-        loader_params = {"multiprocessing_context": "spawn",
-                         "persistent_workers": True,
-                         "pin_memory": torch.cuda.is_available()}
+        loader_params = get_default_loader_params()
         loader_params.update(self.val_loader_params)
         dataset = ShuffledDistributedDataset(self.val_data, rank=rank, world_size=world_size,
                                              parallelize=loader_params.pop("parallelize", DEFAULT_PARALLELIZM))  # Disable shuffle.

@@ -143,9 +153,7 @@ def val_dataloader(self, rank=None, world_size=None):
     def test_dataloader(self, rank=None, world_size=None):
         rank = self.trainer.global_rank if rank is None else rank
         world_size = self.trainer.world_size if world_size is None else world_size
-        loader_params = {"multiprocessing_context": "spawn",
-                         "persistent_workers": True,
-                         "pin_memory": torch.cuda.is_available()}
+        loader_params = get_default_loader_params()
         loader_params.update(self.test_loader_params)
         dataset = ShuffledDistributedDataset(self.test_data, rank=rank, world_size=world_size,
                                              parallelize=loader_params.pop("parallelize", DEFAULT_PARALLELIZM))  # Disable shuffle.
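
For context, a minimal standalone sketch of consuming these defaults with a stock torch.utils.data.DataLoader; the TensorDataset is a stand-in, not part of the commit. DataLoader rejects persistent_workers=True and a non-None multiprocessing_context when num_workers == 0, so a positive worker count is passed alongside the defaults:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    from hotpp.data import get_default_loader_params  # re-exported by this commit

    dataset = TensorDataset(torch.arange(8, dtype=torch.float32))

    params = get_default_loader_params()
    # e.g. {'persistent_workers': True, 'pin_memory': False,
    #       'multiprocessing_context': 'forkserver'} on a Linux CPU machine.

    loader = DataLoader(dataset, batch_size=4, num_workers=2, **params)
    for (batch,) in loader:
        print(batch.shape)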

hotpp/embed.py

Lines changed: 3 additions & 2 deletions
@@ -13,7 +13,7 @@
 from torchmetrics.utilities.distributed import gather_all_tensors
 
 from .common import get_trainer
-from .data import ShuffledDistributedDataset, DEFAULT_PARALLELIZM
+from .data import ShuffledDistributedDataset, DEFAULT_PARALLELIZM, get_default_loader_params
 
 logger = logging.getLogger(__name__)
 

@@ -125,7 +125,8 @@ def test_dataloader(self):
             collate_fn=dataset.dataset.collate_fn,
             shuffle=False,
             num_workers=num_workers,
-            batch_size=loader_params.get("batch_size", 1)
+            batch_size=loader_params.get("batch_size", 1),
+            **get_default_loader_params()
         )
 
 
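
One caveat on the spread in embed.py: if num_workers resolves to 0, a stock DataLoader raises a ValueError for both persistent_workers=True and a non-None multiprocessing_context. A hedged sketch of a guard a caller could apply; filter_worker_params is hypothetical, not part of hotpp:

    from hotpp.data import get_default_loader_params

    def filter_worker_params(num_workers):
        # Hypothetical helper: drop multi-process-only keys when loading
        # runs in the main process (num_workers == 0).
        params = get_default_loader_params()
        if num_workers == 0:
            params.pop("persistent_workers", None)
            params.pop("multiprocessing_context", None)
        return params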
