diff --git a/tests/test_weighting/test_self_adaptive_weighting.py b/tests/test_weighting/test_self_adaptive_weighting.py
index 066e8855e..af9818a60 100644
--- a/tests/test_weighting/test_self_adaptive_weighting.py
+++ b/tests/test_weighting/test_self_adaptive_weighting.py
@@ -1,4 +1,5 @@
 import pytest
+import torch
 from pina import Trainer
 from pina.solver import PINN
 from pina.model import FeedForward
@@ -37,3 +38,26 @@ def test_train_aggregation(update_every_n_epochs):
     solver = PINN(problem=problem, model=model, weighting=weighting)
     trainer = Trainer(solver=solver, max_epochs=5, accelerator="cpu")
     trainer.train()
+
+class Net_biased(torch.nn.Module):
+    def __init__(self, input_dim, output_dim, num_layers=2):
+        super().__init__()
+        self.mlp = FeedForward(
+            input_dimensions=input_dim,
+            output_dimensions=output_dim,
+            layers=[10 for _ in range(num_layers)]
+        )
+        self.bias = torch.nn.Parameter(torch.zeros(1))
+
+    def forward(self, x):
+        return self.mlp(x)
+
+@pytest.mark.parametrize("update_every_n_epochs", [1, 3])
+def test_train_aggregation_freezed_weights(update_every_n_epochs):
+    model = Net_biased(len(problem.input_variables), len(problem.output_variables))
+    weighting = SelfAdaptiveWeighting(
+        update_every_n_epochs=update_every_n_epochs
+    )
+    solver = PINN(problem=problem, model=model, weighting=weighting)
+    trainer = Trainer(solver=solver, max_epochs=5, accelerator="cpu")
+    trainer.train()
\ No newline at end of file
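
Note on the new test: Net_biased registers an extra bias parameter that is never used in forward, so after backpropagation its gradient stays None. Running the full training loop with such a model therefore exercises how the self-adaptive weighting behaves when some parameters receive no gradient. A minimal, PINA-independent sketch of that situation is below; the class and variable names in the sketch are illustrative only and are not part of the patch.

    import torch

    class BiasedNet(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(2, 1)
            # Registered on the module but never used in forward.
            self.bias = torch.nn.Parameter(torch.zeros(1))

        def forward(self, x):
            return self.linear(x)

    net = BiasedNet()
    net(torch.randn(4, 2)).sum().backward()
    print(net.linear.weight.grad is None)  # False: used parameters receive gradients
    print(net.bias.grad is None)           # True: the unused parameter's grad stays None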