|
5 | 5 | import numbers |
6 | 6 | import numpy as np |
7 | 7 | import pybamm |
8 | | -import scipy.interpolate as interp |
9 | 8 | from scipy.integrate import cumulative_trapezoid |
| 9 | +import xarray as xr |
10 | 10 |
|
11 | 11 |
|
12 | 12 | class ProcessedVariable(object): |
@@ -131,18 +131,7 @@ def initialise_0D(self): |
131 | 131 | ) |
132 | 132 |
|
133 | 133 | # set up interpolation |
134 | | - if len(self.t_pts) == 1: |
135 | | - # Variable is just a scalar value, but we need to create a callable |
136 | | - # function to be consistent with other processed variables |
137 | | - self._interpolation_function = Interpolant0D(entries) |
138 | | - else: |
139 | | - self._interpolation_function = interp.interp1d( |
140 | | - self.t_pts, |
141 | | - entries, |
142 | | - kind="linear", |
143 | | - fill_value=np.nan, |
144 | | - bounds_error=False, |
145 | | - ) |
| 134 | + self._xr_data_array = xr.DataArray(entries, coords=[("t", self.t_pts)]) |
146 | 135 |
|
147 | 136 | self.entries = entries |
148 | 137 | self.dimensions = 0 |
@@ -211,22 +200,10 @@ def initialise_1D(self, fixed_t=False): |
211 | 200 | self.first_dim_pts = edges |
212 | 201 |
|
213 | 202 | # set up interpolation |
214 | | - if len(self.t_pts) == 1: |
215 | | - # function of space only |
216 | | - self._interpolation_function = Interpolant1D( |
217 | | - pts_for_interp, entries_for_interp |
218 | | - ) |
219 | | - else: |
220 | | - # function of space and time. Note that the order of 't' and 'space' |
221 | | - # is the reverse of what you'd expect |
222 | | - self._interpolation_function = interp.interp2d( |
223 | | - self.t_pts, |
224 | | - pts_for_interp, |
225 | | - entries_for_interp, |
226 | | - kind="linear", |
227 | | - fill_value=np.nan, |
228 | | - bounds_error=False, |
229 | | - ) |
| 203 | + self._xr_data_array = xr.DataArray( |
| 204 | + entries_for_interp, |
| 205 | + coords=[(self.first_dimension, pts_for_interp), ("t", self.t_pts)], |
| 206 | + ) |
230 | 207 |
|
231 | 208 | def initialise_2D(self): |
232 | 209 | """ |
@@ -362,21 +339,14 @@ def initialise_2D(self): |
362 | 339 | self.second_dim_pts = second_dim_edges |
363 | 340 |
|
364 | 341 | # set up interpolation |
365 | | - if len(self.t_pts) == 1: |
366 | | - # function of space only. Note the order of the points is the reverse |
367 | | - # of what you'd expect |
368 | | - self._interpolation_function = Interpolant2D( |
369 | | - first_dim_pts_for_interp, second_dim_pts_for_interp, entries_for_interp |
370 | | - ) |
371 | | - else: |
372 | | - # function of space and time. |
373 | | - self._interpolation_function = interp.RegularGridInterpolator( |
374 | | - (first_dim_pts_for_interp, second_dim_pts_for_interp, self.t_pts), |
375 | | - entries_for_interp, |
376 | | - method="linear", |
377 | | - fill_value=np.nan, |
378 | | - bounds_error=False, |
379 | | - ) |
| 342 | + self._xr_data_array = xr.DataArray( |
| 343 | + entries_for_interp, |
| 344 | + coords={ |
| 345 | + self.first_dimension: first_dim_pts_for_interp, |
| 346 | + self.second_dimension: second_dim_pts_for_interp, |
| 347 | + "t": self.t_pts, |
| 348 | + }, |
| 349 | + ) |
380 | 350 |
|
381 | 351 | def initialise_2D_scikit_fem(self): |
382 | 352 | y_sol = self.mesh.edges["y"] |
@@ -411,74 +381,21 @@ def initialise_2D_scikit_fem(self): |
411 | 381 | self.second_dim_pts = z_sol |
412 | 382 |
|
413 | 383 | # set up interpolation |
414 | | - if len(self.t_pts) == 1: |
415 | | - # function of space only. Note the order of the points is the reverse |
416 | | - # of what you'd expect |
417 | | - self._interpolation_function = Interpolant2D( |
418 | | - self.first_dim_pts, self.second_dim_pts, entries |
419 | | - ) |
420 | | - else: |
421 | | - # function of space and time. |
422 | | - self._interpolation_function = interp.RegularGridInterpolator( |
423 | | - (self.first_dim_pts, self.second_dim_pts, self.t_pts), |
424 | | - entries, |
425 | | - method="linear", |
426 | | - fill_value=np.nan, |
427 | | - bounds_error=False, |
428 | | - ) |
| 384 | + self._xr_data_array = xr.DataArray( |
| 385 | + entries, |
| 386 | + coords={"y": y_sol, "z": z_sol, "t": self.t_pts}, |
| 387 | + ) |
429 | 388 |
|
430 | 389 | def __call__(self, t=None, x=None, r=None, y=None, z=None, R=None, warn=True): |
431 | 390 | """ |
432 | 391 | Evaluate the variable at arbitrary *dimensional* t (and x, r, y, z and/or R), |
433 | 392 | using interpolation |
434 | 393 | """ |
435 | | - # If t is None and there is only one value of time in the soluton (i.e. |
436 | | - # the solution is independent of time) then we set t equal to the value |
437 | | - # stored in the solution. If the variable is constant (doesn't depend on |
438 | | - # time) evaluate arbitrarily at the first value of t. Otherwise, raise |
439 | | - # an error |
440 | | - if t is None: |
441 | | - if len(self.t_pts) == 1: |
442 | | - t = self.t_pts |
443 | | - elif len(self.base_variables) == 1 and self.base_variables[0].is_constant(): |
444 | | - t = self.t_pts[0] |
445 | | - else: |
446 | | - raise ValueError( |
447 | | - "t cannot be None for variable {}".format(self.base_variables) |
448 | | - ) |
449 | | - |
450 | | - # Call interpolant of correct spatial dimension |
451 | | - if self.dimensions == 0: |
452 | | - out = self._interpolation_function(t) |
453 | | - elif self.dimensions == 1: |
454 | | - out = self.call_1D(t, x, r, z, R) |
455 | | - elif self.dimensions == 2: |
456 | | - out = self.call_2D(t, x, r, y, z, R) |
457 | | - if warn is True and np.isnan(out).any(): |
458 | | - pybamm.logger.warning( |
459 | | - "Calling variable outside interpolation range (returns 'nan')" |
460 | | - ) |
461 | | - return out |
462 | | - |
463 | | - def call_1D(self, t, x, r, z, R): |
464 | | - """Evaluate a 1D variable""" |
465 | | - spatial_var = eval_dimension_name(self.first_dimension, x, r, None, z, R) |
466 | | - return self._interpolation_function(t, spatial_var) |
467 | | - |
468 | | - def call_2D(self, t, x, r, y, z, R): |
469 | | - """Evaluate a 2D variable""" |
470 | | - first_dim = eval_dimension_name(self.first_dimension, x, r, y, z, R) |
471 | | - second_dim = eval_dimension_name(self.second_dimension, x, r, y, z, R) |
472 | | - if isinstance(first_dim, np.ndarray): |
473 | | - if isinstance(second_dim, np.ndarray) and isinstance(t, np.ndarray): |
474 | | - first_dim = first_dim[:, np.newaxis, np.newaxis] |
475 | | - second_dim = second_dim[:, np.newaxis] |
476 | | - elif isinstance(second_dim, np.ndarray) or isinstance(t, np.ndarray): |
477 | | - first_dim = first_dim[:, np.newaxis] |
478 | | - else: |
479 | | - if isinstance(second_dim, np.ndarray) and isinstance(t, np.ndarray): |
480 | | - second_dim = second_dim[:, np.newaxis] |
481 | | - return self._interpolation_function((first_dim, second_dim, t)) |
| 394 | + kwargs = {"t": t, "x": x, "r": r, "y": y, "z": z, "R": R} |
| 395 | + # Remove any None arguments |
| 396 | + kwargs = {key: value for key, value in kwargs.items() if value is not None} |
| 397 | + # Use xarray interpolation, return numpy array |
| 398 | + return self._xr_data_array.interp(**kwargs).values |
482 | 399 |
|
483 | 400 | @property |
484 | 401 | def data(self): |
@@ -564,79 +481,3 @@ def initialise_sensitivity_explicit_forward(self): |
564 | 481 |
|
565 | 482 | # Save attribute |
566 | 483 | self._sensitivities = sensitivities |
567 | | - |
568 | | - |
569 | | -class Interpolant0D: |
570 | | - def __init__(self, entries): |
571 | | - self.entries = entries |
572 | | - |
573 | | - def __call__(self, t): |
574 | | - return self.entries |
575 | | - |
576 | | - |
577 | | -class Interpolant1D: |
578 | | - def __init__(self, pts_for_interp, entries_for_interp): |
579 | | - self.interpolant = interp.interp1d( |
580 | | - pts_for_interp, |
581 | | - entries_for_interp[:, 0], |
582 | | - kind="linear", |
583 | | - fill_value=np.nan, |
584 | | - bounds_error=False, |
585 | | - ) |
586 | | - |
587 | | - def __call__(self, t, z): |
588 | | - if isinstance(z, np.ndarray): |
589 | | - return self.interpolant(z)[:, np.newaxis] |
590 | | - else: |
591 | | - return self.interpolant(z) |
592 | | - |
593 | | - |
594 | | -class Interpolant2D: |
595 | | - def __init__( |
596 | | - self, first_dim_pts_for_interp, second_dim_pts_for_interp, entries_for_interp |
597 | | - ): |
598 | | - self.interpolant = interp.interp2d( |
599 | | - second_dim_pts_for_interp, |
600 | | - first_dim_pts_for_interp, |
601 | | - entries_for_interp[:, :, 0], |
602 | | - kind="linear", |
603 | | - fill_value=np.nan, |
604 | | - bounds_error=False, |
605 | | - ) |
606 | | - |
607 | | - def __call__(self, input): |
608 | | - """ |
609 | | - Calls and returns a 2D interpolant of the correct shape depending on the |
610 | | - shape of the input |
611 | | - """ |
612 | | - first_dim, second_dim, _ = input |
613 | | - if isinstance(first_dim, np.ndarray) and isinstance(second_dim, np.ndarray): |
614 | | - first_dim = first_dim[:, 0, 0] |
615 | | - second_dim = second_dim[:, 0] |
616 | | - return self.interpolant(second_dim, first_dim) |
617 | | - elif isinstance(first_dim, np.ndarray): |
618 | | - first_dim = first_dim[:, 0] |
619 | | - return self.interpolant(second_dim, first_dim)[:, 0] |
620 | | - elif isinstance(second_dim, np.ndarray): |
621 | | - second_dim = second_dim[:, 0] |
622 | | - return self.interpolant(second_dim, first_dim) |
623 | | - else: |
624 | | - return self.interpolant(second_dim, first_dim)[0] |
625 | | - |
626 | | - |
627 | | -def eval_dimension_name(name, x, r, y, z, R): |
628 | | - if name == "x": |
629 | | - out = x |
630 | | - elif name == "r": |
631 | | - out = r |
632 | | - elif name == "y": |
633 | | - out = y |
634 | | - elif name == "z": |
635 | | - out = z |
636 | | - elif name == "R": |
637 | | - out = R |
638 | | - |
639 | | - if out is None: |
640 | | - raise ValueError("inputs {} cannot be None".format(name)) |
641 | | - else: |
642 | | - return out |
0 commit comments