Skip to content

Commit ca11778

Browse files
committed
feat: add backwards compatibility for the expression parameter
1 parent 608349d commit ca11778

File tree

5 files changed

+297
-27
lines changed

5 files changed

+297
-27
lines changed

tests/test_dependencies.py

Lines changed: 105 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,11 @@
44

55

66
from titiler.cmr import dependencies
7-
from titiler.cmr.dependencies import InterpolatedXarrayParams
7+
from titiler.cmr.dependencies import (
8+
CMRAssetsExprParams,
9+
CMRXarrayExprParams,
10+
InterpolatedXarrayParams,
11+
)
812
from titiler.cmr.models import GranuleSearch
913

1014

@@ -24,6 +28,40 @@ def test_granule_search_temporal_interval_unchanged():
2428
assert GranuleSearch(temporal=half_open).temporal == half_open
2529

2630

31+
def test_cmr_assets_expr_params_three_assets():
32+
"""Legacy expression with three assets: all detected and mapped in appearance order."""
33+
params = CMRAssetsExprParams(expression="(NIR-RED)/(NIR+RED+BLUE)")
34+
assert list(params.assets) == ["NIR", "RED", "BLUE"]
35+
assert params.expression == "(b1-b2)/(b1+b2+b3)"
36+
37+
38+
def test_cmr_assets_expr_params_substring_asset_names():
39+
"""Asset names that are substrings of each other are distinguished by word boundaries."""
40+
params = CMRAssetsExprParams(assets=["B4", "B4_mask"], expression="B4 * B4_mask")
41+
assert params.expression == "b1 * b2"
42+
43+
44+
def test_cmr_assets_expr_params_math_functions():
45+
"""Math function names (sqrt, log) are not treated as asset names."""
46+
params = CMRAssetsExprParams(expression="sqrt(NIR)/log(RED)")
47+
assert list(params.assets) == ["NIR", "RED"]
48+
assert params.expression == "sqrt(b1)/log(b2)"
49+
50+
51+
def test_cmr_assets_expr_params_extra_assets():
52+
"""Assets list with more entries than referenced in expression: order preserved, extras ignored."""
53+
params = CMRAssetsExprParams(assets=["B03", "B04", "B05"], expression="B04-B05")
54+
assert list(params.assets) == ["B03", "B04", "B05"]
55+
assert params.expression == "b2-b3"
56+
57+
58+
def test_cmr_assets_expr_params_b_prefix_asset_names():
59+
"""Asset names starting with 'b' but not new-style (e.g. blue, band1) are treated as legacy."""
60+
params = CMRAssetsExprParams(expression="(blue-red)/(blue+red)")
61+
assert list(params.assets) == ["blue", "red"]
62+
assert params.expression == "(b1-b2)/(b1+b2)"
63+
64+
2765
def test_interpolated_xarray_params_single_datetime():
2866
"""Test InterpolatedXarrayParams with single datetime interpolation."""
2967
xarray_params = InterpolatedXarrayParams(
@@ -94,6 +132,72 @@ def test_interpolated_xarray_params_no_sel():
94132
assert result.variables == ["temperature"]
95133

96134

135+
def test_cmr_assets_expr_params_legacy_no_assets():
136+
"""Legacy expression with no assets: auto-detect assets and translate expression."""
137+
params = CMRAssetsExprParams(expression="(B04-B05)/(B05+B04)")
138+
assert list(params.assets) == ["B04", "B05"]
139+
assert params.expression == "(b1-b2)/(b2+b1)"
140+
141+
142+
def test_cmr_assets_expr_params_legacy_with_assets():
143+
"""Legacy expression with assets provided: use assets order for mapping."""
144+
params = CMRAssetsExprParams(
145+
assets=["B05", "B04"], expression="(B04-B05)/(B04+B05)"
146+
)
147+
assert list(params.assets) == ["B05", "B04"]
148+
assert params.expression == "(b2-b1)/(b2+b1)"
149+
150+
151+
def test_cmr_assets_expr_params_new_style_passthrough():
152+
"""New-style expression (b1, b2, ...) passes through unchanged."""
153+
params = CMRAssetsExprParams(assets=["B04", "B05"], expression="(b1-b2)/(b1+b2)")
154+
assert list(params.assets) == ["B04", "B05"]
155+
assert params.expression == "(b1-b2)/(b1+b2)"
156+
157+
158+
def test_cmr_assets_expr_params_no_expression():
159+
"""No expression: assets unchanged, no error."""
160+
params = CMRAssetsExprParams(assets=["B04"])
161+
assert list(params.assets) == ["B04"]
162+
assert params.expression is None
163+
164+
165+
def test_cmr_xarray_expr_params_legacy_variable_names():
166+
"""Legacy variable names are translated to positional bN format."""
167+
params = CMRXarrayExprParams(
168+
variables=["temperature", "pressure"], expression="temperature/pressure"
169+
)
170+
assert params.expression == "b1/b2"
171+
172+
173+
def test_cmr_xarray_expr_params_legacy_partial_match():
174+
"""Legacy NDVI-style expression with partial variable subset."""
175+
params = CMRXarrayExprParams(
176+
variables=["nir", "red", "green"], expression="(nir-red)/(nir+red)"
177+
)
178+
assert params.expression == "(b1-b2)/(b1+b2)"
179+
180+
181+
def test_cmr_xarray_expr_params_new_style_passthrough():
182+
"""New-style bN expressions pass through unchanged."""
183+
params = CMRXarrayExprParams(variables=["nir", "red"], expression="(b1-b2)/(b1+b2)")
184+
assert params.expression == "(b1-b2)/(b1+b2)"
185+
186+
187+
def test_cmr_xarray_expr_params_with_math_functions():
188+
"""Math function names are not treated as variable names."""
189+
params = CMRXarrayExprParams(
190+
variables=["nir", "red"], expression="log10(nir)/sqrt(red)"
191+
)
192+
assert params.expression == "log10(b1)/sqrt(b2)"
193+
194+
195+
def test_cmr_xarray_expr_params_no_expression():
196+
"""No expression: no error, variables unchanged."""
197+
params = CMRXarrayExprParams(variables=["nir"], expression=None)
198+
assert params.expression is None
199+
200+
97201
def test_interpolated_xarray_params_multiple_templates():
98202
"""Test InterpolatedXarrayParams with multiple datetime templates."""
99203
xarray_params = InterpolatedXarrayParams(

titiler/cmr/dependencies.py

Lines changed: 103 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,20 @@
11
"""titiler.cmr FastAPI dependencies."""
22

3+
import re
34
from collections.abc import Callable
45
from dataclasses import dataclass, field
56
from datetime import datetime
67
from typing import Annotated, List, Optional
78

89
from fastapi import Depends, HTTPException, Query, Request
910
from httpx import Client
10-
from titiler.core.dependencies import DefaultDependency, ExpressionParams
11+
from pydantic import AfterValidator
12+
from titiler.core.dependencies import (
13+
AssetsExprParams,
14+
DefaultDependency,
15+
ExpressionParams,
16+
_parse_asset,
17+
)
1118
from titiler.xarray.dependencies import SelDimStr, XarrayIOParams
1219

1320
from titiler.cmr.models import (
@@ -58,7 +65,11 @@ def GranuleSearchParams(
5865

5966
@dataclass(init=False)
6067
class BackendParams(DefaultDependency):
61-
"""backend parameters."""
68+
"""Reader backend parameters sourced from application state.
69+
70+
Reads the HTTP client, Earthdata auth token, S3 access flag, and S3
71+
credential provider from the FastAPI app state on each request.
72+
"""
6273

6374
client: Client = field(init=False)
6475
auth_token: str | None = field(init=False)
@@ -76,7 +87,7 @@ def __init__(self, request: Request):
7687

7788
@dataclass
7889
class GranuleSearchBackendParams(DefaultDependency):
79-
"""PgSTAC parameters."""
90+
"""Backend parameters controlling granule search coverage behaviour."""
8091

8192
items_limit: Annotated[
8293
int | None,
@@ -118,6 +129,72 @@ def __post_init__(self):
118129
self.bands_regex = None
119130

120131

132+
def _translate_legacy_expr(expression: str, names: list[str]) -> str:
133+
"""Translate legacy name-based expression to positional bN format.
134+
135+
If the expression contains identifiers from `names` (not already bN-style,
136+
not function calls), replaces them with b1, b2, ... based on their position
137+
in `names`.
138+
"""
139+
identifiers = re.findall(r"\b([a-zA-Z_]\w*)\b(?!\s*\()", expression)
140+
new_style = re.compile(r"^b[1-9][0-9]*$", re.IGNORECASE)
141+
legacy = list(dict.fromkeys(n for n in identifiers if not new_style.match(n)))
142+
if not legacy:
143+
return expression
144+
mapping = {name: f"b{i + 1}" for i, name in enumerate(names)}
145+
expr = expression
146+
for name, band_ref in mapping.items():
147+
expr = re.sub(r"\b" + re.escape(name) + r"\b", band_ref, expr)
148+
return expr
149+
150+
151+
@dataclass
152+
class CMRAssetsExprParams(AssetsExprParams):
153+
"""AssetsExprParams with backwards-compatible legacy expression translation.
154+
155+
Detects legacy expressions that reference asset names directly (e.g. B04, NIR)
156+
and translates them to the new rio-tiler 9.0 positional band format (b1, b2, ...).
157+
"""
158+
159+
assets: Annotated[
160+
list[str] | None,
161+
AfterValidator(_parse_asset),
162+
Query(
163+
title="Asset names",
164+
description="Asset's names.",
165+
),
166+
] = None
167+
168+
def __post_init__(self):
169+
"""Translate legacy asset-name expressions to positional bN format.
170+
171+
If the expression already uses bN references (e.g. b1-b2), it is left
172+
unchanged. Otherwise, identifiers are matched against the provided or
173+
auto-detected asset list and substituted with b1, b2, ... in order.
174+
"""
175+
if not self.expression:
176+
return
177+
178+
identifiers = re.findall(r"\b([a-zA-Z_]\w*)\b(?!\s*\()", self.expression)
179+
new_style_pattern = re.compile(r"^b[1-9][0-9]*$", re.IGNORECASE)
180+
asset_names = list(
181+
dict.fromkeys(
182+
name for name in identifiers if not new_style_pattern.match(name)
183+
)
184+
)
185+
186+
if not asset_names:
187+
return
188+
189+
if self.assets:
190+
ordered_assets = list(self.assets)
191+
else:
192+
ordered_assets = asset_names
193+
self.assets = ordered_assets
194+
195+
self.expression = _translate_legacy_expr(self.expression, ordered_assets)
196+
197+
121198
@dataclass
122199
class XarrayDsParams(DefaultDependency):
123200
"""Xarray Dataset Options."""
@@ -152,12 +229,29 @@ class InterpolatedXarrayParams(XarrayParams):
152229
] = None
153230

154231

232+
@dataclass
233+
class CMRXarrayExprParams(InterpolatedXarrayParams):
234+
"""InterpolatedXarrayParams with legacy variable-name expression translation.
235+
236+
Translates expressions like `temperature/pressure` to `b1/b2` based on the
237+
order of `variables`.
238+
"""
239+
240+
def __post_init__(self):
241+
"""Translate legacy variable-name expressions to positional bN format.
242+
243+
Skipped when expression is already new-style (contains only bN refs) or
244+
when no expression is provided. Safe to call more than once — already-
245+
translated expressions are returned unchanged.
246+
"""
247+
if self.expression and self.variables:
248+
self.expression = _translate_legacy_expr(self.expression, self.variables)
249+
250+
155251
def interpolated_xarray_ds_params(
156-
xarray_params: Annotated[
157-
InterpolatedXarrayParams, Depends(InterpolatedXarrayParams)
158-
],
252+
xarray_params: Annotated[CMRXarrayExprParams, Depends(CMRXarrayExprParams)],
159253
granule_search: Annotated[GranuleSearch, Depends(GranuleSearchParams)],
160-
) -> InterpolatedXarrayParams:
254+
) -> CMRXarrayExprParams:
161255
"""
162256
Xarray parameters with string interpolation support for the sel parameter.
163257
@@ -184,9 +278,10 @@ def interpolated_xarray_ds_params(
184278
else:
185279
interpolated_sel.append(sel_item)
186280

187-
return InterpolatedXarrayParams(
281+
return CMRXarrayExprParams(
188282
variables=xarray_params.variables,
189283
group=xarray_params.group,
190284
sel=interpolated_sel,
191285
decode_times=xarray_params.decode_times,
286+
expression=xarray_params.expression,
192287
)

titiler/cmr/expression.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
"""Expression evaluation utilities for xarray band math."""
2+
3+
from typing import Any
4+
5+
import numpy as np
6+
import xarray as xr
7+
from rio_tiler.expression import get_expression_blocks
8+
9+
from titiler.cmr.logger import logger
10+
11+
_MATH_FUNCTIONS: dict[str, Any] = {
12+
"abs": np.abs,
13+
"ceil": np.ceil,
14+
"floor": np.floor,
15+
"round": np.round,
16+
"trunc": np.trunc,
17+
"sign": np.sign,
18+
"sqrt": np.sqrt,
19+
"exp": np.exp,
20+
"expm1": np.expm1,
21+
"log": np.log,
22+
"log1p": np.log1p,
23+
"log10": np.log10,
24+
"log2": np.log2,
25+
"sin": np.sin,
26+
"cos": np.cos,
27+
"tan": np.tan,
28+
"arcsin": np.arcsin,
29+
"arccos": np.arccos,
30+
"arctan": np.arctan,
31+
"arctan2": np.arctan2,
32+
"sinh": np.sinh,
33+
"cosh": np.cosh,
34+
"tanh": np.tanh,
35+
"arcsinh": np.arcsinh,
36+
"arccosh": np.arccosh,
37+
"arctanh": np.arctanh,
38+
"isnan": np.isnan,
39+
"isfinite": np.isfinite,
40+
"isinf": np.isinf,
41+
"signbit": np.signbit,
42+
"fmod": np.fmod,
43+
"hypot": np.hypot,
44+
"maximum": np.maximum,
45+
"minimum": np.minimum,
46+
"where": np.where,
47+
}
48+
49+
50+
def apply_expression(
51+
da: xr.DataArray,
52+
expression: str,
53+
) -> xr.DataArray:
54+
"""Evaluate a band-math expression against a DataArray.
55+
56+
The DataArray must have a "band" dimension. Each band is exposed as b1, b2, ...
57+
in the expression namespace, along with numexpr-compatible math functions and
58+
the full `np` and `xr` namespaces for backwards compatibility.
59+
60+
Args:
61+
da: Input DataArray with a "band" dimension.
62+
expression: Band-math expression string (e.g. "log10(b1)/sqrt(b2)").
63+
64+
Returns:
65+
Result DataArray, preserving the CRS if present.
66+
"""
67+
logger.info(f"applying expression: {expression}")
68+
pre_expression_crs = da.rio.crs
69+
expression_blocks = get_expression_blocks(expression)
70+
band_vars = {
71+
f"b{i + 1}": da.isel(band=i, drop=True) for i in range(da.sizes["band"])
72+
}
73+
namespace = {"np": np, "xr": xr, **_MATH_FUNCTIONS, **band_vars}
74+
results = [
75+
eval(block, {"__builtins__": {}}, namespace) for block in expression_blocks
76+
]
77+
result = results[0] if len(results) == 1 else xr.concat(results, dim="band")
78+
if pre_expression_crs is not None:
79+
result = result.rio.write_crs(pre_expression_crs)
80+
return result

0 commit comments

Comments
 (0)