Skip to content

Commit 53af02e

Browse files
committed
feat: add basic raster data loader support
- Implement RasterLoader to load raster files with rasterio
- Support loading GeoTIFF, PNG, JP2 and other formats
- Modified to add example usage of raster loader
- Raster loader returning 3D array
1 parent a840447 commit 53af02e

File tree

8 files changed

+1501
-33
lines changed

8 files changed

+1501
-33
lines changed

examples/1-Per-Module/1-loader.ipynb

Lines changed: 317 additions & 4 deletions
Large diffs are not rendered by default.

examples/1-Per-Module/6-visualiser.ipynb

Lines changed: 525 additions & 9 deletions
Large diffs are not rendered by default.

src/urban_mapper/modules/loader/abc_loader.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -40,21 +40,20 @@ def __init__(
4040
self.additional_loader_parameters: Dict[str, Any] = additional_loader_parameters
4141

4242
@abstractmethod
43-
def _load_data_from_file(self) -> gpd.GeoDataFrame:
43+
def _load_data_from_file(self) -> Any:
4444
"""Internal implementation method for loading data from a file.
4545
46-
This method is called by `load_data_from_file()` after validation is performed.
47-
48-
!!! warning "Method Not Implemented"
49-
This method must be implemented by subclasses. It should contain the logic
50-
for reading the file and converting it to a `GeoDataFrame`.
46+
This method must be implemented by subclasses.
47+
For tabular loaders (CSV, Shapefile, Parquet), it must return a GeoDataFrame.
48+
For raster loaders, it may return a dictionary or a numpy array containing the raster data and metadata.
5149
5250
Returns:
53-
A `GeoDataFrame` containing the loaded spatial data.
51+
- For tabular loaders: a `GeoDataFrame` containing the loaded spatial data.
52+
- For raster loaders: an object containing the raster data (e.g. dict or numpy.ndarray).
5453
5554
Raises:
56-
ValueError: If required columns are missing or the file format is invalid.
57-
FileNotFoundError: If the file does not exist.
55+
ValueError: If required columns are missing or the file format is invalid.
56+
FileNotFoundError: If the file does not exist.
5857
"""
5958
...
6059

@@ -112,3 +111,4 @@ def preview(self, format: str = "ascii") -> Any:
112111
ValueError: If an unsupported format is requested.
113112
"""
114113
pass
114+
pass

src/urban_mapper/modules/loader/loader_factory.py

Lines changed: 44 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
import json
2-
from collections import defaultdict
3-
from itertools import islice
1+
import json
2+
from collections import defaultdict
3+
from itertools import islice
44
from pathlib import Path
55
from typing import Optional, Union, Dict
66

@@ -16,6 +16,7 @@
1616
from urban_mapper.modules.loader.abc_loader import LoaderBase
1717
from urban_mapper.modules.loader.loaders.csv_loader import CSVLoader
1818
from urban_mapper.modules.loader.loaders.parquet_loader import ParquetLoader
19+
from urban_mapper.modules.loader.loaders.raster_loader import RasterLoader # Importing RasterLoader of the new raster loader module
1920
from urban_mapper.modules.loader.loaders.shapefile_loader import ShapefileLoader
2021
from urban_mapper.utils import require_attributes
2122
from urban_mapper.utils.helpers.reset_attribute_before import reset_attributes_before
@@ -24,6 +25,12 @@
2425
".csv": {"class": CSVLoader, "requires_columns": True},
2526
".shp": {"class": ShapefileLoader, "requires_columns": False},
2627
".parquet": {"class": ParquetLoader, "requires_columns": True},
28+
# Adding of new formats supported by RasterLoader
29+
".tif": {"class": RasterLoader, "requires_columns": False},
30+
".tiff": {"class": RasterLoader, "requires_columns": False},
31+
".jp2": {"class": RasterLoader, "requires_columns": False},
32+
".png": {"class": RasterLoader, "requires_columns": False},
33+
2734
}
2835

2936

@@ -75,6 +82,7 @@ def __init__(self):
7582
self.crs: str = DEFAULT_CRS
7683
self._instance: Optional[LoaderBase] = None
7784
self._preview: Optional[dict] = None
85+
self.options = {}
7886

7987
@reset_attributes_before(
8088
["source_type", "source_data", "latitude_column", "longitude_column"]
@@ -84,8 +92,8 @@ def from_file(self, file_path: str) -> "LoaderFactory":
8492
8593
This method sets up the factory to load data from a file path. The file format
8694
is determined by the file extension. Supported formats include `CSV`, `shapefile`,
87-
and `Parquet`.
88-
95+
and `Parquet`.
96+
8997
Args:
9098
file_path: Path to the data file to load.
9199
@@ -463,8 +471,32 @@ def with_map(
463471
f"WITH_MAP: Initialised LoaderFactory with map_columns={map_columns}",
464472
)
465473
return self
474+
475+
def with_options(self, **options) -> "LoaderFactory":
    """Record additional key-value options on the factory.

    The given keyword arguments are merged into ``self.options``; a key
    passed again in a later call overwrites the earlier value.

    NOTE(review): these options are meant to configure the loader (e.g.
    block size), but this method only stores them. Confirm that the
    loading path actually forwards ``self.options`` to the loader.

    Args:
        **options: Arbitrary keyword arguments representing loader
            configuration options.

    Returns:
        The LoaderFactory instance for method chaining.

    Examples:
        >>> loader = (
        ...     mapper.loader.from_file("data/raster.tif")
        ...     .with_options(block_size=10, use_polygons=True)
        ... )
    """
    self.options.update(options)
    logger.log(
        "DEBUG_LOW",
        f"WITH_OPTIONS: Updated LoaderFactory with options={options}",
    )
    return self
497+
466498

467-
def _load_from_file(self, coordinate_reference_system: str) -> gpd.GeoDataFrame:
499+
def _load_from_file(self, coordinate_reference_system: str):
468500
file_path: str = self.source_data
469501
file_ext = Path(file_path).suffix.lower()
470502
loader_class = FILE_LOADER_FACTORY[file_ext]["class"]
@@ -475,7 +507,8 @@ def _load_from_file(self, coordinate_reference_system: str) -> gpd.GeoDataFrame:
475507
coordinate_reference_system=coordinate_reference_system,
476508
map_columns=self.map_columns,
477509
)
478-
return self._instance.load_data_from_file()
510+
# Appel générique, le type de retour dépend du loader (GeoDataFrame pour tabulaire, dict/array pour raster)
511+
return self._instance._load_data_from_file()
479512

480513
def _load_from_dataframe(
481514
self, coordinate_reference_system: str
@@ -502,9 +535,9 @@ def _load_from_dataframe(
502535
return geo_dataframe
503536

504537
@require_attributes(["source_type", "source_data"])
505-
def load(self, coordinate_reference_system: str = DEFAULT_CRS) -> gpd.GeoDataFrame:
506-
"""Load the data and return it as a `GeoDataFrame`.
507-
538+
def load(self, coordinate_reference_system: str = DEFAULT_CRS):
539+
"""Load the data and return it as a `GeoDataFrame` or raster object.
540+
508541
This method loads the data from the configured source and returns it as a
509542
geopandas `GeoDataFrame`. It handles the details of loading from different
510543
source types and formats.
@@ -543,7 +576,7 @@ def load(self, coordinate_reference_system: str = DEFAULT_CRS) -> gpd.GeoDataFra
543576
loaded_data = self._load_from_file(coordinate_reference_system)
544577
if self._preview is not None:
545578
self.preview(format=self._preview["format"])
546-
return loaded_data
579+
return loaded_data # Peut être un GeoDataFrame ou un objet raster selon le loader
547580
elif self.source_type == "dataframe":
548581
if self.latitude_column == "None" or self.longitude_column == "None":
549582
raise ValueError(
Lines changed: 216 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,216 @@
1+
from ..abc_loader import LoaderBase
2+
import rasterio
3+
from typing import Any
4+
import numpy as np
5+
import geopandas as gpd
6+
from shapely.geometry import Point
7+
from shapely.geometry import Polygon
8+
from rasterio.transform import xy
9+
from pyproj import CRS, Transformer
10+
11+
class RasterLoader:
12+
"""
13+
Loader for raster files (GeoTIFF, PNG+world file, etc.) with block-wise downsampling (average pooling) and polygons as geometry.
14+
Returns a GeoDataFrame where each row corresponds to an aggregated pixel (block).
15+
16+
It allows fast preview of raster properties, pixel-wise spatialization, and direct integration with the UrbanMapper factory.
17+
18+
Attributes:
19+
file_path (str): Path to the raster file to load.
20+
gdf (geopandas.GeoDataFrame): GeoDataFrame where each row is a pixel (with geometry, area, coordinates, and value).
21+
meta (dict): Raster metadata (dimensions, CRS, etc.).
22+
bounds (tuple): Geographic extent of the raster as (left, bottom, right, top).
23+
block_size (int): Size of the blocks for downsampling (default is 10, meaning 10x10 pixels).
24+
25+
Example:
26+
>>> rst_loader = (
27+
mapper
28+
.loader # From the loader module
29+
.from_file("file_path.tif") # To update with your own path
30+
)
31+
>>> gdf = rst_loader.load() # Load the data and return data
32+
>>> gdf
33+
>>> meta = rst_loader._instance.meta
34+
>>> bounds = rst_loader._instance.bounds
35+
36+
37+
"""
38+
39+
def __init__(self, file_path, block_size=10, **kwargs): # block_size est le facteur de downsampling (4x4 par défaut)
40+
self.file_path = file_path
41+
self.block_size = block_size
42+
self.meta = None
43+
self.bounds = None
44+
45+
def _downsample_band(self, band):
46+
"""
47+
Downsamples the raster band by averaging non-overlapping blocks of pixels.Effectue le downsampling par blocs et calcule la moyenne pour chaque bloc non-recouvrant.
48+
"""
49+
h, w = band.shape
50+
bs = self.block_size
51+
52+
# Découpe l'image aux dimensions multiples du block_size pour éviter les bords incomplets
53+
h_ds = h // bs
54+
w_ds = w // bs
55+
band_cropped = band[:h_ds * bs, :w_ds * bs]
56+
57+
# Remodèle pour avoir un 4D (h_blocks, block_size, w_blocks, block_size) puis moyenne par bloc
58+
band_blocks = band_cropped.reshape(h_ds, bs, w_ds, bs)
59+
band_ds = band_blocks.mean(axis=(1, 3)) # moyenne sur les axes blocs internes
60+
61+
return band_ds
62+
63+
def _load_data_from_file(self) -> gpd.GeoDataFrame:
64+
"""
65+
Loads raster data and returns a GeoDataFrame where each row represents an aggregated pixel (bloc) by downsampling (with geometry, area, coordinates, and value).
66+
NoData pixels are included with value set to None.
67+
68+
Returns :
69+
-------
70+
gpd.GeoDataFrame
71+
A GeoDataFrame with columns: pixel_id, row, col, area, latitude, longitude, value, geometry.
72+
Raises:
73+
RuntimeError: If there is an error while loading the raster file.
74+
NB : the loader doesn't return metadata and bounds, but they are stored in the instance attributes (cf docstring example).
75+
"""
76+
try:
77+
with rasterio.open(self.file_path) as src:
78+
band = src.read(1)
79+
transform = src.transform
80+
crs = src.crs
81+
nodata = src.nodata
82+
83+
self.meta = src.meta
84+
self.bounds = src.bounds
85+
86+
# Handle NoData:
87+
if nodata is not None:
88+
band = np.where(band == nodata, np.nan, band)
89+
90+
# Downsampling
91+
band_ds = self._downsample_band(band)
92+
h_ds, w_ds = band_ds.shape
93+
94+
# Generate indices for the downsampled raster
95+
rows, cols = np.indices((h_ds, w_ds))
96+
97+
# Calcul coordinates of the center of each block
98+
bs = self.block_size
99+
center_rows = rows * bs + bs // 2
100+
center_cols = cols * bs + bs // 2
101+
102+
# Transform raster to world coordinates for block centers
103+
xs, ys = rasterio.transform.xy(transform, center_rows, center_cols)
104+
xs = np.array(xs).flatten()
105+
ys = np.array(ys).flatten()
106+
107+
# Flatten values for the downsampled band
108+
values = band_ds.flatten()
109+
110+
# Geometry creation : Polygon for each aggregated pixel
111+
# Each pixel is represented as a polygon with 4 corners
112+
geoms = []
113+
for r, c in zip(center_rows.flatten(), center_cols.flatten()):
114+
# (r, c) = ligne et colonne du bloc (dans la grille agrégée)
115+
min_row = r - bs // 2
116+
min_col = c - bs // 2
117+
max_row = min_row + bs
118+
max_col = min_col + bs
119+
120+
# Collect coordinates of the 4 corners of the aggregated pixel :
121+
corners = [
122+
rasterio.transform.xy(transform, min_row, min_col, offset='ul'), # haut-gauche
123+
rasterio.transform.xy(transform, min_row, max_col, offset='ur'), # haut-droit
124+
rasterio.transform.xy(transform, max_row, max_col, offset='lr'), # bas-droit
125+
rasterio.transform.xy(transform, max_row, min_col, offset='ll') # bas-gauche
126+
]
127+
poly = Polygon(corners)
128+
geoms.append(poly)
129+
130+
131+
# Latitude/longitude per transformation from CRS to WGS84
132+
transformer = Transformer.from_crs(crs, 4326, always_xy=True)
133+
lon, lat = transformer.transform(xs, ys)
134+
135+
# Area for each block (in projected CRS units) — area of the block
136+
if CRS.from_user_input(crs).is_projected:
137+
pixel_width = abs(transform.a)
138+
pixel_height = abs(transform.e)
139+
area = (pixel_width * pixel_height) * (bs * bs)
140+
areas = np.full(values.shape, area)
141+
else:
142+
areas = [None] * values.size # Hors CRS projeté, zone complexe : à raffiner si utile
143+
144+
145+
# Create GeoDataFrame with pixel_id, row, col, area, value, latitude, longitude, and geometry
146+
gdf = gpd.GeoDataFrame({
147+
'pixel_id': np.arange(len(values)),
148+
'row': rows.flatten(),
149+
'col': cols.flatten(),
150+
'area': areas,
151+
'value': values,
152+
'latitude': lat,
153+
'longitude': lon,
154+
'geometry': geoms
155+
}, crs=crs)
156+
157+
return gdf
158+
159+
except Exception as e:
160+
raise RuntimeError(f"Error while loading downsampled raster: {e}")
161+
162+
163+
def preview(self, format: str = "ascii") -> Any:
164+
"""
165+
Generates a preview of the loaded raster information.
166+
167+
Args:
168+
format (str): Output format ("ascii" for text display, "json" for dictionary).
169+
170+
Returns:
171+
str or dict: A summary of the raster properties.
172+
173+
Raises:
174+
ValueError: If the requested format is not supported.
175+
176+
"""
177+
# If metadata is not loaded, try to load it
178+
if self.meta is None:
179+
try:
180+
with rasterio.open(self.file_path) as src:
181+
self.meta = src.meta
182+
except Exception as e:
183+
return f"Unable to open raster: {e}"
184+
185+
# Get main raster information
186+
shape = (
187+
self.meta.get("count", "?"),
188+
self.meta.get("height", "?"),
189+
self.meta.get("width", "?")
190+
)
191+
dtype = self.meta.get("dtype", "?")
192+
crs = self.meta.get("crs", "?")
193+
194+
# Return preview according to requested format
195+
if format == "ascii":
196+
return (
197+
f"Loader: RasterLoader\n"
198+
f" File: {self.file_path}\n"
199+
f" Dimensions (bands, height, width): {shape}\n"
200+
f" Data type: {dtype}\n"
201+
f" CRS: {crs}"
202+
)
203+
elif format == "json":
204+
return {
205+
"loader": "RasterLoader",
206+
"file": self.file_path,
207+
"shape": shape,
208+
"dtype": str(dtype),
209+
"crs": str(crs)
210+
}
211+
else:
212+
raise ValueError(f"Unsupported format: {format}")
213+
214+
215+
216+

0 commit comments

Comments
 (0)