Skip to content

Commit 8eb9fdc

Browse files
committed
more changes
1 parent ee6f1ec commit 8eb9fdc

File tree

5 files changed

+68
-38
lines changed

5 files changed

+68
-38
lines changed

linopy/model.py

Lines changed: 35 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -268,11 +268,7 @@ def objective(self) -> Objective:
268268
def objective(
269269
self, obj: Objective | LinearExpression | QuadraticExpression
270270
) -> Objective:
271-
if not isinstance(obj, Objective):
272-
expr = obj
273-
else:
274-
expr = obj.expression
275-
self.add_objective(expr=expr, overwrite=True, allow_constant=False)
271+
self.add_objective(expr=obj, overwrite=True, allow_constant=False)
276272
return self._objective
277273

278274
@property
@@ -782,23 +778,46 @@ def add_constraints(
782778
self.constraints.add(constraint)
783779
return constraint
784780

781+
@overload
782+
def add_objective(
783+
self,
784+
expr: Objective,
785+
sense: None = None,
786+
overwrite: bool = False,
787+
allow_constant: bool = False,
788+
) -> None: ...
789+
790+
@overload
785791
def add_objective(
786792
self,
787793
expr: Variable
788794
| LinearExpression
789795
| QuadraticExpression
790796
| Sequence[tuple[ConstantLike, VariableLike]],
797+
sense: Literal["min", "max"] | None = None,
798+
overwrite: bool = False,
799+
allow_constant: bool = False,
800+
) -> None: ...
801+
802+
def add_objective(
803+
self,
804+
expr: Variable
805+
| LinearExpression
806+
| QuadraticExpression
807+
| Sequence[tuple[ConstantLike, VariableLike]]
808+
| Objective,
809+
sense: Literal["min", "max"] | None = None,
791810
overwrite: bool = False,
792-
sense: str = "min",
793811
allow_constant: bool = False,
794812
) -> None:
795813
"""
796814
Add an objective function to the model.
797815
798816
Parameters
799817
----------
800-
expr : linopy.LinearExpression, linopy.QuadraticExpression
818+
expr : linopy.Variable, linopy.LinearExpression, linopy.QuadraticExpression, Objective
801819
Expression describing the objective function.
820+
sense: "min" or "max", the sense to optimize for. Defaults to min. Cannot be set if passing Objective directly
802821
overwrite : False, optional
803822
Whether to overwrite the existing objective. The default is False.
804823
allow_constant: bool, optional
@@ -814,15 +833,20 @@ def add_objective(
814833
"Objective already defined."
815834
" Set `overwrite` to True to force overwriting."
816835
)
817-
if isinstance(expr, Variable):
818-
expr = 1 * expr
819836

820-
if not allow_constant and expr.has_constant:
837+
if isinstance(expr, Objective):
838+
assert sense is None, "Cannot set sense if objective object is passed"
839+
objective = expr
840+
assert objective.model == self
841+
else:
842+
sense = sense or "min"
843+
objective = Objective(expression=expr, model=self, sense=sense)
844+
845+
if not allow_constant and objective.expression.has_constant:
821846
raise ConstantObjectiveError(
822847
"Objective contains constant term. Either remove constants from the expression with expr.drop_constants() or use model.add_objective(..., allow_constant=True).",
823848
)
824849

825-
objective = Objective(expression=expr, model=self, sense=sense)
826850
self._objective = objective
827851

828852
def remove_variables(self, name: str) -> None:

linopy/objective.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
import functools
1111
from collections.abc import Callable, Sequence
12-
from typing import TYPE_CHECKING, Any
12+
from typing import TYPE_CHECKING, Any, Literal
1313

1414
import numpy as np
1515
import polars as pl
@@ -24,6 +24,7 @@
2424

2525
from linopy import expressions
2626
from linopy.types import ConstantLike
27+
from linopy.variables import Variable
2728

2829
if TYPE_CHECKING:
2930
from linopy.expressions import LinearExpression, QuadraticExpression
@@ -64,13 +65,19 @@ class Objective:
6465

6566
def __init__(
6667
self,
67-
expression: expressions.LinearExpression | expressions.QuadraticExpression,
68+
expression: Variable
69+
| expressions.LinearExpression
70+
| expressions.QuadraticExpression,
6871
model: Model,
69-
sense: str = "min",
72+
sense: Literal["min", "max"] = "min",
7073
) -> None:
7174
self._model: Model = model
7275
self._value: float | None = None
7376

77+
if isinstance(expression, Variable):
78+
expression = 1 * expression
79+
80+
assert sense in ["min", "max"]
7481
self.sense: str = sense
7582
self.expression: (
7683
expressions.LinearExpression | expressions.QuadraticExpression

test/test_model.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,12 @@
99
from tempfile import gettempdir
1010

1111
import numpy as np
12+
import pandas as pd
1213
import pytest
1314
import xarray as xr
1415

1516
from linopy import EQUAL, Model
17+
from linopy.model import ConstantObjectiveError
1618
from linopy.testing import assert_model_equal
1719

1820
target_shape: tuple[int, int] = (10, 10)
@@ -67,7 +69,7 @@ def test_objective() -> None:
6769
y = m.add_variables(lower, upper, name="y")
6870

6971
obj1 = (10 * x + 5 * y).sum()
70-
m.add_objective(obj1)
72+
m.add_objective(obj1, allow_constant=True)
7173
assert m.objective.vars.size == 200
7274

7375
# test overwriting
@@ -82,8 +84,9 @@ def test_objective() -> None:
8284
assert m.objectiverange.min() == 2
8385
assert m.objectiverange.max() == 2
8486

85-
# test objective with constant which is supported
86-
m.objective = m.objective + 3
87+
# test setting constant term in objective with explicitly allowing it
88+
with pytest.raises(ConstantObjectiveError):
89+
m.objective = m.objective + 3
8790

8891

8992
def test_remove_variable() -> None:
@@ -162,3 +165,20 @@ def test_assert_model_equal() -> None:
162165
m.add_objective(obj)
163166

164167
assert_model_equal(m, m)
168+
169+
170+
def test_constant_not_allowed_in_objective_unless_specified_explicitly() -> None:
171+
model = Model()
172+
days = pd.Index(["Mon", "Tue", "Wed", "Thu", "Fri"], name="day")
173+
x = model.add_variables(name="x", coords=[days])
174+
non_linear = x + 1
175+
176+
with pytest.raises(ConstantObjectiveError):
177+
model.add_objective(expr=non_linear, overwrite=True, allow_constant=False)
178+
with pytest.raises(ConstantObjectiveError):
179+
model.add_objective(expr=non_linear, overwrite=True)
180+
181+
with pytest.raises(ConstantObjectiveError):
182+
model.objective = non_linear
183+
184+
model.add_objective(expr=non_linear, overwrite=True, allow_constant=True)

test/test_objective.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -187,9 +187,3 @@ def test_repr(linear_objective: Objective, quadratic_objective: Objective) -> No
187187

188188
assert "Linear" in linear_objective.__repr__()
189189
assert "Quadratic" in quadratic_objective.__repr__()
190-
191-
192-
def test_objective_constant() -> None:
193-
m = Model()
194-
linear_expr = LinearExpression(None, m) + 1
195-
m.objective = Objective(linear_expr, m)

test/test_optimization.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
from linopy import GREATER_EQUAL, LESS_EQUAL, Model, solvers
2121
from linopy.common import to_path
2222
from linopy.expressions import LinearExpression
23-
from linopy.model import ConstantObjectiveError
2423
from linopy.solver_capabilities import (
2524
SolverFeature,
2625
get_available_solvers_with_feature,
@@ -956,20 +955,6 @@ def test_model_resolve(
956955
assert np.isclose(model.objective.value or 0, 5.25)
957956

958957

959-
def test_constant_not_allowed_unless_specified_explicitly(model: Model) -> None:
960-
objective = model.objective.expression + 1
961-
962-
with pytest.raises(ConstantObjectiveError):
963-
model.add_objective(expr=objective, overwrite=True, allow_constant=False)
964-
with pytest.raises(ConstantObjectiveError):
965-
model.add_objective(expr=objective, overwrite=True)
966-
967-
with pytest.raises(ConstantObjectiveError):
968-
model.objective = objective
969-
970-
model.add_objective(expr=objective, overwrite=True, allow_constant=True)
971-
972-
973958
def test_constant_feasible(model: Model) -> None:
974959
objective = model.objective.expression + 1
975960
model.add_objective(expr=objective, overwrite=True, allow_constant=True)

0 commit comments

Comments
 (0)