77import casadi
88import numpy as np
99from scipy import special
10+ from scipy import interpolate
1011
1112
1213class CasadiConverter :
@@ -165,6 +166,18 @@ def _convert(self, symbol, t, y, y_dot, inputs):
165166 # for some reason, pybamm.Interpolant always returns a column vector, so match that
166167 test = test .T
167168 return test
169+ elif solver == "bspline" :
170+ bspline = interpolate .make_interp_spline (
171+ symbol .x [0 ], symbol .y , k = 3
172+ )
173+ knots = [bspline .t ]
174+ coeffs = bspline .c .flatten ()
175+ degree = [bspline .k ]
176+ m = len (coeffs ) // len (symbol .x [0 ])
177+ f = casadi .Function .bspline (
178+ symbol .name , knots , coeffs , degree , m
179+ )
180+ return f (converted_children [0 ])
168181 else :
169182 return casadi .interpolant (
170183 "LUT" , solver , symbol .x , symbol .y .flatten ()
@@ -176,6 +189,20 @@ def _convert(self, symbol, t, y, y_dot, inputs):
176189 symbol .y .ravel (order = "F" ),
177190 converted_children ,
178191 )
192+ elif solver == "bspline" and len (converted_children ) == 2 :
193+ bspline = interpolate .RectBivariateSpline (
194+ symbol .x [0 ], symbol .x [1 ], symbol .y
195+ )
196+ [tx , ty , c ] = bspline .tck
197+ [kx , ky ] = bspline .degrees
198+ knots = [tx , ty ]
199+ coeffs = c
200+ degree = [kx , ky ]
201+ m = 1
202+ f = casadi .Function .bspline (
203+ symbol .name , knots , coeffs , degree , m
204+ )
205+ return f (casadi .hcat (converted_children ).T ).T
179206 else :
180207 LUT = casadi .interpolant (
181208 "LUT" , solver , symbol .x , symbol .y .ravel (order = "F" )
0 commit comments