forked from sktime/pytorch-forecasting
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbaseline.py
More file actions
85 lines (68 loc) · 2.44 KB
/
baseline.py
File metadata and controls
85 lines (68 loc) · 2.44 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
"""
Baseline model.
"""
from typing import Any
import torch
from pytorch_forecasting.models import BaseModel
class Baseline(BaseModel):
    """
    Baseline model that uses last known target value to make prediction.

    Example:

    .. code-block:: python

        from pytorch_forecasting import BaseModel, MAE

        # generating predictions
        predictions = Baseline().predict(dataloader)

        # calculate baseline performance in terms of mean absolute error (MAE)
        metric = MAE()
        model = Baseline()
        for x, y in dataloader:
            metric.update(model(x), y)
        metric.compute()
    """

    def forward(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """
        Network forward pass.

        Parameters
        ----------
        x : Dict[str, torch.Tensor]
            network input

        Returns
        -------
        Dict[str, torch.Tensor]
            network outputs
        """
        encoder_lengths = x["encoder_lengths"]
        decoder_lengths = x["decoder_lengths"]
        targets = x["encoder_target"]
        if isinstance(targets, (tuple, list)):
            # Multiple targets: produce one baseline prediction per target.
            prediction = [
                self.forward_one_target(
                    encoder_lengths=encoder_lengths,
                    decoder_lengths=decoder_lengths,
                    encoder_target=target,
                )
                for target in targets
            ]
        else:
            # Single target.
            prediction = self.forward_one_target(
                encoder_lengths=encoder_lengths,
                decoder_lengths=decoder_lengths,
                encoder_target=targets,
            )
        return self.to_network_output(prediction=prediction)

    def forward_one_target(
        self,
        encoder_lengths: torch.Tensor,
        decoder_lengths: torch.Tensor,
        encoder_target: torch.Tensor,
    ):
        # Horizon over which the last observed value is repeated.
        horizon = decoder_lengths.max()
        assert (
            encoder_lengths.min() > 0
        ), "Encoder lengths of at least 1 required to obtain last value"
        # Gather the last valid encoder value per sample (lengths are 1-based).
        batch_index = torch.arange(encoder_target.size(0))
        last_values = encoder_target[batch_index, encoder_lengths - 1]
        # Broadcast the last value across the prediction horizon.
        prediction = last_values.unsqueeze(1).expand(-1, horizon)
        return prediction

    def to_prediction(self, out: dict[str, Any], use_metric: bool = True, **kwargs):
        return out.prediction

    def to_quantiles(self, out: dict[str, Any], use_metric: bool = True, **kwargs):
        return out.prediction[..., None]