Commit 1daa10b

implement autoregressive condition, time_weighting, solver
1 parent 2888ae3

9 files changed: +375 −0 lines changed

pina/condition/__init__.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -15,6 +15,7 @@
     "DataCondition",
     "GraphDataCondition",
     "TensorDataCondition",
+    "AutoregressiveCondition",
 ]

 from .condition_interface import ConditionInterface
@@ -37,3 +38,5 @@
     GraphDataCondition,
     TensorDataCondition,
 )
+
+from .autoregressive_condition import AutoregressiveCondition
```
pina/condition/autoregressive_condition.py

Lines changed: 91 additions & 0 deletions

```python
import torch
from .condition_interface import ConditionInterface
from ..loss import TimeWeightingInterface, ConstantTimeWeighting
from ..utils import check_consistency


class AutoregressiveCondition(ConditionInterface):
    """
    A specialized condition for autoregressive tasks.
    It generates input/unroll pairs from a single time-series tensor.
    """

    __slots__ = ["input", "unroll"]

    def __init__(
        self,
        data,
        unroll_length,
        num_unrolls=None,
        randomize=True,
        time_weighting=None,
    ):
        """
        Create an AutoregressiveCondition.

        :param torch.Tensor data: Time series of shape
            ``[n_timesteps, n_features]``.
        :param int unroll_length: Number of target steps paired with each
            initial state.
        :param int num_unrolls: Number of windows to extract. Default is
            ``None``, i.e. every admissible starting index is used.
        :param bool randomize: If ``True``, starting indices are drawn at
            random; otherwise they are evenly spaced. Default is ``True``.
        :param TimeWeightingInterface time_weighting: Weighting of the loss
            along the unroll. Default is ``None``, i.e. constant weighting.
        """
        super().__init__()

        self._n_timesteps, n_features = data.shape
        self._unroll_length = unroll_length
        self._requested_num_unrolls = num_unrolls
        self._randomize = randomize

        # time weighting: weight the loss differently along the unroll
        if time_weighting is None:
            self._time_weighting = ConstantTimeWeighting()
        else:
            check_consistency(time_weighting, TimeWeightingInterface)
            self._time_weighting = time_weighting

        # window creation: pair each starting state with the next
        # `unroll_length` states as targets
        initial_data = []
        unroll_data = []

        for starting_index in self.starting_indices:
            initial_data.append(data[starting_index])
            target_start = starting_index + 1
            unroll_data.append(
                data[target_start : target_start + self._unroll_length, :]
            )

        self.input = torch.stack(initial_data)  # [num_unrolls, features]
        self.unroll = torch.stack(
            unroll_data
        )  # [num_unrolls, unroll_length, features]

    @property
    def unroll_length(self):
        return self._unroll_length

    @property
    def time_weighting(self):
        return self._time_weighting

    @property
    def max_start_idx(self):
        max_start_idx = self._n_timesteps - self._unroll_length
        assert max_start_idx > 0, "Provided data sequence too short"
        return max_start_idx

    @property
    def num_unrolls(self):
        if self._requested_num_unrolls is None:
            return self.max_start_idx
        # there are exactly `max_start_idx` admissible starting indices
        assert (
            self._requested_num_unrolls <= self.max_start_idx
        ), "too many samples requested"
        return self._requested_num_unrolls

    @property
    def starting_indices(self):
        all_starting_indices = torch.arange(self.max_start_idx)

        if self._randomize:
            perm = torch.randperm(len(all_starting_indices))
            return all_starting_indices[perm[: self.num_unrolls]]
        # evenly spaced starting indices when not randomizing
        selected_indices = torch.linspace(
            0, len(all_starting_indices) - 1, self.num_unrolls
        ).long()
        return all_starting_indices[selected_indices]
```
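A minimal usage sketch of the windowing; the synthetic tensor and parameter values are illustrative, and the shapes follow the comments in the constructor:

```python
import torch
from pina.condition import AutoregressiveCondition

# synthetic series: 100 time steps, 3 features
data = torch.randn(100, 3)

# extract 8 windows, each pairing one state with the 5 following states
condition = AutoregressiveCondition(data, unroll_length=5, num_unrolls=8)

print(condition.input.shape)   # torch.Size([8, 3])
print(condition.unroll.shape)  # torch.Size([8, 5, 3])
```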

pina/loss/__init__.py

Lines changed: 10 additions & 0 deletions
```diff
@@ -9,6 +9,10 @@
     "NeuralTangentKernelWeighting",
     "SelfAdaptiveWeighting",
     "LinearWeighting",
+    "TimeWeightingInterface",
+    "ConstantTimeWeighting",
+    "ExponentialTimeWeighting",
+    "LinearTimeWeighting",
 ]

 from .loss_interface import LossInterface
@@ -19,3 +23,9 @@
 from .ntk_weighting import NeuralTangentKernelWeighting
 from .self_adaptive_weighting import SelfAdaptiveWeighting
 from .linear_weighting import LinearWeighting
+from .time_weighting_interface import TimeWeightingInterface
+from .time_weighting import (
+    ConstantTimeWeighting,
+    ExponentialTimeWeighting,
+    LinearTimeWeighting,
+)
```

pina/loss/time_weighting.py

Lines changed: 57 additions & 0 deletions
```python
"""Module for the Time Weighting."""

import torch
from .time_weighting_interface import TimeWeightingInterface


class ConstantTimeWeighting(TimeWeightingInterface):
    """
    Weighting scheme that assigns equal weight to all time steps.
    """

    def __call__(self, num_steps, device):
        return torch.ones(num_steps, device=device) / num_steps


class ExponentialTimeWeighting(TimeWeightingInterface):
    """
    Weighting scheme whose weights change exponentially with time: the
    weight at time t is proportional to gamma**t.

    gamma > 1.0: increasing weights
    0 < gamma < 1.0: decreasing weights
    """

    def __init__(self, gamma=0.9):
        """
        Initialization of the :class:`ExponentialTimeWeighting` class.

        :param float gamma: The growth (gamma > 1) or decay (gamma < 1)
            factor. Default is 0.9.
        """
        self.gamma = gamma

    def __call__(self, num_steps, device):
        steps = torch.arange(num_steps, device=device, dtype=torch.float32)
        weights = self.gamma**steps
        return weights / weights.sum()


class LinearTimeWeighting(TimeWeightingInterface):
    """
    Weighting scheme that changes linearly from a start weight to an end
    weight.
    """

    def __init__(self, start=0.1, end=1.0):
        """
        Initialization of the :class:`LinearTimeWeighting` class.

        :param float start: The starting weight. Default is 0.1.
        :param float end: The ending weight. Default is 1.0.
        """
        self.start = start
        self.end = end

    def __call__(self, num_steps, device):
        if num_steps == 1:
            return torch.ones(1, device=device)

        weights = torch.linspace(self.start, self.end, num_steps, device=device)
        return weights / weights.sum()
```
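Every scheme returns weights normalized to sum to 1; for the exponential scheme the weight of step t is gamma**t / sum_s gamma**s. A quick sketch (printed values rounded):

```python
import torch
from pina.loss import ExponentialTimeWeighting, LinearTimeWeighting

device = torch.device("cpu")
exp_w = ExponentialTimeWeighting(gamma=0.9)(4, device)
lin_w = LinearTimeWeighting(start=0.1, end=1.0)(4, device)

print(exp_w)        # tensor([0.2908, 0.2617, 0.2355, 0.2120])
print(exp_w.sum())  # tensor(1.)
print(lin_w)        # tensor([0.0455, 0.1818, 0.3182, 0.4545])
```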
pina/loss/time_weighting_interface.py

Lines changed: 24 additions & 0 deletions

```python
"""Module for the Time Weighting Interface."""

from abc import ABCMeta, abstractmethod
import torch


class TimeWeightingInterface(metaclass=ABCMeta):
    """
    Abstract base class for all time weighting schemas. All time weighting
    schemas should inherit from this class.
    """

    @abstractmethod
    def __call__(self, num_steps, device):
        """
        Compute the weights for the time steps.

        :param int num_steps: The number of time steps.
        :param torch.device device: The device on which the weights should be
            created.
        :return: The weights for the time steps.
        :rtype: torch.Tensor
        """
        pass
```
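A sketch of a custom scheme written against this interface; the class name and the 1/(t+1) profile are hypothetical, not part of this commit:

```python
import torch
from pina.loss import TimeWeightingInterface


class InverseTimeWeighting(TimeWeightingInterface):
    """Hypothetical scheme: weight of step t proportional to 1 / (t + 1)."""

    def __call__(self, num_steps, device):
        steps = torch.arange(num_steps, device=device, dtype=torch.float32)
        weights = 1.0 / (steps + 1.0)
        return weights / weights.sum()  # normalized like the built-in schemes
```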

pina/solver/__init__.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -18,6 +18,7 @@
     "DeepEnsembleSupervisedSolver",
     "DeepEnsemblePINN",
     "GAROM",
+    "AutoregressiveSolver",
 ]

 from .solver import SolverInterface, SingleSolverInterface, MultiSolverInterface
@@ -41,3 +42,7 @@
     DeepEnsemblePINN,
 )
 from .garom import GAROM
+from .autoregressive_solver import (
+    AutoregressiveSolver,
+    AutoregressiveSolverInterface,
+)
```
pina/solver/autoregressive_solver/__init__.py

Lines changed: 4 additions & 0 deletions

```python
__all__ = ["AutoregressiveSolver", "AutoregressiveSolverInterface"]

from .autoregressive_solver import AutoregressiveSolver
from .autoregressive_solver_interface import AutoregressiveSolverInterface
```

pina/solver/autoregressive_solver/autoregressive_solver.py

Lines changed: 88 additions & 0 deletions

```python
import torch
from torch.nn.modules.loss import _Loss

from pina.utils import check_consistency
from pina.solver.solver import SingleSolverInterface
from pina.condition import AutoregressiveCondition
from pina.loss import (
    LossInterface,
    TimeWeightingInterface,
    ConstantTimeWeighting,
)
from .autoregressive_solver_interface import AutoregressiveSolverInterface


class AutoregressiveSolver(
    AutoregressiveSolverInterface, SingleSolverInterface
):
    """
    Autoregressive Solver class.
    """

    accepted_conditions_types = AutoregressiveCondition

    def __init__(
        self,
        problem,
        model,
        loss=None,
        optimizer=None,
        scheduler=None,
        weighting=None,
        use_lt=False,
    ):
        """
        Initialization of the :class:`AutoregressiveSolver` class.
        """
        super().__init__(
            problem=problem,
            model=model,
            loss=loss,
            optimizer=optimizer,
            scheduler=scheduler,
            weighting=weighting,
            use_lt=use_lt,
        )

    def loss_data(self, input, target, unroll_length, time_weighting):
        """
        Compute the data loss for the recursive autoregressive solver.
        This will be applied to each condition individually.
        """
        steps_to_predict = unroll_length - 1
        # weights are passed from the condition
        weights = time_weighting(steps_to_predict, device=input.device)

        total_loss = 0.0
        current_state = input

        for step in range(steps_to_predict):
            predicted_next_state = self.forward(
                current_state
            )  # [batch_size, features]
            actual_next_state = target[:, step, :]  # [batch_size, features]

            step_loss = self.loss(predicted_next_state, actual_next_state)
            total_loss += step_loss * weights[step]

            # detach so gradients do not flow across unroll steps
            current_state = predicted_next_state.detach()

        return total_loss

    def predict(self, initial_state, num_steps):
        """
        Make recursive predictions starting from an initial state.
        """
        self.eval()  # set the model to evaluation mode

        current_state = initial_state
        predictions = [current_state]  # the initial state is kept as step 0
        with torch.no_grad():
            for step in range(num_steps):
                next_state = self.forward(current_state)
                predictions.append(next_state)
                current_state = next_state

        return torch.stack(predictions)  # [num_steps + 1, *state_shape]
```
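A hedged sketch of how the new pieces compose; the problem class, network, and trainer wiring below are assumptions (not shown in this commit), while AutoregressiveCondition, ExponentialTimeWeighting, and predict come from the files above:

```python
import torch
from pina.condition import AutoregressiveCondition
from pina.loss import ExponentialTimeWeighting

# illustrative trajectory: 200 time steps, 2 features
data = torch.randn(200, 2)

condition = AutoregressiveCondition(
    data,
    unroll_length=10,
    num_unrolls=32,
    time_weighting=ExponentialTimeWeighting(gamma=0.9),
)

# assumed wiring, following the usual PINA pattern:
#   problem = MyTimeSeriesProblem(...)   # a problem holding the condition above
#   solver = AutoregressiveSolver(problem=problem, model=my_network)
#   ... train with a Trainer, then roll the model forward:
#   trajectory = solver.predict(initial_state, num_steps=50)
#   trajectory.shape  # [51, *initial_state.shape]
```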
