Skip to content

Commit 05394a2

Browse files
committed
dtype checking during exponential euler method
1 parent e0ee142 commit 05394a2

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

brainpy/_src/integrators/ode/exponential.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,9 @@
106106
"""
107107

108108
from functools import wraps
109+
110+
import jax.numpy as jnp
111+
109112
from brainpy import errors
110113
from brainpy._src import math as bm
111114
from brainpy._src.integrators import constants as C, utils, joint_eq
@@ -356,6 +359,9 @@ def _build_integrator(self, eq):
356359
# integration function
357360
def integral(*args, **kwargs):
358361
assert len(args) > 0
362+
if args[0].dtype not in [jnp.float32, jnp.float64, jnp.float16, jnp.bfloat16]:
363+
raise ValueError('The input data type should be float32, float64, float16, or bfloat16 when using Exponential Euler method.'
364+
f'But we got {args[0].dtype}.')
359365
dt = kwargs.pop(C.DT, self.dt)
360366
linear, derivative = value_and_grad(*args, **kwargs)
361367
phi = bm.exprel(dt * linear)

0 commit comments

Comments
 (0)