Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions changelog/209.fix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Fixed the behaviour of FacetFilter with `keep=False` so all facets need to match
before excluding a file.
9 changes: 8 additions & 1 deletion packages/ref-core/src/cmip_ref_core/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,7 @@ def apply_filters(self, data_catalog: pd.DataFrame) -> pd.DataFrame:
Filtered data catalog
"""
for facet_filter in self.filters:
mask = None
for facet, value in facet_filter.facets.items():
clean_value = value if isinstance(value, tuple) else (value,)

Expand All @@ -336,7 +337,13 @@ def apply_filters(self, data_catalog: pd.DataFrame) -> pd.DataFrame:
f"Facet {facet!r} not in data catalog columns: {data_catalog.columns.to_list()}"
)

mask = data_catalog[facet].isin(clean_value)
facet_mask = data_catalog[facet].isin(clean_value)
if mask is None:
mask = facet_mask
else:
mask &= facet_mask

if mask is not None:
if not facet_filter.keep:
mask = ~mask

Expand Down
34 changes: 34 additions & 0 deletions packages/ref-core/tests/unit/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,40 @@ def test_apply_filters_dont_keep(apply_data_catalog):
)


def test_apply_filters_dont_keep_multifacet(apply_data_catalog):
"""Test that all facet values must match to exclude a file from the catalog."""
requirement = DataRequirement(
source_type=SourceDatasetType.CMIP6,
filters=(
FacetFilter(
{
"variable": "tas",
"source_id": "CAS",
},
keep=False,
),
),
group_by=None,
)

filtered = requirement.apply_filters(apply_data_catalog)
pd.testing.assert_frame_equal(
filtered,
pd.DataFrame(
{
"variable": ["tas", "pr", "rsut", "tas"],
"source_id": [
"CESM2",
"CESM2",
"CESM2",
"ACCESS",
],
},
index=[0, 1, 2, 3],
),
)


def test_apply_filters_missing(apply_data_catalog):
requirement = DataRequirement(
source_type=SourceDatasetType.CMIP6,
Expand Down