Skip to content

Commit 62cdebb

Browse files
fix: support multiindexed and arbitrarily-named dimensions for grouping (#373)
1 parent 78331b9 commit 62cdebb

File tree

4 files changed

+90
-20
lines changed

4 files changed

+90
-20
lines changed

doc/release_notes.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ Upcoming Version
66

77
* When writing out an LP file, large variables and constraints are now chunked to avoid memory issues. This is especially useful for large models with constraints with many terms. The chunk size can be set with the `slice_size` argument in the `solve` function.
88
* Constraints of the form `<= infinity` and `>= -infinity` are now automatically filtered out when solving. The `solve` function now has a new argument `sanitize_infinities` to control this feature. Default is set to `True`.
9+
* Grouping expressions is now supported on dimensions called "group" and dimensions that have the same name as the grouping object.
10+
* Grouping on dimensions that have multiindexed coordinates is now supported.
911

1012
Version 0.3.15
1113
--------------

linopy/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
TERM_DIM = "_term"
3737
STACKED_TERM_DIM = "_stacked_term"
3838
GROUPED_TERM_DIM = "_grouped_term"
39+
GROUP_DIM = "_group"
3940
FACTOR_DIM = "_factor"
4041
CONCAT_DIM = "_concat"
4142
HELPER_DIMS = [TERM_DIM, STACKED_TERM_DIM, GROUPED_TERM_DIM, FACTOR_DIM, CONCAT_DIM]

linopy/expressions.py

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
EQUAL,
6363
FACTOR_DIM,
6464
GREATER_EQUAL,
65+
GROUP_DIM,
6566
GROUPED_TERM_DIM,
6667
HELPER_DIMS,
6768
LESS_EQUAL,
@@ -218,42 +219,43 @@ def sum(self, use_fallback: bool = False, **kwargs) -> LinearExpression:
218219
group: pd.Series | pd.DataFrame | xr.DataArray = self.group
219220
if isinstance(group, pd.DataFrame):
220221
# dataframes do not have a name, so we need to set it
221-
group_name = "group"
222+
final_group_name = "group"
222223
else:
223-
group_name = getattr(group, "name", "group") or "group"
224+
final_group_name = getattr(group, "name", "group") or "group"
224225

225226
if isinstance(group, DataArray):
226227
group = group.to_pandas()
227228

228229
int_map = None
229230
if isinstance(group, pd.DataFrame):
231+
index_name = group.index.name
230232
group = group.reindex(self.data.indexes[group.index.name])
233+
group.index.name = index_name # ensure name for multiindex
231234
int_map = get_index_map(*group.values.T)
232235
orig_group = group
233236
group = group.apply(tuple, axis=1).map(int_map)
234237

235238
group_dim = group.index.name
236-
if group_name == group_dim:
237-
raise ValueError(
238-
"Group name cannot be the same as group dimension in non-fallback mode."
239-
)
240239

241240
arrays = [group, group.groupby(group).cumcount()]
242-
idx = pd.MultiIndex.from_arrays(
243-
arrays, names=[group_name, GROUPED_TERM_DIM]
244-
)
245-
coords = Coordinates.from_pandas_multiindex(idx, group_dim)
246-
ds = self.data.assign_coords(coords)
241+
idx = pd.MultiIndex.from_arrays(arrays, names=[GROUP_DIM, GROUPED_TERM_DIM])
242+
new_coords = Coordinates.from_pandas_multiindex(idx, group_dim)
243+
coords = self.data.indexes[group_dim]
244+
names_to_drop = [coords.name]
245+
if isinstance(coords, pd.MultiIndex):
246+
names_to_drop += list(coords.names)
247+
ds = self.data.drop_vars(names_to_drop).assign_coords(new_coords)
247248
ds = ds.unstack(group_dim, fill_value=LinearExpression._fill_value)
248249
ds = LinearExpression._sum(ds, dim=GROUPED_TERM_DIM)
249250

250251
if int_map is not None:
251-
index = ds.indexes["group"].map({v: k for k, v in int_map.items()})
252+
index = ds.indexes[GROUP_DIM].map({v: k for k, v in int_map.items()})
252253
index.names = [str(col) for col in orig_group.columns]
253-
index.name = group_name
254-
coords = Coordinates.from_pandas_multiindex(index, group_name)
255-
ds = xr.Dataset(ds.assign_coords(coords))
254+
index.name = GROUP_DIM
255+
new_coords = Coordinates.from_pandas_multiindex(index, GROUP_DIM)
256+
ds = xr.Dataset(ds.assign_coords(new_coords))
256257

258+
ds = ds.rename({GROUP_DIM: final_group_name})
257259
return LinearExpression(ds, self.model)
258260

259261
def func(ds):
@@ -1428,6 +1430,8 @@ def to_polars(self) -> pl.DataFrame:
14281430

14291431
drop = exprwrap(Dataset.drop)
14301432

1433+
drop_vars = exprwrap(Dataset.drop_vars)
1434+
14311435
drop_sel = exprwrap(Dataset.drop_sel)
14321436

14331437
drop_isel = exprwrap(Dataset.drop_isel)
@@ -1452,6 +1456,8 @@ def to_polars(self) -> pl.DataFrame:
14521456

14531457
rename = exprwrap(Dataset.rename)
14541458

1459+
reset_index = exprwrap(Dataset.reset_index)
1460+
14551461
rename_dims = exprwrap(Dataset.rename_dims)
14561462

14571463
roll = exprwrap(Dataset.roll)

test/test_linear_expression.py

Lines changed: 66 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -668,6 +668,17 @@ def test_linear_expression_diff(v):
668668

669669
@pytest.mark.parametrize("use_fallback", [True, False])
670670
def test_linear_expression_groupby(v, use_fallback):
671+
expr = 1 * v
672+
dim = v.dims[0]
673+
groups = xr.DataArray([1] * 10 + [2] * 10, coords=v.coords, name=dim)
674+
grouped = expr.groupby(groups).sum(use_fallback=use_fallback)
675+
assert dim in grouped.dims
676+
assert (grouped.data[dim] == [1, 2]).all()
677+
assert grouped.nterm == 10
678+
679+
680+
@pytest.mark.parametrize("use_fallback", [True, False])
681+
def test_linear_expression_groupby_on_same_name_as_target_dim(v, use_fallback):
671682
expr = 1 * v
672683
groups = xr.DataArray([1] * 10 + [2] * 10, coords=v.coords)
673684
grouped = expr.groupby(groups).sum(use_fallback=use_fallback)
@@ -719,20 +730,31 @@ def test_linear_expression_groupby_series_with_name(v, use_fallback):
719730

720731

721732
@pytest.mark.parametrize("use_fallback", [True, False])
722-
def test_linear_expression_groupby_with_series_false(v, use_fallback):
733+
def test_linear_expression_groupby_with_series_with_same_group_name(v, use_fallback):
734+
"""
735+
Test that the group by works with a series whose name is the same as
736+
the dimension to group.
737+
"""
723738
expr = 1 * v
724739
groups = pd.Series([1] * 10 + [2] * 10, index=v.indexes["dim_2"])
725740
groups.name = "dim_2"
726-
if not use_fallback:
727-
with pytest.raises(ValueError):
728-
expr.groupby(groups).sum(use_fallback=use_fallback)
729-
return
730741
grouped = expr.groupby(groups).sum(use_fallback=use_fallback)
731742
assert "dim_2" in grouped.dims
732743
assert (grouped.data.dim_2 == [1, 2]).all()
733744
assert grouped.nterm == 10
734745

735746

747+
@pytest.mark.parametrize("use_fallback", [True, False])
748+
def test_linear_expression_groupby_with_series_on_multiindex(u, use_fallback):
749+
expr = 1 * u
750+
len_grouped_dim = len(u.data["dim_3"])
751+
groups = pd.Series([1] * len_grouped_dim, index=u.indexes["dim_3"])
752+
grouped = expr.groupby(groups).sum(use_fallback=use_fallback)
753+
assert "group" in grouped.dims
754+
assert (grouped.data.group == [1]).all()
755+
assert grouped.nterm == len_grouped_dim
756+
757+
736758
@pytest.mark.parametrize("use_fallback", [True, False])
737759
def test_linear_expression_groupby_with_dataframe(v, use_fallback):
738760
expr = 1 * v
@@ -751,6 +773,45 @@ def test_linear_expression_groupby_with_dataframe(v, use_fallback):
751773
assert grouped.nterm == 3
752774

753775

776+
@pytest.mark.parametrize("use_fallback", [True, False])
777+
def test_linear_expression_groupby_with_dataframe_with_same_group_name(v, use_fallback):
778+
"""
779+
Test that the group by works with a dataframe whose column name is the same as
780+
the dimension to group.
781+
"""
782+
expr = 1 * v
783+
groups = pd.DataFrame(
784+
{"dim_2": [1] * 10 + [2] * 10, "b": list(range(4)) * 5},
785+
index=v.indexes["dim_2"],
786+
)
787+
if use_fallback:
788+
with pytest.raises(ValueError):
789+
expr.groupby(groups).sum(use_fallback=use_fallback)
790+
return
791+
792+
grouped = expr.groupby(groups).sum(use_fallback=use_fallback)
793+
index = pd.MultiIndex.from_frame(groups)
794+
assert "group" in grouped.dims
795+
assert set(grouped.data.group.values) == set(index.values)
796+
assert grouped.nterm == 3
797+
798+
799+
@pytest.mark.parametrize("use_fallback", [True, False])
800+
def test_linear_expression_groupby_with_dataframe_on_multiindex(u, use_fallback):
801+
expr = 1 * u
802+
len_grouped_dim = len(u.data["dim_3"])
803+
groups = pd.DataFrame({"a": [1] * len_grouped_dim}, index=u.indexes["dim_3"])
804+
805+
if use_fallback:
806+
with pytest.raises(ValueError):
807+
expr.groupby(groups).sum(use_fallback=use_fallback)
808+
return
809+
grouped = expr.groupby(groups).sum(use_fallback=use_fallback)
810+
assert "group" in grouped.dims
811+
assert isinstance(grouped.indexes["group"], pd.MultiIndex)
812+
assert grouped.nterm == len_grouped_dim
813+
814+
754815
@pytest.mark.parametrize("use_fallback", [True, False])
755816
def test_linear_expression_groupby_with_dataarray(v, use_fallback):
756817
expr = 1 * v

0 commit comments

Comments
 (0)