Skip to content

Commit f05cae2

Browse files
bug: use direct casadi bspline function for 1D & 2D cubic interp (#4572)
* bug: use direct casadi bspline function for 1D cubic interp * bug: use direct bspline func for 2d cubic interp
1 parent 9a479cf commit f05cae2

File tree

1 file changed

+27
-0
lines changed

1 file changed

+27
-0
lines changed

src/pybamm/expression_tree/operations/convert_to_casadi.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import casadi
88
import numpy as np
99
from scipy import special
10+
from scipy import interpolate
1011

1112

1213
class CasadiConverter:
@@ -165,6 +166,18 @@ def _convert(self, symbol, t, y, y_dot, inputs):
165166
# for some reason, pybamm.Interpolant always returns a column vector, so match that
166167
test = test.T
167168
return test
169+
elif solver == "bspline":
170+
bspline = interpolate.make_interp_spline(
171+
symbol.x[0], symbol.y, k=3
172+
)
173+
knots = [bspline.t]
174+
coeffs = bspline.c.flatten()
175+
degree = [bspline.k]
176+
m = len(coeffs) // len(symbol.x[0])
177+
f = casadi.Function.bspline(
178+
symbol.name, knots, coeffs, degree, m
179+
)
180+
return f(converted_children[0])
168181
else:
169182
return casadi.interpolant(
170183
"LUT", solver, symbol.x, symbol.y.flatten()
@@ -176,6 +189,20 @@ def _convert(self, symbol, t, y, y_dot, inputs):
176189
symbol.y.ravel(order="F"),
177190
converted_children,
178191
)
192+
elif solver == "bspline" and len(converted_children) == 2:
193+
bspline = interpolate.RectBivariateSpline(
194+
symbol.x[0], symbol.x[1], symbol.y
195+
)
196+
[tx, ty, c] = bspline.tck
197+
[kx, ky] = bspline.degrees
198+
knots = [tx, ty]
199+
coeffs = c
200+
degree = [kx, ky]
201+
m = 1
202+
f = casadi.Function.bspline(
203+
symbol.name, knots, coeffs, degree, m
204+
)
205+
return f(casadi.hcat(converted_children).T).T
179206
else:
180207
LUT = casadi.interpolant(
181208
"LUT", solver, symbol.x, symbol.y.ravel(order="F")

0 commit comments

Comments
 (0)