Skip to content

Commit 90f7d69

Browse files
author
Gerrit Ansmann
committed
Fixing problem with small integration steps with Solve IVP. Allowing backwards integration for Solve IVP.
1 parent e6e141b commit 90f7d69

File tree

2 files changed

+21
-23
lines changed

2 files changed

+21
-23
lines changed

jitcode/integrator_tools.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -99,15 +99,12 @@ def set_params(self,*args):
9999
self.try_to_initiate()
100100

101101
def integrate(self,t):
102-
if self.backend.t < t:
103-
while self.backend.t < t:
104-
self.backend.step()
105-
if self.backend.status == "failed":
106-
raise UnsuccessfulIntegration
107-
self.kwargs["y0"] = self.backend.dense_output()(t)
108-
self.kwargs["t0"] = t
109-
elif self.backend.t > t:
110-
raise ValueError("Target time smaller than current time. Cannot integrate backwards in time")
102+
while self.backend.t < t or self.backend.t_old is None:
103+
self.backend.step()
104+
if self.backend.status == "failed":
105+
raise UnsuccessfulIntegration
106+
self.kwargs["y0"] = self.backend.dense_output()(t)
107+
self.kwargs["t0"] = t
111108
return self.kwargs["y0"]
112109

113110
def successful(self):

tests/test_integrator_tools.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -85,18 +85,19 @@ def test_zero_integration(self):
8585
self.integrator.set_initial_value(initial)
8686
assert_allclose(initial,self.integrator.integrate(0))
8787

88-
def test_backwards_integration(self):
89-
self.initialise(f,jac)
90-
self.integrator.set_initial_value([0,1,2,3],0)
91-
with self.assertRaises(ValueError):
92-
self.integrator.integrate(-1)
93-
9488
def test_failed_integration(self):
9589
self.initialise(f_back,jac_back,atol=1e-12,rtol=1e-12)
9690
self.integrator.set_initial_value([0,1,2,3],0)
9791
with self.assertRaises(UnsuccessfulIntegration):
9892
self.integrator.integrate(1)
9993

94+
class TestSkeletonWithBackwards(TestSkeleton):
95+
def test_backwards_integration(self):
96+
self.initialise(f,jac)
97+
self.integrator.set_initial_value([0,1,2,3],0)
98+
with self.assertRaises(ValueError):
99+
self.integrator.integrate(-1)
100+
100101
class TestRK45(unittest.TestCase,TestSkeleton):
101102
def initialise(self,f,jac,**kwargs):
102103
self.integrator = IVP_wrapper("RK45",f,**kwargs)
@@ -113,19 +114,19 @@ class TestBDF(unittest.TestCase,TestSkeleton):
113114
def initialise(self,f,jac,**kwargs):
114115
self.integrator = IVP_wrapper("BDF",f,jac,**kwargs)
115116

116-
class TestRK45_no_interpolation(unittest.TestCase,TestSkeleton):
117+
class TestRK45_no_interpolation(unittest.TestCase,TestSkeletonWithBackwards):
117118
def initialise(self,f,jac,**kwargs):
118119
self.integrator = IVP_wrapper_no_interpolation("RK45",f,**kwargs)
119120

120-
class TestRK23_no_interpolation(unittest.TestCase,TestSkeleton):
121+
class TestRK23_no_interpolation(unittest.TestCase,TestSkeletonWithBackwards):
121122
def initialise(self,f,jac,**kwargs):
122123
self.integrator = IVP_wrapper_no_interpolation("RK45",f,**kwargs)
123124

124-
class TestRadau_no_interpolation(unittest.TestCase,TestSkeleton):
125+
class TestRadau_no_interpolation(unittest.TestCase,TestSkeletonWithBackwards):
125126
def initialise(self,f,jac,**kwargs):
126127
self.integrator = IVP_wrapper_no_interpolation("Radau",f,jac,**kwargs)
127128

128-
class TestBDF_no_interpolation(unittest.TestCase,TestSkeleton):
129+
class TestBDF_no_interpolation(unittest.TestCase,TestSkeletonWithBackwards):
129130
def initialise(self,f,jac,**kwargs):
130131
self.integrator = IVP_wrapper_no_interpolation("BDF",f,jac,**kwargs)
131132

@@ -137,17 +138,17 @@ def initialise(self,f,jac,**kwargs):
137138
def test_failed_integration(self):
138139
pass
139140

140-
class TestDopri5(unittest.TestCase,TestSkeleton):
141+
class TestDopri5(unittest.TestCase,TestSkeletonWithBackwards):
141142
def initialise(self,f,jac,**kwargs):
142143
self.integrator = ODE_wrapper(f)
143144
self.integrator.set_integrator("dopri5")
144145

145-
class TestDop853(unittest.TestCase,TestSkeleton):
146+
class TestDop853(unittest.TestCase,TestSkeletonWithBackwards):
146147
def initialise(self,f,jac,**kwargs):
147148
self.integrator = ODE_wrapper(f)
148149
self.integrator.set_integrator("dop853")
149150

150-
class TestLsoda(unittest.TestCase,TestSkeleton):
151+
class TestLsoda(unittest.TestCase,TestSkeletonWithBackwards):
151152
def initialise(self,f,jac,**kwargs):
152153
self.integrator = ODE_wrapper(f,jac)
153154
self.integrator.set_integrator("lsoda")
@@ -156,7 +157,7 @@ def initialise(self,f,jac,**kwargs):
156157
def test_failed_integration(self):
157158
pass
158159

159-
class TestVode(unittest.TestCase,TestSkeleton):
160+
class TestVode(unittest.TestCase,TestSkeletonWithBackwards):
160161
def initialise(self,f,jac,**kwargs):
161162
self.integrator = ODE_wrapper(f,jac)
162163
self.integrator.set_integrator("vode")

0 commit comments

Comments
 (0)