Skip to content

Commit 7886c38

Browse files
authored
Add test supervised solver for graph based models (#480)
1 parent e541af0 commit 7886c38

File tree

1 file changed

+112
-1
lines changed

1 file changed

+112
-1
lines changed

tests/test_solver/test_supervised_solver.py

Lines changed: 112 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
import torch
22
import pytest
3+
from torch._dynamo.eval_frame import OptimizedModule
4+
from torch_geometric.nn import GCNConv
35
from pina import Condition, LabelTensor
46
from pina.condition import InputTargetCondition
57
from pina.problem import AbstractProblem
68
from pina.solver import SupervisedSolver
79
from pina.model import FeedForward
810
from pina.trainer import Trainer
9-
from torch._dynamo.eval_frame import OptimizedModule
11+
from pina.graph import KNNGraph
1012

1113

1214
class LabelTensorProblem(AbstractProblem):
@@ -28,9 +30,64 @@ class TensorProblem(AbstractProblem):
2830
}
2931

3032

33+
x = torch.rand((100, 20, 5))
34+
pos = torch.rand((100, 20, 2))
35+
output_ = torch.rand((100, 20, 1))
36+
input_ = [
37+
KNNGraph(x=x_, pos=pos_, neighbours=3, edge_attr=True)
38+
for x_, pos_ in zip(x, pos)
39+
]
40+
41+
42+
class GraphProblem(AbstractProblem):
43+
output_variables = None
44+
conditions = {"data": Condition(input=input_, target=output_)}
45+
46+
47+
x = LabelTensor(torch.rand((100, 20, 5)), ["a", "b", "c", "d", "e"])
48+
pos = LabelTensor(torch.rand((100, 20, 2)), ["x", "y"])
49+
output_ = LabelTensor(torch.rand((100, 20, 1)), ["u"])
50+
input_ = [
51+
KNNGraph(x=x[i], pos=pos[i], neighbours=3, edge_attr=True)
52+
for i in range(len(x))
53+
]
54+
55+
56+
class GraphProblemLT(AbstractProblem):
57+
output_variables = ["u"]
58+
input_variables = ["a", "b", "c", "d", "e"]
59+
conditions = {"data": Condition(input=input_, target=output_)}
60+
61+
3162
model = FeedForward(2, 1)
3263

3364

65+
class Model(torch.nn.Module):
66+
67+
def __init__(self, *args, **kwargs):
68+
super().__init__(*args, **kwargs)
69+
self.lift = torch.nn.Linear(5, 10)
70+
self.activation = torch.nn.Tanh()
71+
self.output = torch.nn.Linear(10, 1)
72+
73+
self.conv = GCNConv(10, 10)
74+
75+
def forward(self, batch):
76+
77+
x = batch.x
78+
edge_index = batch.edge_index
79+
for _ in range(1):
80+
y = self.lift(x)
81+
y = self.activation(y)
82+
y = self.conv(y, edge_index)
83+
y = self.activation(y)
84+
y = self.output(y)
85+
return y
86+
87+
88+
graph_model = Model()
89+
90+
3491
def test_constructor():
3592
SupervisedSolver(problem=TensorProblem(), model=model)
3693
SupervisedSolver(problem=LabelTensorProblem(), model=model)
@@ -59,6 +116,24 @@ def test_solver_train(use_lt, batch_size, compile):
59116
assert isinstance(solver.model, OptimizedModule)
60117

61118

119+
@pytest.mark.parametrize("batch_size", [None, 1, 5, 20])
120+
@pytest.mark.parametrize("use_lt", [True, False])
121+
def test_solver_train_graph(batch_size, use_lt):
122+
problem = GraphProblemLT() if use_lt else GraphProblem()
123+
solver = SupervisedSolver(problem=problem, model=graph_model, use_lt=use_lt)
124+
trainer = Trainer(
125+
solver=solver,
126+
max_epochs=2,
127+
accelerator="cpu",
128+
batch_size=batch_size,
129+
train_size=1.0,
130+
test_size=0.0,
131+
val_size=0.0,
132+
)
133+
134+
trainer.train()
135+
136+
62137
@pytest.mark.parametrize("use_lt", [True, False])
63138
@pytest.mark.parametrize("compile", [True, False])
64139
def test_solver_validation(use_lt, compile):
@@ -79,6 +154,24 @@ def test_solver_validation(use_lt, compile):
79154
assert isinstance(solver.model, OptimizedModule)
80155

81156

157+
@pytest.mark.parametrize("batch_size", [None, 1, 5, 20])
158+
@pytest.mark.parametrize("use_lt", [True, False])
159+
def test_solver_validation_graph(batch_size, use_lt):
160+
problem = GraphProblemLT() if use_lt else GraphProblem()
161+
solver = SupervisedSolver(problem=problem, model=graph_model, use_lt=use_lt)
162+
trainer = Trainer(
163+
solver=solver,
164+
max_epochs=2,
165+
accelerator="cpu",
166+
batch_size=batch_size,
167+
train_size=0.9,
168+
val_size=0.1,
169+
test_size=0.0,
170+
)
171+
172+
trainer.train()
173+
174+
82175
@pytest.mark.parametrize("use_lt", [True, False])
83176
@pytest.mark.parametrize("compile", [True, False])
84177
def test_solver_test(use_lt, compile):
@@ -99,6 +192,24 @@ def test_solver_test(use_lt, compile):
99192
assert isinstance(solver.model, OptimizedModule)
100193

101194

195+
@pytest.mark.parametrize("batch_size", [None, 1, 5, 20])
196+
@pytest.mark.parametrize("use_lt", [True, False])
197+
def test_solver_test_graph(batch_size, use_lt):
198+
problem = GraphProblemLT() if use_lt else GraphProblem()
199+
solver = SupervisedSolver(problem=problem, model=graph_model, use_lt=use_lt)
200+
trainer = Trainer(
201+
solver=solver,
202+
max_epochs=2,
203+
accelerator="cpu",
204+
batch_size=batch_size,
205+
train_size=0.8,
206+
val_size=0.1,
207+
test_size=0.1,
208+
)
209+
210+
trainer.test()
211+
212+
102213
def test_train_load_restore():
103214
dir = "tests/test_solver/tmp/"
104215
problem = LabelTensorProblem()

0 commit comments

Comments
 (0)