@@ -273,26 +273,59 @@ def set_variable_slices(self, variables):
273273 end = 0
274274 lower_bounds = []
275275 upper_bounds = []
276+
276277 # Iterate through unpacked variables, adding appropriate slices to y_slices
277278 for variable in variables :
279+ if variable in y_slices :
280+ continue
278281 # Add up the size of all the domains in variable.domain
279282 if isinstance (variable , pybamm .ConcatenationVariable ):
280- start_ = start
281283 spatial_method = self .spatial_methods [variable .domain [0 ]]
284+ dimension = spatial_method .mesh [variable .domain [0 ]].dimension
285+ start_ = start
282286 children = variable .children
283287 meshes = OrderedDict ()
288+ lr_points = OrderedDict ()
289+ tb_points = OrderedDict ()
284290 for child in children :
285291 meshes [child ] = [spatial_method .mesh [dom ] for dom in child .domain ]
292+ if dimension == 2 :
293+ lr_points [child ] = sum (
294+ spatial_method .mesh [dom ].npts_lr for dom in child .domain
295+ )
296+ tb_points [child ] = sum (
297+ spatial_method .mesh [dom ].npts_tb for dom in child .domain
298+ )
286299 sec_points = spatial_method ._get_auxiliary_domain_repeats (
287300 variable .domains
288301 )
289302 for _ in range (sec_points ):
303+ start_this_child = start_
290304 for child , mesh in meshes .items ():
291305 for domain_mesh in mesh :
292306 end += domain_mesh .npts_for_broadcast_to_nodes
293307 # Add to slices
294- y_slices [child ].append (slice (start_ , end ))
295- y_slices_explicit [child ].append (slice (start_ , end ))
308+ if dimension == 2 :
309+ other_children = set (meshes .keys ()) - {child }
310+ num_pts_to_skip = sum (
311+ lr_points [other_child ] for other_child in other_children
312+ )
313+ for row in range (tb_points [child ]):
314+ start_this_row = (
315+ start_this_child
316+ + (lr_points [child ] + num_pts_to_skip ) * row
317+ )
318+ end_this_row = start_this_row + lr_points [child ]
319+ y_slices [child ].append (
320+ slice (start_this_row , end_this_row )
321+ )
322+ y_slices_explicit [child ].append (
323+ slice (start_this_row , end_this_row )
324+ )
325+ start_this_child += lr_points [child ]
326+ else :
327+ y_slices [child ].append (slice (start_ , end ))
328+ y_slices_explicit [child ].append (slice (start_ , end ))
296329 # Increment start_
297330 start_ = end
298331 else :
@@ -737,6 +770,7 @@ def process_dict(self, var_eqn_dict, ics=False):
737770 eqn = pybamm .FullBroadcast (eqn , broadcast_domains = eqn_key .domains )
738771
739772 pybamm .logger .debug (f"Discretise { eqn_key !r} " )
773+
740774 processed_eqn = self .process_symbol (eqn )
741775 # Calculate scale if the key has a scale
742776 scale = getattr (eqn_key , "scale" , 1 )
@@ -811,9 +845,46 @@ def _process_symbol(self, symbol):
811845 if isinstance (symbol , pybamm .BinaryOperator ):
812846 # Pre-process children
813847 left , right = symbol .children
848+ # Catch case where diffusion is a scalar and turn it into an identity matrix vector field.
849+ if len (symbol .domain ) != 0 :
850+ spatial_method = self .spatial_methods [symbol .domain [0 ]]
851+ else :
852+ spatial_method = None
853+ if isinstance (spatial_method , pybamm .FiniteVolume2D ):
854+ if isinstance (left , pybamm .Scalar ) and (
855+ isinstance (right , pybamm .VectorField )
856+ or isinstance (right , pybamm .Gradient )
857+ ):
858+ left = pybamm .VectorField (left , left )
859+ elif isinstance (right , pybamm .Scalar ) and (
860+ isinstance (left , pybamm .VectorField )
861+ or isinstance (left , pybamm .Gradient )
862+ ):
863+ right = pybamm .VectorField (right , right )
814864 disc_left = self .process_symbol (left )
815865 disc_right = self .process_symbol (right )
816866 if symbol .domain == []:
867+ if isinstance (disc_left , pybamm .VectorField ) or isinstance (
868+ disc_right , pybamm .VectorField
869+ ):
870+ if not isinstance (disc_right , pybamm .VectorField ):
871+ disc_right = pybamm .VectorField (disc_right , disc_right )
872+ if not isinstance (disc_left , pybamm .VectorField ):
873+ disc_left = pybamm .VectorField (disc_left , disc_left )
874+ else : # both are vector fields already
875+ pass
876+ disc_lr = pybamm .simplify_if_constant (
877+ symbol .create_copy (
878+ new_children = [disc_left .lr_field , disc_right .lr_field ]
879+ )
880+ )
881+ disc_tb = pybamm .simplify_if_constant (
882+ symbol .create_copy (
883+ new_children = [disc_left .tb_field , disc_right .tb_field ]
884+ )
885+ )
886+ return pybamm .VectorField (disc_lr , disc_tb )
887+
817888 return pybamm .simplify_if_constant (
818889 symbol .create_copy (new_children = [disc_left , disc_right ])
819890 )
@@ -878,7 +949,10 @@ def _process_symbol(self, symbol):
878949 symbol .integration_variable [0 ].domain [0 ]
879950 ]
880951 out = integral_spatial_method .integral (
881- child , disc_child , symbol ._integration_dimension
952+ child ,
953+ disc_child ,
954+ symbol ._integration_dimension ,
955+ symbol .integration_variable ,
882956 )
883957 out .copy_domains (symbol )
884958 return out
@@ -888,6 +962,17 @@ def _process_symbol(self, symbol):
888962 child , vector_type = symbol .vector_type
889963 )
890964
965+ elif isinstance (symbol , pybamm .OneDimensionalIntegral ):
966+ child_spatial_method = self .spatial_methods [
967+ symbol .integration_domain [0 ]
968+ ]
969+ return child_spatial_method .one_dimensional_integral (
970+ symbol ,
971+ child ,
972+ disc_child ,
973+ symbol .integration_domain ,
974+ symbol .direction ,
975+ )
891976 elif isinstance (symbol , pybamm .BoundaryIntegral ):
892977 return child_spatial_method .boundary_integral (
893978 child , disc_child , symbol .region
@@ -918,6 +1003,14 @@ def _process_symbol(self, symbol):
9181003 return child_spatial_method .evaluate_at (
9191004 symbol , disc_child , symbol .position
9201005 )
1006+ elif isinstance (symbol , pybamm .UpwindDownwind2D ):
1007+ return spatial_method .upwind_or_downwind (
1008+ child ,
1009+ disc_child ,
1010+ self .bcs ,
1011+ symbol .lr_direction ,
1012+ symbol .tb_direction ,
1013+ )
9211014 elif isinstance (symbol , pybamm .UpwindDownwind ):
9221015 direction = symbol .name # upwind or downwind
9231016 return spatial_method .upwind_or_downwind (
@@ -926,8 +1019,24 @@ def _process_symbol(self, symbol):
9261019 elif isinstance (symbol , pybamm .NotConstant ):
9271020 # After discretisation, we can make the symbol constant
9281021 return disc_child
1022+ elif isinstance (symbol , pybamm .Magnitude ):
1023+ if not isinstance (disc_child , pybamm .VectorField ):
1024+ raise ValueError ("Magnitude can only be applied to a vector field" )
1025+ direction = symbol .direction
1026+ if direction == "lr" :
1027+ return disc_child .lr_field
1028+ elif direction == "tb" :
1029+ return disc_child .tb_field
1030+ else :
1031+ raise ValueError ("Invalid direction" )
9291032 else :
930- return symbol .create_copy (new_children = [disc_child ])
1033+ if isinstance (disc_child , pybamm .VectorField ):
1034+ return pybamm .VectorField (
1035+ symbol .create_copy (new_children = [disc_child .lr_field ]),
1036+ symbol .create_copy (new_children = [disc_child .tb_field ]),
1037+ )
1038+ else :
1039+ return symbol .create_copy (new_children = [disc_child ])
9311040
9321041 elif isinstance (symbol , pybamm .Function ):
9331042 disc_children = [self .process_symbol (child ) for child in symbol .children ]
@@ -996,6 +1105,12 @@ def _process_symbol(self, symbol):
9961105 elif isinstance (symbol , pybamm .CoupledVariable ):
9971106 new_symbol = self .process_symbol (symbol .children [0 ])
9981107 return new_symbol
1108+
1109+ elif isinstance (symbol , pybamm .VectorField ):
1110+ left_symbol = self .process_symbol (symbol .lr_field )
1111+ right_symbol = self .process_symbol (symbol .tb_field )
1112+ return symbol .create_copy (new_children = [left_symbol , right_symbol ])
1113+
9991114 elif isinstance (symbol , pybamm .Constant ):
10001115 # after discretisation we just care about the value, not the name
10011116 return self .process_symbol (pybamm .Scalar (symbol .value ))
0 commit comments