Skip to content

Commit 395610d

Browse files
committed
Add tests for total_horizontal_gradient function
Includes tests for: - Agreement with a synthetic model - Rejection of 1D and 3D grids - Handling of grids containing NaN values Ensures correctness and robustness of the total_horizontal_gradient filter.
1 parent 8e3f55c commit 395610d

File tree

1 file changed

+65
-0
lines changed

1 file changed

+65
-0
lines changed

harmonica/tests/test_transformations.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
reduction_to_pole,
2828
tilt_angle,
2929
total_gradient_amplitude,
30+
total_horizontal_gradient,
3031
upward_continuation,
3132
)
3233
from .utils import root_mean_square_error
@@ -602,6 +603,70 @@ def test_invalid_grid_with_nans(self, sample_potential):
602603
total_gradient_amplitude(sample_potential)
603604

604605

606+
class TestTotalHorizontalGradient:
607+
"""
608+
Test total_horizontal_gradient function
609+
"""
610+
611+
def test_against_synthetic(
612+
self, sample_potential, sample_g_n, sample_g_e
613+
):
614+
"""
615+
Test total_horizontal_gradient function against the synthetic model
616+
"""
617+
pad_width = {
618+
"easting": sample_potential.easting.size // 3,
619+
"northing": sample_potential.northing.size // 3,
620+
}
621+
potential_padded = xrft.pad(
622+
sample_potential.drop_vars("upward"),
623+
pad_width=pad_width,
624+
)
625+
thg = total_horizontal_gradient(potential_padded)
626+
thg = xrft.unpad(thg, pad_width)
627+
628+
trim = 6
629+
thg = thg[trim:-trim, trim:-trim]
630+
g_e = sample_g_e[trim:-trim, trim:-trim] * 1e-5 # convert to SI
631+
g_n = sample_g_n[trim:-trim, trim:-trim] * 1e-5
632+
g_thg = np.sqrt(g_e**2 + g_n**2)
633+
rms = root_mean_square_error(thg, g_thg)
634+
assert rms / np.abs(g_thg).max() < 0.1
635+
636+
def test_invalid_grid_single_dimension(self):
637+
"""
638+
Check if total_horizontal_gradient raises error on grid with single
639+
dimension
640+
"""
641+
x = np.linspace(0, 10, 11)
642+
y = x**2
643+
grid = xr.DataArray(y, coords={"x": x}, dims=("x",))
644+
with pytest.raises(ValueError, match="Invalid grid with 1 dimensions."):
645+
total_horizontal_gradient(grid)
646+
647+
def test_invalid_grid_three_dimensions(self):
648+
"""
649+
Check if total_horizontal_gradient raises error on grid with three
650+
dimensions
651+
"""
652+
x = np.linspace(0, 10, 11)
653+
y = np.linspace(-4, 4, 9)
654+
z = np.linspace(20, 30, 5)
655+
xx, yy, zz = np.meshgrid(x, y, z)
656+
data = xx + yy + zz
657+
grid = xr.DataArray(data, coords={"x": x, "y": y, "z": z}, dims=("y", "x", "z"))
658+
with pytest.raises(ValueError, match="Invalid grid with 3 dimensions."):
659+
total_horizontal_gradient(grid)
660+
661+
def test_invalid_grid_with_nans(self, sample_potential):
662+
"""
663+
Check if total_horizontal_gradient raises error if grid contains nans
664+
"""
665+
sample_potential.values[0, 0] = np.nan
666+
with pytest.raises(ValueError, match="Found nan"):
667+
total_horizontal_gradient(sample_potential)
668+
669+
605670
class TestTilt:
606671
"""
607672
Test tilt function

0 commit comments

Comments
 (0)