Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,5 @@ log
.idea

examples/checkpoints/*

notebooks/checkpoints/*
18 changes: 14 additions & 4 deletions easy_tpp/default_registers/register_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,13 @@ def rmse_metric_function(predictions, labels, **kwargs):
float: average rmse of the time predictions.
"""
seq_mask = kwargs.get('seq_mask')
pred = predictions[PredOutputIndex.TimePredIndex][seq_mask]
label = labels[PredOutputIndex.TimePredIndex][seq_mask]
if seq_mask is None or len(seq_mask) == 0:
# If mask is empty or None, use all predictions
pred = predictions[PredOutputIndex.TimePredIndex]
label = labels[PredOutputIndex.TimePredIndex]
else:
pred = predictions[PredOutputIndex.TimePredIndex][seq_mask]
label = labels[PredOutputIndex.TimePredIndex][seq_mask]

pred = np.reshape(pred, [-1])
label = np.reshape(label, [-1])
Expand All @@ -36,8 +41,13 @@ def acc_metric_function(predictions, labels, **kwargs):
float: accuracy ratio of the type predictions.
"""
seq_mask = kwargs.get('seq_mask')
pred = predictions[PredOutputIndex.TypePredIndex][seq_mask]
label = labels[PredOutputIndex.TypePredIndex][seq_mask]
if seq_mask is None or len(seq_mask) == 0:
# If mask is empty or None, use all predictions
pred = predictions[PredOutputIndex.TypePredIndex]
label = labels[PredOutputIndex.TypePredIndex]
else:
pred = predictions[PredOutputIndex.TypePredIndex][seq_mask]
label = labels[PredOutputIndex.TypePredIndex][seq_mask]
pred = np.reshape(pred, [-1])
label = np.reshape(label, [-1])
return np.mean(pred == label)
13 changes: 8 additions & 5 deletions easy_tpp/model/torch_model/torch_basemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,15 +205,18 @@ def predict_multi_step_since_last_event(self, batch, forward=False):
"""Multi-step prediction since last event in the sequence.

Args:
time_seqs (tensor): [batch_size, seq_len].
time_delta_seqs (tensor): [batch_size, seq_len].
type_seqs (tensor): [batch_size, seq_len].
num_step (int): num of steps for prediction.
batch (tuple): A tuple containing:
- time_seq_label (tensor): Timestamps of events [batch_size, seq_len].
- time_delta_seq_label (tensor): Time intervals between events [batch_size, seq_len].
- event_seq_label (tensor): Event types [batch_size, seq_len].
- batch_non_pad_mask_label (tensor): Mask for non-padding elements [batch_size, seq_len].
- attention_mask (tensor): Mask for attention [batch_size, seq_len].
forward (bool, optional): Whether to use the entire sequence for prediction. Defaults to False.

Returns:
tuple: tensors of dtime and type prediction, [batch_size, seq_len].
"""
time_seq_label, time_delta_seq_label, event_seq_label, batch_non_pad_mask_label, _, type_mask_label = batch
time_seq_label, time_delta_seq_label, event_seq_label, _, _ = batch

num_step = self.gen_config.num_step_gen

Expand Down
2 changes: 1 addition & 1 deletion notebooks/easytpp_1_dataset.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@
],
"source": [
"# ues the latest release\n",
"# !pip install easy_tpp\n",
"# !pip install easy-tpp\n",
"\n",
"# or use the git main branch\n",
"!pip install git+https://github.com/ant-research/EasyTemporalPointProcess.git"
Expand Down
2 changes: 1 addition & 1 deletion version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.1.0'
__version__ = '0.1.2'