33
44from pina import LabelTensor , Condition
55from pina .domain import CartesianDomain
6+ from pina .condition import (
7+ GraphInputOutputCondition ,
8+ GraphInputEquationCondition ,
9+ )
610from pina .equation .equation_factory import FixedValue
11+ from pina .graph import RadiusGraph
12+ from torch_geometric .data import Data
13+ from pina .operator import laplacian
14+ from pina .equation .equation import Equation
715
8- example_domain = CartesianDomain ({'x' : [0 , 1 ], 'y' : [0 , 1 ]})
9- example_input_pts = LabelTensor (torch .tensor ([[0 , 0 , 0 ]]), ['x' , 'y' , 'z' ])
10- example_output_pts = LabelTensor (torch .tensor ([[1 , 2 ]]), ['a' , 'b' ])
16+ example_domain = CartesianDomain ({"x" : [0 , 1 ], "y" : [0 , 1 ]})
17+ example_input_pts = LabelTensor (torch .tensor ([[0 , 0 , 0 ]]), ["x" , "y" , "z" ])
18+ example_output_pts = LabelTensor (torch .tensor ([[1 , 2 ]]), ["a" , "b" ])
1119
1220
1321def test_init_inputoutput ():
1422 Condition (input_points = example_input_pts , output_points = example_output_pts )
1523 with pytest .raises (ValueError ):
1624 Condition (example_input_pts , example_output_pts )
1725 with pytest .raises (ValueError ):
18- Condition (input_points = 3. , output_points = ' example' )
26+ Condition (input_points = 3.0 , output_points = " example" )
1927 with pytest .raises (ValueError ):
2028 Condition (input_points = example_domain , output_points = example_domain )
2129
2230
23- test_init_inputoutput ()
24-
25-
2631def test_init_domainfunc ():
2732 Condition (domain = example_domain , equation = FixedValue (0.0 ))
2833 with pytest .raises (ValueError ):
2934 Condition (example_domain , FixedValue (0.0 ))
3035 with pytest .raises (ValueError ):
31- Condition (domain = 3. , equation = ' example' )
36+ Condition (domain = 3.0 , equation = " example" )
3237 with pytest .raises (ValueError ):
3338 Condition (domain = example_input_pts , equation = example_output_pts )
3439
@@ -38,6 +43,78 @@ def test_init_inputfunc():
3843 with pytest .raises (ValueError ):
3944 Condition (example_domain , FixedValue (0.0 ))
4045 with pytest .raises (ValueError ):
41- Condition (input_points = 3. , equation = ' example' )
46+ Condition (input_points = 3.0 , equation = " example" )
4247 with pytest .raises (ValueError ):
4348 Condition (input_points = example_domain , equation = example_output_pts )
49+
50+
51+ def test_graph_io_condition ():
52+ x = torch .rand (10 , 10 , 4 )
53+ pos = torch .rand (10 , 10 , 2 )
54+ y = torch .rand (10 , 10 , 2 )
55+ graph = [
56+ RadiusGraph (x = x_ , pos = pos_ , radius = 0.1 , build_edge_attr = True , y = y_ )
57+ for x_ , pos_ , y_ in zip (x , pos , y )
58+ ]
59+ condition = Condition (graph = graph )
60+ assert isinstance (condition , GraphInputOutputCondition )
61+ assert isinstance (condition .graph , list )
62+
63+ x = x [0 ]
64+ pos = pos [0 ]
65+ y = y [0 ]
66+ edge_index = graph [0 ].edge_index
67+ graph = Data (x = x , pos = pos , edge_index = edge_index , y = y )
68+ condition = Condition (graph = graph )
69+ assert isinstance (condition , GraphInputOutputCondition )
70+ assert isinstance (condition .graph , Data )
71+
72+
73+ def laplace_equation (input_ , output_ ):
74+ """
75+ Implementation of the laplace equation.
76+ """
77+ force_term = torch .sin (input_ .extract (["x" ]) * torch .pi ) * torch .sin (
78+ input_ .extract (["y" ]) * torch .pi
79+ )
80+ delta_u = laplacian (output_ .extract (["u" ]), input_ )
81+ return delta_u - force_term
82+
83+
84+ def test_graph_eq_condition ():
85+ def laplace (input_ , output_ ):
86+ """
87+ Implementation of the laplace equation.
88+ """
89+ force_term = torch .sin (input_ .extract (["x" ]) * torch .pi ) * torch .sin (
90+ input_ .extract (["y" ]) * torch .pi
91+ )
92+ delta_u = laplacian (output_ .extract (["u" ]), input_ )
93+ return delta_u - force_term
94+
95+ x = torch .rand (10 , 10 , 4 )
96+ pos = torch .rand (10 , 10 , 2 )
97+ graph = [
98+ RadiusGraph (
99+ x = x_ ,
100+ pos = pos_ ,
101+ radius = 0.1 ,
102+ build_edge_attr = True ,
103+ )
104+ for x_ , pos_ , in zip (
105+ x ,
106+ pos ,
107+ )
108+ ]
109+ laplace_equation = Equation (laplace )
110+ condition = Condition (graph = graph , equation = laplace_equation )
111+ assert isinstance (condition , GraphInputEquationCondition )
112+ assert isinstance (condition .graph , list )
113+
114+ x = x [0 ]
115+ pos = pos [0 ]
116+ edge_index = graph [0 ].edge_index
117+ graph = Data (x = x , pos = pos , edge_index = edge_index )
118+ condition = Condition (graph = graph , equation = laplace_equation )
119+ assert isinstance (condition , GraphInputEquationCondition )
120+ assert isinstance (condition .graph , Data )
0 commit comments