Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 112 additions & 1 deletion tests/test_solver/test_supervised_solver.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import torch
import pytest
from torch._dynamo.eval_frame import OptimizedModule
from torch_geometric.nn import GCNConv
from pina import Condition, LabelTensor
from pina.condition import InputTargetCondition
from pina.problem import AbstractProblem
from pina.solver import SupervisedSolver
from pina.model import FeedForward
from pina.trainer import Trainer
from torch._dynamo.eval_frame import OptimizedModule
from pina.graph import KNNGraph


class LabelTensorProblem(AbstractProblem):
Expand All @@ -28,9 +30,64 @@ class TensorProblem(AbstractProblem):
}


x = torch.rand((100, 20, 5))
pos = torch.rand((100, 20, 2))
output_ = torch.rand((100, 20, 1))
input_ = [
KNNGraph(x=x_, pos=pos_, neighbours=3, edge_attr=True)
for x_, pos_ in zip(x, pos)
]


class GraphProblem(AbstractProblem):
output_variables = None
conditions = {"data": Condition(input=input_, target=output_)}


x = LabelTensor(torch.rand((100, 20, 5)), ["a", "b", "c", "d", "e"])
pos = LabelTensor(torch.rand((100, 20, 2)), ["x", "y"])
output_ = LabelTensor(torch.rand((100, 20, 1)), ["u"])
input_ = [
KNNGraph(x=x[i], pos=pos[i], neighbours=3, edge_attr=True)
for i in range(len(x))
]


class GraphProblemLT(AbstractProblem):
output_variables = ["u"]
input_variables = ["a", "b", "c", "d", "e"]
conditions = {"data": Condition(input=input_, target=output_)}


model = FeedForward(2, 1)


class Model(torch.nn.Module):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.lift = torch.nn.Linear(5, 10)
self.activation = torch.nn.Tanh()
self.output = torch.nn.Linear(10, 1)

self.conv = GCNConv(10, 10)

def forward(self, batch):

x = batch.x
edge_index = batch.edge_index
for _ in range(1):
y = self.lift(x)
y = self.activation(y)
y = self.conv(y, edge_index)
y = self.activation(y)
y = self.output(y)
return y


graph_model = Model()


def test_constructor():
SupervisedSolver(problem=TensorProblem(), model=model)
SupervisedSolver(problem=LabelTensorProblem(), model=model)
Expand Down Expand Up @@ -59,6 +116,24 @@ def test_solver_train(use_lt, batch_size, compile):
assert isinstance(solver.model, OptimizedModule)


@pytest.mark.parametrize("batch_size", [None, 1, 5, 20])
@pytest.mark.parametrize("use_lt", [True, False])
def test_solver_train_graph(batch_size, use_lt):
problem = GraphProblemLT() if use_lt else GraphProblem()
solver = SupervisedSolver(problem=problem, model=graph_model, use_lt=use_lt)
trainer = Trainer(
solver=solver,
max_epochs=2,
accelerator="cpu",
batch_size=batch_size,
train_size=1.0,
test_size=0.0,
val_size=0.0,
)

trainer.train()


@pytest.mark.parametrize("use_lt", [True, False])
@pytest.mark.parametrize("compile", [True, False])
def test_solver_validation(use_lt, compile):
Expand All @@ -79,6 +154,24 @@ def test_solver_validation(use_lt, compile):
assert isinstance(solver.model, OptimizedModule)


@pytest.mark.parametrize("batch_size", [None, 1, 5, 20])
@pytest.mark.parametrize("use_lt", [True, False])
def test_solver_validation_graph(batch_size, use_lt):
problem = GraphProblemLT() if use_lt else GraphProblem()
solver = SupervisedSolver(problem=problem, model=graph_model, use_lt=use_lt)
trainer = Trainer(
solver=solver,
max_epochs=2,
accelerator="cpu",
batch_size=batch_size,
train_size=0.9,
val_size=0.1,
test_size=0.0,
)

trainer.train()


@pytest.mark.parametrize("use_lt", [True, False])
@pytest.mark.parametrize("compile", [True, False])
def test_solver_test(use_lt, compile):
Expand All @@ -99,6 +192,24 @@ def test_solver_test(use_lt, compile):
assert isinstance(solver.model, OptimizedModule)


@pytest.mark.parametrize("batch_size", [None, 1, 5, 20])
@pytest.mark.parametrize("use_lt", [True, False])
def test_solver_test_graph(batch_size, use_lt):
problem = GraphProblemLT() if use_lt else GraphProblem()
solver = SupervisedSolver(problem=problem, model=graph_model, use_lt=use_lt)
trainer = Trainer(
solver=solver,
max_epochs=2,
accelerator="cpu",
batch_size=batch_size,
train_size=0.8,
val_size=0.1,
test_size=0.1,
)

trainer.test()


def test_train_load_restore():
dir = "tests/test_solver/tmp/"
problem = LabelTensorProblem()
Expand Down
Loading