@@ -194,7 +194,7 @@ def __init__(
194194 Origin is at upper left corner of the grid. Use rowcol instead.
195195 :param rowcol: Indices of the cell in (row, col) format.
196196 Origin is at upper left corner of the grid
197- :param xy: Cell center coordinates in the CRS.
197+ :param xy: Geographic/projected (x, y) coordinates of the cell center in the CRS.
198198 """
199199
200200 super ().__init__ (model )
@@ -263,7 +263,7 @@ def rowcol(self) -> Coordinate | None:
263263 @property
264264 def xy (self ) -> FloatCoordinate | None :
265265 """
266- Cell center coordinates in the CRS.
266+ Geographic/projected (x, y) coordinates of the cell center in the CRS.
267267 """
268268 return self ._xy
269269
@@ -426,53 +426,104 @@ def coord_iter(self) -> Iterator[tuple[Cell, int, int]]:
426426 for col in range (self .height ):
427427 yield self .cells [row ][col ], row , col # cell, x, y
428428
429- def apply_raster (self , data : np .ndarray , attr_name : str | None = None ) -> None :
429+ def apply_raster (
430+ self , data : np .ndarray , attr_name : str | Sequence [str ] | None = None
431+ ) -> None :
430432 """
431433 Apply raster data to the cells.
432434
433- :param np.ndarray data: 2D numpy array with shape (1, height, width).
434- :param str | None attr_name: Name of the attribute to be added to the cells.
435- If None, a random name will be generated. Default is None.
436- :raises ValueError: If the shape of the data is not (1, height, width).
435+ :param np.ndarray data: 3D numpy array with shape (bands, height, width).
436+ :param str | Sequence[str] | None attr_name: Attribute name(s) to be added to the
437+ cells. For multi-band rasters, pass a list of names with length equal to
438+ the number of bands, or a single base name to be suffixed per band. If None,
439+ names are generated. Default is None.
440+ :raises ValueError: If the shape of the data does not match the raster.
437441 """
438442
439- if data .shape != ( 1 , self .height , self .width ):
443+ if data .ndim != 3 or data . shape [ 1 :] != ( self .height , self .width ):
440444 raise ValueError (
441445 f"Data shape does not match raster shape. "
442- f"Expected { ( 1 , self .height , self .width ) } , received { data .shape } ."
446+ f"Expected (*, { self .height } , { self .width } ) , received { data .shape } ."
443447 )
444- if attr_name is None :
445- attr_name = f"attribute_{ len (self .cell_cls .__dict__ )} "
446- self ._attributes .add (attr_name )
447- for grid_x in range (self .width ):
448- for grid_y in range (self .height ):
449- setattr (
450- self .cells [grid_x ][grid_y ],
451- attr_name ,
452- data [0 , self .height - grid_y - 1 , grid_x ],
453- )
448+ num_bands = data .shape [0 ]
449+
450+ if num_bands == 1 :
451+ if isinstance (attr_name , Sequence ) and not isinstance (attr_name , str ):
452+ if len (attr_name ) != 1 :
453+ raise ValueError (
454+ "attr_name sequence length must match the number of raster bands; "
455+ f"expected { num_bands } band names, got { len (attr_name )} ."
456+ )
457+ names = [attr_name [0 ]]
458+ else :
459+ names = [cast (str | None , attr_name )]
460+ else :
461+ if isinstance (attr_name , Sequence ) and not isinstance (attr_name , str ):
462+ if len (attr_name ) != num_bands :
463+ raise ValueError (
464+ "attr_name sequence length must match the number of raster bands; "
465+ f"expected { num_bands } band names, got { len (attr_name )} ."
466+ )
467+ names = list (attr_name )
468+ elif isinstance (attr_name , str ):
469+ names = [f"{ attr_name } _{ band_idx + 1 } " for band_idx in range (num_bands )]
470+ else :
471+ names = [None ] * num_bands
472+
473+ def _default_attr_name () -> str :
474+ base = f"attribute_{ len (self .cell_cls .__dict__ )} "
475+ if base not in self ._attributes :
476+ return base
477+ suffix = 1
478+ candidate = f"{ base } _{ suffix } "
479+ while candidate in self ._attributes :
480+ suffix += 1
481+ candidate = f"{ base } _{ suffix } "
482+ return candidate
483+
484+ for band_idx , name in enumerate (names ):
485+ attr = _default_attr_name () if name is None else name
486+ self ._attributes .add (attr )
487+ for grid_x in range (self .width ):
488+ for grid_y in range (self .height ):
489+ setattr (
490+ self .cells [grid_x ][grid_y ],
491+ attr ,
492+ data [band_idx , self .height - grid_y - 1 , grid_x ],
493+ )
454494
455- def get_raster (self , attr_name : str | None = None ) -> np .ndarray :
495+ def get_raster (self , attr_name : str | Sequence [ str ] | None = None ) -> np .ndarray :
456496 """
457497 Return the values of given attribute.
458498
459- :param str | None attr_name: Name of the attribute to be returned. If None,
460- returns all attributes. Default is None.
461- :return: The values of given attribute as a 2D numpy array with shape (1, height, width).
499+ :param str | Sequence[str] | None attr_name: Name(s) of attributes to be returned.
500+ If None, returns all attributes. Default is None.
501+ :return: The values of given attribute(s) as a numpy array with shape
502+ (bands, height, width).
462503 :rtype: np.ndarray
463504 """
464505
465- if attr_name is not None and attr_name not in self .attributes :
506+ if isinstance ( attr_name , str ) and attr_name not in self .attributes :
466507 raise ValueError (
467508 f"Attribute { attr_name } does not exist. "
468509 f"Choose from { self .attributes } , or set `attr_name` to `None` to retrieve all."
469510 )
511+ if isinstance (attr_name , Sequence ) and not isinstance (attr_name , str ):
512+ missing = [name for name in attr_name if name not in self .attributes ]
513+ if missing :
514+ raise ValueError (
515+ f"Attribute { missing [0 ]} does not exist. "
516+ f"Choose from { self .attributes } , or set `attr_name` to `None` to retrieve all."
517+ )
470518 if attr_name is None :
471519 num_bands = len (self .attributes )
472520 attr_names = self .attributes
521+ elif isinstance (attr_name , Sequence ) and not isinstance (attr_name , str ):
522+ num_bands = len (attr_name )
523+ attr_names = list (attr_name )
473524 else :
474525 num_bands = 1
475- attr_names = { attr_name }
526+ attr_names = [ attr_name ]
476527 data = np .empty ((num_bands , self .height , self .width ))
477528 for ind , name in enumerate (attr_names ):
478529 for grid_x in range (self .width ):
@@ -684,16 +735,18 @@ def from_file(
684735 raster_file : str ,
685736 model : Model ,
686737 cell_cls : type [Cell ] = Cell ,
687- attr_name : str | None = None ,
738+ attr_name : str | Sequence [ str ] | None = None ,
688739 rio_opener : Callable | None = None ,
689740 ) -> RasterLayer :
690741 """
691742 Creates a RasterLayer from a raster file.
692743
693744 :param str raster_file: Path to the raster file.
694745 :param Type[Cell] cell_cls: The class of the cells in the layer.
695- :param str | None attr_name: The name of the attribute to use for the cell values.
696- If None, a random name will be generated. Default is None.
746+ :param str | Sequence[str] | None attr_name: Attribute name(s) to use for the cell
747+ values. For multi-band rasters, pass a list of names with length equal to
748+ the number of bands, or a single base name to be suffixed per band. If None,
749+ names are generated. Default is None.
697750 :param Callable | None rio_opener: A callable passed to Rasterio open() function.
698751 """
699752
@@ -713,14 +766,17 @@ def from_file(
713766 return obj
714767
715768 def to_file (
716- self , raster_file : str , attr_name : str | None = None , driver : str = "GTiff"
769+ self ,
770+ raster_file : str ,
771+ attr_name : str | Sequence [str ] | None = None ,
772+ driver : str = "GTiff" ,
717773 ) -> None :
718774 """
719775 Writes a raster layer to a file.
720776
721777 :param str raster_file: The path to the raster file to write to.
722- :param str | None attr_name: The name of the attribute to write to the raster.
723- If None, all attributes are written. Default is None.
778+ :param str | Sequence[str] | None attr_name: The name(s) of attributes to write
779+ to the raster. If None, all attributes are written. Default is None.
724780 :param str driver: The GDAL driver to use for writing the raster file.
725781 Default is 'GTiff'. See GDAL docs at https://gdal.org/drivers/raster/index.html.
726782 """
0 commit comments