22# Latexify class
33#
44import copy
5- import operator
65import re
76import warnings
87
@@ -55,58 +54,6 @@ def __init__(self, model, filename=None, newline=True):
5554 self .filename = filename
5655 self .newline = newline
5756
58- def _get_concat_displays (self , node ):
59- """
60- Returns all the concatenation nodes by doing a depth first search through the
61- entire equation tree with ranges in front of all nodes.
62- """
63- concat_displays = []
64- dfs_nodes = [node ]
65- while dfs_nodes :
66- node = dfs_nodes .pop ()
67- if (
68- hasattr (node , "concat_latex" )
69- and getattr (node , "print_name" , None ) is not None
70- ):
71- # Combine list of concatenations with list of ranges
72- concat_geo = map (
73- operator .add ,
74- node .concat_latex ,
75- self ._get_concat_geometry_displays (node ),
76- )
77-
78- # Add cases and split by new line
79- concat_sym = (
80- r"\begin{cases}" + r" \\ " .join (concat_geo ) + r"\end{cases}"
81- )
82- concat_eqn = sympy .Eq (
83- sympy .Symbol (node .print_name ),
84- sympy .Symbol (concat_sym ),
85- evaluate = False ,
86- )
87- concat_displays .append (concat_eqn )
88- dfs_nodes .extend (node .children )
89-
90- # Remove duplicates from the list whilst preserving order
91- return list (dict .fromkeys (concat_displays ))
92-
93- def _get_concat_geometry_displays (self , var ):
94- """Returns a list of min/max ranges of all concatenation nodes in latex."""
95- geo = []
96-
97- # Loop through all subdomains for concatenations
98- for domain in var .domain :
99- for var_name , rng in self .model .default_geometry [domain ].items ():
100- if "min" in rng and "max" in rng :
101- rng_min = get_rng_min_max_name (rng , "min" )
102- rng_max = get_rng_min_max_name (rng , "max" )
103-
104- name = sympy .latex (var_name )
105- geo_latex = f"& { rng_min } < { name } < { rng_max } "
106- geo .append (geo_latex )
107-
108- return geo
109-
11057 def _get_geometry_displays (self , var ):
11158 """
11259 Returns min range from the first domain and max range from the last domain of
@@ -136,7 +83,7 @@ def _get_geometry_displays(self, var):
13683
13784 return geo
13885
139- def _get_bcs_displays (self , lhs_dr , var ):
86+ def _get_bcs_displays (self , var ):
14087 """
14188 Returns a list of boundary condition equations with ranges in front of
14289 the equations.
@@ -146,27 +93,25 @@ def _get_bcs_displays(self, lhs_dr, var):
14693
14794 if bcs :
14895 # Take range minimum from the first domain
149- for var_name , rng in self .model .default_geometry [var .domain [0 ]].items ():
150- # Trim name (r_n --> r)
151- name = re .findall (r"(.)_*.*" , str (var_name ))[0 ]
152- rng_min = get_rng_min_max_name (rng , "min" )
153-
154- bcs_left = sympy .latex (bcs ["left" ][0 ].to_equation ())
155- bcs_left_latex = bcs_left + f"\quad { name } = { rng_min } "
156- bcs_eqn = sympy .Eq (lhs_dr , sympy .Symbol (bcs_left_latex ), evaluate = False )
157- bcs_eqn_list .append (bcs_eqn )
96+ var_name = list (self .model .default_geometry [var .domain [0 ]].keys ())[0 ]
97+ rng_left = list (self .model .default_geometry [var .domain [0 ]].values ())[0 ]
98+ rng_right = list (self .model .default_geometry [var .domain [- 1 ]].values ())[0 ]
15899
159- # Take range maximum from the last domain
160- for var_name , rng in self .model .default_geometry [var .domain [- 1 ]].items ():
161- # Trim name (r_n --> r)
162- name = re .findall (r"(.)_*.*" , str (var_name ))[0 ]
163- rng_max = get_rng_min_max_name (rng , "max" )
100+ # Trim name (r_n --> r)
101+ var_name = re .findall (r"(.)_*.*" , str (var_name ))[0 ]
164102
165- bcs_right = sympy .latex (bcs ["right" ][0 ].to_equation ())
166- bcs_right_latex = bcs_right + f"\quad { name } = { rng_max } "
167- bcs_eqn = sympy .Eq (
168- lhs_dr , sympy .Symbol (bcs_right_latex ), evaluate = False
169- )
103+ rng_min = get_rng_min_max_name (rng_left , "min" )
104+ rng_max = get_rng_min_max_name (rng_right , "max" )
105+
106+ for side , rng in [("left" , rng_min ), ("right" , rng_max )]:
107+ bc_value , bc_type = bcs [side ]
108+ bcs_side = sympy .latex (bc_value .to_equation ())
109+ bcs_side_latex = bcs_side + f"\\ quad \\ text{{at }} { var_name } = { rng } "
110+ if bc_type == "Dirichlet" :
111+ lhs = sympy .Symbol (var .print_name )
112+ else :
113+ lhs = sympy .Symbol (r"\nabla " + var .print_name )
114+ bcs_eqn = sympy .Eq (lhs , sympy .Symbol (bcs_side_latex ), evaluate = False )
170115 bcs_eqn_list .append (bcs_eqn )
171116
172117 return bcs_eqn_list
@@ -214,7 +159,14 @@ def _get_param_var(self, node):
214159
215160 return param_list , var_list
216161
217- def latexify (self ):
162+ def latexify (self , output_variables = None ):
163+ # Voltage is the default output variable if it exists
164+ if output_variables is None :
165+ if "Voltage [V]" in self .model .variables :
166+ output_variables = ["Voltage [V]" ]
167+ else :
168+ output_variables = []
169+
218170 eqn_list = []
219171 param_list = []
220172 var_list = []
@@ -234,22 +186,24 @@ def latexify(self):
234186 eqn_list .append (sympy .Symbol (r"\\ \textbf{" + str (var ) + "}" ))
235187
236188 # Set lhs derivative
237- lhs = sympy .Derivative (var_symbol , "t" )
238- lhs_dr = sympy .Derivative (var_symbol , "r" )
189+ ddt = sympy .Derivative (var_symbol , "t" )
239190
240191 # Override lhs for algebraic
241- if eqn_type == "algebraic" :
192+ if eqn_type == "rhs" :
193+ lhs = ddt
194+ else :
242195 lhs = 0
243196
244197 # Override derivative to partial derivative
245- if len (var .domain ) != 0 and var .domain != "current collector" :
246- lhs_dr .force_partial = True
247-
248- if not eqn_type == "algebraic" :
249- lhs .force_partial = True
198+ if (
199+ len (var .domain ) != 0
200+ and var .domain != "current collector"
201+ and eqn_type == "rhs"
202+ ):
203+ lhs .force_partial = True
250204
251205 # Boundary conditions equations
252- bcs = self ._get_bcs_displays (lhs_dr , var )
206+ bcs = self ._get_bcs_displays (var )
253207
254208 # Add ranges from geometry in rhs
255209 geo = self ._get_geometry_displays (var )
@@ -263,14 +217,13 @@ def latexify(self):
263217 if not eqn_type == "algebraic" :
264218 init = self .model .initial_conditions .get (var , None )
265219 init_eqn = sympy .Eq (var_symbol , init .to_equation (), evaluate = False )
266- init_eqn = sympy .Symbol (sympy .latex (init_eqn ) + r"\quad at\; t=0" )
220+ init_eqn = sympy .Symbol (
221+ sympy .latex (init_eqn ) + r"\quad \text{at}\; t=0"
222+ )
267223
268224 # Make equation from lhs and rhs
269225 lhs_rhs = sympy .Eq (lhs , rhs , evaluate = False )
270226
271- # Get all concatenation nodes
272- concat_displays = self ._get_concat_displays (eqn )
273-
274227 # Set SymPy's init printing to use CustomPrint from sympy_overrides.py
275228 sympy .init_printing (
276229 use_latex = True ,
@@ -282,10 +235,6 @@ def latexify(self):
282235 # Add model equations to the list
283236 eqn_list .append (lhs_rhs )
284237
285- # Add concatenation to the list
286- if concat_displays :
287- eqn_list .append (concat_displays )
288-
289238 # Add initial conditions to the list
290239 if not eqn_type == "algebraic" :
291240 eqn_list .extend ([init_eqn ])
@@ -298,23 +247,21 @@ def latexify(self):
298247 param_list .extend (list1 )
299248 var_list .extend (list2 )
300249
301- # Add voltage expression to the list
302- if "Voltage [V]" in self . model . variables :
303- voltage = self .model .variables ["Voltage [V]" ].to_equation ()
304- voltage_eqn = sympy .Eq (sympy .Symbol ("V" ), voltage , evaluate = False )
305- # Add voltage to the list
306- eqn_list .append (sympy .Symbol (r"\\ \textbf{Voltage [V] }" ))
307- eqn_list .extend ([voltage_eqn ])
250+ # Add output variables to the list
251+ for var_name in output_variables :
252+ var = self .model .variables [var_name ].to_equation ()
253+ var_eqn = sympy .Eq (sympy .Symbol ("V" ), var , evaluate = False )
254+ # Add var to the list
255+ eqn_list .append (sympy .Symbol (r"\\ \textbf{" + var_name + " }" ))
256+ eqn_list .extend ([var_eqn ])
308257
309258 # Remove duplicates from the list whilst preserving order
310259 param_list = list (dict .fromkeys (param_list ))
311260 var_list = list (dict .fromkeys (var_list ))
312261 # Add Parameters and Variables to the list
313262 eqn_list .append (sympy .Symbol (r"\\ \textbf{Parameters and Variables}" ))
314- # Add parameters to the list
315- eqn_list .extend (param_list )
316- # Add names to the list
317263 eqn_list .extend (var_list )
264+ eqn_list .extend (param_list )
318265
319266 # Split list with new lines
320267 eqn_new_line = sympy .Symbol (r"\\\\" .join (map (custom_print_func , eqn_list )))
0 commit comments