Skip to content

Commit e59bb51

Browse files
Merge pull request #2764 from pybamm-team/issue-2763-latexify
#2763 start improving latexify
2 parents 9d06ea3 + 9800bbd commit e59bb51

File tree

28 files changed

+1226
-741
lines changed

28 files changed

+1226
-741
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
## Bug fixes
1212

13+
- Improved `model.latexify()` to have a cleaner and more readable output ([#2764](https://github.com/pybamm-team/PyBaMM/pull/2764))
1314
- Fixed electrolyte conservation in the case of concentration-dependent transference number ([#2758](https://github.com/pybamm-team/PyBaMM/pull/2758))
1415
- Fixed `plot_voltage_components` so that the sum of overpotentials is now equal to the voltage ([#2740](https://github.com/pybamm-team/PyBaMM/pull/2740))
1516

examples/notebooks/models/latexify.ipynb

Lines changed: 963 additions & 448 deletions
Large diffs are not rendered by default.
-4.93 MB
Binary file not shown.

pybamm/expression_tree/concatenations.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -380,16 +380,12 @@ def __init__(self, *children):
380380
raise ValueError("Cannot concatenate symbols with different bounds")
381381
super().__init__(*children, name=name)
382382

383-
if not any(c._raw_print_name is None for c in children):
384-
print_name = intersect(
385-
children[0]._raw_print_name, children[1]._raw_print_name
386-
)
387-
for child in children[2:]:
388-
print_name = intersect(print_name, child._raw_print_name)
389-
if print_name.endswith("_"):
390-
print_name = print_name[:-1]
391-
else:
392-
print_name = None
383+
print_name = intersect(children[0]._raw_print_name, children[1]._raw_print_name)
384+
for child in children[2:]:
385+
print_name = intersect(print_name, child._raw_print_name)
386+
if print_name.endswith("_"):
387+
print_name = print_name[:-1]
388+
393389
self.print_name = print_name
394390

395391

pybamm/expression_tree/operations/latexify.py

Lines changed: 48 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
# Latexify class
33
#
44
import copy
5-
import operator
65
import re
76
import warnings
87

@@ -55,58 +54,6 @@ def __init__(self, model, filename=None, newline=True):
5554
self.filename = filename
5655
self.newline = newline
5756

58-
def _get_concat_displays(self, node):
59-
"""
60-
Returns all the concatenation nodes by doing a depth first search through the
61-
entire equation tree with ranges in front of all nodes.
62-
"""
63-
concat_displays = []
64-
dfs_nodes = [node]
65-
while dfs_nodes:
66-
node = dfs_nodes.pop()
67-
if (
68-
hasattr(node, "concat_latex")
69-
and getattr(node, "print_name", None) is not None
70-
):
71-
# Combine list of concatenations with list of ranges
72-
concat_geo = map(
73-
operator.add,
74-
node.concat_latex,
75-
self._get_concat_geometry_displays(node),
76-
)
77-
78-
# Add cases and split by new line
79-
concat_sym = (
80-
r"\begin{cases}" + r" \\ ".join(concat_geo) + r"\end{cases}"
81-
)
82-
concat_eqn = sympy.Eq(
83-
sympy.Symbol(node.print_name),
84-
sympy.Symbol(concat_sym),
85-
evaluate=False,
86-
)
87-
concat_displays.append(concat_eqn)
88-
dfs_nodes.extend(node.children)
89-
90-
# Remove duplicates from the list whilst preserving order
91-
return list(dict.fromkeys(concat_displays))
92-
93-
def _get_concat_geometry_displays(self, var):
94-
"""Returns a list of min/max ranges of all concatenation nodes in latex."""
95-
geo = []
96-
97-
# Loop through all subdomains for concatenations
98-
for domain in var.domain:
99-
for var_name, rng in self.model.default_geometry[domain].items():
100-
if "min" in rng and "max" in rng:
101-
rng_min = get_rng_min_max_name(rng, "min")
102-
rng_max = get_rng_min_max_name(rng, "max")
103-
104-
name = sympy.latex(var_name)
105-
geo_latex = f"& {rng_min} < {name} < {rng_max}"
106-
geo.append(geo_latex)
107-
108-
return geo
109-
11057
def _get_geometry_displays(self, var):
11158
"""
11259
Returns min range from the first domain and max range from the last domain of
@@ -136,7 +83,7 @@ def _get_geometry_displays(self, var):
13683

13784
return geo
13885

139-
def _get_bcs_displays(self, lhs_dr, var):
86+
def _get_bcs_displays(self, var):
14087
"""
14188
Returns a list of boundary condition equations with ranges in front of
14289
the equations.
@@ -146,27 +93,25 @@ def _get_bcs_displays(self, lhs_dr, var):
14693

14794
if bcs:
14895
# Take range minimum from the first domain
149-
for var_name, rng in self.model.default_geometry[var.domain[0]].items():
150-
# Trim name (r_n --> r)
151-
name = re.findall(r"(.)_*.*", str(var_name))[0]
152-
rng_min = get_rng_min_max_name(rng, "min")
153-
154-
bcs_left = sympy.latex(bcs["left"][0].to_equation())
155-
bcs_left_latex = bcs_left + f"\quad {name} = {rng_min}"
156-
bcs_eqn = sympy.Eq(lhs_dr, sympy.Symbol(bcs_left_latex), evaluate=False)
157-
bcs_eqn_list.append(bcs_eqn)
96+
var_name = list(self.model.default_geometry[var.domain[0]].keys())[0]
97+
rng_left = list(self.model.default_geometry[var.domain[0]].values())[0]
98+
rng_right = list(self.model.default_geometry[var.domain[-1]].values())[0]
15899

159-
# Take range maximum from the last domain
160-
for var_name, rng in self.model.default_geometry[var.domain[-1]].items():
161-
# Trim name (r_n --> r)
162-
name = re.findall(r"(.)_*.*", str(var_name))[0]
163-
rng_max = get_rng_min_max_name(rng, "max")
100+
# Trim name (r_n --> r)
101+
var_name = re.findall(r"(.)_*.*", str(var_name))[0]
164102

165-
bcs_right = sympy.latex(bcs["right"][0].to_equation())
166-
bcs_right_latex = bcs_right + f"\quad {name} = {rng_max}"
167-
bcs_eqn = sympy.Eq(
168-
lhs_dr, sympy.Symbol(bcs_right_latex), evaluate=False
169-
)
103+
rng_min = get_rng_min_max_name(rng_left, "min")
104+
rng_max = get_rng_min_max_name(rng_right, "max")
105+
106+
for side, rng in [("left", rng_min), ("right", rng_max)]:
107+
bc_value, bc_type = bcs[side]
108+
bcs_side = sympy.latex(bc_value.to_equation())
109+
bcs_side_latex = bcs_side + f"\\quad \\text{{at }} {var_name} = {rng}"
110+
if bc_type == "Dirichlet":
111+
lhs = sympy.Symbol(var.print_name)
112+
else:
113+
lhs = sympy.Symbol(r"\nabla " + var.print_name)
114+
bcs_eqn = sympy.Eq(lhs, sympy.Symbol(bcs_side_latex), evaluate=False)
170115
bcs_eqn_list.append(bcs_eqn)
171116

172117
return bcs_eqn_list
@@ -214,7 +159,14 @@ def _get_param_var(self, node):
214159

215160
return param_list, var_list
216161

217-
def latexify(self):
162+
def latexify(self, output_variables=None):
163+
# Voltage is the default output variable if it exists
164+
if output_variables is None:
165+
if "Voltage [V]" in self.model.variables:
166+
output_variables = ["Voltage [V]"]
167+
else:
168+
output_variables = []
169+
218170
eqn_list = []
219171
param_list = []
220172
var_list = []
@@ -234,22 +186,24 @@ def latexify(self):
234186
eqn_list.append(sympy.Symbol(r"\\ \textbf{" + str(var) + "}"))
235187

236188
# Set lhs derivative
237-
lhs = sympy.Derivative(var_symbol, "t")
238-
lhs_dr = sympy.Derivative(var_symbol, "r")
189+
ddt = sympy.Derivative(var_symbol, "t")
239190

240191
# Override lhs for algebraic
241-
if eqn_type == "algebraic":
192+
if eqn_type == "rhs":
193+
lhs = ddt
194+
else:
242195
lhs = 0
243196

244197
# Override derivative to partial derivative
245-
if len(var.domain) != 0 and var.domain != "current collector":
246-
lhs_dr.force_partial = True
247-
248-
if not eqn_type == "algebraic":
249-
lhs.force_partial = True
198+
if (
199+
len(var.domain) != 0
200+
and var.domain != "current collector"
201+
and eqn_type == "rhs"
202+
):
203+
lhs.force_partial = True
250204

251205
# Boundary conditions equations
252-
bcs = self._get_bcs_displays(lhs_dr, var)
206+
bcs = self._get_bcs_displays(var)
253207

254208
# Add ranges from geometry in rhs
255209
geo = self._get_geometry_displays(var)
@@ -263,14 +217,13 @@ def latexify(self):
263217
if not eqn_type == "algebraic":
264218
init = self.model.initial_conditions.get(var, None)
265219
init_eqn = sympy.Eq(var_symbol, init.to_equation(), evaluate=False)
266-
init_eqn = sympy.Symbol(sympy.latex(init_eqn) + r"\quad at\; t=0")
220+
init_eqn = sympy.Symbol(
221+
sympy.latex(init_eqn) + r"\quad \text{at}\; t=0"
222+
)
267223

268224
# Make equation from lhs and rhs
269225
lhs_rhs = sympy.Eq(lhs, rhs, evaluate=False)
270226

271-
# Get all concatenation nodes
272-
concat_displays = self._get_concat_displays(eqn)
273-
274227
# Set SymPy's init printing to use CustomPrint from sympy_overrides.py
275228
sympy.init_printing(
276229
use_latex=True,
@@ -282,10 +235,6 @@ def latexify(self):
282235
# Add model equations to the list
283236
eqn_list.append(lhs_rhs)
284237

285-
# Add concatenation to the list
286-
if concat_displays:
287-
eqn_list.append(concat_displays)
288-
289238
# Add initial conditions to the list
290239
if not eqn_type == "algebraic":
291240
eqn_list.extend([init_eqn])
@@ -298,23 +247,21 @@ def latexify(self):
298247
param_list.extend(list1)
299248
var_list.extend(list2)
300249

301-
# Add voltage expression to the list
302-
if "Voltage [V]" in self.model.variables:
303-
voltage = self.model.variables["Voltage [V]"].to_equation()
304-
voltage_eqn = sympy.Eq(sympy.Symbol("V"), voltage, evaluate=False)
305-
# Add voltage to the list
306-
eqn_list.append(sympy.Symbol(r"\\ \textbf{Voltage [V]}"))
307-
eqn_list.extend([voltage_eqn])
250+
# Add output variables to the list
251+
for var_name in output_variables:
252+
var = self.model.variables[var_name].to_equation()
253+
var_eqn = sympy.Eq(sympy.Symbol("V"), var, evaluate=False)
254+
# Add var to the list
255+
eqn_list.append(sympy.Symbol(r"\\ \textbf{" + var_name + "}"))
256+
eqn_list.extend([var_eqn])
308257

309258
# Remove duplicates from the list whilst preserving order
310259
param_list = list(dict.fromkeys(param_list))
311260
var_list = list(dict.fromkeys(var_list))
312261
# Add Parameters and Variables to the list
313262
eqn_list.append(sympy.Symbol(r"\\ \textbf{Parameters and Variables}"))
314-
# Add parameters to the list
315-
eqn_list.extend(param_list)
316-
# Add names to the list
317263
eqn_list.extend(var_list)
264+
eqn_list.extend(param_list)
318265

319266
# Split list with new lines
320267
eqn_new_line = sympy.Symbol(r"\\\\".join(map(custom_print_func, eqn_list)))
Lines changed: 57 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,11 @@
11
#
22
# Prettify print_name
33
#
4-
import re
5-
64
PRINT_NAME_OVERRIDES = {
75
"current_with_time": "I",
8-
"eps_c_e": r"\epsilon{c_e}",
9-
"thermodynamic_factor": r"1+\frac{dlnf}{dlnc}",
10-
"negative_particle_concentration_scale": r"c_{n}^{max}",
11-
"positive_particle_concentration_scale": r"c_{p}^{max}",
6+
"current_density_with_time": r"i_{\mathrm{cell}}",
7+
"thermodynamic_factor": r"\left(1+\frac{dlnf}{dlnc}\right)",
8+
"t_plus": r"t_{\mathrm{+}}",
129
}
1310

1411
GREEK_LETTERS = [
@@ -45,58 +42,64 @@ def prettify_print_name(name):
4542
if name is None or "{" in name or "\\" in name:
4643
return name
4744

48-
# Return print_name if name exists in the dictionary
49-
if name in PRINT_NAME_OVERRIDES:
50-
return PRINT_NAME_OVERRIDES[name]
51-
52-
# Superscripts with comma separated (U_ref_n --> U_{n}^{ref})
53-
sup_re1 = re.search(r"^[\da-zA-Z]+_?((?:init|ref|typ|max|0))_?(.*)", name)
54-
if sup_re1:
55-
sup_str = (
56-
r"{"
57-
+ sup_re1.group(2).replace("_", "\,")
58-
+ r"}^{"
59-
+ sup_re1.group(1)
60-
+ r"}"
61-
)
62-
sup_var = sup_re1.group(1) + "_" + sup_re1.group(2)
63-
name = name.replace(sup_var, sup_str)
45+
# Find subscripts, superscripts, and averaging
46+
# Remove them from the name one by one and add them later in processed form
47+
subscripts = []
48+
superscripts = []
49+
average = False
6450

65-
# Superscripts with comma separated (U_n_ref --> U_{n}^{ref})
66-
sup_re2 = re.search(r"^[\da-zA-Z]+_?(.*?)_?((?:init|ref|typ|max|0))", name)
67-
if sup_re2:
68-
sup_str = (
69-
r"{"
70-
+ sup_re2.group(1).replace("_", "\,")
71-
+ r"}^{"
72-
+ sup_re2.group(2)
73-
+ r"}"
74-
)
75-
sup_var = sup_re2.group(1) + "_" + sup_re2.group(2)
76-
name = name.replace(sup_var, sup_str)
51+
processing = True
52+
while processing:
53+
# Set processing to False. If any of the following conditions are met,
54+
# it will be set to True again
55+
processing = False
56+
for superscript in ["init", "ref", "typ", "max", "0", "surf"]:
57+
if f"_{superscript}_" in name or name.endswith(f"_{superscript}"):
58+
superscripts.append(superscript)
59+
name = name.replace(f"_{superscript}", "")
60+
processing = True
61+
break
62+
# "0" might also appear without a preceding underscore
63+
for superscript in ["0"]:
64+
if superscript in name:
65+
superscripts.append(superscript)
66+
name = name.replace(superscript, "")
67+
processing = True
68+
break
69+
for subscript in ["cc", "dl", "R", "e", "s", "n", "p", "amb"]:
70+
if f"_{subscript}_" in name or name.endswith(f"_{subscript}"):
71+
subscripts.append(subscript)
72+
name = name.replace(f"_{subscript}", "")
73+
processing = True
74+
break
75+
for av in ["av", "xav"]:
76+
if f"_{av}_" in name or name.endswith(f"_{av}"):
77+
average = True
78+
name = name.replace(f"_{av}", "")
79+
processing = True
80+
break
7781

78-
# Subscripts with comma separated (a_R_p --> a_{R\,p})
79-
sub_re = re.search(r"^a_+(\w+)", name)
80-
if sub_re:
81-
sub_str = r"{" + sub_re.group(1).replace("_", "\,") + r"}"
82-
name = name.replace(sub_re.group(1), sub_str)
83-
84-
# Bar with comma separated (c_s_n_xav --> \bar{c}_{s\,n})
85-
bar_re = re.search(r"^([a-zA-Z]+)_*(\w*?)_(?:av|xav)", name)
86-
if bar_re:
87-
name = (
88-
r"\bar{"
89-
+ bar_re.group(1)
90-
+ r"}_{"
91-
+ bar_re.group(2).replace("_", "\,")
92-
+ r"}"
93-
)
82+
# Process name
83+
# Override print_name if name exists in the dictionary
84+
if name in PRINT_NAME_OVERRIDES:
85+
name = PRINT_NAME_OVERRIDES[name]
9486

95-
# Replace eps with epsilon (eps_n --> epsilon_n)
96-
name = re.sub(r"(eps)(?![0-9a-zA-Z])", "epsilon", name)
87+
# Replace eps with epsilon (e.g. eps_n --> epsilon_n)
88+
if name == "eps":
89+
name = r"\epsilon"
90+
if name == "eps_c":
91+
name = r"(\epsilon c)"
9792

9893
# Greek letters (delta --> \delta)
99-
greek_re = r"(?<!\\)(" + "|".join(GREEK_LETTERS) + r")(?![0-9a-zA-Z])"
100-
name = re.sub(greek_re, r"\\\1", name, flags=re.IGNORECASE)
94+
if name.lower() in GREEK_LETTERS:
95+
name = "\\" + name
96+
97+
# Add subscripts and superscripts
98+
if average:
99+
name = r"\overline{" + name + "}"
100+
if subscripts:
101+
name += r"_{\mathrm{" + ",".join(subscripts) + "}}"
102+
if superscripts:
103+
name += r"^{\mathrm{" + ",".join(superscripts) + "}}"
101104

102105
return name

0 commit comments

Comments
 (0)