Skip to content

Commit 205ca81

Browse files
Merge pull request #4196 from pybamm-team/issue-4183-remove-autograd
#4183 remove autograd
2 parents 8ba4791 + d38117b commit 205ca81

File tree

10 files changed

+56
-116
lines changed

10 files changed

+56
-116
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050

5151
## Breaking changes
5252

53+
- Functions that are created using `pybamm.Function(function_object, children)` can no longer be differentiated symbolically (e.g. to compute the Jacobian). This should affect no users, since function derivatives for all "standard" functions are explicitly implemented ([#4196](https://github.com/pybamm-team/PyBaMM/pull/4196))
5354
- Removed data files under `pybamm/input` and released them in a separate repository upstream at [pybamm-data](https://github.com/pybamm-team/pybamm-data/releases/tag/v1.0.0). Note that data files under `pybamm/input/parameters` have not been removed. ([#4098](https://github.com/pybamm-team/PyBaMM/pull/4098))
5455
- Removed `check_model` argument from `Simulation.solve`. To change the `check_model` option, use `Simulation(..., discretisation_kwargs={"check_model": False})`. ([#4020](https://github.com/pybamm-team/PyBaMM/pull/4020))
5556
- Removed multiple Docker images. Here on, a single Docker image tagged `pybamm/pybamm:latest` will be provided with both solvers (`IDAKLU` and `JAX`) pre-installed. ([#3992](https://github.com/pybamm-team/PyBaMM/pull/3992))

asv.conf.json

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,6 @@
8282
"wget": [],
8383
"cmake": [],
8484
"anytree": [],
85-
"autograd": [],
8685
"scikit-fem": [],
8786
"imageio": [],
8887
"pybtex": [],

pybamm/expression_tree/functions.py

Lines changed: 4 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from typing_extensions import TypeVar
1212

1313
import pybamm
14-
from pybamm.util import import_optional_dependency
1514

1615

1716
class Function(pybamm.Symbol):
@@ -26,9 +25,6 @@ class Function(pybamm.Symbol):
2625
func(child0.evaluate(t, y, u), child1.evaluate(t, y, u), etc).
2726
children : :class:`pybamm.Symbol`
2827
The children nodes to apply the function to
29-
derivative : str, optional
30-
Which derivative to use when differentiating ("autograd" or "derivative").
31-
Default is "autograd".
3228
differentiated_function : method, optional
3329
The function which was differentiated to obtain this one. Default is None.
3430
"""
@@ -38,7 +34,6 @@ def __init__(
3834
function: Callable,
3935
*children: pybamm.Symbol,
4036
name: str | None = None,
41-
derivative: str | None = "autograd",
4237
differentiated_function: Callable | None = None,
4338
):
4439
# Turn numbers into scalars
@@ -57,7 +52,6 @@ def __init__(
5752
domains = self.get_children_domains(children)
5853

5954
self.function = function
60-
self.derivative = derivative
6155
self.differentiated_function = differentiated_function
6256

6357
super().__init__(name, children=children, domains=domains)
@@ -99,30 +93,10 @@ def _function_diff(self, children: Sequence[pybamm.Symbol], idx: float):
9993
Derivative with respect to child number 'idx'.
10094
See :meth:`pybamm.Symbol._diff()`.
10195
"""
102-
autograd = import_optional_dependency("autograd")
103-
# Store differentiated function, needed in case we want to convert to CasADi
104-
if self.derivative == "autograd":
105-
return Function(
106-
autograd.elementwise_grad(self.function, idx),
107-
*children,
108-
differentiated_function=self.function,
109-
)
110-
elif self.derivative == "derivative":
111-
if len(children) > 1:
112-
raise ValueError(
113-
"""
114-
differentiation using '.derivative()' not implemented for functions
115-
with more than one child
116-
"""
117-
)
118-
else:
119-
# keep using "derivative" as derivative
120-
return pybamm.Function(
121-
self.function.derivative(), # type: ignore[attr-defined]
122-
*children,
123-
derivative="derivative",
124-
differentiated_function=self.function,
125-
)
96+
raise NotImplementedError(
97+
"Derivative of base Function class is not implemented. "
98+
"Please implement in child class."
99+
)
126100

127101
def _function_jac(self, children_jacs):
128102
"""Calculate the Jacobian of a function."""
@@ -190,7 +164,6 @@ def create_copy(
190164
self.function,
191165
*children,
192166
name=self.name,
193-
derivative=self.derivative,
194167
differentiated_function=self.differentiated_function,
195168
)
196169
else:
@@ -217,7 +190,6 @@ def _function_new_copy(self, children: list) -> Function:
217190
self.function,
218191
*children,
219192
name=self.name,
220-
derivative=self.derivative,
221193
differentiated_function=self.differentiated_function,
222194
)
223195
)

pybamm/expression_tree/interpolant.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ def __init__(
5050
interpolator: str | None = "linear",
5151
extrapolate: bool = True,
5252
entries_string: str | None = None,
53+
_num_derivatives: int = 0,
5354
):
5455
# Check interpolator is valid
5556
if interpolator not in ["linear", "cubic", "pchip"]:
@@ -189,9 +190,13 @@ def __init__(
189190
self.x = x
190191
self.y = y
191192
self.entries_string = entries_string
192-
super().__init__(
193-
interpolating_function, *children, name=name, derivative="derivative"
194-
)
193+
194+
# Differentiate the interpolating function if necessary
195+
self._num_derivatives = _num_derivatives
196+
for _ in range(_num_derivatives):
197+
interpolating_function = interpolating_function.derivative()
198+
199+
super().__init__(interpolating_function, *children, name=name)
195200

196201
# Store information as attributes
197202
self.interpolator = interpolator
@@ -213,6 +218,7 @@ def _from_json(cls, snippet: dict):
213218
name=snippet["name"],
214219
interpolator=snippet["interpolator"],
215220
extrapolate=snippet["extrapolate"],
221+
_num_derivatives=snippet["_num_derivatives"],
216222
)
217223

218224
@property
@@ -241,6 +247,7 @@ def set_id(self):
241247
self.entries_string,
242248
*tuple([child.id for child in self.children]),
243249
*tuple(self.domain),
250+
self._num_derivatives,
244251
)
245252
)
246253

@@ -256,6 +263,7 @@ def create_copy(self, new_children=None, perform_simplifications=True):
256263
interpolator=self.interpolator,
257264
extrapolate=self.extrapolate,
258265
entries_string=self.entries_string,
266+
_num_derivatives=self._num_derivatives,
259267
)
260268

261269
def _function_evaluate(self, evaluated_children):
@@ -311,6 +319,27 @@ def _function_evaluate(self, evaluated_children):
311319
else: # pragma: no cover
312320
raise ValueError(f"Invalid dimension: {self.dimension}")
313321

322+
def _function_diff(self, children: Sequence[pybamm.Symbol], idx: float):
323+
"""
324+
Derivative with respect to child number 'idx'.
325+
See :meth:`pybamm.Symbol._diff()`.
326+
"""
327+
if len(children) > 1:
328+
raise NotImplementedError(
329+
"differentiation not implemented for functions with more than one child"
330+
)
331+
else:
332+
# keep using "derivative" as derivative
333+
return Interpolant(
334+
self.x,
335+
self.y,
336+
children,
337+
name=self.name,
338+
interpolator=self.interpolator,
339+
extrapolate=self.extrapolate,
340+
_num_derivatives=self._num_derivatives + 1,
341+
)
342+
314343
def to_json(self):
315344
"""
316345
Method to serialise an Interpolant object into JSON.
@@ -323,6 +352,7 @@ def to_json(self):
323352
"y": self.y.tolist(),
324353
"interpolator": self.interpolator,
325354
"extrapolate": self.extrapolate,
355+
"_num_derivatives": self._num_derivatives,
326356
}
327357

328358
return json_dict

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,6 @@ jax = [
121121
]
122122
# Contains all optional dependencies, except for jax and dev dependencies
123123
all = [
124-
"autograd>=1.6.2",
125124
"scikit-fem>=8.1.0",
126125
"pybamm[examples,plot,cite,bpx,tqdm]",
127126
]

tests/unit/test_expression_tree/test_functions.py

Lines changed: 3 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from tests import (
1414
function_test,
1515
multi_var_function_test,
16-
multi_var_function_cube_test,
1716
)
1817

1918

@@ -52,57 +51,12 @@ def test_function_of_one_variable(self):
5251

5352
def test_diff(self):
5453
a = pybamm.StateVector(slice(0, 1))
55-
b = pybamm.StateVector(slice(1, 2))
56-
y = np.array([5])
5754
func = pybamm.Function(function_test, a)
58-
self.assertEqual(func.diff(a).evaluate(y=y), 2)
59-
self.assertEqual(func.diff(func).evaluate(), 1)
60-
func = pybamm.sin(a)
61-
self.assertEqual(func.evaluate(y=y), np.sin(a.evaluate(y=y)))
62-
self.assertEqual(func.diff(a).evaluate(y=y), np.cos(a.evaluate(y=y)))
63-
func = pybamm.exp(a)
64-
self.assertEqual(func.evaluate(y=y), np.exp(a.evaluate(y=y)))
65-
self.assertEqual(func.diff(a).evaluate(y=y), np.exp(a.evaluate(y=y)))
66-
67-
# multiple variables
68-
func = pybamm.Function(multi_var_function_test, 4 * a, 3 * a)
69-
self.assertEqual(func.diff(a).evaluate(y=y), 7)
70-
func = pybamm.Function(multi_var_function_test, 4 * a, 3 * b)
71-
self.assertEqual(func.diff(a).evaluate(y=np.array([5, 6])), 4)
72-
self.assertEqual(func.diff(b).evaluate(y=np.array([5, 6])), 3)
73-
func = pybamm.Function(multi_var_function_cube_test, 4 * a, 3 * b)
74-
self.assertEqual(func.diff(a).evaluate(y=np.array([5, 6])), 4)
75-
self.assertEqual(
76-
func.diff(b).evaluate(y=np.array([5, 6])), 3 * 3 * (3 * 6) ** 2
77-
)
78-
79-
# exceptions
80-
func = pybamm.Function(
81-
multi_var_function_cube_test, 4 * a, 3 * b, derivative="derivative"
82-
)
83-
with self.assertRaises(ValueError):
55+
with self.assertRaisesRegex(
56+
NotImplementedError, "Derivative of base Function class is not implemented"
57+
):
8458
func.diff(a)
8559

86-
def test_function_of_multiple_variables(self):
87-
a = pybamm.Variable("a")
88-
b = pybamm.Parameter("b")
89-
func = pybamm.Function(multi_var_function_test, a, b)
90-
self.assertEqual(func.name, "function (multi_var_function_test)")
91-
self.assertEqual(str(func), "multi_var_function_test(a, b)")
92-
self.assertEqual(func.children[0].name, a.name)
93-
self.assertEqual(func.children[1].name, b.name)
94-
95-
# test eval and diff
96-
a = pybamm.StateVector(slice(0, 1))
97-
b = pybamm.StateVector(slice(1, 2))
98-
y = np.array([5, 2])
99-
func = pybamm.Function(multi_var_function_test, a, b)
100-
101-
self.assertEqual(func.evaluate(y=y), 7)
102-
self.assertEqual(func.diff(a).evaluate(y=y), 1)
103-
self.assertEqual(func.diff(b).evaluate(y=y), 1)
104-
self.assertEqual(func.diff(func).evaluate(), 1)
105-
10660
def test_exceptions(self):
10761
a = pybamm.Variable("a", domain="something")
10862
b = pybamm.Variable("b", domain="something else")

tests/unit/test_expression_tree/test_interpolant.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,20 @@ def test_diff(self):
326326
decimal=3,
327327
)
328328

329+
# test 2D interpolation diff fails
330+
x = (np.arange(-5.01, 5.01, 0.05), np.arange(-5.01, 5.01, 0.01))
331+
xx, yy = np.meshgrid(x[0], x[1], indexing="ij")
332+
z = np.sin(xx**2 + yy**2)
333+
var1 = pybamm.StateVector(slice(0, 1))
334+
var2 = pybamm.StateVector(slice(1, 2))
335+
# linear
336+
interp = pybamm.Interpolant(x, z, (var1, var2), interpolator="linear")
337+
with self.assertRaisesRegex(
338+
NotImplementedError,
339+
"differentiation not implemented for functions with more than one child",
340+
):
341+
interp.diff(var1)
342+
329343
def test_processing(self):
330344
x = np.linspace(0, 1, 200)
331345
y = pybamm.StateVector(slice(0, 2))
@@ -369,6 +383,7 @@ def test_to_from_json(self):
369383
],
370384
"interpolator": "linear",
371385
"extrapolate": True,
386+
"_num_derivatives": 0,
372387
}
373388

374389
# check correct writing to json

tests/unit/test_expression_tree/test_operations/test_convert_to_casadi.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -314,22 +314,6 @@ def test_concatenations(self):
314314
y_eval = np.linspace(0, 1, expr.size)
315315
self.assert_casadi_equal(f(y_eval), casadi.SX(expr.evaluate(y=y_eval)))
316316

317-
def test_convert_differentiated_function(self):
318-
a = pybamm.InputParameter("a")
319-
b = pybamm.InputParameter("b")
320-
321-
def myfunction(x, y):
322-
return x + y**3
323-
324-
f = pybamm.Function(myfunction, a, b).diff(a)
325-
self.assert_casadi_equal(
326-
f.to_casadi(inputs={"a": 1, "b": 2}), casadi.DM(1), evalf=True
327-
)
328-
f = pybamm.Function(myfunction, a, b).diff(b)
329-
self.assert_casadi_equal(
330-
f.to_casadi(inputs={"a": 1, "b": 2}), casadi.DM(12), evalf=True
331-
)
332-
333317
def test_convert_input_parameter(self):
334318
casadi_t = casadi.MX.sym("t")
335319
casadi_y = casadi.MX.sym("y", 10)

tests/unit/test_expression_tree/test_operations/test_jac.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import unittest
99
from scipy.sparse import eye
1010
from tests import get_mesh_for_testing
11-
from tests import multi_var_function_test
1211

1312

1413
class TestJacobian(TestCase):
@@ -213,12 +212,6 @@ def test_functions(self):
213212
dfunc_dy = func.jac(y).evaluate(y=y0)
214213
np.testing.assert_array_equal(0, dfunc_dy)
215214

216-
# several children
217-
func = pybamm.Function(multi_var_function_test, 2 * y, 3 * y)
218-
jacobian = np.diag(5 * np.ones(4))
219-
dfunc_dy = func.jac(y).evaluate(y=y0)
220-
np.testing.assert_array_equal(jacobian, dfunc_dy.toarray())
221-
222215
def test_index(self):
223216
vec = pybamm.StateVector(slice(0, 5))
224217
ind = pybamm.Index(vec, 3)

tests/unit/test_expression_tree/test_operations/test_jac_2D.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from scipy.sparse import eye
1010
from tests import (
1111
get_1p1d_discretisation_for_testing,
12-
multi_var_function_test,
1312
)
1413

1514

@@ -200,12 +199,6 @@ def test_functions(self):
200199
dfunc_dy = func.jac(y).evaluate(y=y0)
201200
np.testing.assert_array_equal(0, dfunc_dy)
202201

203-
# several children
204-
func = pybamm.Function(multi_var_function_test, 2 * y, 3 * y)
205-
jacobian = np.diag(5 * np.ones(8))
206-
dfunc_dy = func.jac(y).evaluate(y=y0)
207-
np.testing.assert_array_equal(jacobian, dfunc_dy.toarray())
208-
209202
def test_jac_of_domain_concatenation(self):
210203
# create mesh
211204
disc = get_1p1d_discretisation_for_testing()

0 commit comments

Comments
 (0)