Skip to content

Commit f13f07c

Browse files
committed
Made merge expressions function infer class wihthout warning
1 parent 7e8f363 commit f13f07c

File tree

3 files changed

+19
-24
lines changed

3 files changed

+19
-24
lines changed

linopy/expressions.py

Lines changed: 14 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1941,50 +1941,43 @@ def merge(
19411941
dim : str
19421942
Dimension along which the expressions should be concatenated.
19431943
cls : type
1944-
Type of the resulting expression.
1944+
Explicitly set the type of the resulting expression (So that the type checker will know the return type)
19451945
**kwargs
19461946
Additional keyword arguments passed to xarray.concat. Defaults to
19471947
{coords: "minimal", compat: "override"} or, in the special case described
19481948
above, to {coords: "minimal", compat: "override", "join": "override"}.
19491949
19501950
Returns
19511951
-------
1952-
res : linopy.LinearExpression
1952+
res : linopy.LinearExpression or linopy.QuadraticExpression
19531953
"""
1954-
if cls is None:
1955-
warn(
1956-
"Using merge without specifying the class is deprecated",
1957-
DeprecationWarning,
1958-
)
1959-
cls = LinearExpression
1960-
1961-
linopy_types = (variables.Variable, LinearExpression, QuadraticExpression)
1962-
19631954
if not isinstance(exprs, list) and len(add_exprs):
19641955
warn(
19651956
"Passing a tuple to the merge function is deprecated. Please pass a list of objects to be merged",
19661957
DeprecationWarning,
19671958
)
19681959
exprs = [exprs] + list(add_exprs) # type: ignore
1969-
model = exprs[0].model
19701960

1971-
if (
1972-
cls is QuadraticExpression
1973-
and dim == TERM_DIM
1974-
and any(type(e) is LinearExpression for e in exprs)
1975-
):
1961+
has_quad_expression = any(type(e) is QuadraticExpression for e in exprs)
1962+
has_linear_expression = any(type(e) is LinearExpression for e in exprs)
1963+
if cls is None:
1964+
cls = QuadraticExpression if has_quad_expression else LinearExpression
1965+
1966+
if cls is QuadraticExpression and dim == TERM_DIM and has_linear_expression:
19761967
raise ValueError(
19771968
"Cannot merge linear and quadratic expressions along term dimension."
19781969
"Convert to QuadraticExpression first."
19791970
)
19801971

1981-
if cls is not QuadraticExpression and any(
1982-
type(e) is QuadraticExpression for e in exprs
1983-
):
1972+
if has_quad_expression and cls is not QuadraticExpression:
19841973
raise ValueError(
1985-
"Cannot merge linear and quadratic expressions to QuadraticExpression"
1974+
"Cannot merge linear expressions to QuadraticExpression"
19861975
)
19871976

1977+
linopy_types = (variables.Variable, LinearExpression, QuadraticExpression)
1978+
1979+
model = exprs[0].model
1980+
19881981
if cls in linopy_types and dim in HELPER_DIMS:
19891982
coord_dims = [
19901983
{k: v for k, v in e.sizes.items() if k not in HELPER_DIMS} for e in exprs

test/test_linear_expression.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1128,9 +1128,8 @@ def test_merge(x: Variable, y: Variable, z: Variable) -> None:
11281128
res = merge([expr1, expr2], cls=LinearExpression)
11291129
assert res.nterm == 6
11301130

1131-
with pytest.warns(DeprecationWarning):
1132-
res: LinearExpression = merge([expr1, expr2]) # type: ignore
1133-
assert res.nterm == 6
1131+
res: LinearExpression = merge([expr1, expr2]) # type: ignore
1132+
assert isinstance(res, LinearExpression)
11341133

11351134
# now concat with same length of terms
11361135
expr1 = z.sel(dim_0=0).sum("dim_1")

test/test_quadratic_expression.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,9 @@ def test_merge_linear_expression_and_quadratic_expression(
238238
with pytest.raises(ValueError):
239239
merge([linexpr, quadexpr], cls=QuadraticExpression)
240240

241+
new_quad_ex = merge([linexpr.to_quadexpr(), quadexpr]) # type: ignore
242+
assert isinstance(new_quad_ex, QuadraticExpression)
243+
241244
with pytest.warns(DeprecationWarning):
242245
merge(quadexpr, quadexpr, cls=QuadraticExpression) # type: ignore
243246

0 commit comments

Comments
 (0)