Skip to content

Commit e8e2ac5

Browse files
committed
Merge pull request #3423 from jsbrittain/jax_gpu
JaxSolver fails when using GPU support with no input parameters
1 parent 37dfe89 commit e8e2ac5

File tree

3 files changed

+6
-1
lines changed

3 files changed

+6
-1
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
# [Unreleased](https://github.com/pybamm-team/PyBaMM/)
22

3+
## Bug fixes
4+
5+
- Fixed a bug where the JaxSolver would fails when using GPU support with no input parameters ([#3423](https://github.com/pybamm-team/PyBaMM/pull/3423))
6+
37
# [v23.9rc0](https://github.com/pybamm-team/PyBaMM/tree/v23.9rc0) - 2023-10-31
48

59
## Features

docs/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@
154154
"navbar_end": ["theme-switcher", "navbar-icon-links"],
155155
# add Algolia to the persistent navbar, this removes the default search icon
156156
"navbar_persistent": "algolia-searchbox",
157+
"navigation_with_keys": False,
157158
"use_edit_page_button": True,
158159
"pygment_light_style": "xcode",
159160
"pygment_dark_style": "monokai",

pybamm/solvers/jax_solver.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ def _integrate(self, model, t_eval, inputs=None):
215215

216216
y = []
217217
platform = jax.lib.xla_bridge.get_backend().platform.casefold()
218-
if platform.startswith("cpu"):
218+
if len(inputs) <= 1 or platform.startswith("cpu"):
219219
# cpu execution runs faster when multithreaded
220220
async def solve_model_for_inputs():
221221
async def solve_model_async(inputs_v):

0 commit comments

Comments
 (0)