diff --git a/doc/release_notes.rst b/doc/release_notes.rst index 730d3e04..93290680 100644 --- a/doc/release_notes.rst +++ b/doc/release_notes.rst @@ -3,6 +3,7 @@ Release Notes .. Upcoming Version +* Fix warning when multiplying variables with pd.Series containing time-zone aware index * Add support for SOS1 and SOS2 (Special Ordered Sets) constraints via ``Model.add_sos_constraints()`` and ``Model.remove_sos_constraints()`` * Add simplify method to LinearExpression to combine duplicate terms * Add convenience function to create LinearExpression from constant diff --git a/linopy/common.py b/linopy/common.py index 7dd97b65..c8f0f184 100644 --- a/linopy/common.py +++ b/linopy/common.py @@ -12,12 +12,13 @@ from collections.abc import Callable, Generator, Hashable, Iterable, Sequence from functools import partial, reduce, wraps from pathlib import Path -from typing import TYPE_CHECKING, Any, Generic, TypeVar, overload +from typing import TYPE_CHECKING, Any, Generic, Literal, ParamSpec, TypeVar, overload from warnings import warn import numpy as np import pandas as pd import polars as pl +import xarray as xr from numpy import arange, signedinteger from xarray import DataArray, Dataset, apply_ufunc, broadcast from xarray import align as xr_align @@ -45,6 +46,48 @@ from linopy.variables import Variable +class CoordAlignWarning(UserWarning): ... + + +class TimezoneAlignError(ValueError): ... + + +P = ParamSpec("P") +R = TypeVar("R") + + +class CatchDatetimeTypeError: + """Context manager that catches datetime-related TypeErrors and re-raises as TimezoneAlignError.""" + + def __enter__(self) -> CatchDatetimeTypeError: + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: Any, + ) -> Literal[False]: + if exc_type is TypeError and exc_val is not None: + if "Cannot interpret 'datetime" in str(exc_val): + raise TimezoneAlignError( + "Timezone information across datetime coordinates not aligned." + ) from exc_val + return False + + +def catch_datetime_type_error_and_re_raise(func: Callable[P, R]) -> Callable[P, R]: + """Decorator that catches datetime-related TypeErrors and re-raises as TimezoneAlignError.""" + + @wraps(func) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: + with CatchDatetimeTypeError(): + result = func(*args, **kwargs) + return result + + return wrapper + + def set_int_index(series: pd.Series) -> pd.Series: """ Convert string index to int index. @@ -128,6 +171,21 @@ def get_from_iterable(lst: DimsLike | None, index: int) -> Any | None: return lst[index] if 0 <= index < len(lst) else None +def try_to_convert_to_pd_datetime_index( + coord: xr.DataArray | Sequence | pd.Index | Any, +) -> pd.DatetimeIndex | xr.DataArray | Sequence | pd.Index | Any: + if isinstance(coord, pd.DatetimeIndex): + return coord + try: + if isinstance(coord, xr.DataArray): + index = coord.to_index() + assert isinstance(index, pd.DatetimeIndex) + return index + return pd.DatetimeIndex(coord) + except Exception: + return coord + + def pandas_to_dataarray( arr: pd.DataFrame | pd.Series, coords: CoordsLike | None = None, @@ -168,7 +226,10 @@ def pandas_to_dataarray( shared_dims = set(pandas_coords.keys()) & set(coords.keys()) non_aligned = [] for dim in shared_dims: + pd_coord = pandas_coords[dim] coord = coords[dim] + if isinstance(pd_coord, pd.DatetimeIndex): + coord = try_to_convert_to_pd_datetime_index(coord) if not isinstance(coord, pd.Index): coord = pd.Index(coord) if not pandas_coords[dim].equals(coord): @@ -178,7 +239,8 @@ def pandas_to_dataarray( f"coords for dimension(s) {non_aligned} is not aligned with the pandas object. " "Previously, the indexes of the pandas were ignored and overwritten in " "these cases. Now, the pandas object's coordinates are taken considered" - " for alignment." + " for alignment.", + CoordAlignWarning, ) return DataArray(arr, coords=None, dims=dims, **kwargs) @@ -449,6 +511,7 @@ def group_terms_polars(df: pl.DataFrame) -> pl.DataFrame: return df +@catch_datetime_type_error_and_re_raise def save_join(*dataarrays: DataArray, integer_dtype: bool = False) -> Dataset: """ Join multiple xarray Dataarray's to a Dataset and warn if coordinates are not equal. @@ -458,7 +521,7 @@ def save_join(*dataarrays: DataArray, integer_dtype: bool = False) -> Dataset: except ValueError: warn( "Coordinates across variables not equal. Perform outer join.", - UserWarning, + CoordAlignWarning, ) arrs = xr_align(*dataarrays, join="outer") if integer_dtype: @@ -466,6 +529,7 @@ def save_join(*dataarrays: DataArray, integer_dtype: bool = False) -> Dataset: return Dataset({ds.name: ds for ds in arrs}) +@catch_datetime_type_error_and_re_raise def assign_multiindex_safe(ds: Dataset, **fields: Any) -> Dataset: """ Assign a field to a xarray Dataset while being safe against warnings about multiindex corruption. diff --git a/linopy/expressions.py b/linopy/expressions.py index 10e243de..59749479 100644 --- a/linopy/expressions.py +++ b/linopy/expressions.py @@ -47,6 +47,7 @@ LocIndexer, as_dataarray, assign_multiindex_safe, + catch_datetime_type_error_and_re_raise, check_common_keys_values, check_has_nulls, check_has_nulls_polars, @@ -505,6 +506,7 @@ def __neg__(self: GenericExpression) -> GenericExpression: """ return self.assign_multiindex_safe(coeffs=-self.coeffs, const=-self.const) + @catch_datetime_type_error_and_re_raise def _multiply_by_linear_expression( self, other: LinearExpression | ScalarLinearExpression ) -> QuadraticExpression: @@ -526,6 +528,7 @@ def _multiply_by_linear_expression( res = res + self.reset_const() * other.const return res + @catch_datetime_type_error_and_re_raise def _multiply_by_constant( self: GenericExpression, other: ConstantLike ) -> GenericExpression: diff --git a/linopy/variables.py b/linopy/variables.py index e2570b5d..60d689f2 100644 --- a/linopy/variables.py +++ b/linopy/variables.py @@ -34,6 +34,7 @@ LocIndexer, as_dataarray, assign_multiindex_safe, + catch_datetime_type_error_and_re_raise, check_has_nulls, check_has_nulls_polars, filter_nulls_polars, @@ -295,6 +296,7 @@ def loc(self) -> LocIndexer: def to_pandas(self) -> pd.Series: return self.labels.to_pandas() + @catch_datetime_type_error_and_re_raise def to_linexpr( self, coefficient: ConstantLike = 1, diff --git a/pyproject.toml b/pyproject.toml index b5105230..c089da13 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -70,6 +70,7 @@ dev = [ "types-requests", "gurobipy", "highspy", + "types-pytz" ] solvers = [ "gurobipy", diff --git a/test/test_common.py b/test/test_common.py index db218375..f59c100a 100644 --- a/test/test_common.py +++ b/test/test_common.py @@ -5,17 +5,21 @@ @author: fabian """ +from datetime import datetime + import numpy as np import pandas as pd import polars as pl import pytest import xarray as xr +from pytz import UTC from test_linear_expression import m, u, x # noqa: F401 from xarray import DataArray from xarray.testing.assertions import assert_equal from linopy import LinearExpression, Model, Variable from linopy.common import ( + CoordAlignWarning, align, as_dataarray, assign_multiindex_safe, @@ -73,6 +77,67 @@ def test_as_dataarray_with_series_dims_priority() -> None: assert list(da.coords[target_dim].values) == target_index +def test_as_datarray_with_tz_aware_series_index() -> None: + time_index = pd.date_range( + start=datetime(2025, 1, 1), + freq="15min", + periods=4, + tz=UTC, + name="time", + ) + other_index = pd.Index(name="time", data=[0, 1, 2, 3]) + + panda_series = pd.Series(index=time_index, data=1.0) + + data_array = xr.DataArray(data=[0, 1, 2, 3], coords=[time_index]) + result = as_dataarray(arr=panda_series, coords=data_array.coords) + assert time_index.equals(result.coords["time"].to_index()) + + data_array = xr.DataArray(data=[0, 1, 2, 3], coords=[other_index]) + with pytest.warns(CoordAlignWarning): + result = as_dataarray(arr=panda_series, coords=data_array.coords) + assert time_index.equals(result.coords["time"].to_index()) + + coords = {"time": time_index} + result = as_dataarray(arr=panda_series, coords=coords) + assert time_index.equals(result.coords["time"].to_index()) + + coords = {"time": [0, 1, 2, 3]} + result = as_dataarray(arr=panda_series, coords=coords) + assert time_index.equals(result.coords["time"].to_index()) + + +def test_as_datarray_with_tz_aware_dataframe_columns_index() -> None: + time_index = pd.date_range( + start=datetime(2025, 1, 1), + freq="15min", + periods=4, + tz=UTC, + name="time", + ) + other_index = pd.Index(name="time", data=[0, 1, 2, 3]) + + index = pd.Index([0, 1, 2, 3], name="x") + pandas_df = pd.DataFrame(index=index, columns=time_index, data=1.0) + + data_array = xr.DataArray(data=[0, 1, 2, 3], coords=[time_index]) + result = as_dataarray(arr=pandas_df, coords=data_array.coords) + assert time_index.equals(result.coords["time"].to_index()) + + data_array = xr.DataArray(data=[0, 1, 2, 3], coords=[other_index]) + with pytest.warns(CoordAlignWarning): + result = as_dataarray(arr=pandas_df, coords=data_array.coords) + assert time_index.equals(result.coords["time"].to_index()) + + coords = {"time": time_index} + result = as_dataarray(arr=pandas_df, coords=coords) + assert time_index.equals(result.coords["time"].to_index()) + + coords = {"time": [0, 1, 2, 3]} + result = as_dataarray(arr=pandas_df, coords=coords) + assert time_index.equals(result.coords["time"].to_index()) + + def test_as_dataarray_with_series_dims_subset() -> None: target_dim = "dim_0" target_index = ["a", "b", "c"] @@ -99,7 +164,7 @@ def test_as_dataarray_with_series_override_coords() -> None: target_dim = "dim_0" target_index = ["a", "b", "c"] s = pd.Series([1, 2, 3], index=target_index) - with pytest.warns(UserWarning): + with pytest.warns(CoordAlignWarning): da = as_dataarray(s, coords=[[1, 2, 3]]) assert isinstance(da, DataArray) assert da.dims == (target_dim,) @@ -218,7 +283,7 @@ def test_as_dataarray_dataframe_override_coords() -> None: target_index = ["a", "b"] target_columns = ["A", "B"] df = pd.DataFrame([[1, 2], [3, 4]], index=target_index, columns=target_columns) - with pytest.warns(UserWarning): + with pytest.warns(CoordAlignWarning): da = as_dataarray(df, coords=[[1, 2], [2, 3]]) assert isinstance(da, DataArray) assert da.dims == target_dims diff --git a/test/test_linear_expression.py b/test/test_linear_expression.py index a75ace3f..fce31427 100644 --- a/test/test_linear_expression.py +++ b/test/test_linear_expression.py @@ -7,14 +7,18 @@ from __future__ import annotations +from datetime import datetime + import numpy as np import pandas as pd import polars as pl import pytest import xarray as xr +from pytz import UTC from xarray.testing import assert_equal from linopy import LinearExpression, Model, QuadraticExpression, Variable, merge +from linopy.common import TimezoneAlignError from linopy.constants import HELPER_DIMS, TERM_DIM from linopy.expressions import ScalarLinearExpression from linopy.testing import assert_linequal, assert_quadequal @@ -1230,6 +1234,30 @@ def test_cumsum(m: Model, multiple: float) -> None: cumsum.nterm == 2 +def test_timezone_alignment_failure() -> None: + utc_index = pd.date_range( + start=datetime(2025, 1, 1), + freq="15min", + periods=4, + tz=UTC, + name="time", + ) + tz_naive_index = pd.date_range( + start=datetime(2025, 1, 1), + freq="15min", + periods=4, + tz=None, + name="time", + ) + model = Model() + series1 = pd.Series(index=tz_naive_index, data=1.0) + expr = model.add_variables(coords=[utc_index], name="var1") * 1.0 + + with pytest.raises(TimezoneAlignError): + # We expect to get a useful error (TimezoneAlignError) instead of a not implemented error falsely claiming that we cannot multiply these types together + _ = expr * series1 + + def test_simplify_basic(x: Variable) -> None: """Test basic simplification with duplicate terms.""" expr = 2 * x + 3 * x + 1 * x diff --git a/test/test_quadratic_expression.py b/test/test_quadratic_expression.py index fc1bb25f..4e47844a 100644 --- a/test/test_quadratic_expression.py +++ b/test/test_quadratic_expression.py @@ -1,13 +1,17 @@ #!/usr/bin/env python3 +from datetime import datetime + import numpy as np import pandas as pd import polars as pl import pytest +from pytz import UTC from scipy.sparse import csc_matrix from xarray import DataArray from linopy import Model, Variable, merge +from linopy.common import TimezoneAlignError from linopy.constants import FACTOR_DIM, TERM_DIM from linopy.expressions import LinearExpression, QuadraticExpression from linopy.testing import assert_quadequal @@ -360,3 +364,28 @@ def test_power_of_three(x: Variable) -> None: x**3 with pytest.raises(TypeError): (x * x) * (x * x) + + +def test_timezone_alignment_failure() -> None: + utc_index = pd.date_range( + start=datetime(2025, 1, 1), + freq="15min", + periods=4, + tz=UTC, + name="time", + ) + tz_naive_index = pd.date_range( + start=datetime(2025, 1, 1), + freq="15min", + periods=4, + tz=None, + name="time", + ) + model = Model() + series1 = pd.Series(index=tz_naive_index, data=1.0) + var = model.add_variables(coords=[utc_index], name="var1") + expr = var * var + + with pytest.raises(TimezoneAlignError): + # We expect to get a useful error (TimezoneAlignError) instead of a not implemented error falsely claiming that we cannot multiply these types together + _ = expr * series1 diff --git a/test/test_variables.py b/test/test_variables.py index 3984b091..aed5219c 100644 --- a/test/test_variables.py +++ b/test/test_variables.py @@ -3,15 +3,20 @@ This module aims at testing the correct behavior of the Variables class. """ +import warnings +from datetime import datetime + import numpy as np import pandas as pd import pytest import xarray as xr import xarray.core.indexes import xarray.core.utils +from pytz import UTC import linopy from linopy import Model +from linopy.common import CoordAlignWarning, TimezoneAlignError from linopy.testing import assert_varequal from linopy.variables import ScalarVariable @@ -122,3 +127,48 @@ def test_scalar_variable(m: Model) -> None: x = ScalarVariable(label=0, model=m) assert isinstance(x, ScalarVariable) assert x.__rmul__(x) is NotImplemented # type: ignore + + +def test_timezone_alignment_with_multiplication() -> None: + utc_index = pd.date_range( + start=datetime(2025, 1, 1), + freq="15min", + periods=4, + tz=UTC, + name="time", + ) + model = Model() + series1 = pd.Series(index=utc_index, data=1.0) + var1 = model.add_variables(coords=[utc_index], name="var1") + + with warnings.catch_warnings(): + warnings.simplefilter("error", CoordAlignWarning) + expr = var1 * series1 + + index: pd.DatetimeIndex = expr.coords["time"].to_index() + assert index.equals(utc_index) + assert index.tzinfo is UTC + + +def test_timezone_alignment_failure() -> None: + utc_index = pd.date_range( + start=datetime(2025, 1, 1), + freq="15min", + periods=4, + tz=UTC, + name="time", + ) + tz_naive_index = pd.date_range( + start=datetime(2025, 1, 1), + freq="15min", + periods=4, + tz=None, + name="time", + ) + model = Model() + series1 = pd.Series(index=tz_naive_index, data=1.0) + var1 = model.add_variables(coords=[utc_index], name="var1") + + with pytest.raises(TimezoneAlignError): + # We expect to get a useful error (TimezoneAlignError) instead of a not implemented error falsely claiming that we cannot multiply these types together + _ = var1 * series1