import torch
import pytest
+from torch._dynamo.eval_frame import OptimizedModule
+from torch_geometric.nn import GCNConv
from pina import Condition, LabelTensor
from pina.condition import InputTargetCondition
from pina.problem import AbstractProblem
from pina.solver import SupervisedSolver
from pina.model import FeedForward
from pina.trainer import Trainer
-from torch._dynamo.eval_frame import OptimizedModule
+from pina.graph import KNNGraph


class LabelTensorProblem(AbstractProblem):
@@ -28,9 +30,64 @@ class TensorProblem(AbstractProblem):
    }


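+# Graph data built from plain tensors: 100 samples of 20 nodes, each node carrying
+# 5 features and a 2D position, connected as a 3-nearest-neighbour graph with edge attributes.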
+x = torch.rand((100, 20, 5))
+pos = torch.rand((100, 20, 2))
+output_ = torch.rand((100, 20, 1))
+input_ = [
+    KNNGraph(x=x_, pos=pos_, neighbours=3, edge_attr=True)
+    for x_, pos_ in zip(x, pos)
+]
+
+
+class GraphProblem(AbstractProblem):
+    output_variables = None
+    conditions = {"data": Condition(input=input_, target=output_)}
+
+
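+# Same data wrapped in LabelTensor objects, so the graph problem can also be run with use_lt=True.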
+x = LabelTensor(torch.rand((100, 20, 5)), ["a", "b", "c", "d", "e"])
+pos = LabelTensor(torch.rand((100, 20, 2)), ["x", "y"])
+output_ = LabelTensor(torch.rand((100, 20, 1)), ["u"])
+input_ = [
+    KNNGraph(x=x[i], pos=pos[i], neighbours=3, edge_attr=True)
+    for i in range(len(x))
+]
+
+
+class GraphProblemLT(AbstractProblem):
+    output_variables = ["u"]
+    input_variables = ["a", "b", "c", "d", "e"]
+    conditions = {"data": Condition(input=input_, target=output_)}
+
+
model = FeedForward(2, 1)


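+# Small graph network for the graph tests: a linear lift, a single GCN convolution, and a linear output head.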
+class Model(torch.nn.Module):
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.lift = torch.nn.Linear(5, 10)
+        self.activation = torch.nn.Tanh()
+        self.output = torch.nn.Linear(10, 1)
+
+        self.conv = GCNConv(10, 10)
+
+    def forward(self, batch):
+
+        x = batch.x
+        edge_index = batch.edge_index
+        for _ in range(1):
+            y = self.lift(x)
+            y = self.activation(y)
+            y = self.conv(y, edge_index)
+            y = self.activation(y)
+        y = self.output(y)
+        return y
+
+
+graph_model = Model()
+
+
def test_constructor():
    SupervisedSolver(problem=TensorProblem(), model=model)
    SupervisedSolver(problem=LabelTensorProblem(), model=model)
@@ -59,6 +116,24 @@ def test_solver_train(use_lt, batch_size, compile):
        assert isinstance(solver.model, OptimizedModule)


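+# Training on graph data, with and without LabelTensor inputs; the full dataset is used for training.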
+@pytest.mark.parametrize("batch_size", [None, 1, 5, 20])
+@pytest.mark.parametrize("use_lt", [True, False])
+def test_solver_train_graph(batch_size, use_lt):
+    problem = GraphProblemLT() if use_lt else GraphProblem()
+    solver = SupervisedSolver(problem=problem, model=graph_model, use_lt=use_lt)
+    trainer = Trainer(
+        solver=solver,
+        max_epochs=2,
+        accelerator="cpu",
+        batch_size=batch_size,
+        train_size=1.0,
+        test_size=0.0,
+        val_size=0.0,
+    )
+
+    trainer.train()
+
+
@pytest.mark.parametrize("use_lt", [True, False])
@pytest.mark.parametrize("compile", [True, False])
def test_solver_validation(use_lt, compile):
@@ -79,6 +154,24 @@ def test_solver_validation(use_lt, compile):
        assert isinstance(solver.model, OptimizedModule)


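+# Validation on graph data: 90/10 train/validation split, no test split.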
+@pytest.mark.parametrize("batch_size", [None, 1, 5, 20])
+@pytest.mark.parametrize("use_lt", [True, False])
+def test_solver_validation_graph(batch_size, use_lt):
+    problem = GraphProblemLT() if use_lt else GraphProblem()
+    solver = SupervisedSolver(problem=problem, model=graph_model, use_lt=use_lt)
+    trainer = Trainer(
+        solver=solver,
+        max_epochs=2,
+        accelerator="cpu",
+        batch_size=batch_size,
+        train_size=0.9,
+        val_size=0.1,
+        test_size=0.0,
+    )
+
+    trainer.train()
+
+
@pytest.mark.parametrize("use_lt", [True, False])
@pytest.mark.parametrize("compile", [True, False])
def test_solver_test(use_lt, compile):
@@ -99,6 +192,24 @@ def test_solver_test(use_lt, compile):
        assert isinstance(solver.model, OptimizedModule)


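+# Testing on graph data: 80/10/10 train/validation/test split; only trainer.test() is called.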
+@pytest.mark.parametrize("batch_size", [None, 1, 5, 20])
+@pytest.mark.parametrize("use_lt", [True, False])
+def test_solver_test_graph(batch_size, use_lt):
+    problem = GraphProblemLT() if use_lt else GraphProblem()
+    solver = SupervisedSolver(problem=problem, model=graph_model, use_lt=use_lt)
+    trainer = Trainer(
+        solver=solver,
+        max_epochs=2,
+        accelerator="cpu",
+        batch_size=batch_size,
+        train_size=0.8,
+        val_size=0.1,
+        test_size=0.1,
+    )
+
+    trainer.test()
+
+
def test_train_load_restore():
    dir = "tests/test_solver/tmp/"
    problem = LabelTensorProblem()