Skip to content

Commit dc4855c

Browse files
Merge pull request #2907 from pybamm-team/issue-2670-interp2d
Issue 2670 interp2d
2 parents 9b4490b + 9ebbb11 commit dc4855c

File tree

9 files changed

+68
-269
lines changed

9 files changed

+68
-269
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@
77
- PyBaMM is now supported on Python `3.10` and `3.11` ([#2435](https://github.com/pybamm-team/PyBaMM/pull/2435))
88
- Updated to casadi 3.6, which required some changes to the casadi integrator. ([#2859](https://github.com/pybamm-team/PyBaMM/pull/2859))
99

10+
# Optimizations
11+
12+
- Fixed deprecated `interp2d` method by switching to `xarray.DataArray` as the backend for `ProcessedVariable` ([#2907](https://github.com/pybamm-team/PyBaMM/pull/2907))
13+
1014
## Bug fixes
1115

1216
- Parameter sets can now contain the key "chemistry", and will ignore its value (this previously would give errors in some cases) ([#2901](https://github.com/pybamm-team/PyBaMM/pull/2901))

docs/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ imageio>=2.9.0
1010
jupyter # For example notebooks
1111
pybtex
1212
sympy >= 1.8
13+
xarray
1314
# Note: Matplotlib is loaded for debug plots but to ensure pybamm runs
1415
# on systems without an attached display it should never be imported
1516
# outside of plot() methods.

pybamm/solvers/processed_variable.py

Lines changed: 23 additions & 182 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
import numbers
66
import numpy as np
77
import pybamm
8-
import scipy.interpolate as interp
98
from scipy.integrate import cumulative_trapezoid
9+
import xarray as xr
1010

1111

1212
class ProcessedVariable(object):
@@ -131,18 +131,7 @@ def initialise_0D(self):
131131
)
132132

133133
# set up interpolation
134-
if len(self.t_pts) == 1:
135-
# Variable is just a scalar value, but we need to create a callable
136-
# function to be consistent with other processed variables
137-
self._interpolation_function = Interpolant0D(entries)
138-
else:
139-
self._interpolation_function = interp.interp1d(
140-
self.t_pts,
141-
entries,
142-
kind="linear",
143-
fill_value=np.nan,
144-
bounds_error=False,
145-
)
134+
self._xr_data_array = xr.DataArray(entries, coords=[("t", self.t_pts)])
146135

147136
self.entries = entries
148137
self.dimensions = 0
@@ -211,22 +200,10 @@ def initialise_1D(self, fixed_t=False):
211200
self.first_dim_pts = edges
212201

213202
# set up interpolation
214-
if len(self.t_pts) == 1:
215-
# function of space only
216-
self._interpolation_function = Interpolant1D(
217-
pts_for_interp, entries_for_interp
218-
)
219-
else:
220-
# function of space and time. Note that the order of 't' and 'space'
221-
# is the reverse of what you'd expect
222-
self._interpolation_function = interp.interp2d(
223-
self.t_pts,
224-
pts_for_interp,
225-
entries_for_interp,
226-
kind="linear",
227-
fill_value=np.nan,
228-
bounds_error=False,
229-
)
203+
self._xr_data_array = xr.DataArray(
204+
entries_for_interp,
205+
coords=[(self.first_dimension, pts_for_interp), ("t", self.t_pts)],
206+
)
230207

231208
def initialise_2D(self):
232209
"""
@@ -362,21 +339,14 @@ def initialise_2D(self):
362339
self.second_dim_pts = second_dim_edges
363340

364341
# set up interpolation
365-
if len(self.t_pts) == 1:
366-
# function of space only. Note the order of the points is the reverse
367-
# of what you'd expect
368-
self._interpolation_function = Interpolant2D(
369-
first_dim_pts_for_interp, second_dim_pts_for_interp, entries_for_interp
370-
)
371-
else:
372-
# function of space and time.
373-
self._interpolation_function = interp.RegularGridInterpolator(
374-
(first_dim_pts_for_interp, second_dim_pts_for_interp, self.t_pts),
375-
entries_for_interp,
376-
method="linear",
377-
fill_value=np.nan,
378-
bounds_error=False,
379-
)
342+
self._xr_data_array = xr.DataArray(
343+
entries_for_interp,
344+
coords={
345+
self.first_dimension: first_dim_pts_for_interp,
346+
self.second_dimension: second_dim_pts_for_interp,
347+
"t": self.t_pts,
348+
},
349+
)
380350

381351
def initialise_2D_scikit_fem(self):
382352
y_sol = self.mesh.edges["y"]
@@ -411,74 +381,21 @@ def initialise_2D_scikit_fem(self):
411381
self.second_dim_pts = z_sol
412382

413383
# set up interpolation
414-
if len(self.t_pts) == 1:
415-
# function of space only. Note the order of the points is the reverse
416-
# of what you'd expect
417-
self._interpolation_function = Interpolant2D(
418-
self.first_dim_pts, self.second_dim_pts, entries
419-
)
420-
else:
421-
# function of space and time.
422-
self._interpolation_function = interp.RegularGridInterpolator(
423-
(self.first_dim_pts, self.second_dim_pts, self.t_pts),
424-
entries,
425-
method="linear",
426-
fill_value=np.nan,
427-
bounds_error=False,
428-
)
384+
self._xr_data_array = xr.DataArray(
385+
entries,
386+
coords={"y": y_sol, "z": z_sol, "t": self.t_pts},
387+
)
429388

430389
def __call__(self, t=None, x=None, r=None, y=None, z=None, R=None, warn=True):
431390
"""
432391
Evaluate the variable at arbitrary *dimensional* t (and x, r, y, z and/or R),
433392
using interpolation
434393
"""
435-
# If t is None and there is only one value of time in the soluton (i.e.
436-
# the solution is independent of time) then we set t equal to the value
437-
# stored in the solution. If the variable is constant (doesn't depend on
438-
# time) evaluate arbitrarily at the first value of t. Otherwise, raise
439-
# an error
440-
if t is None:
441-
if len(self.t_pts) == 1:
442-
t = self.t_pts
443-
elif len(self.base_variables) == 1 and self.base_variables[0].is_constant():
444-
t = self.t_pts[0]
445-
else:
446-
raise ValueError(
447-
"t cannot be None for variable {}".format(self.base_variables)
448-
)
449-
450-
# Call interpolant of correct spatial dimension
451-
if self.dimensions == 0:
452-
out = self._interpolation_function(t)
453-
elif self.dimensions == 1:
454-
out = self.call_1D(t, x, r, z, R)
455-
elif self.dimensions == 2:
456-
out = self.call_2D(t, x, r, y, z, R)
457-
if warn is True and np.isnan(out).any():
458-
pybamm.logger.warning(
459-
"Calling variable outside interpolation range (returns 'nan')"
460-
)
461-
return out
462-
463-
def call_1D(self, t, x, r, z, R):
464-
"""Evaluate a 1D variable"""
465-
spatial_var = eval_dimension_name(self.first_dimension, x, r, None, z, R)
466-
return self._interpolation_function(t, spatial_var)
467-
468-
def call_2D(self, t, x, r, y, z, R):
469-
"""Evaluate a 2D variable"""
470-
first_dim = eval_dimension_name(self.first_dimension, x, r, y, z, R)
471-
second_dim = eval_dimension_name(self.second_dimension, x, r, y, z, R)
472-
if isinstance(first_dim, np.ndarray):
473-
if isinstance(second_dim, np.ndarray) and isinstance(t, np.ndarray):
474-
first_dim = first_dim[:, np.newaxis, np.newaxis]
475-
second_dim = second_dim[:, np.newaxis]
476-
elif isinstance(second_dim, np.ndarray) or isinstance(t, np.ndarray):
477-
first_dim = first_dim[:, np.newaxis]
478-
else:
479-
if isinstance(second_dim, np.ndarray) and isinstance(t, np.ndarray):
480-
second_dim = second_dim[:, np.newaxis]
481-
return self._interpolation_function((first_dim, second_dim, t))
394+
kwargs = {"t": t, "x": x, "r": r, "y": y, "z": z, "R": R}
395+
# Remove any None arguments
396+
kwargs = {key: value for key, value in kwargs.items() if value is not None}
397+
# Use xarray interpolation, return numpy array
398+
return self._xr_data_array.interp(**kwargs).values
482399

483400
@property
484401
def data(self):
@@ -564,79 +481,3 @@ def initialise_sensitivity_explicit_forward(self):
564481

565482
# Save attribute
566483
self._sensitivities = sensitivities
567-
568-
569-
class Interpolant0D:
570-
def __init__(self, entries):
571-
self.entries = entries
572-
573-
def __call__(self, t):
574-
return self.entries
575-
576-
577-
class Interpolant1D:
578-
def __init__(self, pts_for_interp, entries_for_interp):
579-
self.interpolant = interp.interp1d(
580-
pts_for_interp,
581-
entries_for_interp[:, 0],
582-
kind="linear",
583-
fill_value=np.nan,
584-
bounds_error=False,
585-
)
586-
587-
def __call__(self, t, z):
588-
if isinstance(z, np.ndarray):
589-
return self.interpolant(z)[:, np.newaxis]
590-
else:
591-
return self.interpolant(z)
592-
593-
594-
class Interpolant2D:
595-
def __init__(
596-
self, first_dim_pts_for_interp, second_dim_pts_for_interp, entries_for_interp
597-
):
598-
self.interpolant = interp.interp2d(
599-
second_dim_pts_for_interp,
600-
first_dim_pts_for_interp,
601-
entries_for_interp[:, :, 0],
602-
kind="linear",
603-
fill_value=np.nan,
604-
bounds_error=False,
605-
)
606-
607-
def __call__(self, input):
608-
"""
609-
Calls and returns a 2D interpolant of the correct shape depending on the
610-
shape of the input
611-
"""
612-
first_dim, second_dim, _ = input
613-
if isinstance(first_dim, np.ndarray) and isinstance(second_dim, np.ndarray):
614-
first_dim = first_dim[:, 0, 0]
615-
second_dim = second_dim[:, 0]
616-
return self.interpolant(second_dim, first_dim)
617-
elif isinstance(first_dim, np.ndarray):
618-
first_dim = first_dim[:, 0]
619-
return self.interpolant(second_dim, first_dim)[:, 0]
620-
elif isinstance(second_dim, np.ndarray):
621-
second_dim = second_dim[:, 0]
622-
return self.interpolant(second_dim, first_dim)
623-
else:
624-
return self.interpolant(second_dim, first_dim)[0]
625-
626-
627-
def eval_dimension_name(name, x, r, y, z, R):
628-
if name == "x":
629-
out = x
630-
elif name == "r":
631-
out = r
632-
elif name == "y":
633-
out = y
634-
elif name == "z":
635-
out = z
636-
elif name == "R":
637-
out = R
638-
639-
if out is None:
640-
raise ValueError("inputs {} cannot be None".format(name))
641-
else:
642-
return out

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ casadi >= 3.6.0
88
imageio>=2.9.0
99
pybtex>=0.24.0
1010
sympy >= 1.8
11+
xarray
1112
bpx
1213
tqdm
1314
# Note: Matplotlib is loaded for debug plots but to ensure pybamm runs

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,7 @@ def compile_KLU():
214214
"importlib-metadata",
215215
"pybtex>=0.24.0",
216216
"sympy>=1.8",
217+
"xarray",
217218
"bpx",
218219
"tqdm",
219220
# Note: Matplotlib is loaded for debug plots, but to ensure pybamm runs

tests/integration/test_models/standard_output_comparison.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -146,8 +146,8 @@ def __init__(self, time, solutions):
146146
def test_all(self):
147147
self.compare("Negative particle concentration [mol.m-3]")
148148
self.compare("Positive particle concentration [mol.m-3]")
149-
self.compare("Negative particle flux [mol.m-2.s-1]", rtol=0.05)
150-
self.compare("Positive particle flux [mol.m-2.s-1]", rtol=0.05)
149+
self.compare("Negative particle flux [mol.m-2.s-1]", atol=1e-7, rtol=0.05)
150+
self.compare("Positive particle flux [mol.m-2.s-1]", atol=1e-7, rtol=0.05)
151151

152152

153153
class PorosityComparison(BaseOutputComparison):

tests/unit/test_models/test_submodels/test_effective_current_collector.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,15 +75,18 @@ def test_get_processed_variables(self):
7575
# for each current collector model
7676
for model in models[1:]:
7777
solution = model.default_solver.solve(model)
78-
vars = model.post_process(solution, param, V, I)
78+
variables = model.post_process(solution, param, V, I)
7979
pts = np.array([0.1, 0.5, 0.9]) * min(
8080
param.evaluate(model.param.L_y), param.evaluate(model.param.L_z)
8181
)
82-
for var, processed_var in vars.items():
82+
for var, processed_var in variables.items():
8383
if "Voltage [V]" in var:
8484
processed_var(t=solution_1D.t[5])
8585
else:
86-
processed_var(t=solution_1D.t[5], y=pts, z=pts)
86+
if model.options["dimensionality"] == 1:
87+
processed_var(t=solution_1D.t[5], z=pts)
88+
else:
89+
processed_var(t=solution_1D.t[5], y=pts, z=pts)
8790

8891

8992
if __name__ == "__main__":

tests/unit/test_plotting/test_quick_plot.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -329,19 +329,15 @@ def test_loqs_spme(self):
329329
)
330330
quick_plot.plot(0)
331331

332-
qp_data = (
333-
quick_plot.plots[("Electrolyte concentration [mol.m-3]",)][0][
334-
0
335-
].get_ydata(),
336-
)[0]
332+
qp_data = quick_plot.plots[("Electrolyte concentration [mol.m-3]",)][0][
333+
0
334+
].get_ydata()
337335
np.testing.assert_array_almost_equal(qp_data, c_e[:, 0])
338-
quick_plot.slider_update(t_eval[-1] / scale)
339336

340-
qp_data = (
341-
quick_plot.plots[("Electrolyte concentration [mol.m-3]",)][0][
342-
0
343-
].get_ydata(),
344-
)[0][:, 0]
337+
quick_plot.slider_update(t_eval[-1] / scale)
338+
qp_data = quick_plot.plots[("Electrolyte concentration [mol.m-3]",)][0][
339+
0
340+
].get_ydata()
345341
np.testing.assert_array_almost_equal(qp_data, c_e[:, 1])
346342

347343
# test quick plot of particle for spme

0 commit comments

Comments
 (0)