Skip to content

Commit ad67433

Browse files
committed
Match all facets with keep=False
1 parent 2863f5d commit ad67433

2 files changed

Lines changed: 42 additions & 1 deletion

File tree

packages/ref-core/src/cmip_ref_core/metrics.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,7 @@ def apply_filters(self, data_catalog: pd.DataFrame) -> pd.DataFrame:
328328
Filtered data catalog
329329
"""
330330
for facet_filter in self.filters:
331+
mask = None
331332
for facet, value in facet_filter.facets.items():
332333
clean_value = value if isinstance(value, tuple) else (value,)
333334

@@ -336,7 +337,13 @@ def apply_filters(self, data_catalog: pd.DataFrame) -> pd.DataFrame:
336337
f"Facet {facet!r} not in data catalog columns: {data_catalog.columns.to_list()}"
337338
)
338339

339-
mask = data_catalog[facet].isin(clean_value)
340+
facet_mask = data_catalog[facet].isin(clean_value)
341+
if mask is None:
342+
mask = facet_mask
343+
else:
344+
mask &= facet_mask
345+
346+
if mask is not None:
340347
if not facet_filter.keep:
341348
mask = ~mask
342349

packages/ref-core/tests/unit/test_metrics.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,40 @@ def test_apply_filters_dont_keep(apply_data_catalog):
336336
)
337337

338338

339+
def test_apply_filters_dont_keep_multifacet(apply_data_catalog):
340+
"""Test that all facet values must match to exclude a file from the catalog."""
341+
requirement = DataRequirement(
342+
source_type=SourceDatasetType.CMIP6,
343+
filters=(
344+
FacetFilter(
345+
{
346+
"variable": "tas",
347+
"source_id": "CAS",
348+
},
349+
keep=False,
350+
),
351+
),
352+
group_by=None,
353+
)
354+
355+
filtered = requirement.apply_filters(apply_data_catalog)
356+
pd.testing.assert_frame_equal(
357+
filtered,
358+
pd.DataFrame(
359+
{
360+
"variable": ["tas", "pr", "rsut", "tas"],
361+
"source_id": [
362+
"CESM2",
363+
"CESM2",
364+
"CESM2",
365+
"ACCESS",
366+
],
367+
},
368+
index=[0, 1, 2, 3],
369+
),
370+
)
371+
372+
339373
def test_apply_filters_missing(apply_data_catalog):
340374
requirement = DataRequirement(
341375
source_type=SourceDatasetType.CMIP6,

0 commit comments

Comments
 (0)