diff --git a/tests/test_solver/test_supervised_solver.py b/tests/test_solver/test_supervised_solver.py index 63f6b307e..30ae08064 100644 --- a/tests/test_solver/test_supervised_solver.py +++ b/tests/test_solver/test_supervised_solver.py @@ -1,12 +1,14 @@ import torch import pytest +from torch._dynamo.eval_frame import OptimizedModule +from torch_geometric.nn import GCNConv from pina import Condition, LabelTensor from pina.condition import InputTargetCondition from pina.problem import AbstractProblem from pina.solver import SupervisedSolver from pina.model import FeedForward from pina.trainer import Trainer -from torch._dynamo.eval_frame import OptimizedModule +from pina.graph import KNNGraph class LabelTensorProblem(AbstractProblem): @@ -28,9 +30,64 @@ class TensorProblem(AbstractProblem): } +x = torch.rand((100, 20, 5)) +pos = torch.rand((100, 20, 2)) +output_ = torch.rand((100, 20, 1)) +input_ = [ + KNNGraph(x=x_, pos=pos_, neighbours=3, edge_attr=True) + for x_, pos_ in zip(x, pos) +] + + +class GraphProblem(AbstractProblem): + output_variables = None + conditions = {"data": Condition(input=input_, target=output_)} + + +x = LabelTensor(torch.rand((100, 20, 5)), ["a", "b", "c", "d", "e"]) +pos = LabelTensor(torch.rand((100, 20, 2)), ["x", "y"]) +output_ = LabelTensor(torch.rand((100, 20, 1)), ["u"]) +input_ = [ + KNNGraph(x=x[i], pos=pos[i], neighbours=3, edge_attr=True) + for i in range(len(x)) +] + + +class GraphProblemLT(AbstractProblem): + output_variables = ["u"] + input_variables = ["a", "b", "c", "d", "e"] + conditions = {"data": Condition(input=input_, target=output_)} + + model = FeedForward(2, 1) +class Model(torch.nn.Module): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.lift = torch.nn.Linear(5, 10) + self.activation = torch.nn.Tanh() + self.output = torch.nn.Linear(10, 1) + + self.conv = GCNConv(10, 10) + + def forward(self, batch): + + x = batch.x + edge_index = batch.edge_index + for _ in range(1): + y = self.lift(x) + y = self.activation(y) + y = self.conv(y, edge_index) + y = self.activation(y) + y = self.output(y) + return y + + +graph_model = Model() + + def test_constructor(): SupervisedSolver(problem=TensorProblem(), model=model) SupervisedSolver(problem=LabelTensorProblem(), model=model) @@ -59,6 +116,24 @@ def test_solver_train(use_lt, batch_size, compile): assert isinstance(solver.model, OptimizedModule) +@pytest.mark.parametrize("batch_size", [None, 1, 5, 20]) +@pytest.mark.parametrize("use_lt", [True, False]) +def test_solver_train_graph(batch_size, use_lt): + problem = GraphProblemLT() if use_lt else GraphProblem() + solver = SupervisedSolver(problem=problem, model=graph_model, use_lt=use_lt) + trainer = Trainer( + solver=solver, + max_epochs=2, + accelerator="cpu", + batch_size=batch_size, + train_size=1.0, + test_size=0.0, + val_size=0.0, + ) + + trainer.train() + + @pytest.mark.parametrize("use_lt", [True, False]) @pytest.mark.parametrize("compile", [True, False]) def test_solver_validation(use_lt, compile): @@ -79,6 +154,24 @@ def test_solver_validation(use_lt, compile): assert isinstance(solver.model, OptimizedModule) +@pytest.mark.parametrize("batch_size", [None, 1, 5, 20]) +@pytest.mark.parametrize("use_lt", [True, False]) +def test_solver_validation_graph(batch_size, use_lt): + problem = GraphProblemLT() if use_lt else GraphProblem() + solver = SupervisedSolver(problem=problem, model=graph_model, use_lt=use_lt) + trainer = Trainer( + solver=solver, + max_epochs=2, + accelerator="cpu", + batch_size=batch_size, + train_size=0.9, + val_size=0.1, + test_size=0.0, + ) + + trainer.train() + + @pytest.mark.parametrize("use_lt", [True, False]) @pytest.mark.parametrize("compile", [True, False]) def test_solver_test(use_lt, compile): @@ -99,6 +192,24 @@ def test_solver_test(use_lt, compile): assert isinstance(solver.model, OptimizedModule) +@pytest.mark.parametrize("batch_size", [None, 1, 5, 20]) +@pytest.mark.parametrize("use_lt", [True, False]) +def test_solver_test_graph(batch_size, use_lt): + problem = GraphProblemLT() if use_lt else GraphProblem() + solver = SupervisedSolver(problem=problem, model=graph_model, use_lt=use_lt) + trainer = Trainer( + solver=solver, + max_epochs=2, + accelerator="cpu", + batch_size=batch_size, + train_size=0.8, + val_size=0.1, + test_size=0.1, + ) + + trainer.test() + + def test_train_load_restore(): dir = "tests/test_solver/tmp/" problem = LabelTensorProblem()