diff --git a/hotpp/nn/encoder/transformer/simple.py b/hotpp/nn/encoder/transformer/simple.py
index 3a032f0e..9233e370 100644
--- a/hotpp/nn/encoder/transformer/simple.py
+++ b/hotpp/nn/encoder/transformer/simple.py
@@ -269,12 +269,13 @@ class SimpleTransformer(torch.nn.Module):
         max_duration: Must be provided if time encodings are used.
         min_time_step: The minimum time step (> 0). By default it is max_duration / n_positions.
         rope: Either "time[-train]", "none" or None.
+        sos: Whether to use a start (SOS) token.
     """

     def __init__(self, input_size, n_positions=1024, n_embd=768, n_layer=12, n_head=12, n_inner=None,
                  dropout=0.1, causal=False,
                  activation=torch.nn.functional.relu, normalization=torch.nn.LayerNorm,
-                 mlp="default", pos_type="pos-angular", rope=None, group_size=1,
+                 mlp="default", pos_type="pos-angular", rope=None, group_size=1, sos=True,
                  max_duration=None, min_time_step=None):
         super().__init__()
         n_inner = n_inner if n_inner is not None else 4 * n_embd
@@ -287,6 +288,7 @@ def __init__(self, input_size, n_positions=1024, n_embd=768, n_layer=12, n_head=
         self.causal = causal

         self.input_projection = torch.nn.Linear(input_size, n_embd)
+        self.sos = torch.nn.Parameter(torch.randn(n_embd)) if sos else None

         # We use norm_first by default.
         # See the original paper: Xiong R. et al. "On layer normalization in the transformer architecture" ICML 2020.
@@ -334,6 +336,26 @@ def output_size(self):

     def delta_time(self):
         return False

+    def add_sos(self, embeddings, timestamps):
+        if self.sos is None:
+            return embeddings, timestamps
+        b, l, d = embeddings.payload.shape
+        lengths = torch.where(embeddings.seq_lens > 0, embeddings.seq_lens + 1, embeddings.seq_lens)
+        embeddings = torch.cat([self.sos[None, None].expand(b, 1, d), embeddings.payload], 1)  # (B, 1 + L, D).
+        init_timestamps = timestamps.payload[:, :1] if l > 0 else torch.zeros(b, 1, device=timestamps.device, dtype=timestamps.dtype)
+        timestamps = torch.cat([init_timestamps, timestamps.payload], 1)  # (B, 1 + L).
+        return PaddedBatch(embeddings, lengths), PaddedBatch(timestamps, lengths)
+
+    def remove_sos(self, outputs):
+        if outputs is None:
+            return outputs
+        if self.sos is None:
+            return outputs
+        if isinstance(outputs, PaddedBatch):
+            return PaddedBatch(outputs.payload[:, 1:], (outputs.seq_lens - 1).clip(min=0))
+        else:
+            return outputs[:, 1:]
+
     def transform(self, embeddings, return_states=False, attention_mask=None):
         """Apply encoder after input projection and positional encoding.
@@ -388,12 +410,16 @@ def forward(self, x, timestamps, states=None, return_states=False, attention_mas
         if return_states not in {False, "full"}:
             raise ValueError(f"Unknown states mode: {return_states}")

-        embeddings = self.input_projection(x.payload)  # (B, L, D).
-        embeddings = self.positional(embeddings, timestamps.payload)  # (B, L, D).
+        embeddings = PaddedBatch(self.input_projection(x.payload),  # (B, L, D).
+                                 x.seq_lens)
+        embeddings, timestamps = self.add_sos(embeddings, timestamps)
+        embeddings = self.positional(embeddings.payload, timestamps.payload)  # (B, L, D).
         embeddings = PaddedBatch(embeddings, x.seq_lens)
         if self.rope is not None:
             with self.rope.cache(timestamps.payload):
                 outputs, states = self.transform(embeddings, attention_mask=attention_mask)
         else:
             outputs, states = self.transform(embeddings, attention_mask=attention_mask)
+        outputs = self.remove_sos(outputs)
+        states = self.remove_sos(states)
         return outputs, states
diff --git a/setup.py b/setup.py
index d7fb8a04..2563a95c 100644
--- a/setup.py
+++ b/setup.py
@@ -15,7 +15,7 @@
 setuptools.setup(
     name="hotpp-benchmark",
-    version="0.6.4",
+    version="0.6.5",
     author="Ivan Karpukhin",
     author_email="karpuhini@yandex.ru",
     description="Evaluate generative event sequence models on the long horizon prediction task.",
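For reviewers, here is a self-contained sketch of the mechanics that `add_sos` / `remove_sos` implement, written with plain tensors instead of hotpp's `PaddedBatch`; the standalone helper names, shapes, and example values below are illustrative only and not part of the library API. The idea is to prepend a learned start embedding, reuse the first timestamp for the new position, bump non-empty lengths by one, and trim the extra position again after the encoder so callers see the original sequence length.

```python
import torch

def add_sos(sos, embeddings, timestamps, lengths):
    """Prepend a learned start token and repeat the first timestamp for it."""
    b, l, d = embeddings.shape
    new_lengths = torch.where(lengths > 0, lengths + 1, lengths)
    embeddings = torch.cat([sos[None, None].expand(b, 1, d), embeddings], 1)  # (B, 1 + L, D).
    init_ts = timestamps[:, :1] if l > 0 else torch.zeros(b, 1, dtype=timestamps.dtype)
    timestamps = torch.cat([init_ts, timestamps], 1)  # (B, 1 + L).
    return embeddings, timestamps, new_lengths

def remove_sos(outputs, lengths):
    """Drop the start position so outputs align with the original events again."""
    return outputs[:, 1:], (lengths - 1).clip(min=0)

sos = torch.randn(16)                    # analogue of the new self.sos parameter
emb = torch.randn(2, 5, 16)              # (B, L, D) projected embeddings
ts = torch.cumsum(torch.rand(2, 5), 1)   # (B, L) event timestamps
lens = torch.tensor([5, 3])              # valid lengths per sequence

emb2, ts2, lens2 = add_sos(sos, emb, ts, lens)
assert emb2.shape == (2, 6, 16) and lens2.tolist() == [6, 4]

out, out_lens = remove_sos(emb2, lens2)
assert out.shape == (2, 5, 16) and out_lens.tolist() == [5, 3]
```

The extra position presumably gives the causal encoder a token that every event, including the first one, can attend to, while trimming it in `remove_sos` keeps the encoder's external contract unchanged, which is why `forward` applies it to both `outputs` and `states` before returning.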