Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 29 additions & 3 deletions hotpp/nn/encoder/transformer/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,12 +269,13 @@ class SimpleTransformer(torch.nn.Module):
max_duration: Must be provided if time encodings are used.
min_time_step: The minimum time step (> 0). By default it is max_duration / n_positions.
rope: Either "time[-train]", "none" or None.
sos: Whether to use start token or not.
"""
def __init__(self, input_size, n_positions=1024, n_embd=768, n_layer=12, n_head=12,
n_inner=None, dropout=0.1, causal=False,
activation=torch.nn.functional.relu,
normalization=torch.nn.LayerNorm,
mlp="default", pos_type="pos-angular", rope=None, group_size=1,
mlp="default", pos_type="pos-angular", rope=None, group_size=1, sos=True,
max_duration=None, min_time_step=None):
super().__init__()
n_inner = n_inner if n_inner is not None else 4 * n_embd
Expand All @@ -287,6 +288,7 @@ def __init__(self, input_size, n_positions=1024, n_embd=768, n_layer=12, n_head=
self.causal = causal

self.input_projection = torch.nn.Linear(input_size, n_embd)
self.sos = torch.nn.Parameter(torch.randn(n_embd)) if sos else None

# We use norm_first by default.
# See the original paper: Xiong R. et al. "On layer normalization in the transformer architecture" ICML 2020.
Expand Down Expand Up @@ -334,6 +336,26 @@ def output_size(self):
def delta_time(self):
return False

def add_sos(self, embeddings, timestamps):
if self.sos is None:
return embeddings, timestamps
b, l, d = embeddings.payload.shape
lengths = torch.where(embeddings.seq_lens > 0, embeddings.seq_lens + 1, embeddings.seq_lens)
embeddings = torch.cat([self.sos[None, None].expand(b, 1, d), embeddings.payload], 1) # (B, 1 + L, D).
init_timestamps = timestamps.payload[:, :1] if l > 0 else torch.zeros(b, 1, device=timestamps.device, dtype=timestamps.dtype)
timestamps = torch.cat([init_timestamps, timestamps.payload], 1) # (B, 1 + L).
return PaddedBatch(embeddings, lengths), PaddedBatch(timestamps, lengths)

def remove_sos(self, outputs):
if outputs is None:
return outputs
if self.sos is None:
return outputs
if isinstance(outputs, PaddedBatch):
return PaddedBatch(outputs.payload[:, 1:], (outputs.seq_lens - 1).clip(min=0))
else:
return outputs[:, 1:]

def transform(self, embeddings, return_states=False, attention_mask=None):
"""Apply encoder after input projection and positional encoding.

Expand Down Expand Up @@ -388,12 +410,16 @@ def forward(self, x, timestamps, states=None, return_states=False, attention_mas
if return_states not in {False, "full"}:
raise ValueError(f"Unknown states mode: {return_states}")

embeddings = self.input_projection(x.payload) # (B, L, D).
embeddings = self.positional(embeddings, timestamps.payload) # (B, L, D).
embeddings = PaddedBatch(self.input_projection(x.payload), # (B, L, D).
x.seq_lens)
embeddings, timestamps = self.add_sos(embeddings, timestamps)
embeddings = self.positional(embeddings.payload, timestamps.payload) # (B, L, D).
embeddings = PaddedBatch(embeddings, x.seq_lens)
if self.rope is not None:
with self.rope.cache(timestamps.payload):
outputs, states = self.transform(embeddings, attention_mask=attention_mask)
else:
outputs, states = self.transform(embeddings, attention_mask=attention_mask)
outputs = self.remove_sos(outputs)
states = self.remove_sos(states)
return outputs, states
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

setuptools.setup(
name="hotpp-benchmark",
version="0.6.4",
version="0.6.5",
author="Ivan Karpukhin",
author_email="karpuhini@yandex.ru",
description="Evaluate generative event sequence models on the long horizon prediction task.",
Expand Down