Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 61 additions & 26 deletions hotpp/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ class HotppDataset(torch.utils.data.IterableDataset):
position: Sample position (`random` or `last`).
rename: A dictionary for mapping field names during read.
fields: A list of fields to keep in data. Other fields will be discarded.
offset: Skip some initial records.
limit: If set, limit the number of elements in the dataset.
drop_nans: A list of fields to skip nans for.
add_seq_fields: A dictionary with additional constant fields.
global_target_fields: The name of the target field or a list of fields. Global targets are assigned to sequences.
Expand All @@ -91,11 +93,15 @@ def __init__(self, data,
fields=None,
id_field="id",
timestamps_field="timestamps",
offset=0,
limit=None,
drop_nans=None,
add_seq_fields=None,
global_target_fields=None,
local_targets_fields=None,
local_targets_indices_field=None):
if (limit is not None) and (min_required_length or drop_nans):
raise NotImplementedError("Can't combine `limit` with input filters.")
super().__init__()
if isinstance(data, str):
self.filenames = list(sorted(parquet_file_scan(data)))
Expand All @@ -105,9 +111,26 @@ def __init__(self, data,
raise ValueError(f"Unknown data type: {type(data)}")
if not self.filenames:
raise RuntimeError("Empty dataset")
self.total_length = sum(map(get_parquet_length, self.filenames))
self.random_split = random_split
self.random_part = random_part
if self.filenames and ((random_split != 1) or (random_part != "train")):
if limit is not None:
raise NotImplementedError("Can't combine `limit` with splitting.")
if random_part not in {"train", "val"}:
raise ValueError(f"Unknown random part: {random_part}. Must be either `train` or `val`.")
s = 1000000000
root = os.path.commonprefix(self.filenames)
selected_filenames = []
for filename in self.filenames:
h = immutable_hash(os.path.relpath(filename, root))
in_train = h % s <= s * random_split
if not (in_train ^ (random_part == "train")):
selected_filenames.append(filename)
self.filenames = selected_filenames
self.offset = offset
self.limit = limit
self.total_length = max(0, sum(map(get_parquet_length, self.filenames)) - offset)
if self.limit is not None:
self.total_length = min(self.limit, self.total_length)

self.min_length = min_length
self.max_length = max_length
self.position = position
Expand All @@ -134,7 +157,7 @@ def __init__(self, data,

def replace_files(self, filenames, **kwargs):
names = set(inspect.signature(self.__init__).parameters.keys())
names = names - {"self", "data"}
names = names - {"self", "data", "random_split", "random_part"}
kwargs = {name: getattr(self, name) for name in names} | kwargs
return HotppDataset(filenames, **kwargs)

Expand Down Expand Up @@ -190,16 +213,12 @@ def __len__(self):
return self.total_length

def __iter__(self):
if self.filenames:
root = os.path.commonprefix(self.filenames)
total = 0
for filename in self.filenames:
if (self.random_split != 1) or (self.random_part != "train"):
s = 1000000000
h = immutable_hash(os.path.relpath(filename, root))
in_train = h % s <= s * self.random_split
if in_train ^ (self.random_part == "train"):
for rec in read_pyarrow_file(filename):
total += 1
if total <= self.offset:
continue
for rec in read_pyarrow_file(filename, use_threads=True):
for src, dst in self.rename.items():
if src not in rec:
raise RuntimeError(f"The field `{src}` not found")
Expand All @@ -217,6 +236,8 @@ def __iter__(self):
if skip:
continue
yield self.process(features)
if (self.limit is not None) and (total - self.offset == self.limit):
return

def _make_batch(self, by_name, batch_size, seq_feature_name=None):
# Compute lengths.
Expand Down Expand Up @@ -277,14 +298,16 @@ class ShuffledDistributedDataset(torch.utils.data.IterableDataset):
Args:
parallelize: Parallel reading mode, either `records` (better granularity) or `files` (faster).
"""
def __init__(self, dataset, rank=None, world_size=None, cache_size=None, parallelize=DEFAULT_PARALLELIZM, seed=0):
def __init__(self, dataset, rank=None, world_size=None, cache_size=None, parallelize=DEFAULT_PARALLELIZM, seed=0,
drop_last=False):
super().__init__()
self.dataset = dataset
self.rank = rank
self.world_size = world_size
self.cache_size = cache_size
self.parallelize = parallelize
self.seed = seed
self.drop_last = drop_last
self.epoch = 0

def _get_context(self):
Expand Down Expand Up @@ -320,19 +343,34 @@ def _iter_shuffled_files(self, dataset, seed, rank, world_size):
filenames = list(dataset.filenames)
if not filenames:
raise RuntimeError("Empty dataset")
root = os.path.commonprefix(filenames)
splits = [list() for _ in range(world_size)]
for filename in filenames:
splits[immutable_hash(os.path.relpath(filename, root)) % world_size].append(filename)
if any([len(split) == 0 for split in splits]):
if rank == 0:
warnings.warn(f"Some workers got zero files, switch to record parallelizm")
yield from self._iter_shuffled_records(dataset, seed, rank, world_size)
return
dataset = dataset.replace_files(splits[rank])
rnd = Random(seed)
rnd.shuffle(filenames)
lengths = list(map(get_parquet_length, filenames))
records_per_worker = sum(lengths) // world_size
if records_per_worker == 0:
raise RuntimeError(f"Very small dataset for {world_size} workers")
offset = records_per_worker * rank
skipped = 0
accepted = 0
selected_filenames = []
for filename, length in zip(filenames, lengths):
if skipped + accepted + length <= offset:
skipped += length
elif accepted >= records_per_worker:
break
else:
selected_filenames.append(filename)
accepted += length - max(0, offset - skipped - accepted)
dataset = dataset.replace_files(selected_filenames,
offset=offset - skipped,
limit=records_per_worker if self.drop_last or rank != world_size - 1 else None)
yield from self._iter_shuffled_records_impl(dataset, seed)

def _iter_shuffled_records(self, dataset, seed, rank, world_size):
rnd = Random(seed)
filenames = list(dataset.filenames)
rnd.shuffle(filenames)
dataset = dataset.replace_files(filenames)
for i, item in enumerate(self._iter_shuffled_records_impl(dataset, seed)):
if i % world_size == rank:
yield item
Expand All @@ -342,9 +380,6 @@ def _iter_shuffled_records_impl(self, dataset, seed):
yield from dataset
else:
rnd = Random(seed)
filenames = list(dataset.filenames)
rnd.shuffle(filenames)
dataset = dataset.replace_files(filenames)
cache = []
for item in dataset:
cache.append(item)
Expand Down
3 changes: 2 additions & 1 deletion hotpp/data/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,8 @@ def train_dataloader(self, rank=None, world_size=None):
dataset = ShuffledDistributedDataset(self.train_data, rank=rank, world_size=world_size,
cache_size=loader_params.pop("cache_size", 4096),
parallelize=loader_params.pop("parallelize", DEFAULT_PARALLELIZM),
seed=loader_params.pop("seed", 0))
seed=loader_params.pop("seed", 0),
drop_last=loader_params.get("drop_last", False))
loader = torch.utils.data.DataLoader(
dataset=dataset,
collate_fn=dataset.dataset.collate_fn,
Expand Down
19 changes: 17 additions & 2 deletions hotpp/data/padded_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,7 @@ def to(self, *args, **kwargs):

@property
def seq_len_mask(self):
"""mask with B*T size for valid tokens in `payload`
"""
"""A mask with (B, L) size for valid tokens in `payload`."""
if type(self._payload) is dict:
name = self.seq_names[0]
l = self._payload[name].shape[1]
Expand All @@ -139,3 +138,19 @@ def seq_len_mask(self):
if self._left:
indices = indices.flip(0)
return indices[None] < self._lengths[:, None]

def pack(self):
    """Convert this batch into a ``torch.nn.utils.rnn.PackedSequence``.

    Only right-padded batches whose payload is a single tensor can be
    packed; dictionary payloads and left padding are rejected.

    Returns:
        A ``PackedSequence`` built with ``batch_first=True`` and
        ``enforce_sorted=False`` (lengths need not be descending).

    Raises:
        NotImplementedError: If the batch uses left padding.
        ValueError: If the payload is not a plain tensor.
    """
    if self.left:
        raise NotImplementedError("Can't pack left padding.")
    if not isinstance(self._payload, torch.Tensor):
        raise ValueError("Can pack only tensor batches.")
    # pack_padded_sequence requires lengths on CPU.
    seq_lengths = self._lengths.cpu()
    return torch.nn.utils.rnn.pack_padded_sequence(
        self._payload,
        seq_lengths,
        batch_first=True,
        enforce_sorted=False,
    )

@staticmethod
def unpack(packed, padding_value=0.0, total_length=None):
    """Build a PaddedBatch from a PyTorch ``PackedSequence``.

    Args:
        packed: The packed sequence to pad back out.
        padding_value: Fill value for positions beyond each sequence length.
        total_length: If set, pad every sequence to this length.

    Returns:
        A right-padded PaddedBatch with the recovered payload and lengths.
    """
    # pad_packed_sequence returns (payload, lengths) — forward both
    # directly into the PaddedBatch constructor.
    unpacked = torch.nn.utils.rnn.pad_packed_sequence(
        packed,
        batch_first=True,
        padding_value=padding_value,
        total_length=total_length,
    )
    return PaddedBatch(*unpacked)
12 changes: 8 additions & 4 deletions hotpp/nn/encoder/rnn/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

class GRU(torch.nn.GRU):
"""GRU interface."""
def __init__(self, input_size, hidden_size, num_layers=1, dropout=0.0):
def __init__(self, input_size, hidden_size, num_layers=1, dropout=0.0, pack=False):
super().__init__(
input_size,
hidden_size,
Expand All @@ -16,6 +16,7 @@ def __init__(self, input_size, hidden_size, num_layers=1, dropout=0.0):
dropout=dropout
)
self._hidden_size = hidden_size
self._pack = pack

@property
def delta_time(self):
Expand Down Expand Up @@ -47,7 +48,8 @@ def forward(self, x: PaddedBatch, time_deltas: PaddedBatch,
Outputs with shape (B, L, D) and states with shape (N, B, D) or (N, B, L, D), where
N is the number of layers.
"""
outputs, _ = super().forward(x.payload, states) # (B, L, D).
outputs, _ = super().forward(x.pack() if self._pack else x.payload, states) # (B, L, D).
outputs = PaddedBatch.unpack(outputs, total_length=x.shape[1]).payload if self._pack else outputs
if not return_states:
output_states = None
elif return_states == "last":
Expand Down Expand Up @@ -81,7 +83,7 @@ def interpolate(self, states: Tensor, time_deltas: PaddedBatch) -> PaddedBatch:

class LSTM(torch.nn.LSTM):
"""LSTM interface."""
def __init__(self, input_size, hidden_size, num_layers=1, dropout=0.0):
def __init__(self, input_size, hidden_size, num_layers=1, dropout=0.0, pack=False):
super().__init__(
input_size,
hidden_size,
Expand All @@ -90,6 +92,7 @@ def __init__(self, input_size, hidden_size, num_layers=1, dropout=0.0):
dropout=dropout
)
self._hidden_size = hidden_size
self._pack = pack

@property
def delta_time(self):
Expand Down Expand Up @@ -121,7 +124,8 @@ def forward(self, x: PaddedBatch, time_deltas: PaddedBatch,
Outputs with shape (B, L, D) and states with shape (N, B, D) or (N, B, L, D), where
N is the number of layers.
"""
outputs, _ = super().forward(x.payload, states) # (B, L, D).
outputs, _ = super().forward(x.pack() if self._pack else x.payload, states) # (B, L, D).
outputs = PaddedBatch.unpack(outputs, total_length=x.shape[1]).payload if self._pack else outputs
if not return_states:
output_states = None
else:
Expand Down
1 change: 1 addition & 0 deletions tests/data/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ def test_shuffle(self):
self.assertNotEqual(ids1, ids2)

# Joined dataset, file parallelizm.
world_size = 2
data = HotppDataModule(train_path=[self.data15_path, self.data16_path],
drop_last=False,
parallelize="files",
Expand Down
25 changes: 24 additions & 1 deletion tests/nn/encoder/test_rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch

from hotpp.data import PaddedBatch
from hotpp.nn.encoder.rnn import ContTimeLSTM, ODEGRU
from hotpp.nn.encoder.rnn import ContTimeLSTM, ODEGRU, GRU, LSTM


EPS = 1e-10
Expand Down Expand Up @@ -32,6 +32,29 @@ def lin_rk4(x0, a, b, dt):
return x0 + dt * (k1 + 2 * k2 + 2 * k3 + k4) / 6


class TestPacking(TestCase):
    """Verify that packed RNN execution matches the plain (padded) path."""

    def test_pack_unpack(self):
        # Three sequences of lengths 2, 1 and 3; values are broadcast
        # across an 8-dim feature axis.
        payload = torch.tensor([
            [1, 2, 0],
            [3, 0, 0],
            [4, 5, 6],
        ]).float().unsqueeze(2).expand(3, 3, 8)
        lengths = torch.tensor([2, 1, 3])
        batch = PaddedBatch(payload, lengths)
        time_deltas = None
        for rnn_cls in [GRU, LSTM]:
            # Two models with identical weights: one padded, one packed.
            reference = rnn_cls(8, 16, pack=False)
            packed_model = rnn_cls(8, 16, pack=True)
            packed_model.load_state_dict(reference.state_dict())
            if isinstance(packed_model, GRU):
                states = torch.randn(1, 3, 16)
            else:
                # LSTM takes a (hidden, cell) state pair.
                states = (torch.randn(1, 3, 16), torch.randn(1, 3, 16))
            expected, _ = reference(batch, time_deltas, states=states)
            # Zero out positions past each sequence length: the packed
            # path leaves padding at the pad value, so only valid tokens
            # are comparable.
            expected.payload.masked_fill_(~expected.seq_len_mask.unsqueeze(2), 0)
            actual, _ = packed_model(batch, time_deltas, states=states)
            self.assertTrue((expected.seq_lens == actual.seq_lens).all())
            self.assertTrue(expected.payload.isclose(actual.payload).all())


class TestContTimeLSTM(TestCase):
def test_simple_parameters(self):
rnn = ContTimeLSTM(1, 1)
Expand Down