Skip to content

Commit a1b211c

Browse files
authored
ENH: Load and write RPCs #837 (#890)
1 parent 199e917 commit a1b211c

File tree

9 files changed

+376
-47
lines changed

9 files changed

+376
-47
lines changed

docs/history.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ History
33

44
Latest
55
------
6+
- ENH: RPCs: Load and write RPCs (#837)
67

78
0.20.0
89
------

rioxarray/_io.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1209,8 +1209,12 @@ def open_rasterio(
12091209
coord_name = "band"
12101210
coords[coord_name] = numpy.asarray(riods.indexes)
12111211

1212+
# Handle GCPs and RPCs
12121213
has_gcps = riods.gcps[0]
1213-
if has_gcps:
1214+
has_rpcs = riods.rpcs
1215+
1216+
# Skip coordinate parsing when the dataset is georeferenced through GCPs or RPCs
1217+
if has_gcps or has_rpcs:
12141218
parse_coordinates = False
12151219

12161220
# Get geospatial coordinates
@@ -1281,6 +1285,8 @@ def open_rasterio(
12811285
result.rio.write_crs(rio_crs, inplace=True)
12821286
if has_gcps:
12831287
result.rio.write_gcps(*riods.gcps, inplace=True)
1288+
if has_rpcs:
1289+
result.rio.write_rpcs(riods.rpcs, inplace=True)
12841290

12851291
if chunks is not None:
12861292
result = _prepare_dask(

rioxarray/raster_array.py

Lines changed: 96 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -458,37 +458,26 @@ def reproject(
458458
"CRS not found. Please set the CRS with 'rio.write_crs()'."
459459
f"{_get_data_var_message(self._obj)}"
460460
)
461-
gcps = self.get_gcps()
462-
if gcps:
463-
kwargs.setdefault("gcps", gcps)
464461

465-
use_affine = (
466-
"gcps" not in kwargs
467-
and "rpcs" not in kwargs
468-
and "src_geoloc_array" not in kwargs
469-
)
470-
src_affine = None if not use_affine else self.transform(recalc=True)
471-
if transform is None:
472-
dst_affine, dst_width, dst_height = _make_dst_affine(
473-
src_data_array=self._obj,
474-
src_crs=self.crs,
475-
dst_crs=dst_crs,
476-
dst_resolution=resolution,
477-
dst_shape=shape,
478-
**kwargs,
479-
)
480-
else:
481-
dst_affine = transform
482-
if shape is not None:
483-
dst_height, dst_width = shape
484-
else:
485-
dst_height, dst_width = self.shape
462+
kwargs = self._reproj_update_kwargs(**kwargs)
463+
486464
if isinstance(resampling, str):
487465
resampling = _convert_str_to_resampling(resampling)
488466

489-
dst_data = self._create_dst_data(dst_height=dst_height, dst_width=dst_width)
467+
# Get source data from inputs
468+
src_affine, use_affine = self._reproj_get_src(**kwargs)
490469

491-
dst_nodata = self._get_dst_nodata(nodata)
470+
# Get destination data from inputs
471+
dst_data, dst_height, dst_width, dst_affine, dst_nodata = self._reproj_get_dst(
472+
dst_crs=dst_crs,
473+
resolution=resolution,
474+
shape=shape,
475+
transform=transform,
476+
nodata=nodata,
477+
**kwargs,
478+
)
479+
480+
# Do the reprojection using rasterio
492481
rasterio.warp.reproject(
493482
source=self._obj.values,
494483
destination=dst_data,
@@ -501,8 +490,33 @@ def reproject(
501490
resampling=resampling,
502491
**kwargs,
503492
)
493+
494+
# Convert the ndarray to a xarray
495+
return self._reproj_convert_to_xarray(
496+
dst_data=dst_data,
497+
dst_nodata=dst_nodata,
498+
dst_affine=dst_affine,
499+
dst_width=dst_width,
500+
dst_height=dst_height,
501+
dst_crs=dst_crs,
502+
use_affine=use_affine,
503+
)
504+
505+
def _reproj_convert_to_xarray(
506+
self,
507+
*,
508+
dst_data: numpy.ndarray,
509+
dst_nodata: float,
510+
dst_affine: Affine,
511+
dst_width: int,
512+
dst_height: int,
513+
dst_crs: Any,
514+
use_affine: bool,
515+
):
516+
"""Helper function creating a proper xarray (with correct attributes, etc) from the reprojection output"""
504517
# add necessary attributes
505518
new_attrs = _generate_attrs(src_data_array=self._obj, dst_nodata=dst_nodata)
519+
506520
# make sure dimensions with coordinates renamed to x,y
507521
dst_dims: list[Hashable] = []
508522
for dim in self._obj.dims:
@@ -529,8 +543,63 @@ def reproject(
529543
xda.rio.write_transform(dst_affine, inplace=True)
530544
xda.rio.write_crs(dst_crs, inplace=True)
531545
xda.rio.write_coordinate_system(inplace=True)
546+
532547
return xda
533548

549+
def _reproj_update_kwargs(self, **kwargs):
    """
    Complete the reprojection keyword arguments with stored georeferencing.

    GCPs and RPCs read from the object (``get_gcps`` / ``get_rpcs``) are
    added under the ``"gcps"`` / ``"rpcs"`` keys, but only when the caller
    did not already provide that key.

    Returns
    -------
    dict
        The keyword arguments, possibly extended with ``gcps`` and ``rpcs``.
    """
    stored_gcps = self.get_gcps()
    # A key supplied by the caller always wins over the stored value.
    if stored_gcps and "gcps" not in kwargs:
        kwargs["gcps"] = stored_gcps

    stored_rpcs = self.get_rpcs()
    if stored_rpcs and "rpcs" not in kwargs:
        kwargs["rpcs"] = stored_rpcs

    return kwargs
560+
561+
def _reproj_get_src(self, **kwargs):
    """
    Work out the source transform to hand to the reprojection.

    Returns
    -------
    tuple
        ``(src_affine, use_affine)``. When any of ``gcps``, ``rpcs`` or
        ``src_geoloc_array`` is present in ``kwargs`` the result is
        ``(None, False)``; otherwise it is the recalculated transform of
        the object and ``True``.
    """
    # Presence of the key is what matters here, not the value stored under it.
    georef_keys = ("gcps", "rpcs", "src_geoloc_array")
    if any(key in kwargs for key in georef_keys):
        return None, False
    return self.transform(recalc=True), True
570+
571+
def _reproj_get_dst(
    self,
    *,
    dst_crs: Any,
    resolution: Optional[Union[float, tuple[float, float]]] = None,
    shape: Optional[tuple[int, int]] = None,
    transform: Optional[Affine] = None,
    nodata: Optional[float] = None,
    **kwargs,
):
    """
    Build the destination side of a reprojection.

    Parameters
    ----------
    dst_crs: Any
        Destination CRS, forwarded to ``_make_dst_affine``.
    resolution: float or tuple[float, float], optional
        Destination resolution, forwarded to ``_make_dst_affine``.
    shape: tuple[int, int], optional
        Destination ``(height, width)``.
    transform: Affine, optional
        Destination transform. When given it is used as-is and the size
        comes from ``shape``, or from the shape of this object when
        ``shape`` is None.
    nodata: float, optional
        Requested nodata value, resolved through ``_get_dst_nodata``.
    **kwargs
        Extra keyword arguments forwarded to ``_make_dst_affine``.

    Returns
    -------
    tuple
        ``(dst_data, dst_height, dst_width, dst_affine, dst_nodata)``.
    """
    if transform is not None:
        # An explicit transform is trusted; only the size needs resolving.
        dst_affine = transform
        dst_height, dst_width = self.shape if shape is None else shape
    else:
        # NOTE: this helper returns width before height.
        dst_affine, dst_width, dst_height = _make_dst_affine(
            src_data_array=self._obj,
            src_crs=self.crs,
            dst_crs=dst_crs,
            dst_resolution=resolution,
            dst_shape=shape,
            **kwargs,
        )

    dst_data = self._create_dst_data(dst_height=dst_height, dst_width=dst_width)
    dst_nodata = self._get_dst_nodata(nodata)
    return dst_data, dst_height, dst_width, dst_affine, dst_nodata
602+
534603
def _get_dst_nodata(self, nodata: Optional[float]) -> Optional[float]:
535604
default_nodata = (
536605
_NODATA_DTYPE_MAP.get(dtype_rev[self._obj.dtype.name])
@@ -1194,6 +1263,7 @@ def to_raster(
11941263
crs=self.crs,
11951264
transform=self.transform(recalc=recalc_transform),
11961265
gcps=self.get_gcps(),
1266+
rpcs=self.get_rpcs(),
11971267
nodata=rio_nodata,
11981268
windowed=windowed,
11991269
lock=lock,

rioxarray/raster_writer.py

Lines changed: 45 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
"""
1111
import numpy
1212
import rasterio
13+
import xarray
14+
from rasterio.rpc import RPC
1315
from rasterio.windows import Window
1416
from xarray.conventions import encode_cf_variable
1517

@@ -184,6 +186,41 @@ def _get_dtypes(*, rasterio_dtype, encoded_rasterio_dtype, dataarray_dtype):
184186
return rasterio_dtype, numpy_dtype
185187

186188

189+
def _to_raster_dtypes(xarray_dataarray: xarray.DataArray, **kwargs):
    """
    Resolve the dtypes used when writing a raster, handling ``_Unsigned``.

    Parameters
    ----------
    xarray_dataarray: :obj:`xarray.DataArray`
        The array about to be written. Its encoding may be updated.
    **kwargs
        Keyword arguments for writing the raster; must contain ``"dtype"``.

    Returns
    -------
    tuple
        ``(xarray_dataarray, numpy_dtype, kwargs)`` where ``kwargs["dtype"]``
        holds the resolved raster dtype.
    """
    resolved_dtype, numpy_dtype = _get_dtypes(
        rasterio_dtype=kwargs["dtype"],
        encoded_rasterio_dtype=xarray_dataarray.encoding.get("rasterio_dtype"),
        dataarray_dtype=xarray_dataarray.encoding.get(
            "dtype", str(xarray_dataarray.dtype)
        ),
    )
    kwargs["dtype"] = resolved_dtype

    if "_Unsigned" not in xarray_dataarray.encoding:
        return xarray_dataarray, numpy_dtype, kwargs

    # there is no equivalent for netCDF _Unsigned
    # across output GDAL formats. It is safest to convert beforehand.
    # https://github.com/OSGeo/gdal/issues/6352#issuecomment-1245981837
    unsigned_dtype = _get_unsigned_dtype(
        unsigned=xarray_dataarray.encoding["_Unsigned"] == "true",
        dtype=numpy_dtype,
    )
    if unsigned_dtype is not None:
        numpy_dtype = unsigned_dtype
        kwargs["dtype"] = unsigned_dtype
        # Keep the encoding in line with the dtype that will be written.
        for encoding_key in ("rasterio_dtype", "dtype"):
            xarray_dataarray.encoding[encoding_key] = str(unsigned_dtype)

    return xarray_dataarray, numpy_dtype, kwargs
213+
214+
215+
def _to_raster_validate_rpcs(**kwargs):
    """
    Validate the optional ``rpcs`` keyword argument given to ``to_raster``.

    RPCs should be either a :obj:`rasterio.rpc.RPC` object or a
    GDAL-compatible dict. A missing or ``None`` value is accepted.

    Parameters
    ----------
    **kwargs
        Keyword arguments for writing the raster; only ``"rpcs"`` is read.

    Raises
    ------
    AssertionError
        If ``rpcs`` is set and is neither an ``RPC`` nor a dict.
    """
    rpcs = kwargs.get("rpcs")
    if rpcs is None or isinstance(rpcs, dict) or isinstance(rpcs, RPC):
        return

    # Raise explicitly instead of using an ``assert`` statement: asserts are
    # removed when Python runs with -O, which would silently skip this check.
    # AssertionError is kept so callers see the same exception type as before.
    raise AssertionError("RPCs must be of type 'rasterio.rpc.RPC' or dict.")
223+
187224
class RasterioWriter:
188225
"""
189226
@@ -254,26 +291,9 @@ def to_raster(self, *, xarray_dataarray, tags, windowed, lock, compute, **kwargs
254291
Keyword arguments to pass into writing the raster.
255292
"""
256293
xarray_dataarray = xarray_dataarray.copy()
257-
kwargs["dtype"], numpy_dtype = _get_dtypes(
258-
rasterio_dtype=kwargs["dtype"],
259-
encoded_rasterio_dtype=xarray_dataarray.encoding.get("rasterio_dtype"),
260-
dataarray_dtype=xarray_dataarray.encoding.get(
261-
"dtype", str(xarray_dataarray.dtype)
262-
),
294+
xarray_dataarray, numpy_dtype, kwargs = _to_raster_dtypes(
295+
xarray_dataarray, **kwargs
263296
)
264-
# there is no equivalent for netCDF _Unsigned
265-
# across output GDAL formats. It is safest to convert beforehand.
266-
# https://github.com/OSGeo/gdal/issues/6352#issuecomment-1245981837
267-
if "_Unsigned" in xarray_dataarray.encoding:
268-
unsigned_dtype = _get_unsigned_dtype(
269-
unsigned=xarray_dataarray.encoding["_Unsigned"] == "true",
270-
dtype=numpy_dtype,
271-
)
272-
if unsigned_dtype is not None:
273-
numpy_dtype = unsigned_dtype
274-
kwargs["dtype"] = unsigned_dtype
275-
xarray_dataarray.encoding["rasterio_dtype"] = str(unsigned_dtype)
276-
xarray_dataarray.encoding["dtype"] = str(unsigned_dtype)
277297

278298
if kwargs["nodata"] is not None:
279299
# Ensure dtype of output data matches the expected dtype.
@@ -283,12 +303,17 @@ def to_raster(self, *, xarray_dataarray, tags, windowed, lock, compute, **kwargs
283303
original_nodata=kwargs["nodata"], new_dtype=numpy_dtype
284304
)
285305

306+
# Check RPC validity: RPCs should be either an RPC object or a GDAL-compatible dict
307+
_to_raster_validate_rpcs(**kwargs)
308+
309+
# RPCs and GCPs are propagated through **kwargs
286310
with rasterio.open(self.raster_path, "w", **kwargs) as rds:
287311
_write_metatata_to_raster(
288312
raster_handle=rds, xarray_dataset=xarray_dataarray, tags=tags
289313
)
314+
290315
if not (lock and is_dask_collection(xarray_dataarray.data)):
291-
# write data to raster immmediately if not dask array
316+
# write data to raster immediately if not dask array
292317
if windowed:
293318
window_iter = rds.block_windows(1)
294319
else:

rioxarray/rioxarray.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from pyproj.database import query_utm_crs_info
2121
from rasterio.control import GroundControlPoint
2222
from rasterio.crs import CRS
23+
from rasterio.rpc import RPC
2324

2425
from rioxarray._options import EXPORT_GRID_MAPPING, get_option
2526
from rioxarray.crs import crs_from_user_input
@@ -301,6 +302,7 @@ def __init__(self, xarray_obj: Union[xarray.DataArray, xarray.Dataset]):
301302
self._width: Optional[int] = None
302303
self._crs: Union[rasterio.crs.CRS, None, Literal[False]] = None
303304
self._gcps: Optional[list[GroundControlPoint]] = None
305+
self._rpcs: Optional[RPC] = None
304306

305307
@property
306308
def crs(self) -> Optional[rasterio.crs.CRS]:
@@ -360,6 +362,7 @@ def _get_obj(self, inplace: bool) -> Union[xarray.Dataset, xarray.DataArray]:
360362
obj_copy.rio._height = self._height
361363
obj_copy.rio._crs = self._crs
362364
obj_copy.rio._gcps = self._gcps
365+
obj_copy.rio._rpcs = self._rpcs
363366
return obj_copy
364367

365368
def set_crs(
@@ -1332,6 +1335,76 @@ def _parse_gcp(gcp) -> GroundControlPoint:
13321335
self._gcps = [_parse_gcp(gcp) for gcp in geojson_gcps["features"]]
13331336
return self._gcps
13341337

1338+
def write_rpcs(
    self,
    rpcs: RPC,
    *,
    grid_mapping_name: Optional[str] = None,
    inplace: bool = False,
) -> xarray.Dataset | xarray.DataArray:
    """
    Write the Rational Polynomial Coefficients to the dataset.

    https://rasterio.readthedocs.io/en/latest/topics/georeferencing.html#rational-polynomial-coefficients

    The RPCs are stored as a JSON string in the ``rpcs`` attribute of the
    grid_mapping coordinate. The CRS of the dataset is set to EPSG:4326
    as part of this call.

    Parameters
    ----------
    rpcs: :obj:`rasterio.rpc.RPC`
        The Rational Polynomial Coefficients to integrate to the dataset.
    grid_mapping_name: str, optional
        Name of the grid_mapping coordinate to store the RPCs information in.
        Default is the grid_mapping name of the dataset.
    inplace: bool, optional
        If True, it will write to the existing dataset. Default is False.

    Returns
    -------
    :obj:`xarray.Dataset` | :obj:`xarray.DataArray`:
        Modified dataset with Rational Polynomial Coefficients written.
    """
    grid_mapping_name = (
        self.grid_mapping if grid_mapping_name is None else grid_mapping_name
    )
    data_obj = self._get_obj(inplace=True)

    # RPC CRS is always 4326
    # write_crs honours ``inplace``: with inplace=False, data_obj becomes a copy.
    data_obj = data_obj.rio.write_crs(
        "epsg:4326", grid_mapping_name=grid_mapping_name, inplace=inplace
    )
    try:
        grid_map_attrs = data_obj.coords[grid_mapping_name].attrs.copy()
    except KeyError:
        data_obj.coords[grid_mapping_name] = xarray.Variable((), 0)
        grid_map_attrs = data_obj.coords[grid_mapping_name].attrs.copy()

    # Store the RPCs
    grid_map_attrs["rpcs"] = json.dumps(rpcs.to_dict())
    data_obj.coords[grid_mapping_name].rio.set_attrs(grid_map_attrs, inplace=True)

    # Cache the RPCs on the object that actually received them. With
    # inplace=False that is the copy; caching on ``self`` would leave the
    # copy with a stale cache and make the untouched original report RPCs
    # that were never written to it.
    data_obj.rio._rpcs = rpcs
    if inplace:
        self._rpcs = rpcs

    return data_obj
1386+
1387+
def get_rpcs(self) -> Optional[RPC]:
    """
    Get the Rational Polynomial Coefficients from the dataset.

    https://rasterio.readthedocs.io/en/latest/topics/georeferencing.html#rational-polynomial-coefficients

    The value is read from the JSON ``rpcs`` attribute of the grid_mapping
    coordinate and then cached on this accessor for later calls.

    Returns
    -------
    :obj:`rasterio.rpc.RPC` or None
        The Rational Polynomial Coefficients from the dataset or None if not applicable
    """
    if self._rpcs is None:
        try:
            rpc_fields = json.loads(
                self._obj.coords[self.grid_mapping].attrs["rpcs"]
            )
        except (KeyError, AttributeError):
            # No grid_mapping coordinate or no "rpcs" attribute on it.
            return None
        # Built outside the try block so construction errors are not swallowed.
        self._rpcs = RPC(**rpc_fields)

    return self._rpcs
1407+
13351408

13361409
def _convert_gcps_to_geojson(
13371410
gcps: Iterable[GroundControlPoint],

0 commit comments

Comments
 (0)