Skip to content

Commit 3a10b12

Browse files
author
Robbie Muir
committed
fix test
1 parent 61ad751 commit 3a10b12

File tree

2 files changed

+11
-4
lines changed

2 files changed

+11
-4
lines changed

linopy/common.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,9 @@
4242
from linopy.variables import Variable
4343

4444

45+
class CoordAlignWarning(UserWarning): ...
46+
47+
4548
def set_int_index(series: pd.Series) -> pd.Series:
4649
"""
4750
Convert string index to int index.
@@ -189,7 +192,8 @@ def pandas_to_dataarray(
189192
f"coords for dimension(s) {non_aligned} is not aligned with the pandas object. "
190193
"Previously, the indexes of the pandas were ignored and overwritten in "
191194
"these cases. Now, the pandas object's coordinates are taken considered"
192-
" for alignment."
195+
" for alignment.",
196+
CoordAlignWarning,
193197
)
194198

195199
return DataArray(arr, coords=None, dims=dims, **kwargs)
@@ -469,7 +473,7 @@ def save_join(*dataarrays: DataArray, integer_dtype: bool = False) -> Dataset:
469473
except ValueError:
470474
warn(
471475
"Coordinates across variables not equal. Perform outer join.",
472-
UserWarning,
476+
CoordAlignWarning,
473477
)
474478
arrs = xr_align(*dataarrays, join="outer")
475479
if integer_dtype:

test/test_variables.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,19 @@
44
"""
55

66
import warnings
7-
from datetime import UTC, datetime
7+
from datetime import datetime
88

99
import numpy as np
1010
import pandas as pd
1111
import pytest
1212
import xarray as xr
1313
import xarray.core.indexes
1414
import xarray.core.utils
15+
from pytz import UTC
1516

1617
import linopy
1718
from linopy import Model
19+
from linopy.common import CoordAlignWarning
1820
from linopy.testing import assert_varequal
1921
from linopy.variables import ScalarVariable
2022

@@ -139,9 +141,10 @@ def test_timezone_alignment_with_multiplication() -> None:
139141
series1 = pd.Series(index=utc_index, data=1.0)
140142
var1 = model.add_variables(coords=[utc_index], name="var1")
141143

142-
with warnings.catch_warnings():
144+
with warnings.catch_warnings(category=CoordAlignWarning):
143145
warnings.simplefilter("error")
144146
expr = var1 * series1
147+
145148
index: pd.DatetimeIndex = expr.coords["time"].to_index() # type: ignore
146149
assert index.equals(utc_index)
147150
assert index.tzinfo is UTC

0 commit comments

Comments
 (0)