Skip to content

Commit c528ee4

Browse files
authored
Speedup (#75)
1. Add packed RNNs 2. Add file-parallel read
1 parent 0b7f211 commit c528ee4

File tree

6 files changed

+113
-34
lines changed

6 files changed

+113
-34
lines changed

hotpp/data/dataset.py

Lines changed: 61 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,8 @@ class HotppDataset(torch.utils.data.IterableDataset):
7575
position: Sample position (`random` or `last`).
7676
rename: A dictionary for mapping field names during read.
7777
fields: A list of fields to keep in data. Other fields will be discarded.
78+
offset: Skip some initial records.
79+
limit: If set, limit the number of elements in the dataset.
7880
drop_nans: A list of fields to skip nans for.
7981
add_seq_fields: A dictionary with additional constant fields.
8082
global_target_fields: The name of the target field or a list of fields. Global targets are assigned to sequences.
@@ -91,11 +93,15 @@ def __init__(self, data,
9193
fields=None,
9294
id_field="id",
9395
timestamps_field="timestamps",
96+
offset=0,
97+
limit=None,
9498
drop_nans=None,
9599
add_seq_fields=None,
96100
global_target_fields=None,
97101
local_targets_fields=None,
98102
local_targets_indices_field=None):
103+
if (limit is not None) and (min_required_length or drop_nans):
104+
raise NotImplementedError("Can't combine `limit` with input filters.")
99105
super().__init__()
100106
if isinstance(data, str):
101107
self.filenames = list(sorted(parquet_file_scan(data)))
@@ -105,9 +111,26 @@ def __init__(self, data,
105111
raise ValueError(f"Unknown data type: {type(data)}")
106112
if not self.filenames:
107113
raise RuntimeError("Empty dataset")
108-
self.total_length = sum(map(get_parquet_length, self.filenames))
109-
self.random_split = random_split
110-
self.random_part = random_part
114+
if self.filenames and ((random_split != 1) or (random_part != "train")):
115+
if limit is not None:
116+
raise NotImplementedError("Can't combine `limit` with splitting.")
117+
if random_part not in {"train", "val"}:
118+
raise ValueError(f"Unknown random part: {random_part}. Must be either `train` or `val`.")
119+
s = 1000000000
120+
root = os.path.commonprefix(self.filenames)
121+
selected_filenames = []
122+
for filename in self.filenames:
123+
h = immutable_hash(os.path.relpath(filename, root))
124+
in_train = h % s <= s * random_split
125+
if not (in_train ^ (random_part == "train")):
126+
selected_filenames.append(filename)
127+
self.filenames = selected_filenames
128+
self.offset = offset
129+
self.limit = limit
130+
self.total_length = max(0, sum(map(get_parquet_length, self.filenames)) - offset)
131+
if self.limit is not None:
132+
self.total_length = min(self.limit, self.total_length)
133+
111134
self.min_length = min_length
112135
self.max_length = max_length
113136
self.position = position
@@ -134,7 +157,7 @@ def __init__(self, data,
134157

135158
def replace_files(self, filenames, **kwargs):
136159
names = set(inspect.signature(self.__init__).parameters.keys())
137-
names = names - {"self", "data"}
160+
names = names - {"self", "data", "random_split", "random_part"}
138161
kwargs = {name: getattr(self, name) for name in names} | kwargs
139162
return HotppDataset(filenames, **kwargs)
140163

@@ -190,16 +213,12 @@ def __len__(self):
190213
return self.total_length
191214

192215
def __iter__(self):
193-
if self.filenames:
194-
root = os.path.commonprefix(self.filenames)
216+
total = 0
195217
for filename in self.filenames:
196-
if (self.random_split != 1) or (self.random_part != "train"):
197-
s = 1000000000
198-
h = immutable_hash(os.path.relpath(filename, root))
199-
in_train = h % s <= s * self.random_split
200-
if in_train ^ (self.random_part == "train"):
218+
for rec in read_pyarrow_file(filename):
219+
total += 1
220+
if total <= self.offset:
201221
continue
202-
for rec in read_pyarrow_file(filename, use_threads=True):
203222
for src, dst in self.rename.items():
204223
if src not in rec:
205224
raise RuntimeError(f"The field `{src}` not found")
@@ -217,6 +236,8 @@ def __iter__(self):
217236
if skip:
218237
continue
219238
yield self.process(features)
239+
if (self.limit is not None) and (total - self.offset == self.limit):
240+
return
220241

221242
def _make_batch(self, by_name, batch_size, seq_feature_name=None):
222243
# Compute lengths.
@@ -277,14 +298,16 @@ class ShuffledDistributedDataset(torch.utils.data.IterableDataset):
277298
Args:
278299
parallelize: Parallel reading mode, either `records` (better granularity) or `files` (faster).
279300
"""
280-
def __init__(self, dataset, rank=None, world_size=None, cache_size=None, parallelize=DEFAULT_PARALLELIZM, seed=0):
301+
def __init__(self, dataset, rank=None, world_size=None, cache_size=None, parallelize=DEFAULT_PARALLELIZM, seed=0,
302+
drop_last=False):
281303
super().__init__()
282304
self.dataset = dataset
283305
self.rank = rank
284306
self.world_size = world_size
285307
self.cache_size = cache_size
286308
self.parallelize = parallelize
287309
self.seed = seed
310+
self.drop_last = drop_last
288311
self.epoch = 0
289312

290313
def _get_context(self):
@@ -320,19 +343,34 @@ def _iter_shuffled_files(self, dataset, seed, rank, world_size):
320343
filenames = list(dataset.filenames)
321344
if not filenames:
322345
raise RuntimeError("Empty dataset")
323-
root = os.path.commonprefix(filenames)
324-
splits = [list() for _ in range(world_size)]
325-
for filename in filenames:
326-
splits[immutable_hash(os.path.relpath(filename, root)) % world_size].append(filename)
327-
if any([len(split) == 0 for split in splits]):
328-
if rank == 0:
329-
warnings.warn(f"Some workers got zero files, switching to record parallelism")
330-
yield from self._iter_shuffled_records(dataset, seed, rank, world_size)
331-
return
332-
dataset = dataset.replace_files(splits[rank])
346+
rnd = Random(seed)
347+
rnd.shuffle(filenames)
348+
lengths = list(map(get_parquet_length, filenames))
349+
records_per_worker = sum(lengths) // world_size
350+
if records_per_worker == 0:
351+
raise RuntimeError(f"Very small dataset for {world_size} workers")
352+
offset = records_per_worker * rank
353+
skipped = 0
354+
accepted = 0
355+
selected_filenames = []
356+
for filename, length in zip(filenames, lengths):
357+
if skipped + accepted + length <= offset:
358+
skipped += length
359+
elif accepted >= records_per_worker:
360+
break
361+
else:
362+
selected_filenames.append(filename)
363+
accepted += length - max(0, offset - skipped - accepted)
364+
dataset = dataset.replace_files(selected_filenames,
365+
offset=offset - skipped,
366+
limit=records_per_worker if self.drop_last or rank != world_size - 1 else None)
333367
yield from self._iter_shuffled_records_impl(dataset, seed)
334368

335369
def _iter_shuffled_records(self, dataset, seed, rank, world_size):
370+
rnd = Random(seed)
371+
filenames = list(dataset.filenames)
372+
rnd.shuffle(filenames)
373+
dataset = dataset.replace_files(filenames)
336374
for i, item in enumerate(self._iter_shuffled_records_impl(dataset, seed)):
337375
if i % world_size == rank:
338376
yield item
@@ -342,9 +380,6 @@ def _iter_shuffled_records_impl(self, dataset, seed):
342380
yield from dataset
343381
else:
344382
rnd = Random(seed)
345-
filenames = list(dataset.filenames)
346-
rnd.shuffle(filenames)
347-
dataset = dataset.replace_files(filenames)
348383
cache = []
349384
for item in dataset:
350385
cache.append(item)

hotpp/data/module.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,8 @@ def train_dataloader(self, rank=None, world_size=None):
103103
dataset = ShuffledDistributedDataset(self.train_data, rank=rank, world_size=world_size,
104104
cache_size=loader_params.pop("cache_size", 4096),
105105
parallelize=loader_params.pop("parallelize", DEFAULT_PARALLELIZM),
106-
seed=loader_params.pop("seed", 0))
106+
seed=loader_params.pop("seed", 0),
107+
drop_last=loader_params.get("drop_last", False))
107108
loader = torch.utils.data.DataLoader(
108109
dataset=dataset,
109110
collate_fn=dataset.dataset.collate_fn,

hotpp/data/padded_batch.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,8 +128,7 @@ def to(self, *args, **kwargs):
128128

129129
@property
130130
def seq_len_mask(self):
131-
"""mask with B*T size for valid tokens in `payload`
132-
"""
131+
"""A mask with (B, L) size for valid tokens in `payload`."""
133132
if type(self._payload) is dict:
134133
name = self.seq_names[0]
135134
l = self._payload[name].shape[1]
@@ -139,3 +138,19 @@ def seq_len_mask(self):
139138
if self._left:
140139
indices = indices.flip(0)
141140
return indices[None] < self._lengths[:, None]
141+
142+
def pack(self):
143+
"""Create a PyTorch packed sequence object."""
144+
if self.left:
145+
raise NotImplementedError("Can't pack left padding.")
146+
if not isinstance(self._payload, torch.Tensor):
147+
raise ValueError("Can pack only tensor batches.")
148+
return torch.nn.utils.rnn.pack_padded_sequence(self._payload, self._lengths.cpu(), batch_first=True, enforce_sorted=False)
149+
150+
@staticmethod
151+
def unpack(packed, padding_value=0.0, total_length=None):
152+
"""Create PaddedBatch from a PyTorch packed sequence."""
153+
payload, lengths = torch.nn.utils.rnn.pad_packed_sequence(packed, batch_first=True,
154+
padding_value=padding_value,
155+
total_length=total_length)
156+
return PaddedBatch(payload, lengths)

hotpp/nn/encoder/rnn/simple.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
class GRU(torch.nn.GRU):
99
"""GRU interface."""
10-
def __init__(self, input_size, hidden_size, num_layers=1, dropout=0.0):
10+
def __init__(self, input_size, hidden_size, num_layers=1, dropout=0.0, pack=False):
1111
super().__init__(
1212
input_size,
1313
hidden_size,
@@ -16,6 +16,7 @@ def __init__(self, input_size, hidden_size, num_layers=1, dropout=0.0):
1616
dropout=dropout
1717
)
1818
self._hidden_size = hidden_size
19+
self._pack = pack
1920

2021
@property
2122
def delta_time(self):
@@ -47,7 +48,8 @@ def forward(self, x: PaddedBatch, time_deltas: PaddedBatch,
4748
Outputs with shape (B, L, D) and states with shape (N, B, D) or (N, B, L, D), where
4849
N is the number of layers.
4950
"""
50-
outputs, _ = super().forward(x.payload, states) # (B, L, D).
51+
outputs, _ = super().forward(x.pack() if self._pack else x.payload, states) # (B, L, D).
52+
outputs = PaddedBatch.unpack(outputs, total_length=x.shape[1]).payload if self._pack else outputs
5153
if not return_states:
5254
output_states = None
5355
elif return_states == "last":
@@ -81,7 +83,7 @@ def interpolate(self, states: Tensor, time_deltas: PaddedBatch) -> PaddedBatch:
8183

8284
class LSTM(torch.nn.LSTM):
8385
"""LSTM interface."""
84-
def __init__(self, input_size, hidden_size, num_layers=1, dropout=0.0):
86+
def __init__(self, input_size, hidden_size, num_layers=1, dropout=0.0, pack=False):
8587
super().__init__(
8688
input_size,
8789
hidden_size,
@@ -90,6 +92,7 @@ def __init__(self, input_size, hidden_size, num_layers=1, dropout=0.0):
9092
dropout=dropout
9193
)
9294
self._hidden_size = hidden_size
95+
self._pack = pack
9396

9497
@property
9598
def delta_time(self):
@@ -121,7 +124,8 @@ def forward(self, x: PaddedBatch, time_deltas: PaddedBatch,
121124
Outputs with shape (B, L, D) and states with shape (N, B, D) or (N, B, L, D), where
122125
N is the number of layers.
123126
"""
124-
outputs, _ = super().forward(x.payload, states) # (B, L, D).
127+
outputs, _ = super().forward(x.pack() if self._pack else x.payload, states) # (B, L, D).
128+
outputs = PaddedBatch.unpack(outputs, total_length=x.shape[1]).payload if self._pack else outputs
125129
if not return_states:
126130
output_states = None
127131
else:

tests/data/test_module.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@ def test_shuffle(self):
154154
self.assertNotEqual(ids1, ids2)
155155

156156
# Joined dataset, file parallelism.
157+
world_size = 2
157158
data = HotppDataModule(train_path=[self.data15_path, self.data16_path],
158159
drop_last=False,
159160
parallelize="files",

tests/nn/encoder/test_rnn.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import torch
55

66
from hotpp.data import PaddedBatch
7-
from hotpp.nn.encoder.rnn import ContTimeLSTM, ODEGRU
7+
from hotpp.nn.encoder.rnn import ContTimeLSTM, ODEGRU, GRU, LSTM
88

99

1010
EPS = 1e-10
@@ -32,6 +32,29 @@ def lin_rk4(x0, a, b, dt):
3232
return x0 + dt * (k1 + 2 * k2 + 2 * k3 + k4) / 6
3333

3434

35+
class TestPacking(TestCase):
36+
def test_pack_unpack(self):
37+
batch = PaddedBatch(
38+
torch.tensor([
39+
[1, 2, 0],
40+
[3, 0, 0],
41+
[4, 5, 6]
42+
]).float().unsqueeze(2).expand(3, 3, 8),
43+
torch.tensor([2, 1, 3])
44+
)
45+
time_deltas = None
46+
for cls in [GRU, LSTM]:
47+
model_gt = cls(8, 16, pack=False)
48+
model = cls(8, 16, pack=True)
49+
model.load_state_dict(model_gt.state_dict())
50+
states = torch.randn(1, 3, 16) if isinstance(model, GRU) else (torch.randn(1, 3, 16), torch.randn(1, 3, 16))
51+
output_gt, _ = model_gt(batch, time_deltas, states=states)
52+
output_gt.payload.masked_fill_(~output_gt.seq_len_mask.unsqueeze(2), 0)
53+
output, _ = model(batch, time_deltas, states=states)
54+
self.assertTrue((output_gt.seq_lens == output.seq_lens).all())
55+
self.assertTrue(output_gt.payload.isclose(output.payload).all())
56+
57+
3558
class TestContTimeLSTM(TestCase):
3659
def test_simple_parameters(self):
3760
rnn = ContTimeLSTM(1, 1)

0 commit comments

Comments (0)