diff --git a/README.md b/README.md index 232140d..8dc987f 100644 --- a/README.md +++ b/README.md @@ -77,6 +77,7 @@ We provide reference implementations of various state-of-the-art TPP papers: | 6 | ICLR'20 | IntensityFree | [Intensity-Free Learning of Temporal Point Processes](https://arxiv.org/abs/1909.12127) | [PyTorch](easy_tpp/model/torch_model/torch_intensity_free.py) | | 7 | ICLR'21 | ODETPP | [Neural Spatio-Temporal Point Processes (simplified)](https://arxiv.org/abs/2011.04583) | [PyTorch](easy_tpp/model/torch_model/torch_ode_tpp.py) | | 8 | ICLR'22 | AttNHP | [Transformer Embeddings of Irregularly Spaced Events and Their Participants](https://arxiv.org/abs/2201.00044) | [PyTorch](easy_tpp/model/torch_model/torch_attnhp.py) | +| 9 | NeurIPS'25 | S2P2 | Deep Continuous-Time State-Space Models for Marked Event Sequences | [PyTorch](easy_tpp/model/torch_model/torch_s2p2.py) | diff --git a/easy_tpp/model/__init__.py b/easy_tpp/model/__init__.py index 18d53b0..db9dac1 100644 --- a/easy_tpp/model/__init__.py +++ b/easy_tpp/model/__init__.py @@ -6,6 +6,7 @@ from easy_tpp.model.torch_model.torch_nhp import NHP as TorchNHP from easy_tpp.model.torch_model.torch_ode_tpp import ODETPP as TorchODETPP from easy_tpp.model.torch_model.torch_rmtpp import RMTPP as TorchRMTPP +from easy_tpp.model.torch_model.torch_s2p2 import S2P2 as TorchS2P2 from easy_tpp.model.torch_model.torch_sahp import SAHP as TorchSAHP from easy_tpp.model.torch_model.torch_thp import THP as TorchTHP @@ -18,4 +19,5 @@ 'TorchIntensityFree', 'TorchODETPP', 'TorchRMTPP', - 'TorchANHN'] + 'TorchANHN', + 'TorchS2P2'] diff --git a/easy_tpp/model/torch_model/torch_s2p2.py b/easy_tpp/model/torch_model/torch_s2p2.py new file mode 100644 index 0000000..6f609a3 --- /dev/null +++ b/easy_tpp/model/torch_model/torch_s2p2.py @@ -0,0 +1,322 @@ +from typing import List, Optional, Tuple, Union + +import torch +from torch import nn + +from easy_tpp.model.torch_model.torch_baselayer import ScaledSoftplus +from easy_tpp.model.torch_model.torch_basemodel import TorchBaseModel +from easy_tpp.ssm.models import LLH, Int_Backward_LLH, Int_Forward_LLH + + +class ComplexEmbedding(nn.Module): + def __init__(self, *args, **kwargs): + super(ComplexEmbedding, self).__init__() + self.real_embedding = nn.Embedding(*args, **kwargs) + self.imag_embedding = nn.Embedding(*args, **kwargs) + + self.real_embedding.weight.data *= 1e-3 + self.imag_embedding.weight.data *= 1e-3 + + def forward(self, x): + return torch.complex( + self.real_embedding(x), + self.imag_embedding(x), + ) + + +class IntensityNet(nn.Module): + def __init__(self, input_dim, bias, num_event_types): + super().__init__() + self.intensity_net = nn.Linear(input_dim, num_event_types, bias=bias) + self.softplus = ScaledSoftplus(num_event_types) + + def forward(self, x): + return self.softplus(self.intensity_net(x)) + + +class S2P2(TorchBaseModel): + def __init__(self, model_config): + """Initialize the model + + Args: + model_config (EasyTPP.ModelConfig): config of model specs. 
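+
+                ``model_specs`` must provide ``P`` (hidden state dimension); optional
+                keys read below include ``beta``, ``bias``, ``simple_mark``,
+                ``relative_time``, ``int_forward_variant`` and ``int_backward_variant``.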
+ """ + super(S2P2, self).__init__(model_config) + self.n_layers = model_config.num_layers + self.P = model_config.model_specs["P"] # Hidden state dimension + self.H = model_config.hidden_size # Residual stream dimension + self.beta = model_config.model_specs.get("beta", 1.0) + self.bias = model_config.model_specs.get("bias", True) + self.simple_mark = model_config.model_specs.get("simple_mark", True) + + layer_kwargs = dict( + P=self.P, + H=self.H, + dt_init_min=model_config.model_specs.get("dt_init_min", 1e-4), + dt_init_max=model_config.model_specs.get("dt_init_max", 0.1), + act_func=model_config.model_specs.get("act_func", "full_glu"), + dropout_rate=model_config.model_specs.get("dropout_rate", 0.0), + for_loop=model_config.model_specs.get("for_loop", False), + pre_norm=model_config.model_specs.get("pre_norm", True), + post_norm=model_config.model_specs.get("post_norm", False), + simple_mark=self.simple_mark, + relative_time=model_config.model_specs.get("relative_time", False), + complex_values=model_config.model_specs.get("complex_values", True), + ) + + int_forward_variant = model_config.model_specs.get("int_forward_variant", False) + int_backward_variant = model_config.model_specs.get( + "int_backward_variant", False + ) + assert ( + int_forward_variant + int_backward_variant + ) <= 1 # Only one at most is allowed to be specified + + if int_forward_variant: + llh_layer = Int_Forward_LLH + elif int_backward_variant: + llh_layer = Int_Backward_LLH + else: + llh_layer = LLH + + self.backward_variant = int_backward_variant + + self.layers = nn.ModuleList( + [ + llh_layer(**layer_kwargs, is_first_layer=i == 0) + for i in range(self.n_layers) + ] + ) + self.layers_mark_emb = nn.Embedding( + self.num_event_types_pad, + self.H, + ) # One embedding to share amongst layers to be used as input into a layer-specific and input-aware impulse + self.layer_type_emb = None # Remove old embeddings from EasyTPP + self.intensity_net = IntensityNet( + input_dim=self.H, + bias=self.bias, + num_event_types=self.num_event_types, + ) + + def _get_intensity( + self, x_LP: Union[torch.tensor, List[torch.tensor]], right_us_BNH + ) -> torch.Tensor: + """ + Assume time has already been evolved, take a vertical stack of hidden states and produce intensity. 
+ """ + left_u_H = None + for i, layer in enumerate(self.layers): + if isinstance( + x_LP, list + ): # Sometimes it is convenient to pass as a list over the layers rather than a single tensor + left_u_H = layer.depth_pass( + x_LP[i], current_left_u_H=left_u_H, prev_right_u_H=right_us_BNH[i] + ) + else: + left_u_H = layer.depth_pass( + x_LP[..., i, :], + current_left_u_H=left_u_H, + prev_right_u_H=right_us_BNH[i], + ) + + return self.intensity_net(left_u_H) # self.ScaledSoftplus(self.linear(left_u_H)) + + def _evolve_and_get_intensity_at_sampled_dts(self, x_LP, dt_G, right_us_H): + left_u_GH = None + for i, layer in enumerate(self.layers): + x_GP = layer.get_left_limit( + right_limit_P=x_LP[..., i, :], + dt_G=dt_G, + next_left_u_GH=left_u_GH, + current_right_u_H=right_us_H[i], + ) + left_u_GH = layer.depth_pass( + current_left_x_P=x_GP, + current_left_u_H=left_u_GH, + prev_right_u_H=right_us_H[i], + ) + return self.intensity_net(left_u_GH) # self.ScaledSoftplus(self.linear(left_u_GH)) + + def forward( + self, batch, initial_state_BLP: Optional[torch.Tensor] = None, **kwargs + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Batch operations of self._forward + """ + t_BN, dt_BN, marks_BN, batch_non_pad_mask, _ = batch + + right_xs_BNP = [] # including both t_0 and t_N + left_xs_BNm1P = [] + right_us_BNH = [ + None + ] # Start with None as this is the 'input' to the first layer + left_u_BNH, right_u_BNH = None, None + alpha_BNP = self.layers_mark_emb(marks_BN) + + for l_i, layer in enumerate(self.layers): + # for each event, compute the fixed impulse via alpha_m for event i of type m + init_state = ( + initial_state_BLP[:, l_i] if initial_state_BLP is not None else None + ) + + # Returns right limit of xs and us for [t0, t1, ..., tN] + # "layer" returns the right limit of xs at current layer, and us for the next layer (as transformations of ys) + # x_BNP: at time [t_0, t_1, ..., t_{N-1}, t_N] + # next_left_u_BNH: at time [t_0, t_1, ..., t_{N-1}, t_N] -- only available for backward variant + # next_right_u_BNH: at time [t_0, t_1, ..., t_{N-1}, t_N] -- always returned but only used for RT + x_BNP, next_layer_left_u_BNH, next_layer_right_u_BNH = layer.forward( + left_u_BNH, right_u_BNH, alpha_BNP, dt_BN, init_state + ) + assert next_layer_right_u_BNH is not None + + right_xs_BNP.append(x_BNP) + if next_layer_left_u_BNH is None: # NOT backward variant + left_xs_BNm1P.append( + layer.get_left_limit( # current and next at event level + x_BNP[..., :-1, :], # at time [t_0, t_1, ..., t_{N-1}] + dt_BN[..., 1:].unsqueeze( + -1 + ), # with dts [t1-t0, t2-t1, ..., t_N-t_{N-1}] + current_right_u_H=right_u_BNH + if right_u_BNH is None + else right_u_BNH[ + ..., :-1, : + ], # at time [t_0, t_1, ..., t_{N-1}] + next_left_u_GH=left_u_BNH + if left_u_BNH is None + else left_u_BNH[..., 1:, :].unsqueeze( + -2 + ), # at time [t_1, t_2 ..., t_N] + ).squeeze(-2) + ) + right_us_BNH.append(next_layer_right_u_BNH) + + left_u_BNH, right_u_BNH = next_layer_left_u_BNH, next_layer_right_u_BNH + + right_xs_BNLP = torch.stack(right_xs_BNP, dim=-2) + + ret_val = { + "right_xs_BNLP": right_xs_BNLP, # [t_0, ..., t_N] + "right_us_BNH": right_us_BNH, # [t_0, ..., t_N]; list starting with None + } + + if left_u_BNH is not None: # backward variant + ret_val["left_u_BNm1H"] = left_u_BNH[ + ..., 1:, : + ] # The next inputs after last layer -> transformation of ys + else: # NOT backward variant + ret_val["left_xs_BNm1LP"] = torch.stack(left_xs_BNm1P, dim=-2) + + # 'seq_len - 1' left limit for [t_1, ..., t_N] for events (u if available, 
x if not) + # 'seq_len' right limit for [t_0, t_1, ..., t_{N-1}, t_N] for events xs or us + return ret_val + + def loglike_loss(self, batch, **kwargs): + # hidden states at the left and right limits around event time; note for the shift by 1 in indices: + # consider a sequence [t0, t1, ..., tN] + # Produces the following: + # left_x: x0, x1, x2, ... <-> x_{t_1-}, x_{t_2-}, x_{t_3-}, ..., x_{t_N-} (note the shift in indices) for all layers + # OR ==> <-> u_{t_1-}, u_{t_2-}, u_{t_3-}, ..., u_{t_N-} for last layer + # + # right_x: x0, x1, x2, ... <-> x_{t_0+}, x_{t_1+}, ..., x_{t_N+} for all layers + # right_u: u0, u1, u2, ... <-> u_{t_0+}, u_{t_1+}, ..., u_{t_N+} for all layers + forward_results = self.forward( + batch + ) # N minus 1 comparing with sequence lengths + right_xs_BNLP, right_us_BNH = ( + forward_results["right_xs_BNLP"], + forward_results["right_us_BNH"], + ) + right_us_BNm1H = [ + None if right_u_BNH is None else right_u_BNH[:, :-1, :] + for right_u_BNH in right_us_BNH + ] + + ts_BN, dts_BN, marks_BN, batch_non_pad_mask, _ = batch + + # evaluate intensity values at each event *from the left limit*, _get_intensity: [LP] -> [M] + # left_xs_B_Nm1_LP = left_xs_BNm1LP[:, :-1, ...] # discard the left limit of t_N + # Note: no need to discard the left limit of t_N because "marks_mask" will deal with it + if "left_u_BNm1H" in forward_results: # ONLY backward variant + intensity_B_Nm1_M = self.intensity_net( + forward_results["left_u_BNm1H"] + ) # self.ScaledSoftplus(self.linear(forward_results["left_u_BNm1H"])) + else: # NOT backward variant + intensity_B_Nm1_M = self._get_intensity( + forward_results["left_xs_BNm1LP"], right_us_BNm1H + ) + + # sample dt in each interval for MC: [batch_size, num_times=N-1, num_mc_sample] + # N-1 because we only consider the intervals between N events + # G for grid points + dts_sample_B_Nm1_G = self.make_dtime_loss_samples(dts_BN[:, 1:]) + + # evaluate intensity at dt_samples for MC *from the left limit* after decay -> shape (B, N-1, MC, M) + intensity_dts_B_Nm1_G_M = self._evolve_and_get_intensity_at_sampled_dts( + right_xs_BNLP[ + :, :-1 + ], # x_{t_i+} will evolve up to x_{t_{i+1}-} and many times between for i=0,...,N-1 + dts_sample_B_Nm1_G, + right_us_BNm1H, + ) + + event_ll, non_event_ll, num_events = self.compute_loglikelihood( + lambda_at_event=intensity_B_Nm1_M, + lambdas_loss_samples=intensity_dts_B_Nm1_G_M, + time_delta_seq=dts_BN[:, 1:], + seq_mask=batch_non_pad_mask[:, 1:], + type_seq=marks_BN[:, 1:], + ) + + # compute loss to optimize + loss = -(event_ll - non_event_ll).sum() + + return loss, num_events + + def compute_intensities_at_sample_times( + self, event_times_BN, inter_event_times_BN, marks_BN, sample_dtimes, **kwargs + ): + """Compute the intensity at sampled times, not only event times. *from the left limit* + + Args: + time_seq (tensor): [batch_size, seq_len], times seqs. + time_delta_seq (tensor): [batch_size, seq_len], time delta seqs. + event_seq (tensor): [batch_size, seq_len], event type seqs. + sample_dtimes (tensor): [batch_size, seq_len, num_sample], sampled inter-event timestamps. + + Returns: + tensor: [batch_size, num_times, num_mc_sample, num_event_types], + intensity at each timestamp for each event type. 
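+
+        Keyword Args:
+            compute_last_step_only (bool): if True, evolve only from the right limit
+                of the last observed event and return intensities for that step alone.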
+ """ + + compute_last_step_only = kwargs.get("compute_last_step_only", False) + + # assume inter_event_times_BN always starts from 0 + _input = event_times_BN, inter_event_times_BN, marks_BN, None, None + + # 'seq_len - 1' left limit for [t_1, ..., t_N] + # 'seq_len' right limit for [t_0, t_1, ..., t_{N-1}, t_N] + + forward_results = self.forward( + _input + ) # N minus 1 comparing with sequence lengths + right_xs_BNLP, right_us_BNH = ( + forward_results["right_xs_BNLP"], + forward_results["right_us_BNH"], + ) + + if ( + compute_last_step_only + ): # fix indices for right_us_BNH: list [None, tensor([BNH]), ...] + right_us_B1H = [ + None if right_u_BNH is None else right_u_BNH[:, -1:, :] + for right_u_BNH in right_us_BNH + ] + sampled_intensity = self._evolve_and_get_intensity_at_sampled_dts( + right_xs_BNLP[:, -1:, :, :], sample_dtimes[:, -1:, :], right_us_B1H + ) # equiv. to right_xs_BNLP[:, -1, :, :][:, None, ...] + else: + sampled_intensity = self._evolve_and_get_intensity_at_sampled_dts( + right_xs_BNLP, sample_dtimes, right_us_BNH + ) + return sampled_intensity # [B, N, MC, M] diff --git a/easy_tpp/ssm/__init__.py b/easy_tpp/ssm/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/easy_tpp/ssm/initializers.py b/easy_tpp/ssm/initializers.py new file mode 100644 index 0000000..3027c85 --- /dev/null +++ b/easy_tpp/ssm/initializers.py @@ -0,0 +1,117 @@ +import math + +import numpy as np +import numpy as onp +import torch as th +from numpy.linalg import eigh + + +def make_HiPPO(P): + """Create a HiPPO-LegS matrix. + From https://github.com/srush/annotated-s4/blob/main/s4/s4.py + Args: + P (int32): state size + Returns: + P x P HiPPO LegS matrix + """ + M = np.sqrt(1 + 2 * np.arange(P)) + A = M[:, np.newaxis] * M[np.newaxis, :] + A = np.tril(A) - np.diag(np.arange(P)) + return -A + + +def make_NPLR_HiPPO(P): + """ + Makes components needed for NPLR representation of HiPPO-LegS + From https://github.com/srush/annotated-s4/blob/main/s4/s4.py + Args: + P (int32): state size + + Returns: + P x P HiPPO LegS matrix, low-rank factor P, HiPPO input matrix B + + """ + # Make -HiPPO + hippo = make_HiPPO(P) + + # Add in a rank 1 term. Makes it Normal. + R1 = np.sqrt(np.arange(P) + 0.5) + + # HiPPO also specifies the B matrix + B = np.sqrt(2 * np.arange(P) + 1.0) + return hippo, R1, B + + +def make_DPLR_HiPPO(P): + """ + Makes components needed for DPLR representation of HiPPO-LegS + From https://github.com/srush/annotated-s4/blob/main/s4/s4.py + Note, we will only use the diagonal part + Args: + P: + + Returns: + eigenvalues Lambda, low-rank term R1, conjugated HiPPO input matrix B, + eigenvectors V, HiPPO B pre-conjugation + + """ + A, R1, B = make_NPLR_HiPPO(P) + + S = A + R1[:, np.newaxis] * R1[np.newaxis, :] + + S_diag = np.diagonal(S) + Lambda_real = np.mean(S_diag) * np.ones_like(S_diag) + + # Diagonalize S to V \Lambda V^* + Lambda_imag, V = eigh(S * -1j) + + R1 = V.conj().T @ R1 + B_orig = B + B = V.conj().T @ B + return ( + th.tensor(onp.asarray(Lambda_real + 1j * Lambda_imag), dtype=th.complex64), + th.tensor(onp.asarray(R1)), + th.tensor(onp.asarray(B)), + th.tensor(onp.asarray(V), dtype=th.complex64), + th.tensor(onp.asarray(B_orig)), + ) + + +def init_log_steps(P, dt_min, dt_max): + """Initialize an array of learnable timescale parameters. + initialized uniformly in log space. 
+ Args: + input: + Returns: + initialized array of timescales (float32): (P,) + """ + unlog = th.rand(size=(P,)) + log = unlog * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min) + return log + + +def lecun_normal_(tensor: th.Tensor) -> th.Tensor: + input_size = tensor.shape[ + -1 + ] # Assuming that the weights' input dimension is the last. + std = math.sqrt(1 / input_size) + with th.no_grad(): + return tensor.normal_(0, std) # or torch.nn.init.xavier_normal_ + + +def init_VinvB(shape, Vinv): + """Initialize B_tilde=V^{-1}B. First samples B. Then compute V^{-1}B. + Note we will parameterize this with two different matrices for complex + + Modified from https://github.com/lindermanlab/S5/blob/52cc7e22d6963459ad99a8674e4d3cfb0a480008/s5/ssm.py#L165 + + numbers. + Args: + shape (tuple): desired shape (P,H) + Vinv: (complex64) the inverse eigenvectors used for initialization + Returns: + B_tilde (complex64) of shape (P,H) + """ + B = lecun_normal_(th.zeros(shape)) + VinvB = Vinv @ B.type(th.complex64) + return VinvB diff --git a/easy_tpp/ssm/models.py b/easy_tpp/ssm/models.py new file mode 100644 index 0000000..db69575 --- /dev/null +++ b/easy_tpp/ssm/models.py @@ -0,0 +1,820 @@ +from typing import Optional, Tuple + +import torch as th +import torch.nn as nn +import torch.nn.functional as F + +from .initializers import ( + make_DPLR_HiPPO, # , lecun_normal_ # init_VinvB, init_log_steps, +) + +MATRIX_SCALING_FACTOR = 1 + + +class LLH(nn.Module): + """ + This is canon: + L -- number of layers + N -- number of events. + P -- Hidden dimension. Dimensionality of x. + H -- output dimension. Dimensionality of y/u. + """ + + def __init__( + self, + P: int, + H: int, + dt_init_min: float = 1e-4, + dt_init_max: float = 0.1, + dropout_rate: float = 0.0, + act_func: str = "gelu", # F.gelu, + for_loop: bool = False, + pre_norm: bool = True, + post_norm: bool = False, + simple_mark: bool = True, + is_first_layer: bool = False, + relative_time: bool = False, + complex_values: bool = True, + ): + """ + + :param P: + :param H: + :param dt_init_min: + :param dt_init_max: + :param act_func: + """ + + super(LLH, self).__init__() + + # Inscribe the args. + self.P = P + self.H = H + self.dt_init_min = dt_init_min + self.dt_init_max = dt_init_max + self.dropout_rate = dropout_rate + self.complex_values = complex_values + + # select the activation function. + if act_func == "gelu": + self.act_func = nn.Sequential(nn.GELU(), nn.Dropout(p=self.dropout_rate)) + elif act_func == "full_glu": + self.act_func = nn.Sequential( + nn.Linear(self.H, 2 * self.H), + nn.Dropout(p=self.dropout_rate), + nn.GLU(), + nn.Dropout(p=self.dropout_rate), + ) + + elif ( + act_func == "half_glu" + ): # ref: https://github.com/lindermanlab/S5/blob/main/s5/layers.py#L76 + self.act_func1 = nn.Sequential( + nn.Dropout(p=self.dropout_rate), + nn.GELU(), + nn.Linear(self.H, self.H), + ) + self.act_func = lambda x: nn.Dropout(p=self.dropout_rate)( + x * nn.Sigmoid()(self.act_func1(x)) + ) + else: + raise NotImplementedError( + "Unrecognized activation function {}".format(act_func) + ) + + # Assume we always use conjugate symmetry. + self.conj_sym = True + + # Allow a learnable initial state. 
+        # Needs to be =/= 0 since we take the log to compute
+        if self.complex_values:
+            self.initial_state_P = nn.Parameter(
+                th.complex(
+                    th.randn(
+                        self.P,
+                    ),
+                    th.randn(
+                        self.P,
+                    ),
+                )
+                * 1e-3,
+                requires_grad=True,
+            )
+        else:
+            self.initial_state_P = nn.Parameter(
+                th.randn(
+                    self.P,
+                ),
+                requires_grad=True,
+            )
+
+        self.norm = nn.LayerNorm(self.H)
+        self.for_loop = for_loop
+        self.pre_norm = pre_norm
+        self.post_norm = post_norm
+
+        self.is_first_layer = is_first_layer
+        self.relative_time = relative_time
+
+        self._init_ssm_params()
+
+        self.simple_mark = simple_mark
+        if not simple_mark:
+            self.mark_a_net = nn.Linear(self.H, self.P, bias=True)
+            self.mark_u_net = nn.Linear(
+                self.H, self.P, bias=False
+            )  # Only need one bias
+            self.mark_a_net.weight.data = th.complex(
+                nn.init.xavier_normal_(self.mark_a_net.weight.data) * 1e-3,
+                nn.init.xavier_normal_(self.mark_a_net.weight.data) * 1e-3,
+            )
+            # xavier_normal_ requires a >=2-D tensor, so the 1-D bias uses a plain normal init.
+            self.mark_a_net.bias.data = th.complex(
+                th.randn_like(self.mark_a_net.bias.data) * 1e-3,
+                th.randn_like(self.mark_a_net.bias.data) * 1e-3,
+            )
+            self.mark_u_net.weight.data = th.complex(
+                nn.init.xavier_normal_(self.mark_u_net.weight.data) * 1e-3,
+                nn.init.xavier_normal_(self.mark_u_net.weight.data) * 1e-3,
+            )
+            if not self.complex_values:
+                self.mark_a_net.weight.data = self.mark_a_net.weight.data.real
+                self.mark_a_net.bias.data = self.mark_a_net.bias.data.real
+                self.mark_u_net.weight.data = self.mark_u_net.weight.data.real
+
+    def _init_ssm_params(self):
+        self._init_A()
+        if not self.is_first_layer:
+            self._init_B()
+        self._init_C()
+        if (
+            not self.is_first_layer
+        ):  # Could group, but left in same order to not mess with initialization
+            self._init_D()
+            self._init_E()
+
+    def _init_A(self):
+        # Define the initial diagonal HiPPO matrix.
+        # We throw the HiPPO B away.
+        Lambda_P, _, _, V_PP, _ = make_DPLR_HiPPO(self.P)
+        self.Lambda_P_log_neg_real = th.nn.Parameter((-Lambda_P.real).log())
+        self.Lambda_P_imag = th.nn.Parameter(Lambda_P.imag)
+
+        # Store these for use later.
+        self._V_PP = V_PP
+        self._Vc_PP = V_PP.conj().T
+
+        # We also initialize the step size.
+        if self.relative_time:
+            self.delta_net = nn.Linear(
+                self.H, self.P, bias=True
+            )  # nn.Parameter(init_log_steps(self.P, self.dt_init_min, self.dt_init_max))
+            with th.no_grad():
+                self.delta_net.weight.copy_(
+                    nn.init.xavier_normal_(self.delta_net.weight)
+                )
+                bias = th.ones(
+                    self.P,
+                )
+                bias += th.log(-th.expm1(-bias))
+                self.delta_net.bias.copy_(bias)
+        else:
+            self.log_step_size_P = nn.Parameter(
+                th.zeros(size=(self.P,)), requires_grad=False
+            )
+
+    @property
+    def Lambda_P(self):
+        if self.complex_values:
+            return th.complex(
+                -self.Lambda_P_log_neg_real.exp(),
+                self.Lambda_P_imag,
+            )
+        else:
+            return -self.Lambda_P_log_neg_real.exp()
+
+    def _init_B(self):
+        # Initialize the B outside the eigenbasis and then transform.
+        B = nn.init.xavier_normal_(th.zeros((self.P, self.H))) * MATRIX_SCALING_FACTOR
+        B_tilde_PH = self._Vc_PP @ B.type(th.complex64)
+        self.B_tilde_PH = (
+            th.nn.Parameter(B_tilde_PH)
+            if self.complex_values
+            else th.nn.Parameter(B_tilde_PH.real)
+        )
+
+    def _init_C(self):
+        # Use the "complex_normal" initialization.
+ # See ~https://github.com/lindermanlab/S5/blob/52cc7e22d6963459ad99a8674e4d3cfb0a480008/s5/ssm.py#L183 + C = nn.init.xavier_normal_(th.zeros((self.H, self.P))) * MATRIX_SCALING_FACTOR + C_tilde_HP = C.type(th.complex64) @ self._V_PP + self.C_tilde_HP = ( + th.nn.Parameter(C_tilde_HP) + if self.complex_values + else th.nn.Parameter(C_tilde_HP.real) + ) + # self.C_tilde_HP.data *= 1e-3 + + def _init_D(self): + # Initialize feedthrough (D) matrix. Note the intensity depends on all layers. + D_HH = th.zeros(self.H) + nn.init.normal_(D_HH, std=1.0) + self.D_HH = nn.Parameter(D_HH, requires_grad=True) + + def _init_E(self): + E = ( + th.nn.init.xavier_normal_(th.zeros((self.P, self.H))) + * MATRIX_SCALING_FACTOR + ) + E_tilde_PH = self._Vc_PP @ E.type(th.complex64) + self.E_tilde_PH = ( + th.nn.Parameter(E_tilde_PH) + if self.complex_values + else th.nn.Parameter(E_tilde_PH.real) + ) + + def compute_impulse(self, right_u_H, mark_embedding_H): + # Compute impulse to add to left limit of x to make right limit. + alpha_P = th.einsum( + "ph,...h->...p", + self.E_tilde_PH, + mark_embedding_H.type(th.complex64) + if self.complex_values + else mark_embedding_H, + ) + return alpha_P + + def get_lambda(self, right_u_NH, shift_u=True): + if self.relative_time and (right_u_NH is not None): + if shift_u: # during "forward" when dts = [0, t1-t0, ..., t_N-t_{N-1}] + right_u_NH = F.pad( + right_u_NH[..., :-1, :], (0, 0, 1, 0) + ) # pad default 0 at beginning of second to last dim + lambda_rescaled_NP = ( + F.softplus(self.delta_net(right_u_NH)) * self.Lambda_P + ) # predict delta_i from right_u_i + return {"lambda_rescaled_NP": lambda_rescaled_NP} + else: + if self.relative_time: + lambda_rescaled_P = F.softplus(self.delta_net.bias) * self.Lambda_P + else: + lambda_rescaled_P = th.exp(self.log_step_size_P) * self.Lambda_P + return {"lambda_rescaled_P": lambda_rescaled_P} + + def forward( + self, + left_u_NH: Optional[th.Tensor], # Very first layer, should feed in `None` + right_u_NH: Optional[th.Tensor], # Very first layer, should feed in `None` + mark_embedding_NH: th.Tensor, + dt_N: th.Tensor, + initial_state_P: Optional[th.Tensor] = None, + ) -> Tuple[th.Tensor, th.Tensor]: + """ + Apply the linear SSM to the inputs. + + In the context of TPPs, this returns the right limit of the "intensity function". + This intensity will have been passed through a non-linearity, though, and so there is no + guarantee for it is positive. + + :param u_NH: [..., seq_len, input_dim] + :param alpha_NP: [..., seq_len, hidden_dim] + :param dt_N: [..., seq_len] + :param initial_state_P: [..., hidden_dim] + :return: + """ + # Pull out the dimensions. 
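+        # mark_embedding_NH has shape [..., N, H]; everything before the last two
+        # axes is treated as batch dimensions and broadcast through the layer.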
+ *leading_dims, _, _ = mark_embedding_NH.shape + num_leading_dims = len(leading_dims) + + if initial_state_P is None: + # Pad and expand to match leading dimensions of input + initial_state_P = self.initial_state_P.view( + *[1 for _ in range(num_leading_dims)], -1 + ).expand(*leading_dims, -1) + + # Add layer norm + prime_left_u_NH = left_u_NH + prime_right_u_NH = right_u_NH + if prime_left_u_NH is not None: # ONLY for backward variant + assert all( + u_d == a_d + for u_d, a_d in zip(prime_left_u_NH.shape, mark_embedding_NH.shape) + ) # All but last dimensions should match + if self.pre_norm: + prime_left_u_NH = self.norm(prime_left_u_NH) + if prime_right_u_NH is not None: + assert all( + u_d == a_d + for u_d, a_d in zip(prime_right_u_NH.shape, mark_embedding_NH.shape) + ) # All but last dimensions should match + if self.pre_norm: + prime_right_u_NH = self.norm(prime_right_u_NH) + + right_x_NP, left_y_NH, right_y_NH = self._ssm( + left_u_NH=prime_left_u_NH, + right_u_NH=prime_right_u_NH, + impulse_NP=self.compute_impulse(prime_right_u_NH, mark_embedding_NH), + dt_N=dt_N, + initial_state_P=initial_state_P, + ) + + # Given the following: + # right_u: u0, u1, u2, ... <-> u_{t_0}, u_{t_1}, u_{t_2}, ... + # left_u: u0, u1, u2, ... <-> u_{t_0-}, u_{t_1-}, u_{t_2-}, ... + # a: a0, a1, a2, ... <-> mark embeddings for m_0, m_1, m_2, ... at times t_0, t_1, t_2 + # dt: dt0, dt1, dt2, ... <-> 0, t_1-t_0, t_2-t_1, ... + # initial_state_p: hidden state to evolve to to compute x_{0} + + # Returns the following: + # right_x: x0, x1, x2, ... <-> x_{t_0}, x_{t_1}, x_{t_2}, ... + # right_y: y0, y1, y2, ... <-> y_{t_0}, y_{t_1}, y_{t_2}, ... + # left_y: y0, y1, y2, ... <-> y_{t_0-}, y_{t_1-}, y_{t_2-}, ... + + next_layer_left_u_NH = next_layer_right_u_NH = None + if left_y_NH is not None: + next_layer_left_u_NH = self.act_func(left_y_NH) + ( + left_u_NH if left_u_NH is not None else 0.0 + ) + if self.post_norm: + next_layer_left_u_NH = self.norm(next_layer_left_u_NH) + if right_y_NH is not None: + next_layer_right_u_NH = self.act_func(right_y_NH) + ( + right_u_NH if right_u_NH is not None else 0.0 + ) + if self.post_norm: + next_layer_right_u_NH = self.norm(next_layer_right_u_NH) + return right_x_NP, next_layer_left_u_NH, next_layer_right_u_NH + + def _ssm( + self, + left_u_NH: Optional[th.Tensor], # Very first layer, should feed in `None` + right_u_NH: Optional[th.Tensor], # Very first layer, should feed in `None` + impulse_NP: th.Tensor, + dt_N: th.Tensor, # [0, t_1 - t_0, ..., t_N - t_{N-1}] + initial_state_P: th.Tensor, + ): + *leading_dims, N, P = impulse_NP.shape + u_NH = right_u_NH # This implementation does not use left_u, nor does it compute left_y + if u_NH is not None: + impulse_NP = impulse_NP + th.einsum( + "ph,...nh->...np", + self.B_tilde_PH, + u_NH.type(th.complex64) if self.complex_values else u_NH, + ) + y_u_res_NH = th.einsum( + "...nh,h->...nh", u_NH, self.D_HH + ) # D_HH should really be D_H + else: + assert self.is_first_layer + y_u_res_NH = 0.0 + + lambda_res = self.get_lambda(right_u_NH=right_u_NH, shift_u=True) + if "lambda_rescaled_P" in lambda_res: # original formulation + lambda_dt_NP = th.einsum( + "...n,p->...np", dt_N, lambda_res["lambda_rescaled_P"] + ) + else: # relative time + lambda_dt_NP = th.einsum( + "...n,...np->...np", dt_N, lambda_res["lambda_rescaled_NP"] + ) + + if self.for_loop: + right_x_P = initial_state_P + right_x_NP = [] + for i in range(N): + right_x_P = ( + lambda_dt_NP[..., i, :].exp() * right_x_P + impulse_NP[..., i, :] + ) + right_x_NP.append(right_x_P) + 
right_x_NP = th.stack(right_x_NP, dim=-2) + else: + # Trick inspired by: https://github.com/PeaBrane/mamba-tiny/blob/master/scans.py + # .unsqueeze(-2) to add sequence dimension to initial state + log_impulse_Np1_P = th.concat( + (initial_state_P.unsqueeze(-2), impulse_NP), dim=-2 + ).log() + lamdba_dt_star = F.pad(lambda_dt_NP.cumsum(-2), (0, 0, 1, 0)) + right_x_log_NP = ( + th.logcumsumexp(log_impulse_Np1_P - lamdba_dt_star, -2) + lamdba_dt_star + ) + right_x_NP = right_x_log_NP.exp()[..., 1:, :] + + conj_sym_mult = 2 if self.conj_sym else 1 + y_NH = ( + conj_sym_mult + * th.einsum("...np,hp->...nh", right_x_NP, self.C_tilde_HP).real + + y_u_res_NH + ) + + return right_x_NP, None, y_NH + + def get_left_limit( + self, + right_limit_P: th.Tensor, # Along with dt, can have any number of leading dimensions, produces a tensor of dim ...MP + dt_G: th.Tensor, + current_right_u_H: th.Tensor, + next_left_u_GH: th.Tensor, + ) -> th.Tensor: + """ + To get the left limit, we roll on the layer for the right dt. + Computed for a single point (vmap for multiple). + + :param right_limit_P: at [t_0, ..., t_{N-1}] + :param dt: Length of time to roll the layer on for. at [t_1 - t_0, ..., t_N - t_{N-1}] + :param current_right_u_H: at [t_0, ..., t_{N-1}] -- for relative-time variant + :param next_left_u_GH: at [t_1, ..., t_N] -- for backward variant + + :return: + """ + + if current_right_u_H is not None and self.pre_norm: + current_right_u_H = self.norm(current_right_u_H) + + lambda_res = self.get_lambda( + current_right_u_H, shift_u=False + ) # U should already be shifted + if "lambda_rescaled_P" in lambda_res: + lambda_bar_GP = th.exp( + th.einsum("...g,p->...gp", dt_G, lambda_res["lambda_rescaled_P"]) + ) + else: + lambda_bar_GP = th.exp( + th.einsum("...g,...p->...gp", dt_G, lambda_res["lambda_rescaled_NP"]) + ) + + return th.einsum("...p,...gp->...gp", right_limit_P, lambda_bar_GP) + + def depth_pass( + self, + current_left_x_P: th.Tensor, # No leading dimensions (seq, batch, etc.) here because we accommodate any of them + current_left_u_H: Optional[ + th.Tensor + ], # Just assume that x and u match in the leading dimensions. Produces y_H with equivalent leading dimensions + prev_right_u_H: Optional[ + th.Tensor + ], # Just assume that x and u match in the leading dimensions. Produces y_H with equivalent leading dimensions + ) -> th.Tensor: + if current_left_u_H is not None: + if self.pre_norm: + prime_u_H = self.norm(current_left_u_H) + else: + prime_u_H = current_left_u_H + y_u_res_H = th.einsum( + "...h,h->...h", prime_u_H, self.D_HH + ) # D_HH should really be D_H + else: + assert self.is_first_layer + y_u_res_H = 0.0 + + conj_sym_mult = 2 if self.conj_sym else 1 + y_H = ( + conj_sym_mult + * th.einsum("...p,hp->...h", current_left_x_P, self.C_tilde_HP).real + + y_u_res_H + ) + + # Apply an activation function. 
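+        # Residual connection: act(y^{(l)}) + u^{(l)}, optionally followed by a
+        # LayerNorm when post_norm is set; the residual is skipped in the first
+        # layer, where current_left_u_H is None.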
+ if self.post_norm: + new_u_H = self.norm( + self.act_func(y_H) + + (current_left_u_H if current_left_u_H is not None else 0.0) + ) + else: + new_u_H = self.act_func(y_H) + ( + current_left_u_H if current_left_u_H is not None else 0.0 + ) + + return new_u_H + + +class Int_Forward_LLH(LLH): + # LLH but Bu_t is integrated w.r.t dt instead of dN_t + # After discretization, when evolving x_t to x_t', applies ZOH on u_t over [t,t'] forward in time + # (as opposed to u_{t'} backwards over [t,t']) + + def _ssm( + self, + left_u_NH: Optional[th.Tensor], # Very first layer, should feed in `None` + right_u_NH: Optional[th.Tensor], # Very first layer, should feed in `None` + impulse_NP: th.Tensor, + dt_N: th.Tensor, + initial_state_P: th.Tensor, + ) -> Tuple[th.Tensor, th.Tensor]: + """ + Apply the linear SSM to the inputs. + + In the context of TPPs, this returns the right limit of the "intensity function". + This intensity will have been passed through a non-linearity, though, and so there is no + guarantee for it is positive. + + :param u_NH: [..., seq_len, input_dim] + :param alpha_NP: [..., seq_len, hidden_dim] + :param dt_N: [..., seq_len] + :param initial_state_P: [..., hidden_dim] + + :return: + """ + # Pull out the dimensions. + *leading_dims, N, P = impulse_NP.shape + + lambda_res = self.get_lambda(right_u_NH=right_u_NH, shift_u=True) + if "lambda_rescaled_P" in lambda_res: + lambda_rescaled = lambda_res["lambda_rescaled_P"] + lambda_dt_NP = th.einsum( + "...n,p->...np", dt_N, lambda_res["lambda_rescaled_P"] + ) + else: + lambda_rescaled = lambda_res["lambda_rescaled_NP"] + lambda_dt_NP = th.einsum( + "...n,...np->...np", dt_N, lambda_res["lambda_rescaled_NP"] + ) + + if left_u_NH is not None: + left_Du_NH = th.einsum( + "...nh,h->...nh", + left_u_NH, + self.D_HH, + ) + else: + assert self.is_first_layer + left_Du_NH = 0.0 + + if right_u_NH is not None: + right_u_NH = F.pad(right_u_NH[..., :-1, :], (0, 0, 1, 0)) + right_Bu_NP = th.einsum( + "...np,ph,...nh->...np", + lambda_dt_NP.exp() - 1.0, # dts: [0, t1-t0, t2-t1, ...] + self.B_tilde_PH, + right_u_NH.type(th.complex64) if self.complex_values else right_u_NH, + ) + right_Du_NH = th.einsum( + "...nh,h->...nh", + right_u_NH, + self.D_HH, + ) + else: + assert self.is_first_layer + right_Bu_NP = right_Du_NH = 0.0 + + if self.for_loop: + right_x_P = initial_state_P + left_x_NP, right_x_NP = [], [] + for i in range(N): + left_x_P = lambda_dt_NP[..., i, :].exp() * right_x_P + ( + right_Bu_NP[..., i, :] if left_u_NH is not None else 0.0 + ) + right_x_P = left_x_P + impulse_NP[..., i, :] + left_x_NP.append(left_x_P) + right_x_NP.append(right_x_P) + right_x_NP = th.stack( + right_x_NP, dim=-2 + ) # discard initial_hidden_states, right_limit of xs for [t0, t1, ...] + left_x_NP = th.stack( + left_x_NP, dim=-2 + ) # discard initial_hidden_states, left_limit of xs for [t0, t1, ...] 
+ else: + # Trick inspired by: https://github.com/PeaBrane/mamba-tiny/blob/master/scans.py + # .unsqueeze(-2) to add sequence dimension to initial state + log_impulse_Np1_P = th.concat( + (initial_state_P.unsqueeze(-2), right_Bu_NP + impulse_NP), dim=-2 + ).log() + lamdba_dt_star = F.pad(lambda_dt_NP.cumsum(-2), (0, 0, 1, 0)) + right_x_log_NP = ( + th.logcumsumexp(log_impulse_Np1_P - lamdba_dt_star, -2) + lamdba_dt_star + ) + right_x_NP = right_x_log_NP.exp() # Contains initial_state_P in index 0 + left_x_NP = ( + lambda_dt_NP.exp() * right_x_NP[..., :-1, :] + right_Bu_NP + ) # Evolves previous hidden state forward to compute left limit + right_x_NP = right_x_NP[..., 1:, :] + + conj_sym_mult = 2 if self.conj_sym else 1 + left_y_NH = ( + conj_sym_mult + * th.einsum("hp,...np->...nh", self.C_tilde_HP, left_x_NP).real + + left_Du_NH + ) # ys for [t0, t1, ...] + right_y_NH = ( + conj_sym_mult + * th.einsum("hp,...np->...nh", self.C_tilde_HP, right_x_NP).real + + right_Du_NH + ) # ys for [t0, t1, ...] + + return right_x_NP, left_y_NH, right_y_NH + + def get_left_limit( + self, + right_limit_P: th.Tensor, # Along with dt, can have any number of leading dimensions, produces a tensor of dim ...MP + dt_G: th.Tensor, + current_right_u_H: Optional[th.Tensor], + next_left_u_GH: Optional[th.Tensor], + ) -> th.Tensor: + """ + To get the left limit, we roll on the layer for the right dt. + Computed for a single point (vmap for multiple). + + :param right_limit_P: + :param dt: Length of time to roll the layer on for. + :return: + """ + if current_right_u_H is not None and self.pre_norm: + current_right_u_H = self.norm(current_right_u_H) + + lambda_res = self.get_lambda( + current_right_u_H, shift_u=False + ) # U should already be shifted + if "lambda_rescaled_P" in lambda_res: + lambda_bar_GP = th.exp( + th.einsum("...g,p->...gp", dt_G, lambda_res["lambda_rescaled_P"]) + ) + else: + lambda_bar_GP = th.exp( + th.einsum("...g,...p->...gp", dt_G, lambda_res["lambda_rescaled_NP"]) + ) + + # lambda_rescaled_P = th.exp(self.log_step_size_P) * self.Lambda_P + # lambda_bar_GP = th.exp(th.einsum('...g,p->...gp', dt_G, lambda_rescaled_P)) + int_hidden_GP = th.einsum("...p,...gp->...gp", right_limit_P, lambda_bar_GP) + + if current_right_u_H is None: # no Bu term + assert self.is_first_layer + return int_hidden_GP + else: # add Bu to impulse + if self.pre_norm: + current_right_u_H = self.norm(current_right_u_H) + + impulse_GP = th.einsum( + "...gp,ph,...h->...gp", + lambda_bar_GP - 1.0, + self.B_tilde_PH, + current_right_u_H.type(th.complex64) + if self.complex_values + else current_right_u_H, + ) + + return int_hidden_GP + impulse_GP + + +class Int_Backward_LLH(Int_Forward_LLH): + # LLH but Bu_t is integrated w.r.t dt instead of dN_t + # After discretization, when evolving x_t to x_t', applies ZOH on u_t' over [t,t'] backwards in time + # (as opposed to u_{t} forwards over [t,t']) + + def _ssm( + self, + left_u_NH: Optional[th.Tensor], # Very first layer, should feed in `None` + right_u_NH: Optional[th.Tensor], # Very first layer, should feed in `None` + impulse_NP: th.Tensor, + dt_N: th.Tensor, + initial_state_P: th.Tensor, + ) -> Tuple[th.Tensor, th.Tensor]: + """ + Apply the linear SSM to the inputs. + + In the context of TPPs, this returns the right limit of the "intensity function". + This intensity will have been passed through a non-linearity, though, and so there is no + guarantee for it is positive. 
+ + :param u_NH: [..., seq_len, input_dim] + :param alpha_NP: [..., seq_len, hidden_dim] + :param dt_N: [..., seq_len] + :param initial_state_P: [..., hidden_dim] + + :return: + """ + # Pull out the dimensions. + *leading_dims, N, P = impulse_NP.shape + + # lambda_rescaled_P = th.exp(self.log_step_size_P) * self.Lambda_P + # lambda_dt_NP = th.einsum('...n,p->...np', dt_N, lambda_rescaled_P) + lambda_res = self.get_lambda(right_u_NH=right_u_NH, shift_u=True) + if "lambda_rescaled_P" in lambda_res: + lambda_dt_NP = th.einsum( + "...n,p->...np", dt_N, lambda_res["lambda_rescaled_P"] + ) + else: + lambda_dt_NP = th.einsum( + "...n,...np->...np", dt_N, lambda_res["lambda_rescaled_NP"] + ) + + if left_u_NH is not None: + left_Bu_NP = th.einsum( + "...np,ph,...nh->...np", + lambda_dt_NP.exp() - 1.0, # dts: [0, t1-t0, t2-t1, ...] + self.B_tilde_PH, + left_u_NH.type(th.complex64) if self.complex_values else left_u_NH, + ) + left_Du_NH = th.einsum( + "...nh,h->...nh", + left_u_NH, + self.D_HH, + ) + else: + assert self.is_first_layer + left_Bu_NP = left_Du_NH = 0.0 + + if right_u_NH is not None: + right_Du_NH = th.einsum( + "...nh,h->...nh", + right_u_NH, + self.D_HH, + ) + else: + assert self.is_first_layer + right_Du_NH = 0.0 + + if self.for_loop: + right_x_P = initial_state_P + left_x_NP, right_x_NP = [], [] + for i in range(N): + left_x_P = lambda_dt_NP[..., i, :].exp() * right_x_P + ( + left_Bu_NP[..., i, :] if left_u_NH is not None else 0.0 + ) + right_x_P = left_x_P + impulse_NP[..., i, :] + left_x_NP.append(left_x_P) + right_x_NP.append(right_x_P) + right_x_NP = th.stack( + right_x_NP, dim=-2 + ) # discard initial_hidden_states, right_limit of xs for [t0, t1, ...] + left_x_NP = th.stack( + left_x_NP, dim=-2 + ) # discard initial_hidden_states, left_limit of xs for [t0, t1, ...] + else: + # Trick inspired by: https://github.com/PeaBrane/mamba-tiny/blob/master/scans.py + # .unsqueeze(-2) to add sequence dimension to initial state + log_impulse_Np1_P = th.concat( + (initial_state_P.unsqueeze(-2), left_Bu_NP + impulse_NP), dim=-2 + ).log() + lamdba_dt_star = F.pad(lambda_dt_NP.cumsum(-2), (0, 0, 1, 0)) + right_x_log_NP = ( + th.logcumsumexp(log_impulse_Np1_P - lamdba_dt_star, -2) + lamdba_dt_star + ) + right_x_NP = right_x_log_NP.exp() # Contains initial_state_P in index 0 + left_x_NP = ( + lambda_dt_NP.exp() * right_x_NP[..., :-1, :] + left_Bu_NP + ) # Evolves previous hidden state forward to compute left limit + right_x_NP = right_x_NP[..., 1:, :] + + conj_sym_mult = 2 if self.conj_sym else 1 + left_y_NH = ( + conj_sym_mult + * th.einsum("hp,...np->...nh", self.C_tilde_HP, left_x_NP).real + + left_Du_NH + ) # ys for [t0, t1, ...] + right_y_NH = ( + conj_sym_mult + * th.einsum("hp,...np->...nh", self.C_tilde_HP, right_x_NP).real + + right_Du_NH + ) # ys for [t0, t1, ...] + + return right_x_NP, left_y_NH, right_y_NH + + def get_left_limit( + self, + right_limit_P: th.Tensor, # Along with dt, can have any number of leading dimensions, produces a tensor of dim ...MP + dt_G: th.Tensor, + current_right_u_H: th.Tensor, + next_left_u_GH: th.Tensor, + ) -> th.Tensor: + """ + To get the left limit, we roll on the layer for the right dt. + Computed for a single point (vmap for multiple). + + :param right_limit_P: + :param dt: Length of time to roll the layer on for. 
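+        :param current_right_u_H: right limit of u at the interval start; used to rescale the dynamics in the relative-time variant.
+        :param next_left_u_GH: left limit of u at the interval end; used as the ZOH constant for the Bu term in this backward variant.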
+ :return: + """ + + if current_right_u_H is not None and self.pre_norm: + current_right_u_H = self.norm(current_right_u_H) + + lambda_res = self.get_lambda( + current_right_u_H, shift_u=False + ) # U should already be shifted + if "lambda_rescaled_P" in lambda_res: + lambda_bar_GP = th.exp( + th.einsum("...g,p->...gp", dt_G, lambda_res["lambda_rescaled_P"]) + ) + else: + lambda_bar_GP = th.exp( + th.einsum("...g,...p->...gp", dt_G, lambda_res["lambda_rescaled_NP"]) + ) + + int_hidden_GP = th.einsum("...p,...gp->...gp", right_limit_P, lambda_bar_GP) + + if next_left_u_GH is None: # no Bu term + assert self.is_first_layer + return int_hidden_GP + else: # add Bu to impulse + if self.pre_norm: + next_left_u_GH = self.norm(next_left_u_GH) + + impulse_GP = th.einsum( + "...gp,ph,...gh->...gp", + lambda_bar_GP - 1.0, + self.B_tilde_PH, + next_left_u_GH.type(th.complex64) + if self.complex_values + else next_left_u_GH, + ) + + return int_hidden_GP + impulse_GP diff --git a/easy_tpp/ssm/ssm_util.py b/easy_tpp/ssm/ssm_util.py new file mode 100644 index 0000000..72d9b5c --- /dev/null +++ b/easy_tpp/ssm/ssm_util.py @@ -0,0 +1,80 @@ +# @title Imports and environment +import torch as th + + +def discretize_zoh(Lambda, B_tilde, Delta): + """Discretize a diagonalized, continuous-time linear SSM + using zero-order hold method. + + modified from: https://github.com/lindermanlab/S5/blob/3c18fdb6b06414da35e77b94b9cd855f6a95ef17/s5/ssm.py#L29 + + Args: + Lambda (complex64): diagonal state matrix (P,) + B_tilde (complex64): input matrix (P, H) + Delta (float32): discretization step sizes (P,) + Returns: + discretized Lambda_bar (complex64), B_bar (complex64) (P,), (P,H) + """ + Identity = th.ones(Lambda.shape[0]) + Lambda_bar = th.exp(Lambda * Delta) + B_bar = (1 / Lambda * (Lambda_bar - Identity))[..., None] * B_tilde + return Lambda_bar, B_bar + + +def apply_ssm( + Lambda_bar_NP, + B_bar_NPH, + C_tilde_HP, + input_sequence_NH, + alpha_NP, + conj_sym, + initial_state_P=None, +): + """Compute the NxH output of discretized SSM given an NxH input. + + modified from: https://github.com/lindermanlab/S5/blob/3c18fdb6b06414da35e77b94b9cd855f6a95ef17/s5/ssm.py#L60 + - removed bidirectionality. + - assume Lambda_bar is N-length. + + Args: + Lambda_bar_NP (complex64): discretized diagonal state matrix for each interval (N, P) + B_bar_NPH (complex64): "discretized" input matrix. Note: may be outside ZOH (N, P, H) + C_tilde_HP (complex64): output matrix (H, P) + input_sequence_NH (float32): input sequence of features (N, H) + alpha_NP (complex64): mark-specific biases (N, P) + conj_sym (bool): Whether conjugate symmetry is enforced + initial_state_P (): Allow passing in a specific initial state (otherwise zero is assumed.) + Returns: + ys_NH (float32): the SSM outputs (S5 layer preactivations) (N, H) + """ + N, P, H = B_bar_NPH.shape + + # Compute effective inputs. + Bu_elements_NP = th.vmap(lambda b, u, alpha: b @ u.type(th.complex64) + alpha)( + B_bar_NPH, input_sequence_NH, alpha_NP + ) + + # # Torch doesn't roll an associative scan... yet... + # _, xs = jax.lax.associative_scan(binary_operator, (Lambda_elements, Bu_elements)) + + # Set the initial state if we haven't already. + if initial_state_P is None: + state = th.zeros((P,)) + else: + state = initial_state_P + + # Accumulate the hidden states here. Note the initial state shouldn't be returned. 
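+    # Plain sequential loop in place of an associative scan; the LLH layers in
+    # easy_tpp/ssm/models.py avoid this bottleneck via a logcumsumexp-based scan
+    # when `for_loop` is disabled.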
+ # xs = th.zeros((L, P)).type(th.complex64) + xs = [state] + + for i, (lam_P, bu_P) in enumerate(zip(Lambda_bar_NP, Bu_elements_NP)): + # state = lam_P * state + bu_P + # xs[i] = state + xs.append(lam_P * xs[-1] + bu_P) + xs = th.stack(xs)[1:] + + # Output the xs and ys after projecting. + if conj_sym: + return xs, th.vmap(lambda x: 2 * (C_tilde_HP @ x).real)(xs) + else: + return xs, th.vmap(lambda x: (C_tilde_HP @ x).real)(xs) diff --git a/examples/configs/experiment_config.yaml b/examples/configs/experiment_config.yaml index 623c9db..27ece55 100644 --- a/examples/configs/experiment_config.yaml +++ b/examples/configs/experiment_config.yaml @@ -585,4 +585,41 @@ AttNHP_gen: over_sample_rate: 5 num_samples_boundary: 5 dtime_max: 5 - num_step_gen: 10 \ No newline at end of file + num_step_gen: 10 + + +# Example configuration for training State-Space Point Process (S2P2) model. +S2P2_train: + base_config: + stage: train + backend: torch + dataset_id: taxi + runner_id: std_tpp + model_id: S2P2 + base_dir: './checkpoints/' + trainer_config: + batch_size: 256 + max_epoch: 300 + shuffle: True + optimizer: adam + learning_rate: 1.e-2 + valid_freq: 1 + use_tfb: False + metrics: [ 'acc', 'rmse' ] + seed: 2019 + gpu: -1 # ID of GPU to use. Set to -1 to use CPU instead. `mps` backend could lead to incorrect results, please use CPU or CUDA. + model_config: + hidden_size: 128 # Number of dimensions for u_t and y_t, labeled as H in the paper. + loss_integral_num_sample_per_step: 10 # How many time points to use to estimate the integrated intensity between each pair of subsequent events for the log-likelihood. + use_mc_samples: True # Use Monte-Carlo sampling for the integral estimation. If False, uses a quadrature with a grid of evenly spaced points. + num_layers: 4 # Number of LLH layers. + model_specs: + P: 16 # Number of dimensions for the hidden state x_t, labeled as P in the paper. + dropout_rate: 0.1 # Dropout rate, used immediately after the activation function between layers but before the normalization. Formally, we set u^{(l+1)}_t = LayerNorm(dropout(\sigma(y^{(l)}_t)) + u^{(l)}_t). + act_func: gelu # gelu | half_glu | full_glu # Activation function to use between layers. + for_loop: True # If enabled, uses for-loop for computing the recurrence in the LLH layers. If disabled, uses a parallel scan. + pre_norm: False # Should be set to False. If True, uses a LayerNorm on the inputs to a LLH layer. + post_norm: True # Should be set to True. If True, uses a LayerNorm on the outputs of a LLH layer (after transforming and adding the residual). + int_forward_variant: False # Should be set to False. If True, uses u_{t_i} as the ZOH constant for u_t with t \in (t_i, t_{i+1}]. + int_backward_variant: True # Should be set to True. If True, uses u_{t_{i+1}-} as the ZOH constant for u_t with t \in (t_i, t_{i+1}]. + relative_time: True # If True, predicts the scaling factor to be applied to the dynamics between each pair of subsequent events. See Sec. 3.3 of the paper.
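To try the new model end to end, the config above can be driven through EasyTPP's usual entry points. A minimal sketch, assuming the standard `Config`/`Runner` factories used by the repo's other example scripts and that the YAML above lives at `examples/configs/experiment_config.yaml`:

```python
from easy_tpp.config_factory import Config
from easy_tpp.runner import Runner


def main():
    # Load the S2P2_train experiment defined in the YAML above.
    config = Config.build_from_yaml_file(
        'examples/configs/experiment_config.yaml',
        experiment_id='S2P2_train',
    )
    model_runner = Runner.build_from_config(config)
    model_runner.run()


if __name__ == '__main__':
    main()
```

The `for_loop` flag above switches `LLH._ssm` between the sequential recurrence and the log-space cumulative scan borrowed from mamba-tiny. A small self-contained check of the equivalence, run here on real positive impulses (the layers apply the same trick to complex-valued states):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, P = 6, 4
lambda_dt = -torch.rand(N, P)  # per-step decay exponents (negative real part)
impulse = torch.rand(N, P)     # per-event impulses, kept positive so the log stays real
x0 = torch.rand(P)             # initial hidden state

# Sequential recurrence: x_i = exp(lambda_dt_i) * x_{i-1} + impulse_i
x, xs_loop = x0, []
for i in range(N):
    x = lambda_dt[i].exp() * x + impulse[i]
    xs_loop.append(x)
xs_loop = torch.stack(xs_loop)

# Log-space cumulative scan, mirroring the `for_loop=False` branch of LLH._ssm
log_b = torch.cat((x0.unsqueeze(0), impulse), dim=0).log()
star = F.pad(lambda_dt.cumsum(0), (0, 0, 1, 0))  # prefix sums with a leading zero row
xs_scan = (torch.logcumsumexp(log_b - star, dim=0) + star).exp()[1:]

print(torch.allclose(xs_loop, xs_scan, atol=1e-5))  # expected: True
```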