Skip to content

Commit 9859bb6

Browse files
committed
o Fix all test cases, add unittest.TestCase inheritance. TestCase allows the class to inherit unittest’s testing functionality, such as assertEqual, assertTrue, and automatic test discovery. This change is needed for writing and executing proper unit tests in a consistent framework, improving code quality and enabling integration with test runners.
1 parent 1d20619 commit 9859bb6

File tree

4 files changed

+92
-89
lines changed

4 files changed

+92
-89
lines changed

test/test_cross_sections.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import uxarray as ux
22
import pytest
33
import numpy as np
4+
5+
from unittest import TestCase
46
from pathlib import Path
57
import os
68

@@ -17,7 +19,7 @@
1719

1820

1921

20-
class TestQuadHex:
22+
class TestQuadHex(TestCase):
2123
"""The quad hexagon grid contains four faces.
2224
2325
Top Left Face: Index 1
@@ -108,7 +110,7 @@ def test_constant_lon_cross_section_uxds(self):
108110
uxds['t2m'].cross_section.constant_longitude(lon=10.0, )
109111

110112

111-
class TestCubeSphere:
113+
class TestCubeSphere(TestCase):
112114

113115
def test_north_pole(self):
114116
uxgrid = ux.open_grid(cube_sphere_grid)
@@ -132,7 +134,7 @@ def test_south_pole(self):
132134

133135

134136

135-
class TestCandidateFacesUsingBounds:
137+
class TestCandidateFacesUsingBounds(TestCase):
136138

137139
def test_constant_lat(self):
138140
bounds = np.array([
@@ -173,4 +175,4 @@ def test_constant_lat_out_of_bounds(self):
173175
face_bounds_lat=bounds_rad[:, 0],
174176
)
175177

176-
assert len(candidate_faces) == 0
178+
assert len(candidate_faces) == 0

test/test_esmf.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import uxarray as ux
2-
2+
from unittest import TestCase
33
import os
44
from pathlib import Path
55

@@ -12,36 +12,36 @@
1212

1313

1414

15+
class Test_ESMF(TestCase):
16+
def test_read_esmf(self):
17+
"""Tests the reading of an ESMF grid file and its encoding into the UGRID
18+
conventions."""
1519

16-
def test_read_esmf():
17-
"""Tests the reading of an ESMF grid file and its encoding into the UGRID
18-
conventions."""
19-
20-
uxgrid = ux.open_grid(esmf_ne30_grid_path)
20+
uxgrid = ux.open_grid(esmf_ne30_grid_path)
2121

22-
dims = ['n_node', 'n_face', 'n_max_face_nodes']
22+
dims = ['n_node', 'n_face', 'n_max_face_nodes']
2323

24-
coords = ['node_lon', 'node_lat', 'face_lon', 'face_lat']
24+
coords = ['node_lon', 'node_lat', 'face_lon', 'face_lat']
2525

26-
conns = ['face_node_connectivity', 'n_nodes_per_face']
26+
conns = ['face_node_connectivity', 'n_nodes_per_face']
2727

28-
for dim in dims:
29-
assert dim in uxgrid._ds.dims
28+
for dim in dims:
29+
assert dim in uxgrid._ds.dims
3030

31-
for coord in coords:
32-
assert coord in uxgrid._ds
31+
for coord in coords:
32+
assert coord in uxgrid._ds
3333

34-
for conn in conns:
35-
assert conn in uxgrid._ds
34+
for conn in conns:
35+
assert conn in uxgrid._ds
3636

37-
def test_read_esmf_dataset():
38-
"""Tests the constructing of a UxDataset from an ESMF Grid and Data
39-
File."""
37+
def test_read_esmf_dataset(self):
38+
"""Tests the constructing of a UxDataset from an ESMF Grid and Data
39+
File."""
4040

41-
uxds = ux.open_dataset(esmf_ne30_grid_path, esmf_ne30_data_path)
41+
uxds = ux.open_dataset(esmf_ne30_grid_path, esmf_ne30_data_path)
4242

4343

44-
dims = ['n_node', 'n_face']
44+
dims = ['n_node', 'n_face']
4545

46-
for dim in dims:
47-
assert dim in uxds.dims
46+
for dim in dims:
47+
assert dim in uxds.dims

test/test_from_topology.py

Lines changed: 49 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import uxarray as ux
2-
2+
from unittest import TestCase
33
from uxarray.constants import INT_FILL_VALUE
44
import numpy.testing as nt
55
import os
@@ -18,70 +18,70 @@
1818

1919

2020

21+
class TestFromTopology(TestCase):
22+
def test_minimal_class_method(self):
23+
"""Tests the minimal required variables for constructing a grid using the
24+
from topology class method."""
2125

22-
def test_minimal_class_method():
23-
"""Tests the minimal required variables for constructing a grid using the
24-
from topology class method."""
26+
for grid_path in GRID_PATHS:
27+
uxgrid = ux.open_grid(grid_path)
2528

26-
for grid_path in GRID_PATHS:
27-
uxgrid = ux.open_grid(grid_path)
29+
uxgrid_ft = ux.Grid.from_topology(node_lon=uxgrid.node_lon.values,
30+
node_lat=uxgrid.node_lat.values,
31+
face_node_connectivity=uxgrid.face_node_connectivity.values,
32+
fill_value=INT_FILL_VALUE,
33+
start_index=0)
2834

29-
uxgrid_ft = ux.Grid.from_topology(node_lon=uxgrid.node_lon.values,
30-
node_lat=uxgrid.node_lat.values,
31-
face_node_connectivity=uxgrid.face_node_connectivity.values,
32-
fill_value=INT_FILL_VALUE,
33-
start_index=0)
35+
nt.assert_array_equal(uxgrid.node_lon.values, uxgrid_ft.node_lon.values)
36+
nt.assert_array_equal(uxgrid.node_lat.values, uxgrid_ft.node_lat.values)
37+
nt.assert_array_equal(uxgrid.face_node_connectivity.values, uxgrid_ft.face_node_connectivity.values)
3438

35-
nt.assert_array_equal(uxgrid.node_lon.values, uxgrid_ft.node_lon.values)
36-
nt.assert_array_equal(uxgrid.node_lat.values, uxgrid_ft.node_lat.values)
37-
nt.assert_array_equal(uxgrid.face_node_connectivity.values, uxgrid_ft.face_node_connectivity.values)
3839

40+
def test_minimal_api(self):
41+
"""Tests the minimal required variables for constructing a grid using the
42+
``ux.open_dataset`` method."""
3943

40-
def test_minimal_api():
41-
"""Tests the minimal required variables for constructing a grid using the
42-
``ux.open_dataset`` method."""
44+
for grid_path in GRID_PATHS:
45+
uxgrid = ux.open_grid(grid_path)
4346

44-
for grid_path in GRID_PATHS:
45-
uxgrid = ux.open_grid(grid_path)
47+
uxgrid_ft = ux.Grid.from_topology(node_lon=uxgrid.node_lon.values,
48+
node_lat=uxgrid.node_lat.values,
49+
face_node_connectivity=uxgrid.face_node_connectivity.values,
50+
fill_value=INT_FILL_VALUE,
51+
start_index=0)
4652

47-
uxgrid_ft = ux.Grid.from_topology(node_lon=uxgrid.node_lon.values,
48-
node_lat=uxgrid.node_lat.values,
49-
face_node_connectivity=uxgrid.face_node_connectivity.values,
50-
fill_value=INT_FILL_VALUE,
51-
start_index=0)
53+
grid_topology = {'node_lon': uxgrid.node_lon.values,
54+
'node_lat': uxgrid.node_lat.values,
55+
'face_node_connectivity': uxgrid.face_node_connectivity.values,
56+
'fill_value': INT_FILL_VALUE,
57+
'start_index': 0}
5258

53-
grid_topology = {'node_lon': uxgrid.node_lon.values,
54-
'node_lat': uxgrid.node_lat.values,
55-
'face_node_connectivity': uxgrid.face_node_connectivity.values,
56-
'fill_value': INT_FILL_VALUE,
57-
'start_index': 0}
59+
uxgrid_ft = ux.open_grid(grid_topology)
5860

59-
uxgrid_ft = ux.open_grid(grid_topology)
60-
61-
nt.assert_array_equal(uxgrid.node_lon.values, uxgrid_ft.node_lon.values)
62-
nt.assert_array_equal(uxgrid.node_lat.values, uxgrid_ft.node_lat.values)
63-
nt.assert_array_equal(uxgrid.face_node_connectivity.values, uxgrid_ft.face_node_connectivity.values)
61+
nt.assert_array_equal(uxgrid.node_lon.values, uxgrid_ft.node_lon.values)
62+
nt.assert_array_equal(uxgrid.node_lat.values, uxgrid_ft.node_lat.values)
63+
nt.assert_array_equal(uxgrid.face_node_connectivity.values, uxgrid_ft.face_node_connectivity.values)
6464

6565

66-
def test_dataset():
67-
uxds = ux.open_dataset(GRID_PATHS[0], GRID_PATHS[0])
66+
def test_dataset(self):
67+
uxds = ux.open_dataset(GRID_PATHS[0], GRID_PATHS[0])
6868

69-
grid_topology = {'node_lon': uxds.uxgrid.node_lon.values,
70-
'node_lat': uxds.uxgrid.node_lat.values,
71-
'face_node_connectivity': uxds.uxgrid.face_node_connectivity.values,
72-
'fill_value': INT_FILL_VALUE,
73-
'start_index': 0,
74-
"dims_dict" : {"nVertices": "n_node"}}
69+
grid_topology = {'node_lon': uxds.uxgrid.node_lon.values,
70+
'node_lat': uxds.uxgrid.node_lat.values,
71+
'face_node_connectivity': uxds.uxgrid.face_node_connectivity.values,
72+
'fill_value': INT_FILL_VALUE,
73+
'start_index': 0,
74+
"dims_dict" : {"nVertices": "n_node"}}
7575

7676

77-
uxds_ft = ux.open_grid(grid_topology, GRID_PATHS[1])
77+
uxds_ft = ux.open_grid(grid_topology, GRID_PATHS[1])
7878

79-
uxgrid = uxds.uxgrid
80-
uxgrid_ft = uxds_ft
79+
uxgrid = uxds.uxgrid
80+
uxgrid_ft = uxds_ft
8181

8282

83-
nt.assert_array_equal(uxgrid.node_lon.values, uxgrid_ft.node_lon.values)
84-
nt.assert_array_equal(uxgrid.node_lat.values, uxgrid_ft.node_lat.values)
85-
nt.assert_array_equal(uxgrid.face_node_connectivity.values, uxgrid_ft.face_node_connectivity.values)
83+
nt.assert_array_equal(uxgrid.node_lon.values, uxgrid_ft.node_lon.values)
84+
nt.assert_array_equal(uxgrid.node_lat.values, uxgrid_ft.node_lat.values)
85+
nt.assert_array_equal(uxgrid.face_node_connectivity.values, uxgrid_ft.face_node_connectivity.values)
8686

87-
assert uxds_ft.dims == {'n_face', 'n_node', 'n_max_face_nodes'}
87+
assert uxds_ft.dims == {'n_face', 'n_node', 'n_max_face_nodes'}

test/test_geos.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,32 @@
11
import uxarray as ux
22

33
import os
4+
from unittest import TestCase
45
from pathlib import Path
56

67
current_path = Path(os.path.dirname(os.path.realpath(__file__)))
78

89
gridfile_geos_cs = current_path / "meshfiles" / "geos-cs" / "c12" / "test-c12.native.nc4"
910

1011

12+
class Test_GEOS(TestCase):
13+
def test_read_geos_cs_grid(self):
14+
"""Tests the conversion of a CS12 GEOS-CS Grid to the UGRID conventions.
1115
12-
def test_read_geos_cs_grid():
13-
"""Tests the conversion of a CS12 GEOS-CS Grid to the UGRID conventions.
16+
A CS12 grid has 6 faces, each with 12x12 faces and 13x13 nodes each.
17+
"""
1418

15-
A CS12 grid has 6 faces, each with 12x12 faces and 13x13 nodes each.
16-
"""
19+
uxgrid = ux.open_grid(gridfile_geos_cs)
1720

18-
uxgrid = ux.open_grid(gridfile_geos_cs)
21+
n_face = 6 * 12 * 12
22+
n_node = 6 * 13 * 13
1923

20-
n_face = 6 * 12 * 12
21-
n_node = 6 * 13 * 13
24+
assert uxgrid.n_face == n_face
25+
assert uxgrid.n_node == n_node
2226

23-
assert uxgrid.n_face == n_face
24-
assert uxgrid.n_node == n_node
2527

28+
def test_read_geos_cs_uxds(self):
29+
"""Tests the creating of a UxDataset from a CS12 GEOS-CS Grid."""
30+
uxds = ux.open_dataset(gridfile_geos_cs, gridfile_geos_cs)
2631

27-
def test_read_geos_cs_uxds():
28-
"""Tests the creating of a UxDataset from a CS12 GEOS-CS Grid."""
29-
uxds = ux.open_dataset(gridfile_geos_cs, gridfile_geos_cs)
30-
31-
assert uxds['T'].shape[-1] == uxds.uxgrid.n_face
32+
assert uxds['T'].shape[-1] == uxds.uxgrid.n_face

0 commit comments

Comments
 (0)