Skip to content

Commit 9c079fd

Browse files
committed
Revert default to record parallelism
1 parent dc74d32 commit 9c079fd

File tree

6 files changed

+25
-19
lines changed

6 files changed

+25
-19
lines changed

hotpp/calibrate.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import torch
66
from omegaconf import OmegaConf
77

8-
from hotpp.data import ShuffledDistributedDataset
8+
from hotpp.data import ShuffledDistributedDataset, DEFAULT_PARALLELIZM
99
from hotpp.data.module import HotppSampler
1010
from tqdm import tqdm
1111

@@ -19,6 +19,7 @@ def get_loader(dm):
1919
dataset = ShuffledDistributedDataset(dm.val_data, rank=None, world_size=None,
2020
num_workers=loader_params.get("num_workers", 0),
2121
cache_size=loader_params.pop("cache_size", 4096),
22+
parallelize=loader_params.pop("parallelize", DEFAULT_PARALLELIZM),
2223
seed=loader_params.pop("seed", 0))
2324
loader = torch.utils.data.DataLoader(
2425
dataset=dataset,

hotpp/data/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
from .dataset import HotppDataset, ShuffledDistributedDataset
1+
from .dataset import HotppDataset, ShuffledDistributedDataset, DEFAULT_PARALLELIZM
22
from .module import HotppDataModule
33
from .padded_batch import PaddedBatch

hotpp/data/dataset.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717
from .padded_batch import PaddedBatch
1818

1919

20+
DEFAULT_PARALLELIZM = "records"
21+
22+
2023
def immutable_hash(s):
2124
return int(hashlib.sha256(s.encode("utf-8")).hexdigest(), 16)
2225

@@ -90,18 +93,16 @@ def __init__(self, data,
9093
add_seq_fields=None,
9194
global_target_fields=None,
9295
local_targets_fields=None,
93-
local_targets_indices_field=None,
94-
allow_empty=False):
96+
local_targets_indices_field=None):
9597
super().__init__()
9698
if isinstance(data, str):
9799
self.filenames = list(sorted(parquet_file_scan(data)))
98100
elif isinstance(data, list):
99101
self.filenames = data
100102
else:
101103
raise ValueError(f"Unknown data type: {type(data)}")
102-
if (not self.filenames) and (not allow_empty):
104+
if not self.filenames:
103105
raise RuntimeError("Empty dataset")
104-
self.allow_empty = allow_empty
105106
self.total_length = sum(map(get_parquet_length, self.filenames))
106107
self.random_split = random_split
107108
self.random_part = random_part
@@ -268,7 +269,7 @@ class ShuffledDistributedDataset(torch.utils.data.IterableDataset):
268269
Args:
269270
parallelize: Parallel reading mode, either `records` (better granularity) or `files` (faster).
270271
"""
271-
def __init__(self, dataset, rank=None, world_size=None, cache_size=None, parallelize="files", seed=0):
272+
def __init__(self, dataset, rank=None, world_size=None, cache_size=None, parallelize=DEFAULT_PARALLELIZM, seed=0):
272273
super().__init__()
273274
self.dataset = dataset
274275
self.rank = rank
@@ -311,13 +312,16 @@ def _iter_shuffled_files(self, dataset, seed, rank, world_size):
311312
filenames = list(dataset.filenames)
312313
if not filenames:
313314
raise RuntimeError("Empty dataset")
314-
if len(filenames) < world_size:
315-
warnings.warn(f"{len(filenames)} files for {world_size} workers, switch to record parallelizm")
315+
root = os.path.commonprefix(filenames)
316+
splits = [list() for _ in range(world_size)]
317+
for filename in filenames:
318+
splits[immutable_hash(os.path.relpath(filename, root)) % world_size].append(filename)
319+
if any([len(split) == 0 for split in splits]):
320+
if rank == 0:
321+
warnings.warn(f"Some workers got zero files, switch to record parallelizm")
316322
yield from self._iter_shuffled_records(dataset, seed, rank, world_size)
317323
return
318-
root = os.path.commonprefix(filenames)
319-
subset = [filename for filename in filenames if immutable_hash(os.path.relpath(filename, root)) % world_size == rank]
320-
dataset = dataset.replace_files(subset, allow_empty=True)
324+
dataset = dataset.replace_files(splits[rank])
321325
yield from self._iter_shuffled_records_impl(dataset, seed)
322326

323327
def _iter_shuffled_records(self, dataset, seed, rank, world_size):

hotpp/data/module.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import torch
22
import pytorch_lightning as pl
3-
from .dataset import HotppDataset, ShuffledDistributedDataset
3+
from .dataset import HotppDataset, ShuffledDistributedDataset, DEFAULT_PARALLELIZM
44

55

66
def pop_loader_params(params):
@@ -102,7 +102,7 @@ def train_dataloader(self, rank=None, world_size=None):
102102
loader_params.update(self.train_loader_params)
103103
dataset = ShuffledDistributedDataset(self.train_data, rank=rank, world_size=world_size,
104104
cache_size=loader_params.pop("cache_size", 4096),
105-
parallelize=loader_params.pop("parallelize", "files"),
105+
parallelize=loader_params.pop("parallelize", DEFAULT_PARALLELIZM),
106106
seed=loader_params.pop("seed", 0))
107107
loader = torch.utils.data.DataLoader(
108108
dataset=dataset,
@@ -118,7 +118,7 @@ def val_dataloader(self, rank=None, world_size=None):
118118
loader_params = {"pin_memory": torch.cuda.is_available()}
119119
loader_params.update(self.val_loader_params)
120120
dataset = ShuffledDistributedDataset(self.val_data, rank=rank, world_size=world_size,
121-
parallelize=loader_params.pop("parallelize", "files")) # Disable shuffle.
121+
parallelize=loader_params.pop("parallelize", DEFAULT_PARALLELIZM)) # Disable shuffle.
122122
loader = torch.utils.data.DataLoader(
123123
dataset=dataset,
124124
collate_fn=dataset.dataset.collate_fn,
@@ -132,7 +132,7 @@ def test_dataloader(self, rank=None, world_size=None):
132132
loader_params = {"pin_memory": torch.cuda.is_available()}
133133
loader_params.update(self.test_loader_params)
134134
dataset = ShuffledDistributedDataset(self.test_data, rank=rank, world_size=world_size,
135-
parallelize=loader_params.pop("parallelize", "files")) # Disable shuffle.
135+
parallelize=loader_params.pop("parallelize", DEFAULT_PARALLELIZM)) # Disable shuffle.
136136
loader = torch.utils.data.DataLoader(
137137
dataset=dataset,
138138
collate_fn=dataset.dataset.collate_fn,

hotpp/embed.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from torchmetrics.utilities import dim_zero_cat
1414

1515
from .common import get_trainer
16-
from .data import ShuffledDistributedDataset
16+
from .data import ShuffledDistributedDataset, DEFAULT_PARALLELIZM
1717

1818
logger = logging.getLogger(__name__)
1919

@@ -136,7 +136,8 @@ def test_dataloader(self):
136136
loader_params = getattr(self.data, f"{self.split}_loader_params")
137137

138138
num_workers = loader_params.get("num_workers", 0)
139-
dataset = ShuffledDistributedDataset(dataset, rank=self.rank, world_size=self.world_size)
139+
dataset = ShuffledDistributedDataset(dataset, rank=self.rank, world_size=self.world_size,
140+
parallelize=loader_params.pop("parallelize", DEFAULT_PARALLELIZM))
140141
return torch.utils.data.DataLoader(
141142
dataset=dataset,
142143
collate_fn=dataset.dataset.collate_fn,

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
setuptools.setup(
1717
name="hotpp-benchmark",
18-
version="0.6.4",
18+
version="0.6.5",
1919
author="Ivan Karpukhin",
2020
author_email="karpuhini@yandex.ru",
2121
description="Evaluate generative event sequence models on the long horizon prediction task.",

0 commit comments

Comments
 (0)