Skip to content

Commit b688ed5

Browse files
authored
Merge pull request #1420 from UXARRAY/rajeeja/fix_issue1393
to_raster: handle unset extent; accept singleton extra dims
2 parents 7e1c311 + c8ce122 commit b688ed5

File tree

4 files changed

+71
-15
lines changed

4 files changed

+71
-15
lines changed

test/test_plot.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,25 @@ def test_to_raster(gridpath):
103103
mesh_path = gridpath("mpas", "QU", "oQU480.231010.nc")
104104
uxds = ux.open_dataset(mesh_path, mesh_path)
105105

106-
raster = uxds['bottomDepth'].to_raster(ax=ax)
106+
with pytest.warns(UserWarning, match=r"Axes extent was default"):
107+
raster = uxds['bottomDepth'].to_raster(ax=ax)
108+
109+
assert isinstance(raster, np.ndarray)
110+
111+
112+
def test_to_raster_with_extra_dims(gridpath):
113+
fig, ax = plt.subplots(
114+
subplot_kw={'projection': ccrs.Robinson()},
115+
constrained_layout=True,
116+
figsize=(10, 5),
117+
)
118+
119+
mesh_path = gridpath("mpas", "QU", "oQU480.231010.nc")
120+
uxds = ux.open_dataset(mesh_path, mesh_path)
121+
122+
da = uxds['bottomDepth'].expand_dims(time=[0])
123+
with pytest.warns(UserWarning, match=r"Axes extent was default"):
124+
raster = da.to_raster(ax=ax)
107125

108126
assert isinstance(raster, np.ndarray)
109127

@@ -121,9 +139,10 @@ def test_to_raster_reuse_mapping(gridpath, tmpdir):
121139
uxds = ux.open_dataset(mesh_path, mesh_path)
122140

123141
# Returning
124-
raster1, pixel_mapping = uxds['bottomDepth'].to_raster(
125-
ax=ax, pixel_ratio=0.5, return_pixel_mapping=True
126-
)
142+
with pytest.warns(UserWarning, match=r"Axes extent was default"):
143+
raster1, pixel_mapping = uxds['bottomDepth'].to_raster(
144+
ax=ax, pixel_ratio=0.5, return_pixel_mapping=True
145+
)
127146
assert isinstance(raster1, np.ndarray)
128147
assert isinstance(pixel_mapping, xr.DataArray)
129148

uxarray/core/dataarray.py

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -375,16 +375,16 @@ def to_raster(
375375
_RasterAxAttrs,
376376
)
377377

378-
_ensure_dimensions(self)
378+
data = _ensure_dimensions(self)
379379

380380
if not isinstance(ax, GeoAxes):
381381
raise TypeError("`ax` must be an instance of cartopy.mpl.geoaxes.GeoAxes")
382382

383383
pixel_ratio_set = pixel_ratio is not None
384384
if not pixel_ratio_set:
385385
pixel_ratio = 1.0
386-
input_ax_attrs = _RasterAxAttrs.from_ax(ax, pixel_ratio=pixel_ratio)
387386
if pixel_mapping is not None:
387+
input_ax_attrs = _RasterAxAttrs.from_ax(ax, pixel_ratio=pixel_ratio)
388388
if isinstance(pixel_mapping, xr.DataArray):
389389
pixel_ratio_input = pixel_ratio
390390
pixel_ratio = pixel_mapping.attrs["pixel_ratio"]
@@ -403,9 +403,43 @@ def to_raster(
403403
+ input_ax_attrs._value_comparison_message(pm_ax_attrs)
404404
)
405405
pixel_mapping = np.asarray(pixel_mapping, dtype=INT_DTYPE)
406+
else:
407+
408+
def _is_default_extent() -> bool:
409+
# Default extents are indicated by xlim/ylim being (0, 1)
410+
# when autoscale is still on (no extent has been explicitly set)
411+
if not ax.get_autoscale_on():
412+
return False
413+
xlim, ylim = ax.get_xlim(), ax.get_ylim()
414+
return np.allclose(xlim, (0.0, 1.0)) and np.allclose(ylim, (0.0, 1.0))
415+
416+
if _is_default_extent():
417+
try:
418+
import cartopy.crs as ccrs
419+
420+
lon_min = float(self.uxgrid.node_lon.min(skipna=True).values)
421+
lon_max = float(self.uxgrid.node_lon.max(skipna=True).values)
422+
lat_min = float(self.uxgrid.node_lat.min(skipna=True).values)
423+
lat_max = float(self.uxgrid.node_lat.max(skipna=True).values)
424+
ax.set_extent(
425+
(lon_min, lon_max, lat_min, lat_max),
426+
crs=ccrs.PlateCarree(),
427+
)
428+
warn(
429+
"Axes extent was default; auto-setting from grid lon/lat bounds for rasterization. "
430+
"Set the extent explicitly to control this, e.g. via ax.set_global(), "
431+
"ax.set_extent(...), or ax.set_xlim(...) + ax.set_ylim(...).",
432+
stacklevel=2,
433+
)
434+
except Exception as e:
435+
warn(
436+
f"Failed to auto-set extent from grid bounds: {e}",
437+
stacklevel=2,
438+
)
439+
input_ax_attrs = _RasterAxAttrs.from_ax(ax, pixel_ratio=pixel_ratio)
406440

407441
raster, pixel_mapping_np = _nearest_neighbor_resample(
408-
self,
442+
data,
409443
ax,
410444
pixel_ratio=pixel_ratio,
411445
pixel_mapping=pixel_mapping,

uxarray/plot/matplotlib.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,17 +22,21 @@ def _ensure_dimensions(data: UxDataArray) -> UxDataArray:
2222
ValueError
2323
If the sole dimension is not named "n_face".
2424
"""
25-
# Check dimensionality
26-
if data.ndim != 1:
25+
# Allow extra singleton dimensions as long as there's exactly one non-singleton dim
26+
non_trivial_dims = [dim for dim, size in zip(data.dims, data.shape) if size != 1]
27+
28+
if len(non_trivial_dims) != 1:
2729
raise ValueError(
28-
f"Expected a 1D DataArray over 'n_face', but got {data.ndim} dimensions: {data.dims}"
30+
"Expected data with a single dimension (other axes may be length 1), "
31+
f"but got dims {data.dims} with shape {data.shape}"
2932
)
3033

31-
# Check dimension name
32-
if data.dims[0] != "n_face":
33-
raise ValueError(f"Expected dimension 'n_face', but got '{data.dims[0]}'")
34+
sole_dim = non_trivial_dims[0]
35+
if sole_dim != "n_face":
36+
raise ValueError(f"Expected dimension 'n_face', but got '{sole_dim}'")
3437

35-
return data
38+
# Squeeze any singleton axes to ensure we return a true 1D array over n_face
39+
return data.squeeze()
3640

3741

3842
class _RasterAxAttrs(NamedTuple):

uxarray/utils/computing.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,6 @@ def dot_fma(v1, v2):
102102
----------
103103
S. Graillat, Ph. Langlois, and N. Louvet. "Accurate dot products with FMA." Presented at RNC 7, 2007, Nancy, France.
104104
DALI-LP2A Laboratory, University of Perpignan, France.
105-
[Poster](https://www-pequan.lip6.fr/~graillat/papers/posterRNC7.pdf)
106105
"""
107106
if len(v1) != len(v2):
108107
raise ValueError("Input vectors must be of the same length")

0 commit comments

Comments
 (0)