Does this code work for you?

```python
import numpy as np
import pybamm

# Create model
model = pybamm.BaseModel()
model.convert_to_format = "jax"
var = pybamm.Variable("var")
model.rhs = {var: 0.1 * var}
model.initial_conditions = {var: 1.0}
# No need to set parameters or a mesh: the model has no spatial operators,
# so the default (base) discretisation is enough

for method in ["RK45", "BDF"]:
    # Solve with each JAX integration method
    solver = pybamm.JaxSolver(method=method, rtol=1e-8, atol=1e-8)
    t_eval = np.linspace(0, 1, 80)
    solution = solver.solve(model, t_eval)
```
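For reference, the test model above is a linear ODE with a closed-form solution, so the output of either method (for example the state array `solution.y`) can be checked against it:

```math
\frac{d\,\mathrm{var}}{dt} = 0.1\,\mathrm{var}, \quad \mathrm{var}(0) = 1
\quad\Longrightarrow\quad
\mathrm{var}(t) = e^{0.1 t}, \quad \mathrm{var}(1) = e^{0.1} \approx 1.10517
```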
---
Using `JaxSolver` with the SPM gave me this error:

```
raise TracerArrayConversionError(self)
jax._src.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float64[1,1])>with<DynamicJaxprTrace(level=1/2)> (https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError)
```
I used `method="BDF"` because the default method does not work for some reason. I also had to set `model.events = []` to remove the events before I could solve with BDF. What should I do? I am trying to get PyBaMM running on a GPU and thought JAX could help with that (this run was on an ordinary CPU, though).
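The error message itself describes generic JAX behaviour: inside a traced function (for example one compiled with `jax.jit`), arrays are abstract tracers, and anything that forces a tracer into a concrete NumPy array, such as calling a plain `numpy` function on it, raises `TracerArrayConversionError`. The minimal sketch below reproduces that mechanism independently of PyBaMM; the function names are hypothetical, and the traceback above does not show where in the SPM/`JaxSolver` pipeline the NumPy conversion actually happens.

```python
import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def uses_numpy(x):
    # np.sin tries to convert the tracer into a numpy.ndarray via __array__,
    # which JAX forbids while tracing
    return np.sin(x)


@jax.jit
def uses_jax_numpy(x):
    # jnp.sin is traceable, so this compiles and runs
    return jnp.sin(x)


def try_numpy_version(x):
    """Return the name of the exception raised by uses_numpy, or None."""
    try:
        uses_numpy(x)
    except jax.errors.TracerArrayConversionError as err:
        return type(err).__name__
    return None


x = jnp.ones((1, 1))  # same 2-D shape as the tracer in the error message
print(try_numpy_version(x))  # TracerArrayConversionError
print(uses_jax_numpy(x))
```

Swapping the NumPy call for its `jax.numpy` equivalent is the usual way to make such a function traceable.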