Skip to content

Commit 215252b

Browse files
Merge branch 'develop' into telemetry
2 parents 7df73c4 + f05cae2 commit 215252b

File tree

21 files changed

+599
-78
lines changed

21 files changed

+599
-78
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ ci:
44

55
repos:
66
- repo: https://github.com/astral-sh/ruff-pre-commit
7-
rev: "v0.7.1"
7+
rev: "v0.7.2"
88
hooks:
99
- id: ruff
1010
args: [--fix, --show-fixes]

CONTRIBUTING.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ You now have everything you need to start making changes!
4444

4545
### B. Writing your code
4646

47-
6. PyBaMM is developed in [Python](https://www.python.org)), and makes heavy use of [NumPy](https://numpy.org/) (see also [NumPy for MatLab users](https://numpy.org/doc/stable/user/numpy-for-matlab-users.html) and [Python for R users](https://www.rebeccabarter.com/blog/2023-09-11-from_r_to_python)).
47+
6. PyBaMM is developed in [Python](https://www.python.org), and makes heavy use of [NumPy](https://numpy.org/).
4848
7. Make sure to follow our [coding style guidelines](#coding-style-guidelines).
4949
8. Commit your changes to your branch with [useful, descriptive commit messages](https://chris.beams.io/posts/git-commit/): Remember these are
5050
publicly visible and should still make sense a few months ahead in time.
@@ -116,8 +116,8 @@ PyBaMM provides a utility function `import_optional_dependency`, to check for th
116116

117117
Optional dependencies should never be imported at the module level, but always inside methods. For example:
118118

119-
```
120-
def use_pybtex(x,y,z):
119+
```python
120+
def use_pybtex(x, y, z):
121121
pybtex = import_optional_dependency("pybtex")
122122
...
123123
```

docs/source/examples/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ The notebooks are organised into subfolders, and can be viewed in the galleries
8787
notebooks/parameterization/change-input-current.ipynb
8888
notebooks/parameterization/parameter-values.ipynb
8989
notebooks/parameterization/parameterization.ipynb
90+
notebooks/parameterization/sensitivities_and_data_fitting.ipynb
9091

9192
.. nbgallery::
9293
:caption: Simulations and Experiments

docs/source/examples/notebooks/parameterization/sensitivities_and_data_fitting.ipynb

Lines changed: 327 additions & 0 deletions
Large diffs are not rendered by default.

src/pybamm/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from .expression_tree.parameter import Parameter, FunctionParameter
4040
from .expression_tree.scalar import Scalar
4141
from .expression_tree.variable import *
42+
from .expression_tree.coupled_variable import *
4243
from .expression_tree.independent_variable import *
4344
from .expression_tree.independent_variable import t
4445
from .expression_tree.vector import Vector

src/pybamm/discretisations/discretisation.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -500,8 +500,8 @@ def check_tab_conditions(self, symbol, bcs):
500500

501501
if domain != "current collector":
502502
raise pybamm.ModelError(
503-
f"""Boundary conditions can only be applied on the tabs in the domain
504-
'current collector', but {symbol} has domain {domain}"""
503+
"Boundary conditions can only be applied on the tabs in the domain "
504+
f"'current collector', but {symbol} has domain {domain}"
505505
)
506506
# Replace keys with "left" and "right" as appropriate for 1D meshes
507507
if isinstance(mesh, pybamm.SubMesh1D):
@@ -893,11 +893,9 @@ def _process_symbol(self, symbol):
893893
y_slices = self.y_slices[symbol]
894894
except KeyError as error:
895895
raise pybamm.ModelError(
896-
f"""
897-
No key set for variable '{symbol.name}'. Make sure it is included in either
898-
model.rhs or model.algebraic in an unmodified form
899-
(e.g. not Broadcasted)
900-
"""
896+
f"No key set for variable '{symbol.name}'. Make sure it is included in either "
897+
"model.rhs or model.algebraic in an unmodified form "
898+
"(e.g. not Broadcasted)"
901899
) from error
902900
# Add symbol's reference and multiply by the symbol's scale
903901
# so that the state vector is of order 1
@@ -938,6 +936,11 @@ def _process_symbol(self, symbol):
938936
if symbol._expected_size is None:
939937
symbol._expected_size = expected_size
940938
return symbol.create_copy()
939+
940+
elif isinstance(symbol, pybamm.CoupledVariable):
941+
new_symbol = self.process_symbol(symbol.children[0])
942+
return new_symbol
943+
941944
else:
942945
# Backup option: return the object
943946
return symbol

src/pybamm/expression_tree/averages.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -251,8 +251,8 @@ def z_average(symbol: pybamm.Symbol) -> pybamm.Symbol:
251251
# Symbol must have domain [] or ["current collector"]
252252
if symbol.domain not in [[], ["current collector"]]:
253253
raise pybamm.DomainError(
254-
f"""z-average only implemented in the 'current collector' domain,
255-
but symbol has domains {symbol.domain}"""
254+
"z-average only implemented in the 'current collector' domain, "
255+
f"but symbol has domains {symbol.domain}"
256256
)
257257
# If symbol doesn't have a domain, its average value is itself
258258
if symbol.domain == []:
@@ -285,8 +285,8 @@ def yz_average(symbol: pybamm.Symbol) -> pybamm.Symbol:
285285
# Symbol must have domain [] or ["current collector"]
286286
if symbol.domain not in [[], ["current collector"]]:
287287
raise pybamm.DomainError(
288-
f"""y-z-average only implemented in the 'current collector' domain,
289-
but symbol has domains {symbol.domain}"""
288+
"y-z-average only implemented in the 'current collector' domain, "
289+
f"but symbol has domains {symbol.domain}"
290290
)
291291
# If symbol doesn't have a domain, its average value is itself
292292
if symbol.domain == []:

src/pybamm/expression_tree/binary_operators.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def _preprocess_binary(
3636
# Check both left and right are pybamm Symbols
3737
if not (isinstance(left, pybamm.Symbol) and isinstance(right, pybamm.Symbol)):
3838
raise NotImplementedError(
39-
f"""BinaryOperator not implemented for symbols of type {type(left)} and {type(right)}"""
39+
f"BinaryOperator not implemented for symbols of type {type(left)} and {type(right)}"
4040
)
4141

4242
# Do some broadcasting in special cases, to avoid having to do this manually
@@ -389,9 +389,9 @@ def _binary_jac(self, left_jac, right_jac):
389389
return left @ right_jac
390390
else:
391391
raise NotImplementedError(
392-
f"""jac of 'MatrixMultiplication' is only
393-
implemented for left of type 'pybamm.Array',
394-
not {left.__class__}"""
392+
f"jac of 'MatrixMultiplication' is only "
393+
"implemented for left of type 'pybamm.Array', "
394+
f"not {left.__class__}"
395395
)
396396

397397
def _binary_evaluate(self, left, right):
@@ -1541,8 +1541,8 @@ def source(
15411541

15421542
if left.domain != ["current collector"] or right.domain != ["current collector"]:
15431543
raise pybamm.DomainError(
1544-
f"""'source' only implemented in the 'current collector' domain,
1545-
but symbols have domains {left.domain} and {right.domain}"""
1544+
"'source' only implemented in the 'current collector' domain, "
1545+
f"but symbols have domains {left.domain} and {right.domain}"
15461546
)
15471547
if boundary:
15481548
return pybamm.BoundaryMass(right) @ left
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import pybamm
2+
3+
from pybamm.type_definitions import DomainType
4+
5+
6+
class CoupledVariable(pybamm.Symbol):
7+
"""
8+
A node in the expression tree representing a variable whose equation is set by a different model or submodel.
9+
10+
11+
Parameters
12+
----------
13+
name : str
14+
name of the node
15+
domain : iterable of str
16+
list of domains that this coupled variable is valid over
17+
"""
18+
19+
def __init__(
20+
self,
21+
name: str,
22+
domain: DomainType = None,
23+
) -> None:
24+
super().__init__(name, domain=domain)
25+
26+
def _evaluate_for_shape(self):
27+
"""
28+
Returns the scalar 'NaN' to represent the shape of a parameter.
29+
See :meth:`pybamm.Symbol.evaluate_for_shape()`
30+
"""
31+
return pybamm.evaluate_for_shape_using_domain(self.domains)
32+
33+
def create_copy(self):
34+
"""Creates a new copy of the coupled variable."""
35+
new_coupled_variable = CoupledVariable(self.name, self.domain)
36+
return new_coupled_variable
37+
38+
@property
39+
def children(self):
40+
return self._children
41+
42+
@children.setter
43+
def children(self, expr):
44+
self._children = expr
45+
46+
def set_coupled_variable(self, symbol, expr):
47+
"""Sets the children of the coupled variable to the expression passed in expr. If the symbol is not the coupled variable, then it searches the children of the symbol for the coupled variable. The coupled variable will be replaced by its first child (symbol.children[0], which should be expr) in the discretisation step."""
48+
if self == symbol:
49+
symbol.children = [
50+
expr,
51+
]
52+
else:
53+
for child in symbol.children:
54+
self.set_coupled_variable(child, expr)
55+
symbol.set_id()

src/pybamm/expression_tree/operations/convert_to_casadi.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import casadi
88
import numpy as np
99
from scipy import special
10+
from scipy import interpolate
1011

1112

1213
class CasadiConverter:
@@ -165,6 +166,18 @@ def _convert(self, symbol, t, y, y_dot, inputs):
165166
# for some reason, pybamm.Interpolant always returns a column vector, so match that
166167
test = test.T
167168
return test
169+
elif solver == "bspline":
170+
bspline = interpolate.make_interp_spline(
171+
symbol.x[0], symbol.y, k=3
172+
)
173+
knots = [bspline.t]
174+
coeffs = bspline.c.flatten()
175+
degree = [bspline.k]
176+
m = len(coeffs) // len(symbol.x[0])
177+
f = casadi.Function.bspline(
178+
symbol.name, knots, coeffs, degree, m
179+
)
180+
return f(converted_children[0])
168181
else:
169182
return casadi.interpolant(
170183
"LUT", solver, symbol.x, symbol.y.flatten()
@@ -176,6 +189,20 @@ def _convert(self, symbol, t, y, y_dot, inputs):
176189
symbol.y.ravel(order="F"),
177190
converted_children,
178191
)
192+
elif solver == "bspline" and len(converted_children) == 2:
193+
bspline = interpolate.RectBivariateSpline(
194+
symbol.x[0], symbol.x[1], symbol.y
195+
)
196+
[tx, ty, c] = bspline.tck
197+
[kx, ky] = bspline.degrees
198+
knots = [tx, ty]
199+
coeffs = c
200+
degree = [kx, ky]
201+
m = 1
202+
f = casadi.Function.bspline(
203+
symbol.name, knots, coeffs, degree, m
204+
)
205+
return f(casadi.hcat(converted_children).T).T
179206
else:
180207
LUT = casadi.interpolant(
181208
"LUT", solver, symbol.x, symbol.y.ravel(order="F")
@@ -231,8 +258,6 @@ def _convert(self, symbol, t, y, y_dot, inputs):
231258

232259
else:
233260
raise TypeError(
234-
f"""
235-
Cannot convert symbol of type '{type(symbol)}' to CasADi. Symbols must all be
236-
'linear algebra' at this stage.
237-
"""
261+
f"Cannot convert symbol of type '{type(symbol)}' to CasADi. Symbols must all be "
262+
"'linear algebra' at this stage."
238263
)

0 commit comments

Comments
 (0)