Skip to content

Commit dd77bec

Browse files
Merge pull request #1789 from pybamm-team/issue-1768-concat
Issue 1768 concat
2 parents d0712e2 + 2e37286 commit dd77bec

File tree

5 files changed

+50
-27
lines changed

5 files changed

+50
-27
lines changed

CHANGELOG.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,12 @@
55
- Half-cell SPM and SPMe have been implemented ( [#1731](https://github.com/pybamm-team/PyBaMM/pull/1731))
66
## Bug fixes
77

8-
- Fixed finite volume discretization in spherical polar coordinates ([#1782](https://github.com/pybamm-team/PyBaMM/pull/1782))
98
- Fixed `sympy` operators for `Arctan` and `Exponential` ([#1786](https://github.com/pybamm-team/PyBaMM/pull/1786))
9+
- Fixed finite volume discretization in spherical polar coordinates ([#1782](https://github.com/pybamm-team/PyBaMM/pull/1782))
10+
11+
## Breaking changes
1012

13+
- Raise error if `Concatenation` is used directly with `Variable` objects (`concatenation` should be used instead) ([#1789](https://github.com/pybamm-team/PyBaMM/pull/1789))
1114
# [v21.10](https://github.com/pybamm-team/PyBaMM/tree/v21.9) - 2021-10-31
1215

1316
## Features

pybamm/expression_tree/concatenations.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,17 @@ class Concatenation(pybamm.Symbol):
2424
"""
2525

2626
def __init__(self, *children, name=None, check_domain=True, concat_fun=None):
27+
# The second condition checks whether this is the base Concatenation class
28+
# or a subclass of Concatenation
29+
# (ConcatenationVariable, NumpyConcatenation, ...)
30+
if all(isinstance(child, pybamm.Variable) for child in children) and issubclass(
31+
Concatenation, type(self)
32+
):
33+
raise TypeError(
34+
"'ConcatenationVariable' should be used for concatenating 'Variable' "
35+
"objects. We recommend using the 'concatenation' function, which will "
36+
"automatically choose the best form."
37+
)
2738
if name is None:
2839
name = "concatenation"
2940
if check_domain:
@@ -46,10 +57,8 @@ def __str__(self):
4657
return out
4758

4859
def _diff(self, variable):
49-
""" See :meth:`pybamm.Symbol._diff()`. """
50-
children_diffs = [
51-
child.diff(variable) for child in self.cached_children
52-
]
60+
"""See :meth:`pybamm.Symbol._diff()`."""
61+
children_diffs = [child.diff(variable) for child in self.cached_children]
5362
if len(children_diffs) == 1:
5463
diff = children_diffs[0]
5564
else:
@@ -411,15 +420,17 @@ def simplified_concatenation(*children):
411420
"""Perform simplifications on a concatenation."""
412421
# remove children that are None
413422
children = list(filter(lambda x: x is not None, children))
414-
# Create Concatenation to easily read domains
415-
concat = Concatenation(*children)
416423
# Simplify concatenation of broadcasts all with the same child to a single
417424
# broadcast across all domains
418425
if len(children) == 0:
419426
raise ValueError("Cannot create empty concatenation")
420427
elif len(children) == 1:
421428
return children[0]
429+
elif all(isinstance(child, pybamm.Variable) for child in children):
430+
return pybamm.ConcatenationVariable(*children)
422431
else:
432+
# Create Concatenation to easily read domains
433+
concat = Concatenation(*children)
423434
if all(
424435
isinstance(child, pybamm.Broadcast)
425436
and child.child.id == children[0].child.id
@@ -432,9 +443,8 @@ def simplified_concatenation(*children):
432443
return pybamm.FullBroadcast(
433444
unique_child, concat.domain, concat.auxiliary_domains
434445
)
435-
elif all(isinstance(child, pybamm.Variable) for child in children):
436-
return pybamm.ConcatenationVariable(*children)
437-
return concat
446+
else:
447+
return concat
438448

439449

440450
def concatenation(*children):

tests/integration/test_models/standard_model_tests.py

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99

1010
class StandardModelTest(object):
11-
""" Basic processing test for the models. """
11+
"""Basic processing test for the models."""
1212

1313
def __init__(
1414
self,
@@ -62,8 +62,9 @@ def test_processing_disc(self, disc=None):
6262
# Model should still be well-posed after processing
6363
self.model.check_well_posedness(post_discretisation=True)
6464

65-
def test_solving(self, solver=None, t_eval=None, inputs=None,
66-
calculate_sensitivities=False):
65+
def test_solving(
66+
self, solver=None, t_eval=None, inputs=None, calculate_sensitivities=False
67+
):
6768
# Overwrite solver if given
6869
if solver is not None:
6970
self.solver = solver
@@ -82,7 +83,9 @@ def test_solving(self, solver=None, t_eval=None, inputs=None,
8283
t_eval = np.linspace(0, 3600 / Crate, 100)
8384

8485
self.solution = self.solver.solve(
85-
self.model, t_eval, inputs=inputs,
86+
self.model,
87+
t_eval,
88+
inputs=inputs,
8689
)
8790

8891
def test_outputs(self):
@@ -92,8 +95,9 @@ def test_outputs(self):
9295
)
9396
std_out_test.test_all()
9497

95-
def test_sensitivities(self, param_name, param_value,
96-
output_name='Terminal voltage [V]'):
98+
def test_sensitivities(
99+
self, param_name, param_value, output_name="Terminal voltage [V]"
100+
):
97101

98102
self.parameter_values.update({param_name: param_value})
99103
Crate = abs(
@@ -114,8 +118,7 @@ def test_sensitivities(self, param_name, param_value,
114118
self.solver.atol = 1e-8
115119

116120
self.solution = self.solver.solve(
117-
self.model, t_eval, inputs=inputs,
118-
calculate_sensitivities=True
121+
self.model, t_eval, inputs=inputs, calculate_sensitivities=True
119122
)
120123
output_sens = self.solution[output_name].sensitivities[param_name]
121124

@@ -124,18 +127,20 @@ def test_sensitivities(self, param_name, param_value,
124127
inputs_plus = {param_name: (param_value + 0.5 * h)}
125128
inputs_neg = {param_name: (param_value - 0.5 * h)}
126129
sol_plus = self.solver.solve(
127-
self.model, t_eval, inputs=inputs_plus,
130+
self.model,
131+
t_eval,
132+
inputs=inputs_plus,
128133
)
129134
output_plus = sol_plus[output_name](t=t_eval)
130-
sol_neg = self.solver.solve(
131-
self.model, t_eval, inputs=inputs_neg
132-
)
135+
sol_neg = self.solver.solve(self.model, t_eval, inputs=inputs_neg)
133136
output_neg = sol_neg[output_name](t=t_eval)
134-
fd = ((np.array(output_plus) - np.array(output_neg)) / h)
137+
fd = (np.array(output_plus) - np.array(output_neg)) / h
135138
fd = fd.transpose().reshape(-1, 1)
136139
np.testing.assert_allclose(
137-
output_sens, fd,
138-
rtol=1e-2, atol=1e-6,
140+
output_sens,
141+
fd,
142+
rtol=1e-2,
143+
atol=1e-6,
139144
)
140145

141146
def test_all(
@@ -156,7 +161,7 @@ def test_all(
156161

157162

158163
class OptimisationsTest(object):
159-
""" Test that the optimised models give the same result as the original model. """
164+
"""Test that the optimised models give the same result as the original model."""
160165

161166
def __init__(self, model, parameter_values=None, disc=None):
162167
# Set parameter values

tests/unit/test_expression_tree/test_concatenations.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,11 @@ def test_base_concatenation(self):
4040
# concatenation of lenght 1
4141
self.assertEqual(pybamm.concatenation(a), a)
4242

43+
a = pybamm.Variable("a", domain="test a")
44+
b = pybamm.Variable("b", domain="test b")
45+
with self.assertRaisesRegex(TypeError, "ConcatenationVariable"):
46+
pybamm.Concatenation(a, b)
47+
4348
def test_concatenation_domains(self):
4449
a = pybamm.Symbol("a", domain=["negative electrode"])
4550
b = pybamm.Symbol("b", domain=["separator", "positive electrode"])

tests/unit/test_expression_tree/test_operations/test_evaluate_python.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,7 @@ def test_domain_concatenation_2D(self):
275275
np.testing.assert_allclose(result, expr.evaluate(y=y))
276276

277277
# check that concatenating a single domain is consistent
278-
expr = disc.process_symbol(pybamm.Concatenation(a))
278+
expr = disc.process_symbol(pybamm.concatenation(a))
279279
evaluator = pybamm.EvaluatorPython(expr)
280280
result = evaluator.evaluate(y=y)
281281
np.testing.assert_allclose(result, expr.evaluate(y=y))

0 commit comments

Comments
 (0)