@@ -15,12 +15,18 @@ def test_supervised_tensor_collector():
1515 class SupervisedProblem (AbstractProblem ):
1616 output_variables = None
1717 conditions = {
18- 'data1' : Condition (input_points = torch .rand ((10 , 2 )),
19- output_points = torch .rand ((10 , 2 ))),
20- 'data2' : Condition (input_points = torch .rand ((20 , 2 )),
21- output_points = torch .rand ((20 , 2 ))),
22- 'data3' : Condition (input_points = torch .rand ((30 , 2 )),
23- output_points = torch .rand ((30 , 2 ))),
18+ "data1" : Condition (
19+ input_points = torch .rand ((10 , 2 )),
20+ output_points = torch .rand ((10 , 2 )),
21+ ),
22+ "data2" : Condition (
23+ input_points = torch .rand ((20 , 2 )),
24+ output_points = torch .rand ((20 , 2 )),
25+ ),
26+ "data3" : Condition (
27+ input_points = torch .rand ((30 , 2 )),
28+ output_points = torch .rand ((30 , 2 )),
29+ ),
2430 }
2531
2632 problem = SupervisedProblem ()
@@ -31,65 +37,58 @@ class SupervisedProblem(AbstractProblem):
3137
3238def test_pinn_collector ():
3339 def laplace_equation (input_ , output_ ):
34- force_term = (torch .sin (input_ .extract (['x' ]) * torch .pi ) *
35- torch .sin (input_ .extract (['y' ]) * torch .pi ))
36- delta_u = laplacian (output_ .extract (['u' ]), input_ )
40+ force_term = torch .sin (input_ .extract (["x" ]) * torch .pi ) * torch .sin (
41+ input_ .extract (["y" ]) * torch .pi
42+ )
43+ delta_u = laplacian (output_ .extract (["u" ]), input_ )
3744 return delta_u - force_term
3845
3946 my_laplace = Equation (laplace_equation )
40- in_ = LabelTensor (torch .tensor ([[0. , 1. ]], requires_grad = True ), ['x' , 'y' ])
41- out_ = LabelTensor (torch .tensor ([[0. ]], requires_grad = True ), ['u' ])
47+ in_ = LabelTensor (
48+ torch .tensor ([[0.0 , 1.0 ]], requires_grad = True ), ["x" , "y" ]
49+ )
50+ out_ = LabelTensor (torch .tensor ([[0.0 ]], requires_grad = True ), ["u" ])
4251
4352 class Poisson (SpatialProblem ):
44- output_variables = ['u' ]
45- spatial_domain = CartesianDomain ({'x' : [0 , 1 ], 'y' : [0 , 1 ]})
53+ output_variables = ["u" ]
54+ spatial_domain = CartesianDomain ({"x" : [0 , 1 ], "y" : [0 , 1 ]})
4655
4756 conditions = {
48- 'gamma1' :
49- Condition (domain = CartesianDomain ({
50- 'x' : [0 , 1 ],
51- 'y' : 1
52- }),
53- equation = FixedValue (0.0 )),
54- 'gamma2' :
55- Condition (domain = CartesianDomain ({
56- 'x' : [0 , 1 ],
57- 'y' : 0
58- }),
59- equation = FixedValue (0.0 )),
60- 'gamma3' :
61- Condition (domain = CartesianDomain ({
62- 'x' : 1 ,
63- 'y' : [0 , 1 ]
64- }),
65- equation = FixedValue (0.0 )),
66- 'gamma4' :
67- Condition (domain = CartesianDomain ({
68- 'x' : 0 ,
69- 'y' : [0 , 1 ]
70- }),
71- equation = FixedValue (0.0 )),
72- 'D' :
73- Condition (domain = CartesianDomain ({
74- 'x' : [0 , 1 ],
75- 'y' : [0 , 1 ]
76- }),
77- equation = my_laplace ),
78- 'data' :
79- Condition (input_points = in_ , output_points = out_ )
57+ "gamma1" : Condition (
58+ domain = CartesianDomain ({"x" : [0 , 1 ], "y" : 1 }),
59+ equation = FixedValue (0.0 ),
60+ ),
61+ "gamma2" : Condition (
62+ domain = CartesianDomain ({"x" : [0 , 1 ], "y" : 0 }),
63+ equation = FixedValue (0.0 ),
64+ ),
65+ "gamma3" : Condition (
66+ domain = CartesianDomain ({"x" : 1 , "y" : [0 , 1 ]}),
67+ equation = FixedValue (0.0 ),
68+ ),
69+ "gamma4" : Condition (
70+ domain = CartesianDomain ({"x" : 0 , "y" : [0 , 1 ]}),
71+ equation = FixedValue (0.0 ),
72+ ),
73+ "D" : Condition (
74+ domain = CartesianDomain ({"x" : [0 , 1 ], "y" : [0 , 1 ]}),
75+ equation = my_laplace ,
76+ ),
77+ "data" : Condition (input_points = in_ , output_points = out_ ),
8078 }
8179
8280 def poisson_sol (self , pts ):
83- return - (torch .sin (pts .extract (['x' ]) * torch .pi ) *
84- torch .sin (pts .extract (['y' ]) * torch .pi )) / (
85- 2 * torch .pi ** 2 )
81+ return - (
82+ torch .sin (pts .extract (["x" ]) * torch .pi )
83+ * torch .sin (pts .extract (["y" ]) * torch .pi )
84+ ) / (2 * torch .pi ** 2 )
8685
8786 truth_solution = poisson_sol
8887
8988 problem = Poisson ()
90- boundaries = [' gamma1' , ' gamma2' , ' gamma3' , ' gamma4' ]
91- problem .discretise_domain (10 , ' grid' , domains = boundaries )
92- problem .discretise_domain (10 , ' grid' , domains = 'D' )
89+ boundaries = [" gamma1" , " gamma2" , " gamma3" , " gamma4" ]
90+ problem .discretise_domain (10 , " grid" , domains = boundaries )
91+ problem .discretise_domain (10 , " grid" , domains = "D" )
9392
9493 collector = Collector (problem )
9594 collector .store_fixed_data ()
@@ -98,31 +97,34 @@ def poisson_sol(self, pts):
9897 for k , v in problem .conditions .items ():
9998 if isinstance (v , InputOutputPointsCondition ):
10099 assert list (collector .data_collections [k ].keys ()) == [
101- 'input_points' , 'output_points' ]
100+ "input_points" ,
101+ "output_points" ,
102+ ]
102103
103104 for k , v in problem .conditions .items ():
104105 if isinstance (v , DomainEquationCondition ):
105106 assert list (collector .data_collections [k ].keys ()) == [
106- 'input_points' , 'equation' ]
107+ "input_points" ,
108+ "equation" ,
109+ ]
107110
108111
109112def test_supervised_graph_collector ():
110113 pos = torch .rand ((100 , 3 ))
111114 x = [torch .rand ((100 , 3 )) for _ in range (10 )]
112- graph_list_1 = RadiusGraph (pos = pos , x = x , build_edge_attr = True , r = .4 )
115+ graph_list_1 = [ RadiusGraph (pos = pos , radius = 0.4 , x = x_ ) for x_ in x ]
113116 out_1 = torch .rand ((10 , 100 , 3 ))
117+
114118 pos = torch .rand ((50 , 3 ))
115119 x = [torch .rand ((50 , 3 )) for _ in range (10 )]
116- graph_list_2 = RadiusGraph (pos = pos , x = x , build_edge_attr = True , r = .4 )
120+ graph_list_2 = [ RadiusGraph (pos = pos , radius = 0.4 , x = x_ ) for x_ in x ]
117121 out_2 = torch .rand ((10 , 50 , 3 ))
118122
119123 class SupervisedProblem (AbstractProblem ):
120124 output_variables = None
121125 conditions = {
122- 'data1' : Condition (input_points = graph_list_1 ,
123- output_points = out_1 ),
124- 'data2' : Condition (input_points = graph_list_2 ,
125- output_points = out_2 ),
126+ "data1" : Condition (input_points = graph_list_1 , output_points = out_1 ),
127+ "data2" : Condition (input_points = graph_list_2 , output_points = out_2 ),
126128 }
127129
128130 problem = SupervisedProblem ()
0 commit comments