Skip to content

Commit 5b5f21f

Browse files
authored
Match all facet values when using FacetFilter with keep=False (#209)
1 parent 78cd1c8 commit 5b5f21f

3 files changed

Lines changed: 42 additions & 5 deletions

File tree

changelog/209.fix.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Fixed the behaviour of `FacetFilter` with `keep=False` so that a file is excluded
2+
only when every facet in the filter matches.

packages/ref-core/src/cmip_ref_core/metrics.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -328,19 +328,20 @@ def apply_filters(self, data_catalog: pd.DataFrame) -> pd.DataFrame:
328328
Filtered data catalog
329329
"""
330330
for facet_filter in self.filters:
331+
values = {}
331332
for facet, value in facet_filter.facets.items():
332333
clean_value = value if isinstance(value, tuple) else (value,)
333334

334335
if facet not in data_catalog.columns:
335336
raise KeyError(
336337
f"Facet {facet!r} not in data catalog columns: {data_catalog.columns.to_list()}"
337338
)
339+
values[facet] = clean_value
338340

339-
mask = data_catalog[facet].isin(clean_value)
340-
if not facet_filter.keep:
341-
mask = ~mask
342-
343-
data_catalog = data_catalog[mask]
341+
mask = data_catalog[list(values)].isin(values).all(axis="columns")
342+
if not facet_filter.keep:
343+
mask = ~mask
344+
data_catalog = data_catalog[mask]
344345
return data_catalog
345346

346347

packages/ref-core/tests/unit/test_metrics.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,40 @@ def test_apply_filters_dont_keep(apply_data_catalog):
336336
)
337337

338338

339+
def test_apply_filters_dont_keep_multifacet(apply_data_catalog):
340+
"""Test that all facet values must match to exclude a file from the catalog."""
341+
requirement = DataRequirement(
342+
source_type=SourceDatasetType.CMIP6,
343+
filters=(
344+
FacetFilter(
345+
{
346+
"variable": "tas",
347+
"source_id": "CAS",
348+
},
349+
keep=False,
350+
),
351+
),
352+
group_by=None,
353+
)
354+
355+
filtered = requirement.apply_filters(apply_data_catalog)
356+
pd.testing.assert_frame_equal(
357+
filtered,
358+
pd.DataFrame(
359+
{
360+
"variable": ["tas", "pr", "rsut", "tas"],
361+
"source_id": [
362+
"CESM2",
363+
"CESM2",
364+
"CESM2",
365+
"ACCESS",
366+
],
367+
},
368+
index=[0, 1, 2, 3],
369+
),
370+
)
371+
372+
339373
def test_apply_filters_missing(apply_data_catalog):
340374
requirement = DataRequirement(
341375
source_type=SourceDatasetType.CMIP6,

0 commit comments

Comments
 (0)