|
| 1 | +import torch |
| 2 | +import matplotlib.pyplot as plt |
| 3 | + |
| 4 | +from pina import Trainer |
| 5 | +from pina.optim import TorchOptimizer |
| 6 | +from pina.problem import AbstractProblem |
| 7 | +from pina.condition.data_condition import DataCondition |
| 8 | +from pina.solver import AutoregressiveSolver |
| 9 | + |
| 10 | +NUM_TIMESTEPS = 100 |
| 11 | +NUM_FEATURES = 15 |
| 12 | +USE_TEST_MODEL = False |
| 13 | + |
| 14 | +# ============================================================================ |
| 15 | +# DATA |
| 16 | +# ============================================================================ |
| 17 | + |
| 18 | +torch.manual_seed(42) |
| 19 | + |
| 20 | +y = torch.zeros(NUM_TIMESTEPS, NUM_FEATURES) |
| 21 | +y[0] = torch.rand(NUM_FEATURES) # Random initial state |
| 22 | + |
| 23 | +for t in range(NUM_TIMESTEPS - 1): |
| 24 | + y[t + 1] = 0.95 * y[t] # + 0.05 * torch.sin(y[t].sum()) |
| 25 | + |
| 26 | +# ============================================================================ |
| 27 | +# TRAINING |
| 28 | +# ============================================================================ |
| 29 | + |
| 30 | +class SimpleModel(torch.nn.Module): |
| 31 | + def __init__(self): |
| 32 | + super().__init__() |
| 33 | + self.layers = torch.nn.Sequential( |
| 34 | + torch.nn.Linear(y.shape[1], 20), |
| 35 | + torch.nn.ReLU(), |
| 36 | + torch.nn.Dropout(0.2), |
| 37 | + torch.nn.Linear(20, y.shape[1]), |
| 38 | + ) |
| 39 | + |
| 40 | + def forward(self, x): |
| 41 | + return x + self.layers(x) |
| 42 | + |
| 43 | + |
| 44 | +class TestModel(torch.nn.Module): |
| 45 | + """ |
| 46 | + Debug model that implements the EXACT transformation rule. |
| 47 | + y[t+1] = 0.95 * y[t] |
| 48 | + Expected loss is zero |
| 49 | + """ |
| 50 | + |
| 51 | + def __init__(self, data_series=None): |
| 52 | + super().__init__() |
| 53 | + self.dummy_param = torch.nn.Parameter(torch.zeros(1)) |
| 54 | + |
| 55 | + def forward(self, x): |
| 56 | + next_state = 0.95 * x # + 0.05 * torch.sin(x.sum(dim=1, keepdim=True)) |
| 57 | + return next_state + 0.0 * self.dummy_param |
| 58 | + |
| 59 | + |
| 60 | +class Problem(AbstractProblem): |
| 61 | + output_variables = None |
| 62 | + input_variables = None |
| 63 | + conditions = { |
| 64 | + "data_condition_0":DataCondition(input=y), |
| 65 | + "data_condition_1":DataCondition(input=y), |
| 66 | + } |
| 67 | + |
| 68 | +problem = Problem() |
| 69 | + |
| 70 | +#for each condition, define unroll instructions with these keys: |
| 71 | +# - unroll_length: length of each unroll window |
| 72 | +# - num_unrolls: number of unroll windows to create (if None, use all possible) |
| 73 | +# - randomize: whether to randomize the starting indices of the unroll windows |
| 74 | +unroll_instructions = { |
| 75 | + "data_condition_0": { |
| 76 | + "unroll_length": 10, |
| 77 | + "num_unrolls": 89, |
| 78 | + "randomize": True, |
| 79 | + "eps": 5.0 |
| 80 | + }, |
| 81 | + "data_condition_1": { |
| 82 | + "unroll_length": 20, |
| 83 | + "num_unrolls": 79, |
| 84 | + "randomize": True, |
| 85 | + "eps": 10.0 |
| 86 | + }, |
| 87 | +} |
| 88 | + |
| 89 | +solver = AutoregressiveSolver( |
| 90 | + unroll_instructions=unroll_instructions, |
| 91 | + problem=problem, |
| 92 | + model=TestModel() if USE_TEST_MODEL else SimpleModel(), |
| 93 | + optimizer= TorchOptimizer(torch.optim.AdamW, lr=0.01), |
| 94 | + eps=10.0, |
| 95 | +) |
| 96 | + |
| 97 | +trainer = Trainer( |
| 98 | + solver, max_epochs=2000, accelerator="cpu", enable_model_summary=False, shuffle=False |
| 99 | +) |
| 100 | +trainer.train() |
| 101 | + |
| 102 | +# ============================================================================ |
| 103 | +# VISUALIZATION |
| 104 | +# ============================================================================ |
| 105 | + |
| 106 | +test_start_idx = 50 |
| 107 | +num_prediction_steps = 30 |
| 108 | + |
| 109 | +initial_state = y[test_start_idx] # Shape: [features] |
| 110 | +predictions = solver.predict(initial_state, num_prediction_steps) |
| 111 | +actual = y[test_start_idx : test_start_idx + num_prediction_steps + 1] |
| 112 | + |
| 113 | +total_mse = torch.nn.functional.mse_loss(predictions[1:], actual[1:]) |
| 114 | +print(f"\nOverall MSE (all {num_prediction_steps} steps): {total_mse:.6f}") |
| 115 | + |
| 116 | +# viauzlize single dof |
| 117 | +dof_to_plot = [0, 3, 6, 9, 12] |
| 118 | +colors = [ |
| 119 | + "r", |
| 120 | + "g", |
| 121 | + "b", |
| 122 | + "c", |
| 123 | + "m", |
| 124 | + "y", |
| 125 | + "k", |
| 126 | +] |
| 127 | +plt.figure(figsize=(10, 6)) |
| 128 | +for dof, color in zip(dof_to_plot, colors): |
| 129 | + plt.plot( |
| 130 | + range(test_start_idx, test_start_idx + num_prediction_steps + 1), |
| 131 | + actual[:, dof].numpy(), |
| 132 | + label="Actual", |
| 133 | + marker="o", |
| 134 | + color=color, |
| 135 | + markerfacecolor="none", |
| 136 | + ) |
| 137 | + plt.plot( |
| 138 | + range(test_start_idx, test_start_idx + num_prediction_steps + 1), |
| 139 | + predictions[:, dof].numpy(), |
| 140 | + label="Predicted", |
| 141 | + marker="x", |
| 142 | + color=color, |
| 143 | + ) |
| 144 | + |
| 145 | +plt.title(f"Autoregressive Predictions vs Actual, MRSE: {total_mse:.6f}") |
| 146 | +plt.legend() |
| 147 | +plt.xlabel("Timestep") |
| 148 | +plt.savefig(f"autoregressive_predictions.png") |
| 149 | +plt.close() |
0 commit comments