66import numpy as np
77from typing import Literal , Union
88
9+ from bayesflow .adapters import Adapter
910from bayesflow .types import Tensor
1011from bayesflow .utils import filter_kwargs
12+ from bayesflow .utils .logging import warning
1113
1214from . import logging
1315
@@ -265,6 +267,53 @@ def body(_loop_var, _loop_state):
265267 return state
266268
267269
270+ def integrate_scipy (
271+ fn : Callable ,
272+ state : dict [str , ArrayLike ],
273+ start_time : ArrayLike ,
274+ stop_time : ArrayLike ,
275+ scipy_kwargs : dict | None = None ,
276+ ** kwargs ,
277+ ) -> dict [str , ArrayLike ]:
278+ import scipy .integrate
279+
280+ scipy_kwargs = scipy_kwargs or {}
281+ keys = list (state .keys ())
282+ # convert to tensor before determining the shape in case a number was passed
283+ shapes = keras .tree .map_structure (lambda x : keras .ops .shape (keras .ops .convert_to_tensor (x )), state )
284+ adapter = Adapter ().concatenate (keys , into = "x" , axis = - 1 ).convert_dtype (np .float32 , np .float64 )
285+
286+ def state_to_vector (state ):
287+ state = keras .tree .map_structure (keras .ops .convert_to_numpy , state )
288+ # flatten state
289+ state = keras .tree .map_structure (lambda x : keras .ops .reshape (x , (- 1 ,)), state )
290+ # apply concatenation
291+ x = adapter .forward (state )["x" ]
292+ return x
293+
294+ def vector_to_state (x ):
295+ state = adapter .inverse ({"x" : x })
296+ state = {key : keras .ops .reshape (value , shapes [key ]) for key , value in state .items ()}
297+ state = keras .tree .map_structure (keras .ops .convert_to_tensor , state )
298+ return state
299+
300+ def scipy_wrapper_fn (time , x ):
301+ state = vector_to_state (x )
302+ time = keras .ops .convert_to_tensor (time , dtype = "float32" )
303+ deltas = fn (time , ** filter_kwargs (state , fn ))
304+ return state_to_vector (deltas )
305+
306+ result = scipy .integrate .solve_ivp (
307+ scipy_wrapper_fn ,
308+ (start_time , stop_time ),
309+ state_to_vector (state ),
310+ ** scipy_kwargs ,
311+ )
312+
313+ result = vector_to_state (result .y [:, - 1 ])
314+ return result
315+
316+
268317def integrate (
269318 fn : Callable ,
270319 state : dict [str , ArrayLike ],
@@ -282,6 +331,12 @@ def integrate(
282331 "Please provide start_time and stop_time for the integration, was "
283332 f"'start_time={ start_time } ', 'stop_time={ stop_time } '."
284333 )
334+ if method == "scipy" :
335+ if min_steps != 10 :
336+ warning ("Setting min_steps has no effect for method 'scipy'" )
337+ if max_steps != 10_000 :
338+ warning ("Setting max_steps has no effect for method 'scipy'" )
339+ return integrate_scipy (fn , state , start_time , stop_time , ** kwargs )
285340 return integrate_adaptive (fn , state , start_time , stop_time , min_steps , max_steps , method , ** kwargs )
286341 elif isinstance (steps , int ):
287342 if start_time is None or stop_time is None :
0 commit comments