Skip to content

Commit 1bb9d46

Browse files
committed
implement everything into solver
1 parent 1daa10b commit 1bb9d46

File tree

8 files changed

+285
-245
lines changed

8 files changed

+285
-245
lines changed
Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
import torch
2+
import matplotlib.pyplot as plt
3+
4+
from pina import Trainer
5+
from pina.optim import TorchOptimizer
6+
from pina.problem import AbstractProblem
7+
from pina.condition.data_condition import DataCondition
8+
from pina.solver import AutoregressiveSolver
9+
10+
NUM_TIMESTEPS = 100
11+
NUM_FEATURES = 15
12+
USE_TEST_MODEL = False
13+
14+
# ============================================================================
15+
# DATA
16+
# ============================================================================
17+
18+
torch.manual_seed(42)
19+
20+
y = torch.zeros(NUM_TIMESTEPS, NUM_FEATURES)
21+
y[0] = torch.rand(NUM_FEATURES) # Random initial state
22+
23+
for t in range(NUM_TIMESTEPS - 1):
24+
y[t + 1] = 0.95 * y[t] # + 0.05 * torch.sin(y[t].sum())
25+
26+
# ============================================================================
27+
# TRAINING
28+
# ============================================================================
29+
30+
class SimpleModel(torch.nn.Module):
31+
def __init__(self):
32+
super().__init__()
33+
self.layers = torch.nn.Sequential(
34+
torch.nn.Linear(y.shape[1], 20),
35+
torch.nn.ReLU(),
36+
torch.nn.Dropout(0.2),
37+
torch.nn.Linear(20, y.shape[1]),
38+
)
39+
40+
def forward(self, x):
41+
return x + self.layers(x)
42+
43+
44+
class TestModel(torch.nn.Module):
45+
"""
46+
Debug model that implements the EXACT transformation rule.
47+
y[t+1] = 0.95 * y[t]
48+
Expected loss is zero
49+
"""
50+
51+
def __init__(self, data_series=None):
52+
super().__init__()
53+
self.dummy_param = torch.nn.Parameter(torch.zeros(1))
54+
55+
def forward(self, x):
56+
next_state = 0.95 * x # + 0.05 * torch.sin(x.sum(dim=1, keepdim=True))
57+
return next_state + 0.0 * self.dummy_param
58+
59+
60+
class Problem(AbstractProblem):
61+
output_variables = None
62+
input_variables = None
63+
conditions = {
64+
"data_condition_0":DataCondition(input=y),
65+
"data_condition_1":DataCondition(input=y),
66+
}
67+
68+
problem = Problem()
69+
70+
#for each condition, define unroll instructions with these keys:
71+
# - unroll_length: length of each unroll window
72+
# - num_unrolls: number of unroll windows to create (if None, use all possible)
73+
# - randomize: whether to randomize the starting indices of the unroll windows
74+
unroll_instructions = {
75+
"data_condition_0": {
76+
"unroll_length": 10,
77+
"num_unrolls": 89,
78+
"randomize": True,
79+
"eps": 5.0
80+
},
81+
"data_condition_1": {
82+
"unroll_length": 20,
83+
"num_unrolls": 79,
84+
"randomize": True,
85+
"eps": 10.0
86+
},
87+
}
88+
89+
solver = AutoregressiveSolver(
90+
unroll_instructions=unroll_instructions,
91+
problem=problem,
92+
model=TestModel() if USE_TEST_MODEL else SimpleModel(),
93+
optimizer= TorchOptimizer(torch.optim.AdamW, lr=0.01),
94+
eps=10.0,
95+
)
96+
97+
trainer = Trainer(
98+
solver, max_epochs=2000, accelerator="cpu", enable_model_summary=False, shuffle=False
99+
)
100+
trainer.train()
101+
102+
# ============================================================================
103+
# VISUALIZATION
104+
# ============================================================================
105+
106+
test_start_idx = 50
107+
num_prediction_steps = 30
108+
109+
initial_state = y[test_start_idx] # Shape: [features]
110+
predictions = solver.predict(initial_state, num_prediction_steps)
111+
actual = y[test_start_idx : test_start_idx + num_prediction_steps + 1]
112+
113+
total_mse = torch.nn.functional.mse_loss(predictions[1:], actual[1:])
114+
print(f"\nOverall MSE (all {num_prediction_steps} steps): {total_mse:.6f}")
115+
116+
# viauzlize single dof
117+
dof_to_plot = [0, 3, 6, 9, 12]
118+
colors = [
119+
"r",
120+
"g",
121+
"b",
122+
"c",
123+
"m",
124+
"y",
125+
"k",
126+
]
127+
plt.figure(figsize=(10, 6))
128+
for dof, color in zip(dof_to_plot, colors):
129+
plt.plot(
130+
range(test_start_idx, test_start_idx + num_prediction_steps + 1),
131+
actual[:, dof].numpy(),
132+
label="Actual",
133+
marker="o",
134+
color=color,
135+
markerfacecolor="none",
136+
)
137+
plt.plot(
138+
range(test_start_idx, test_start_idx + num_prediction_steps + 1),
139+
predictions[:, dof].numpy(),
140+
label="Predicted",
141+
marker="x",
142+
color=color,
143+
)
144+
145+
plt.title(f"Autoregressive Predictions vs Actual, MRSE: {total_mse:.6f}")
146+
plt.legend()
147+
plt.xlabel("Timestep")
148+
plt.savefig(f"autoregressive_predictions.png")
149+
plt.close()

pina/condition/__init__.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
"DataCondition",
1616
"GraphDataCondition",
1717
"TensorDataCondition",
18-
"AutoregressiveCondition",
1918
]
2019

2120
from .condition_interface import ConditionInterface
@@ -38,5 +37,3 @@
3837
GraphDataCondition,
3938
TensorDataCondition,
4039
)
41-
42-
from .autoregressive_condition import AutoregressiveCondition

pina/condition/autoregressive_condition.py

Lines changed: 0 additions & 91 deletions
This file was deleted.

pina/loss/__init__.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,6 @@
99
"NeuralTangentKernelWeighting",
1010
"SelfAdaptiveWeighting",
1111
"LinearWeighting",
12-
"TimeWeightingInterface",
13-
"ConstantTimeWeighting",
14-
"ExponentialTimeWeighting",
15-
"LinearTimeWeighting",
1612
]
1713

1814
from .loss_interface import LossInterface
@@ -23,9 +19,3 @@
2319
from .ntk_weighting import NeuralTangentKernelWeighting
2420
from .self_adaptive_weighting import SelfAdaptiveWeighting
2521
from .linear_weighting import LinearWeighting
26-
from .time_weighting_interface import TimeWeightingInterface
27-
from .time_weighting import (
28-
ConstantTimeWeighting,
29-
ExponentialTimeWeighting,
30-
LinearTimeWeighting,
31-
)

pina/loss/time_weighting.py

Lines changed: 0 additions & 57 deletions
This file was deleted.

pina/loss/time_weighting_interface.py

Lines changed: 0 additions & 24 deletions
This file was deleted.

0 commit comments

Comments
 (0)