Skip to content

Commit 6618fe1

Browse files
authored
constant extrapolation in 2d (#5178)
* constant extrapolation in 2d * coverage fix
1 parent 7ccb364 commit 6618fe1

File tree

2 files changed

+161
-1
lines changed

2 files changed

+161
-1
lines changed

src/pybamm/spatial_methods/finite_volume_2d.py

Lines changed: 85 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1155,6 +1155,23 @@ def boundary_value_or_flux(self, symbol, discretised_child, bcs=None):
11551155
)
11561156
additive = pybamm.Scalar(0)
11571157

1158+
elif extrap_order_value == "constant":
1159+
# For constant extrapolation, use the first column value
1160+
row_indices = np.arange(0, n_tb)
1161+
col_indices_0 = np.arange(0, n_tb * n_lr, n_lr)
1162+
vals_0 = np.ones(n_tb)
1163+
sub_matrix = csr_matrix(
1164+
(
1165+
vals_0,
1166+
(
1167+
row_indices,
1168+
col_indices_0,
1169+
),
1170+
),
1171+
shape=(n_tb, n_tb * n_lr),
1172+
)
1173+
additive = pybamm.Scalar(0)
1174+
11581175
elif extrap_order_value == "quadratic":
11591176
if use_bcs and pybamm.has_bc_of_form(
11601177
child, symbol.side, bcs, "Neumann"
@@ -1234,6 +1251,24 @@ def boundary_value_or_flux(self, symbol, discretised_child, bcs=None):
12341251
shape=(n_tb, n_tb * n_lr),
12351252
)
12361253
additive = pybamm.Scalar(0)
1254+
1255+
elif extrap_order_value == "constant":
1256+
# For constant extrapolation, use the last column value
1257+
row_indices = np.arange(0, n_tb)
1258+
col_indices_N = np.arange(n_lr - 1, n_lr * n_tb, n_lr)
1259+
vals_N = np.ones(n_tb)
1260+
sub_matrix = csr_matrix(
1261+
(
1262+
vals_N,
1263+
(
1264+
row_indices,
1265+
col_indices_N,
1266+
),
1267+
),
1268+
shape=(n_tb, n_tb * n_lr),
1269+
)
1270+
additive = pybamm.Scalar(0)
1271+
12371272
elif extrap_order_value == "quadratic":
12381273
if use_bcs and pybamm.has_bc_of_form(
12391274
child, symbol.side, bcs, "Neumann"
@@ -1359,7 +1394,16 @@ def boundary_value_or_flux(self, symbol, discretised_child, bcs=None):
13591394
raise NotImplementedError
13601395

13611396
elif side_first == "top":
1362-
if extrap_order_value == "linear":
1397+
if extrap_order_value == "constant":
1398+
first_val = np.ones(n_lr)
1399+
rows_first = np.arange(0, n_lr)
1400+
cols_first = np.arange((n_tb - 1) * n_lr, n_tb * n_lr)
1401+
sub_matrix = csr_matrix(
1402+
(first_val, (rows_first, cols_first)),
1403+
shape=(n_lr, n_lr * n_tb),
1404+
)
1405+
additive = pybamm.Scalar(0)
1406+
elif extrap_order_value == "linear":
13631407
if use_bcs and pybamm.has_bc_of_form(
13641408
child, side_first, bcs, "Neumann"
13651409
):
@@ -1452,6 +1496,26 @@ def boundary_value_or_flux(self, symbol, discretised_child, bcs=None):
14521496
shape=(1, n_tb),
14531497
)
14541498
sub_matrix = sub_matrix_second @ sub_matrix
1499+
1500+
elif extrap_order_value == "constant":
1501+
# For constant extrapolation, use the bottom row value
1502+
# Select bottom row elements: 0, n_tb, 2*n_tb, ..., (n_lr-1)*n_tb
1503+
row_indices = [0]
1504+
col_indices = [0]
1505+
vals = [1]
1506+
sub_matrix_second = csr_matrix(
1507+
(
1508+
vals,
1509+
(
1510+
row_indices,
1511+
col_indices,
1512+
),
1513+
),
1514+
shape=(1, n_tb),
1515+
)
1516+
additive = pybamm.Scalar(0)
1517+
sub_matrix = sub_matrix_second @ sub_matrix
1518+
14551519
else:
14561520
dx0 = dx0_tb
14571521
dx1 = dx1_tb
@@ -1485,6 +1549,26 @@ def boundary_value_or_flux(self, symbol, discretised_child, bcs=None):
14851549
shape=(1, n_tb),
14861550
)
14871551
sub_matrix = sub_matrix_second @ sub_matrix
1552+
1553+
elif extrap_order_value == "constant":
1554+
# For constant extrapolation, use the top row value
1555+
# Select top row elements: n_tb-1, 2*n_tb-1, 3*n_tb-1, ..., n_lr*n_tb-1
1556+
row_indices = [0]
1557+
col_indices = [n_tb - 1]
1558+
vals = [1]
1559+
sub_matrix_second = csr_matrix(
1560+
(
1561+
vals,
1562+
(
1563+
row_indices,
1564+
col_indices,
1565+
),
1566+
),
1567+
shape=(1, n_tb),
1568+
)
1569+
additive = pybamm.Scalar(0)
1570+
sub_matrix = sub_matrix_second @ sub_matrix
1571+
14881572
else:
14891573
dxN = dxN_tb
14901574
dxNm1 = dxNm1_tb

tests/unit/test_spatial_methods/test_finite_volume_2d/test_extrapolation.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@ def test_boundary_value_finite_volume_2d(self, use_bcs, order):
1212
if order == "quadratic" and use_bcs:
1313
# Not implemented
1414
return
15+
if order == "constant" and use_bcs:
16+
# Constant extrapolation doesn't use boundary conditions
17+
return
1518
# Create discretisation
1619
mesh = get_mesh_for_testing_2d()
1720
spatial_methods = {
@@ -447,3 +450,76 @@ def test_boundary_gradient_quadratic_function_finite_volume_2d(self, use_bcs):
447450
solutions_TB[direction],
448451
decimal=5,
449452
)
453+
454+
def test_boundary_value_constant_extrapolation(self):
455+
"""Test constant extrapolation for left and right boundaries."""
456+
# Create discretisation with constant extrapolation
457+
mesh = get_mesh_for_testing_2d()
458+
spatial_methods = {
459+
"macroscale": pybamm.FiniteVolume2D(
460+
{
461+
"extrapolation": {
462+
"order": {"gradient": "linear", "value": "constant"},
463+
"use bcs": False, # Constant extrapolation doesn't use BCS
464+
}
465+
}
466+
),
467+
}
468+
disc = pybamm.Discretisation(mesh, spatial_methods)
469+
submesh = mesh[("negative electrode", "separator", "positive electrode")]
470+
471+
# Create a variable and test data
472+
var = pybamm.Variable(
473+
"test_var", ["negative electrode", "separator", "positive electrode"]
474+
)
475+
disc.set_variable_slices([var])
476+
477+
# Create test data where each column has a distinct value
478+
# Column i has value i+1, so we can verify constant extrapolation takes from boundary columns
479+
LR, TB = np.meshgrid(submesh.nodes_lr, submesh.nodes_tb)
480+
lr = LR.flatten()
481+
tb = TB.flatten()
482+
483+
# For constant extrapolation:
484+
# - Left boundary: should return constant value 1 (leftmost column value)
485+
# - Right boundary: should return constant value len(submesh.nodes_lr) (rightmost column value)
486+
expected_left = np.full(len(submesh.nodes_tb), submesh.nodes_lr[0])
487+
expected_right = np.full(len(submesh.nodes_tb), submesh.nodes_lr[-1])
488+
expected_top = np.full(len(submesh.nodes_lr), submesh.nodes_tb[-1])
489+
expected_bottom = np.full(len(submesh.nodes_lr), submesh.nodes_tb[0])
490+
491+
# Test left boundary
492+
boundary_value_left = pybamm.BoundaryValue(var, "left")
493+
discretised_left = disc.process_symbol(boundary_value_left)
494+
result_left = discretised_left.evaluate(y=lr).flatten()
495+
np.testing.assert_array_almost_equal(result_left, expected_left)
496+
497+
# Test right boundary
498+
boundary_value_right = pybamm.BoundaryValue(var, "right")
499+
discretised_right = disc.process_symbol(boundary_value_right)
500+
result_right = discretised_right.evaluate(y=lr).flatten()
501+
np.testing.assert_array_almost_equal(result_right, expected_right)
502+
503+
# Test top-left (s/b same as left)
504+
boundary_value_top_left = pybamm.BoundaryValue(var, "top-left")
505+
discretised_top_left = disc.process_symbol(boundary_value_top_left)
506+
result_top_left = discretised_top_left.evaluate(y=lr).flatten()
507+
np.testing.assert_array_almost_equal(result_top_left, submesh.nodes_lr[0])
508+
509+
# Test bottom-right (s/b same as right)
510+
boundary_value_bottom_right = pybamm.BoundaryValue(var, "bottom-right")
511+
discretised_bottom_right = disc.process_symbol(boundary_value_bottom_right)
512+
result_bottom_right = discretised_bottom_right.evaluate(y=lr).flatten()
513+
np.testing.assert_array_almost_equal(result_bottom_right, submesh.nodes_lr[-1])
514+
515+
# Test top boundary
516+
boundary_value_top = pybamm.BoundaryValue(var, "top")
517+
discretised_top = disc.process_symbol(boundary_value_top)
518+
result_top = discretised_top.evaluate(y=tb).flatten()
519+
np.testing.assert_array_almost_equal(result_top, expected_top)
520+
521+
# Test bottom boundary
522+
boundary_value_bottom = pybamm.BoundaryValue(var, "bottom")
523+
discretised_bottom = disc.process_symbol(boundary_value_bottom)
524+
result_bottom = discretised_bottom.evaluate(y=tb).flatten()
525+
np.testing.assert_array_almost_equal(result_bottom, expected_bottom)

0 commit comments

Comments
 (0)