66from pina .solver import CausalPINN
77from pina .trainer import Trainer
88from pina .model import FeedForward
9- from pina .problem .zoo import (
10- DiffusionReactionProblem ,
11- InverseDiffusionReactionProblem ,
12- )
9+ from pina .problem .zoo import DiffusionReactionProblem
1310from pina .condition import (
1411 InputTargetCondition ,
1512 InputEquationCondition ,
@@ -31,8 +28,6 @@ class DummySpatialProblem(SpatialProblem):
3128# define problems
3229problem = DiffusionReactionProblem ()
3330problem .discretise_domain (50 )
34- inverse_problem = InverseDiffusionReactionProblem ()
35- inverse_problem .discretise_domain (50 )
3631
3732# add input-output condition to test supervised learning
3833input_pts = torch .rand (50 , len (problem .input_variables ))
@@ -45,7 +40,7 @@ class DummySpatialProblem(SpatialProblem):
4540model = FeedForward (len (problem .input_variables ), len (problem .output_variables ))
4641
4742
48- @pytest .mark .parametrize ("problem" , [problem , inverse_problem ])
43+ @pytest .mark .parametrize ("problem" , [problem ])
4944@pytest .mark .parametrize ("eps" , [100 , 100.1 ])
5045def test_constructor (problem , eps ):
5146 with pytest .raises (ValueError ):
@@ -59,7 +54,7 @@ def test_constructor(problem, eps):
5954 )
6055
6156
62- @pytest .mark .parametrize ("problem" , [problem , inverse_problem ])
57+ @pytest .mark .parametrize ("problem" , [problem ])
6358@pytest .mark .parametrize ("batch_size" , [None , 1 , 5 , 20 ])
6459@pytest .mark .parametrize ("compile" , [True , False ])
6560def test_solver_train (problem , batch_size , compile ):
@@ -79,7 +74,7 @@ def test_solver_train(problem, batch_size, compile):
7974 assert isinstance (solver .model , OptimizedModule )
8075
8176
82- @pytest .mark .parametrize ("problem" , [problem , inverse_problem ])
77+ @pytest .mark .parametrize ("problem" , [problem ])
8378@pytest .mark .parametrize ("batch_size" , [None , 1 , 5 , 20 ])
8479@pytest .mark .parametrize ("compile" , [True , False ])
8580def test_solver_validation (problem , batch_size , compile ):
@@ -99,7 +94,7 @@ def test_solver_validation(problem, batch_size, compile):
9994 assert isinstance (solver .model , OptimizedModule )
10095
10196
102- @pytest .mark .parametrize ("problem" , [problem , inverse_problem ])
97+ @pytest .mark .parametrize ("problem" , [problem ])
10398@pytest .mark .parametrize ("batch_size" , [None , 1 , 5 , 20 ])
10499@pytest .mark .parametrize ("compile" , [True , False ])
105100def test_solver_test (problem , batch_size , compile ):
@@ -119,7 +114,7 @@ def test_solver_test(problem, batch_size, compile):
119114 assert isinstance (solver .model , OptimizedModule )
120115
121116
122- @pytest .mark .parametrize ("problem" , [problem , inverse_problem ])
117+ @pytest .mark .parametrize ("problem" , [problem ])
123118def test_train_load_restore (problem ):
124119 dir = "tests/test_solver/tmp"
125120 problem = problem
0 commit comments