|
1 | 1 | import os |
2 | 2 | import xarray as xr |
3 | 3 | import warnings |
| 4 | +import numpy as np |
4 | 5 | import numpy.testing as nt |
5 | 6 | import pytest |
6 | 7 |
|
7 | 8 | import uxarray as ux |
8 | 9 | from uxarray.constants import INT_DTYPE, INT_FILL_VALUE |
| 10 | +from uxarray.io._scrip import _detect_multigrid |
9 | 11 |
|
10 | 12 |
|
11 | 13 | def test_read_ugrid(gridpath, mesh_constants): |
@@ -50,3 +52,69 @@ def test_to_xarray_ugrid(gridpath): |
50 | 52 | reloaded_grid._ds.close() |
51 | 53 | del reloaded_grid |
52 | 54 | os.remove("scrip_ugrid_csne8.nc") |
| 55 | + |
| 56 | + |
| 57 | +def test_oasis_multigrid_format_detection(): |
| 58 | + """Detect OASIS-style multi-grid naming.""" |
| 59 | + ds = xr.Dataset() |
| 60 | + ds["ocn.cla"] = xr.DataArray(np.random.rand(100, 4), dims=["nc_ocn", "nv_ocn"]) |
| 61 | + ds["ocn.clo"] = xr.DataArray(np.random.rand(100, 4), dims=["nc_ocn", "nv_ocn"]) |
| 62 | + ds["atm.cla"] = xr.DataArray(np.random.rand(200, 4), dims=["nc_atm", "nv_atm"]) |
| 63 | + ds["atm.clo"] = xr.DataArray(np.random.rand(200, 4), dims=["nc_atm", "nv_atm"]) |
| 64 | + |
| 65 | + format_type, grids = _detect_multigrid(ds) |
| 66 | + assert format_type == "multi_scrip" |
| 67 | + assert set(grids.keys()) == {"ocn", "atm"} |
| 68 | + |
| 69 | + |
| 70 | +def test_open_multigrid_with_masks(gridpath): |
| 71 | + """Load OASIS multi-grids with masks applied.""" |
| 72 | + grid_file = gridpath("scrip", "oasis", "grids.nc") |
| 73 | + mask_file = gridpath("scrip", "oasis", "masks.nc") |
| 74 | + |
| 75 | + grids = ux.open_multigrid(grid_file, mask_filename=mask_file) |
| 76 | + assert grids["ocn"].n_face == 8 |
| 77 | + assert grids["atm"].n_face == 20 |
| 78 | + |
| 79 | + ocean_only = ux.open_multigrid( |
| 80 | + grid_file, gridnames=["ocn"], mask_filename=mask_file |
| 81 | + ) |
| 82 | + assert set(ocean_only.keys()) == {"ocn"} |
| 83 | + assert ocean_only["ocn"].n_face == 8 |
| 84 | + |
| 85 | + grid_names = ux.list_grid_names(grid_file) |
| 86 | + assert set(grid_names) == {"ocn", "atm"} |
| 87 | + |
| 88 | + |
| 89 | +def test_open_multigrid_mask_active_value_default(gridpath): |
| 90 | + """Default mask semantics keep value==1 active for both grids.""" |
| 91 | + grid_file = gridpath("scrip", "oasis", "grids.nc") |
| 92 | + mask_file = gridpath("scrip", "oasis", "masks_no_atm.nc") |
| 93 | + |
| 94 | + grids = ux.open_multigrid(grid_file, mask_filename=mask_file) |
| 95 | + |
| 96 | + with xr.open_dataset(mask_file) as mask_ds: |
| 97 | + expected_ocn = int(mask_ds["ocn.msk"].values.sum()) |
| 98 | + expected_atm = int(mask_ds["atm.msk"].values.sum()) |
| 99 | + |
| 100 | + assert grids["ocn"].n_face == expected_ocn |
| 101 | + assert grids["atm"].n_face == expected_atm |
| 102 | + |
| 103 | + |
| 104 | +def test_open_multigrid_mask_active_value_per_grid_override(gridpath): |
| 105 | + """Per-grid override supports masks with different active values.""" |
| 106 | + grid_file = gridpath("scrip", "oasis", "grids.nc") |
| 107 | + mask_file = gridpath("scrip", "oasis", "masks_no_atm.nc") |
| 108 | + |
| 109 | + grids = ux.open_multigrid( |
| 110 | + grid_file, |
| 111 | + mask_filename=mask_file, |
| 112 | + mask_active_value={"atm": 0, "ocn": 1}, |
| 113 | + ) |
| 114 | + |
| 115 | + with xr.open_dataset(mask_file) as mask_ds: |
| 116 | + expected_ocn = int(mask_ds["ocn.msk"].values.sum()) |
| 117 | + expected_atm = int((mask_ds["atm.msk"].values == 0).sum()) |
| 118 | + |
| 119 | + assert grids["ocn"].n_face == expected_ocn |
| 120 | + assert grids["atm"].n_face == expected_atm |
0 commit comments