Commit 2c82a7e

fix: restore NaN validation in polars-based variable bounds checking
The polars migration broke NaN validation: check_has_nulls_polars only checked for null values, not NaN values, and in polars these are distinct concepts. This fix extends the validation to detect both null and NaN values, restricting the NaN check to numeric columns to avoid type errors on non-numeric (e.g. enum) columns.

Fixes failing tests in test_inconsistency_checks.py that expected ValueError to be raised when variables have NaN bounds.
1 parent 3395c45 commit 2c82a7e

File tree

1 file changed: 19 additions, 6 deletions


linopy/common.py

Lines changed: 19 additions & 6 deletions
```diff
@@ -363,22 +363,35 @@ def to_polars(ds: Dataset, **kwargs: Any) -> pl.DataFrame:
 
 def check_has_nulls_polars(df: pl.DataFrame, name: str = "") -> None:
     """
-    Checks if the given DataFrame contains any null values and raises a ValueError if it does.
+    Checks if the given DataFrame contains any null or NaN values and raises a ValueError if it does.
 
     Args:
     ----
-    df (pl.DataFrame): The DataFrame to check for null values.
+    df (pl.DataFrame): The DataFrame to check for null or NaN values.
     name (str): The name of the data container being checked.
 
     Raises:
     ------
-    ValueError: If the DataFrame contains null values,
-    a ValueError is raised with a message indicating the name of the constraint and the fields containing null values.
+    ValueError: If the DataFrame contains null or NaN values,
+    a ValueError is raised with a message indicating the name of the constraint and the fields containing null/NaN values.
     """
+    # Check for null values in all columns
     has_nulls = df.select(pl.col("*").is_null().any())
     null_columns = [col for col in has_nulls.columns if has_nulls[col][0]]
-    if null_columns:
-        raise ValueError(f"{name} contains nan's in field(s) {null_columns}")
+
+    # Check for NaN values only in numeric columns (avoid enum/categorical columns)
+    numeric_cols = [
+        col for col, dtype in zip(df.columns, df.dtypes) if dtype.is_numeric()
+    ]
+
+    nan_columns = []
+    if numeric_cols:
+        has_nans = df.select(pl.col(numeric_cols).is_nan().any())
+        nan_columns = [col for col in has_nans.columns if has_nans[col][0]]
+
+    invalid_columns = list(set(null_columns + nan_columns))
+    if invalid_columns:
+        raise ValueError(f"{name} contains nan's in field(s) {invalid_columns}")
 
 
 def filter_nulls_polars(df: pl.DataFrame) -> pl.DataFrame:
```
