Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 84 additions & 25 deletions test/test_cross_sections.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import uxarray as ux
import pytest
import numpy as np
from pathlib import Path
import os

Expand All @@ -10,7 +11,9 @@
quad_hex_grid_path = current_path / 'meshfiles' / "ugrid" / "quad-hexagon" / 'grid.nc'
quad_hex_data_path = current_path / 'meshfiles' / "ugrid" / "quad-hexagon" / 'data.nc'

cube_sphere_grid = current_path / "meshfiles" / "geos-cs" / "c12" / "test-c12.native.nc4"
cube_sphere_grid = current_path / "meshfiles" / "ugrid" / "outCSne30" / "outCSne30.ug"

from uxarray.grid.intersections import constant_lat_intersections_face_bounds



Expand All @@ -32,93 +35,149 @@ class TestQuadHex:
All four faces intersect a constant latitude of 0.0
"""

def test_constant_lat_cross_section_grid(self):
@pytest.mark.parametrize("use_spherical_bounding_box", [True, False])
def test_constant_lat_cross_section_grid(self, use_spherical_bounding_box):



uxgrid = ux.open_grid(quad_hex_grid_path)

grid_top_two = uxgrid.cross_section.constant_latitude(lat=0.1)
grid_top_two = uxgrid.cross_section.constant_latitude(lat=0.1, use_spherical_bounding_box=use_spherical_bounding_box)

assert grid_top_two.n_face == 2

grid_bottom_two = uxgrid.cross_section.constant_latitude(lat=-0.1)
grid_bottom_two = uxgrid.cross_section.constant_latitude(lat=-0.1, use_spherical_bounding_box=use_spherical_bounding_box)

assert grid_bottom_two.n_face == 2

grid_all_four = uxgrid.cross_section.constant_latitude(lat=0.0)
grid_all_four = uxgrid.cross_section.constant_latitude(lat=0.0, use_spherical_bounding_box=use_spherical_bounding_box)

assert grid_all_four.n_face == 4

with pytest.raises(ValueError):
# no intersections found at this line
uxgrid.cross_section.constant_latitude(lat=10.0)
uxgrid.cross_section.constant_latitude(lat=10.0, use_spherical_bounding_box=use_spherical_bounding_box)

def test_constant_lon_cross_section_grid(self):
@pytest.mark.parametrize("use_spherical_bounding_box", [False])
def test_constant_lon_cross_section_grid(self, use_spherical_bounding_box):
uxgrid = ux.open_grid(quad_hex_grid_path)

grid_left_two = uxgrid.cross_section.constant_longitude(lon=-0.1)
grid_left_two = uxgrid.cross_section.constant_longitude(lon=-0.1, use_spherical_bounding_box=use_spherical_bounding_box)

assert grid_left_two.n_face == 2

grid_right_two = uxgrid.cross_section.constant_longitude(lon=0.2)
grid_right_two = uxgrid.cross_section.constant_longitude(lon=0.2, use_spherical_bounding_box=use_spherical_bounding_box)

assert grid_right_two.n_face == 2

with pytest.raises(ValueError):
# no intersections found at this line
uxgrid.cross_section.constant_longitude(lon=10.0)


def test_constant_lat_cross_section_uxds(self):
@pytest.mark.parametrize("use_spherical_bounding_box", [False])
def test_constant_lat_cross_section_uxds(self, use_spherical_bounding_box):
uxds = ux.open_dataset(quad_hex_grid_path, quad_hex_data_path)
uxds.uxgrid.normalize_cartesian_coordinates()

da_top_two = uxds['t2m'].cross_section.constant_latitude(lat=0.1)
da_top_two = uxds['t2m'].cross_section.constant_latitude(lat=0.1, use_spherical_bounding_box=use_spherical_bounding_box)

nt.assert_array_equal(da_top_two.data, uxds['t2m'].isel(n_face=[1, 2]).data)

da_bottom_two = uxds['t2m'].cross_section.constant_latitude(lat=-0.1)
da_bottom_two = uxds['t2m'].cross_section.constant_latitude(lat=-0.1, use_spherical_bounding_box=use_spherical_bounding_box)

nt.assert_array_equal(da_bottom_two.data, uxds['t2m'].isel(n_face=[0, 3]).data)

da_all_four = uxds['t2m'].cross_section.constant_latitude(lat=0.0)
da_all_four = uxds['t2m'].cross_section.constant_latitude(lat=0.0, use_spherical_bounding_box=use_spherical_bounding_box)

nt.assert_array_equal(da_all_four.data , uxds['t2m'].data)

with pytest.raises(ValueError):
# no intersections found at this line
uxds['t2m'].cross_section.constant_latitude(lat=10.0)
uxds['t2m'].cross_section.constant_latitude(lat=10.0, use_spherical_bounding_box=use_spherical_bounding_box)

def test_constant_lon_cross_section_uxds(self):
@pytest.mark.parametrize("use_spherical_bounding_box", [False])
def test_constant_lon_cross_section_uxds(self, use_spherical_bounding_box):
uxds = ux.open_dataset(quad_hex_grid_path, quad_hex_data_path)
uxds.uxgrid.normalize_cartesian_coordinates()

da_left_two = uxds['t2m'].cross_section.constant_longitude(lon=-0.1)
da_left_two = uxds['t2m'].cross_section.constant_longitude(lon=-0.1, use_spherical_bounding_box=use_spherical_bounding_box)

nt.assert_array_equal(da_left_two.data, uxds['t2m'].isel(n_face=[0, 2]).data)

da_right_two = uxds['t2m'].cross_section.constant_longitude(lon=0.2)
da_right_two = uxds['t2m'].cross_section.constant_longitude(lon=0.2, use_spherical_bounding_box=use_spherical_bounding_box)

nt.assert_array_equal(da_right_two.data, uxds['t2m'].isel(n_face=[1, 3]).data)

with pytest.raises(ValueError):
# no intersections found at this line
uxds['t2m'].cross_section.constant_longitude(lon=10.0)
uxds['t2m'].cross_section.constant_longitude(lon=10.0, use_spherical_bounding_box=use_spherical_bounding_box)


class TestGeosCubeSphere:
def test_north_pole(self):
class TestCubeSphere:
@pytest.mark.parametrize("use_spherical_bounding_box", [True, False])
def test_north_pole(self, use_spherical_bounding_box):
uxgrid = ux.open_grid(cube_sphere_grid)

lats = [89.85, 89.9, 89.95, 89.99]

for lat in lats:
cross_grid = uxgrid.cross_section.constant_latitude(lat=lat)
cross_grid = uxgrid.cross_section.constant_latitude(lat=lat, use_spherical_bounding_box=use_spherical_bounding_box)
# Cube sphere grid should have 4 faces centered around the pole
assert cross_grid.n_face == 4

def test_south_pole(self):
@pytest.mark.parametrize("use_spherical_bounding_box", [True, False])
def test_south_pole(self, use_spherical_bounding_box):
uxgrid = ux.open_grid(cube_sphere_grid)

lats = [-89.85, -89.9, -89.95, -89.99]

for lat in lats:
cross_grid = uxgrid.cross_section.constant_latitude(lat=lat)
cross_grid = uxgrid.cross_section.constant_latitude(lat=lat, use_spherical_bounding_box=use_spherical_bounding_box)
# Cube sphere grid should have 4 faces centered around the pole
assert cross_grid.n_face == 4



class TestCandidateFacesUsingBounds:

def test_constant_lat(self):
bounds = np.array([
[[-45, 45], [0, 360]],
[[-90, -45], [0, 360]],
[[45, 90], [0, 360]],
])

bounds_rad = np.deg2rad(bounds)

const_lat = 0

candidate_faces = constant_lat_intersections_face_bounds(
lat=const_lat,
face_min_lat_rad=bounds_rad[:, 0, 0],
face_max_lat_rad=bounds_rad[:, 0, 1],
)

# Expected output
expected_faces = np.array([0])

# Test the function output
nt.assert_array_equal(candidate_faces, expected_faces)

def test_constant_lat_out_of_bounds(self):

bounds = np.array([
[[-45, 45], [0, 360]],
[[-90, -45], [0, 360]],
[[45, 90], [0, 360]],
])

bounds_rad = np.deg2rad(bounds)

const_lat = 100

candidate_faces = constant_lat_intersections_face_bounds(
lat=const_lat,
face_min_lat_rad=bounds_rad[:, 0, 0],
face_max_lat_rad=bounds_rad[:, 0, 1],
)

assert len(candidate_faces) == 0
3 changes: 3 additions & 0 deletions uxarray/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ def disable_fma():
uxarray.constants.ENABLE_FMA = False


disable_fma()


__all__ = (
"open_grid",
"open_dataset",
Expand Down
33 changes: 12 additions & 21 deletions uxarray/cross_sections/dataarray_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,16 @@ def __repr__(self):

return prefix + methods_heading

def constant_latitude(self, lat: float, method="fast"):
def constant_latitude(self, lat: float, use_spherical_bounding_box=False):
"""Extracts a cross-section of the data array at a specified constant
latitude.

Parameters
----------
lat : float
The latitude at which to extract the cross-section, in degrees.
method : str, optional
The internal method to use when identifying faces at the constant latitude.
Options are:
- 'fast': Uses a faster but potentially less accurate method for face identification.
- 'accurate': Uses a slower but more accurate method.
Default is 'fast'.
use_spherical_bounding_box : bool, optional
If True, uses a spherical bounding box for intersection calculations.

Raises
------
Expand All @@ -44,30 +40,23 @@ def constant_latitude(self, lat: float, method="fast"):
Examples
--------
>>> uxda.constant_latitude_cross_section(lat=-15.5)

Notes
-----
The accuracy and performance of the function can be controlled using the `method` parameter.
For higher precision requreiments, consider using method='acurate'.
"""
faces = self.uxda.uxgrid.get_faces_at_constant_latitude(lat, method)
faces = self.uxda.uxgrid.get_faces_at_constant_latitude(
lat, use_spherical_bounding_box
)

return self.uxda.isel(n_face=faces)

def constant_longitude(self, lon: float, method="fast"):
def constant_longitude(self, lon: float, use_spherical_bounding_box=False):
"""Extracts a cross-section of the data array at a specified constant
longitude.

Parameters
----------
lon : float
The longitude at which to extract the cross-section, in degrees.
method : str, optional
The internal method to use when identifying faces at the constant longitude.
Options are:
- 'fast': Uses a faster but potentially less accurate method for face identification.
- 'accurate': Uses a slower but more accurate method.
Default is 'fast'.
use_spherical_bounding_box : bool, optional
If True, uses a spherical bounding box for intersection calculations.

Raises
------
Expand All @@ -83,7 +72,9 @@ def constant_longitude(self, lon: float, method="fast"):
The accuracy and performance of the function can be controlled using the `method` parameter.
For higher precision requreiments, consider using method='acurate'.
"""
faces = self.uxda.uxgrid.get_faces_at_constant_longitude(lon, method)
faces = self.uxda.uxgrid.get_faces_at_constant_longitude(
lon, use_spherical_bounding_box
)

return self.uxda.isel(n_face=faces)

Expand Down
45 changes: 18 additions & 27 deletions uxarray/cross_sections/grid_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,12 @@ def __repr__(self):
methods_heading += " * constant_latitude(lat, )\n"
return prefix + methods_heading

def constant_latitude(self, lat: float, return_face_indices=False, method="fast"):
def constant_latitude(
self, lat: float, return_face_indices=False, use_spherical_bounding_box=False
):
"""Extracts a cross-section of the grid at a specified constant
latitude.

This method identifies and returns all faces (or grid elements) that intersect
with a given latitude. The returned cross-section can include either just the grid
or both the grid elements and the corresponding face indices, depending
on the `return_face_indices` parameter.

Parameters
----------
lat : float
Expand All @@ -37,12 +34,9 @@ def constant_latitude(self, lat: float, return_face_indices=False, method="fast"
If True, returns both the grid at the specified latitude and the indices
of the intersecting faces. If False, only the grid is returned.
Default is False.
method : str, optional
The internal method to use when identifying faces at the constant latitude.
Options are:
- 'fast': Uses a faster but potentially less accurate method for face identification.
- 'accurate': Uses a slower but more accurate method.
Default is 'fast'.
use_spherical_bounding_box : bool, optional
If True, uses a spherical bounding box for intersection calculations.


Returns
-------
Expand All @@ -69,7 +63,9 @@ def constant_latitude(self, lat: float, return_face_indices=False, method="fast"
The accuracy and performance of the function can be controlled using the `method` parameter.
For higher precision requreiments, consider using method='acurate'.
"""
faces = self.uxgrid.get_faces_at_constant_latitude(lat, method)
faces = self.uxgrid.get_faces_at_constant_latitude(
lat, use_spherical_bounding_box
)

if len(faces) == 0:
raise ValueError(f"No intersections found at lat={lat}.")
Expand All @@ -81,29 +77,22 @@ def constant_latitude(self, lat: float, return_face_indices=False, method="fast"
else:
return grid_at_constant_lat

def constant_longitude(self, lon: float, return_face_indices=False, method="fast"):
def constant_longitude(
self, lon: float, use_spherical_bounding_box=False, return_face_indices=False
):
"""Extracts a cross-section of the grid at a specified constant
longitude.

This method identifies and returns all faces (or grid elements) that intersect
with a given longitude. The returned cross-section can include either just the grid
or both the grid elements and the corresponding face indices, depending
on the `return_face_indices` parameter.

Parameters
----------
lon : float
The longitude at which to extract the cross-section, in degrees.
use_spherical_bounding_box : bool, optional
If True, uses a spherical bounding box for intersection calculations.
return_face_indices : bool, optional
If True, returns both the grid at the specified longitude and the indices
of the intersecting faces. If False, only the grid is returned.
Default is False.
method : str, optional
The internal method to use when identifying faces at the constant longitude.
Options are:
- 'fast': Uses a faster but potentially less accurate method for face identification.
- 'accurate': Uses a slower but more accurate method.
Default is 'fast'.

Returns
-------
Expand All @@ -130,10 +119,12 @@ def constant_longitude(self, lon: float, return_face_indices=False, method="fast
The accuracy and performance of the function can be controlled using the `method` parameter.
For higher precision requreiments, consider using method='acurate'.
"""
faces = self.uxgrid.get_faces_at_constant_longitude(lon, method)
faces = self.uxgrid.get_faces_at_constant_longitude(
lon, use_spherical_bounding_box
)

if len(faces) == 0:
raise ValueError(f"No intersections found at lon={lon}.")
raise ValueError(f"No intersections found at lon={lon}")

grid_at_constant_lon = self.uxgrid.isel(n_face=faces)

Expand Down
3 changes: 1 addition & 2 deletions uxarray/grid/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,7 @@


def _unique_points(points, tolerance=ERROR_TOLERANCE):
"""Identify unique intersection points from a list of points, considering
floating point precision errors.
"""Identify unique intersection points from a list of points, considering floating point precision errors.

Parameters
----------
Expand Down
Loading