Skip to content

Commit a6ee251

Browse files
committed
Issue #699 migrate asset handling from bands_from_stac_collection
1 parent a276c4b commit a6ee251

File tree

4 files changed

+73
-45
lines changed

4 files changed

+73
-45
lines changed

openeo/metadata.py

Lines changed: 30 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
from openeo.internal.jupyter import render_component
2525
from openeo.util import Rfc3339, deep_get
26-
from openeo.utils.normalize import normalize_resample_resolution
26+
from openeo.utils.normalize import normalize_resample_resolution, unique
2727

2828
_log = logging.getLogger(__name__)
2929

@@ -655,23 +655,6 @@ def metadata_from_stac(url: str) -> CubeMetadata:
655655
:return: A :py:class:`CubeMetadata` containing the DataCube band metadata from the url.
656656
"""
657657

658-
# TODO move these nested functions and other logic to _StacMetadataParser
659-
660-
def get_band_metadata(eo_bands_location: dict) -> List[Band]:
661-
# TODO #699 eliminate or migrate to _StacMetadataParser
662-
# TODO: return None iso empty list when no metadata?
663-
return [
664-
Band(name=band["name"], common_name=band.get("common_name"), wavelength_um=band.get("center_wavelength"))
665-
for band in eo_bands_location.get("eo:bands", [])
666-
]
667-
668-
def get_band_names(bands: List[Band]) -> List[str]:
669-
# TODO #699 eliminate or migrate to _StacMetadataParser
670-
return [band.name for band in bands]
671-
672-
def is_band_asset(asset: pystac.Asset) -> bool:
673-
# TODO #699 eliminate or migrate to _StacMetadataParser
674-
return "eo:bands" in asset.extra_fields
675658

676659
stac_object = pystac.read_file(href=url)
677660

@@ -681,17 +664,8 @@ def is_band_asset(asset: pystac.Asset) -> bool:
681664
elif isinstance(stac_object, pystac.Collection):
682665
# TODO #699: migrate to _StacMetadataParser
683666
collection = stac_object
684-
bands = get_band_metadata(collection.summaries.lists)
685-
686-
# Summaries is not a required field in a STAC collection, so also check the assets
687-
for itm in collection.get_items():
688-
band_assets = {asset_id: asset for asset_id, asset in itm.get_assets().items() if is_band_asset(asset)}
667+
bands = _StacMetadataParser().bands_from_stac_collection(collection=stac_object)
689668

690-
for asset in band_assets.values():
691-
asset_bands = get_band_metadata(asset.extra_fields)
692-
for asset_band in asset_bands:
693-
if asset_band.name not in get_band_names(bands):
694-
bands.append(asset_band)
695669
if _PYSTAC_1_9_EXTENSION_INTERFACE and collection.ext.has("item_assets"):
696670
# TODO #575 support unordered band names and avoid conversion to a list.
697671
bands = list(_StacMetadataParser().get_bands_from_item_assets(collection.ext.item_assets))
@@ -739,6 +713,12 @@ def __init__(self, bands: Iterable[Band]):
739713
def band_names(self) -> List[str]:
740714
return [band.name for band in self]
741715

716+
@classmethod
717+
def merge(cls, band_lists: Iterable["_Bands"]) -> "_Bands":
718+
"""Merge multiple lists of bands into a single list (unique by name)."""
719+
all_bands = (band for bands in band_lists for band in bands)
720+
return cls(unique(all_bands, key=lambda b: b.name))
721+
742722
def __init__(self, *, logger=_log, log_level=logging.DEBUG):
743723
self._logger = logger
744724
self._log_level = log_level
@@ -886,33 +866,48 @@ def bands_from_stac_catalog(self, catalog: pystac.Catalog) -> _Bands:
886866
return self._Bands(self._band_from_common_bands_metadata(b) for b in summaries["bands"])
887867

888868
# TODO: instead of warning: exception, or return None?
889-
self._warn("bands_from_stac_catalog no band name source found")
869+
self._warn("bands_from_stac_catalog: no band name source found")
890870
return self._Bands([])
891871

892-
def bands_from_stac_collection(self, collection: pystac.Collection) -> _Bands:
872+
def bands_from_stac_collection(
873+
self, collection: pystac.Collection, *, consult_items: bool = True, consult_assets: bool = True
874+
) -> _Bands:
893875
# TODO: "eo:bands" vs "bands" priority based on STAC and EO extension version information
894876
self._log(f"bands_from_stac_collection with {collection.summaries.lists.keys()=}")
895877
if "eo:bands" in collection.summaries.lists:
896878
return self._Bands(self._band_from_eo_bands_metadata(b) for b in collection.summaries.lists["eo:bands"])
897879
elif "bands" in collection.summaries.lists:
898880
return self._Bands(self._band_from_common_bands_metadata(b) for b in collection.summaries.lists["bands"])
881+
elif consult_items:
882+
bands = self._Bands.merge(
883+
self.bands_from_stac_item(item=i, consult_parent=False, consult_assets=consult_assets)
884+
for i in collection.get_items()
885+
)
886+
if bands:
887+
return bands
899888

900889
# TODO: instead of warning: exception, or return None?
901-
self._warn("bands_from_stac_collection no band name source found")
890+
self._warn("bands_from_stac_collection: no band name source found")
902891
return self._Bands([])
903892

904-
def bands_from_stac_item(self, item: pystac.Item) -> _Bands:
893+
def bands_from_stac_item(
894+
self, item: pystac.Item, *, consult_parent: bool = True, consult_assets: bool = True
895+
) -> _Bands:
905896
# TODO: "eo:bands" vs "bands" priority based on STAC and EO extension version information
906897
self._log(f"bands_from_stac_item with {item.properties.keys()=}")
907898
if "eo:bands" in item.properties:
908899
return self._Bands(self._band_from_eo_bands_metadata(b) for b in item.properties["eo:bands"])
909900
elif "bands" in item.properties:
910901
return self._Bands(self._band_from_common_bands_metadata(b) for b in item.properties["bands"])
911-
elif (parent_collection := item.get_collection()) is not None:
902+
elif consult_parent and (parent_collection := item.get_collection()) is not None:
912903
return self.bands_from_stac_collection(collection=parent_collection)
904+
elif consult_assets:
905+
bands = self._Bands.merge(self.bands_from_stac_asset(asset=a) for a in item.get_assets().values())
906+
if bands:
907+
return bands
913908

914909
# TODO: instead of warning: exception, or return None?
915-
self._warn("bands_from_stac_item no band name source found")
910+
self._warn("bands_from_stac_item: no band name source found")
916911
return self._Bands([])
917912

918913
def bands_from_stac_asset(self, asset: pystac.Asset) -> _Bands:
@@ -925,5 +920,5 @@ def bands_from_stac_asset(self, asset: pystac.Asset) -> _Bands:
925920
return self._Bands(self._band_from_common_bands_metadata(b) for b in asset.extra_fields["bands"])
926921

927922
# TODO: instead of warning: exception, or return None?
928-
self._warn("bands_from_stac_asset no band name source found")
923+
self._warn("bands_from_stac_asset: no band name source found")
929924
return self._Bands([])

openeo/utils/normalize.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Tuple, Union
1+
from typing import Iterable, Tuple, Union
22

33

44
def normalize_resample_resolution(
@@ -14,3 +14,13 @@ def normalize_resample_resolution(
1414
):
1515
return tuple(resolution)
1616
raise ValueError(f"Invalid resolution {resolution!r}")
17+
18+
19+
def unique(iterable, key=lambda x: x) -> Iterable:
20+
"""Deduplicate an iterable based on a key function."""
21+
seen = set()
22+
for x in iterable:
23+
k = key(x)
24+
if k not in seen:
25+
seen.add(k)
26+
yield x

tests/rest/test_connection.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3114,19 +3114,22 @@ def test_load_stac_band_filtering_no_band_metadata_default(self, dummy_backend,
31143114
}
31153115

31163116
@pytest.mark.parametrize(
3117-
["bands", "has_band_dimension", "expected_pg_args", "expected_warning"],
3117+
["bands", "has_band_dimension", "expected_pg_args", "expected_warnings"],
31183118
[
3119-
(None, False, {}, None),
3119+
(None, False, {}, ["bands_from_stac_collection: no band name source found"]),
31203120
(
31213121
["B02", "B03"],
31223122
True,
31233123
{"bands": ["B02", "B03"]},
3124-
"Bands ['B02', 'B03'] were specified in `load_stac`, but no band dimension was detected in the STAC metadata. Working with band dimension and specified bands.",
3124+
[
3125+
"bands_from_stac_collection: no band name source found",
3126+
"Bands ['B02', 'B03'] were specified in `load_stac`, but no band dimension was detected in the STAC metadata. Working with band dimension and specified bands.",
3127+
],
31253128
),
31263129
],
31273130
)
31283131
def test_load_stac_band_filtering_no_band_dimension(
3129-
self, dummy_backend, build_stac_ref, bands, has_band_dimension, expected_pg_args, expected_warning, caplog
3132+
self, dummy_backend, build_stac_ref, bands, has_band_dimension, expected_pg_args, expected_warnings, caplog
31303133
):
31313134
stac_ref = build_stac_ref(StacDummyBuilder.collection())
31323135

@@ -3152,10 +3155,7 @@ def metadata_from_stac(url: str):
31523155
"url": stac_ref,
31533156
}
31543157

3155-
if expected_warning:
3156-
assert expected_warning in caplog.text
3157-
else:
3158-
assert not caplog.text
3158+
assert caplog.messages == expected_warnings
31593159

31603160
def test_load_stac_band_filtering_no_band_metadata(self, dummy_backend, build_stac_ref, caplog):
31613161
caplog.set_level(logging.WARNING)

tests/utils/test_nomalize.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
from typing import Iterable
2+
13
import pytest
24

3-
from openeo.utils.normalize import normalize_resample_resolution
5+
from openeo.utils.normalize import normalize_resample_resolution, unique
46

57

68
@pytest.mark.parametrize(
@@ -27,3 +29,24 @@ def test_normalize_resample_resolution(resolution, expected):
2729
def test_normalize_resample_resolution(resolution):
2830
with pytest.raises(ValueError, match="Invalid resolution"):
2931
normalize_resample_resolution(resolution)
32+
33+
34+
@pytest.mark.parametrize(
35+
"input, expected",
36+
[
37+
([], []),
38+
(["foo"], ["foo"]),
39+
("foo", ["f", "o"]),
40+
([1, 2, 2, 3, 1, 2, 3, 1, 4], [1, 2, 3, 4]),
41+
(["a", "b", "a", "c"], ["a", "b", "c"]),
42+
([(1, 2), (1, 2), (2, 3)], [(1, 2), (2, 3)]),
43+
([1, "a", "1", "a"], [1, "a", "1"]),
44+
((x for x in [1, 2, 2, 3]), [1, 2, 3]),
45+
(range(5), [0, 1, 2, 3, 4]),
46+
(iter("hello"), ["h", "e", "l", "o"]),
47+
],
48+
)
49+
def test_unique(input, expected):
50+
actual = unique(input)
51+
assert isinstance(actual, Iterable)
52+
assert list(actual) == expected

0 commit comments

Comments
 (0)