diff --git a/.gitignore b/.gitignore index 431edb1..7be2005 100644 --- a/.gitignore +++ b/.gitignore @@ -25,3 +25,5 @@ log .idea examples/checkpoints/* + +notebooks/checkpoints/* \ No newline at end of file diff --git a/easy_tpp/default_registers/register_metrics.py b/easy_tpp/default_registers/register_metrics.py index 6df9870..feeab1f 100644 --- a/easy_tpp/default_registers/register_metrics.py +++ b/easy_tpp/default_registers/register_metrics.py @@ -16,8 +16,13 @@ def rmse_metric_function(predictions, labels, **kwargs): float: average rmse of the time predictions. """ seq_mask = kwargs.get('seq_mask') - pred = predictions[PredOutputIndex.TimePredIndex][seq_mask] - label = labels[PredOutputIndex.TimePredIndex][seq_mask] + if seq_mask is None or len(seq_mask) == 0: + # If mask is empty or None, use all predictions + pred = predictions[PredOutputIndex.TimePredIndex] + label = labels[PredOutputIndex.TimePredIndex] + else: + pred = predictions[PredOutputIndex.TimePredIndex][seq_mask] + label = labels[PredOutputIndex.TimePredIndex][seq_mask] pred = np.reshape(pred, [-1]) label = np.reshape(label, [-1]) @@ -36,8 +41,13 @@ def acc_metric_function(predictions, labels, **kwargs): float: accuracy ratio of the type predictions. """ seq_mask = kwargs.get('seq_mask') - pred = predictions[PredOutputIndex.TypePredIndex][seq_mask] - label = labels[PredOutputIndex.TypePredIndex][seq_mask] + if seq_mask is None or len(seq_mask) == 0: + # If mask is empty or None, use all predictions + pred = predictions[PredOutputIndex.TypePredIndex] + label = labels[PredOutputIndex.TypePredIndex] + else: + pred = predictions[PredOutputIndex.TypePredIndex][seq_mask] + label = labels[PredOutputIndex.TypePredIndex][seq_mask] pred = np.reshape(pred, [-1]) label = np.reshape(label, [-1]) return np.mean(pred == label) diff --git a/easy_tpp/model/torch_model/torch_basemodel.py b/easy_tpp/model/torch_model/torch_basemodel.py index 1c59ecd..6b3e42c 100644 --- a/easy_tpp/model/torch_model/torch_basemodel.py +++ b/easy_tpp/model/torch_model/torch_basemodel.py @@ -205,15 +205,18 @@ def predict_multi_step_since_last_event(self, batch, forward=False): """Multi-step prediction since last event in the sequence. Args: - time_seqs (tensor): [batch_size, seq_len]. - time_delta_seqs (tensor): [batch_size, seq_len]. - type_seqs (tensor): [batch_size, seq_len]. - num_step (int): num of steps for prediction. + batch (tuple): A tuple containing: + - time_seq_label (tensor): Timestamps of events [batch_size, seq_len]. + - time_delta_seq_label (tensor): Time intervals between events [batch_size, seq_len]. + - event_seq_label (tensor): Event types [batch_size, seq_len]. + - batch_non_pad_mask_label (tensor): Mask for non-padding elements [batch_size, seq_len]. + - attention_mask (tensor): Mask for attention [batch_size, seq_len]. + forward (bool, optional): Whether to use the entire sequence for prediction. Defaults to False. Returns: tuple: tensors of dtime and type prediction, [batch_size, seq_len]. """ - time_seq_label, time_delta_seq_label, event_seq_label, batch_non_pad_mask_label, _, type_mask_label = batch + time_seq_label, time_delta_seq_label, event_seq_label, _, _ = batch num_step = self.gen_config.num_step_gen diff --git a/notebooks/easytpp_1_dataset.ipynb b/notebooks/easytpp_1_dataset.ipynb index c0cdfee..2dce697 100644 --- a/notebooks/easytpp_1_dataset.ipynb +++ b/notebooks/easytpp_1_dataset.ipynb @@ -97,7 +97,7 @@ ], "source": [ "# ues the latest release\n", - "# !pip install easy_tpp\n", + "# !pip install easy-tpp\n", "\n", "# or use the git main branch\n", "!pip install git+https://github.com/ant-research/EasyTemporalPointProcess.git" diff --git a/version.py b/version.py index b794fd4..10939f0 100644 --- a/version.py +++ b/version.py @@ -1 +1 @@ -__version__ = '0.1.0' +__version__ = '0.1.2'