Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions changelog/209.fix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Fixed the behaviour of FacetFilter with `keep=False` so all facets need to match
before excluding a file.
11 changes: 6 additions & 5 deletions packages/ref-core/src/cmip_ref_core/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,19 +328,20 @@ def apply_filters(self, data_catalog: pd.DataFrame) -> pd.DataFrame:
Filtered data catalog
"""
for facet_filter in self.filters:
values = {}
for facet, value in facet_filter.facets.items():
clean_value = value if isinstance(value, tuple) else (value,)

if facet not in data_catalog.columns:
raise KeyError(
f"Facet {facet!r} not in data catalog columns: {data_catalog.columns.to_list()}"
)
values[facet] = clean_value

mask = data_catalog[facet].isin(clean_value)
if not facet_filter.keep:
mask = ~mask

data_catalog = data_catalog[mask]
mask = data_catalog[list(values)].isin(values).all(axis="columns")
if not facet_filter.keep:
mask = ~mask
data_catalog = data_catalog[mask]
return data_catalog


Expand Down
34 changes: 34 additions & 0 deletions packages/ref-core/tests/unit/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,40 @@ def test_apply_filters_dont_keep(apply_data_catalog):
)


def test_apply_filters_dont_keep_multifacet(apply_data_catalog):
"""Test that all facet values must match to exclude a file from the catalog."""
requirement = DataRequirement(
source_type=SourceDatasetType.CMIP6,
filters=(
FacetFilter(
{
"variable": "tas",
"source_id": "CAS",
},
keep=False,
),
),
group_by=None,
)

filtered = requirement.apply_filters(apply_data_catalog)
pd.testing.assert_frame_equal(
filtered,
pd.DataFrame(
{
"variable": ["tas", "pr", "rsut", "tas"],
"source_id": [
"CESM2",
"CESM2",
"CESM2",
"ACCESS",
],
},
index=[0, 1, 2, 3],
),
)


def test_apply_filters_missing(apply_data_catalog):
requirement = DataRequirement(
source_type=SourceDatasetType.CMIP6,
Expand Down