Skip to content

Commit ee6f1ec

Browse files
committed
refactor
1 parent 9e5847a commit ee6f1ec

File tree

2 files changed

+28
-45
lines changed

2 files changed

+28
-45
lines changed

linopy/model.py

Lines changed: 12 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -108,11 +108,6 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
108108
return func(*args, **kwargs)
109109

110110
# The objective contains a constant term
111-
if not model.allow_constant_objective:
112-
raise ConstantObjectiveError(
113-
"Objective function contains constant terms. Please use expr.drop_constants() or set model.allow_constant_objective=True."
114-
)
115-
116111
# Modify the model objective to drop the constant term
117112
model = self
118113
constant = float(self.objective.expression.const.values)
@@ -162,7 +157,6 @@ class Model:
162157
_dual: Dataset
163158
_status: str
164159
_termination_condition: str
165-
_allow_constant_objective: bool
166160
_xCounter: int
167161
_cCounter: int
168162
_varnameCounter: int
@@ -184,7 +178,6 @@ class Model:
184178
# hidden attributes
185179
"_status",
186180
"_termination_condition",
187-
"_allow_constant_objective",
188181
# TODO: move counters to Variables and Constraints class
189182
"_xCounter",
190183
"_cCounter",
@@ -236,7 +229,6 @@ def __init__(
236229

237230
self._status: str = "initialized"
238231
self._termination_condition: str = ""
239-
self._allow_constant_objective: bool = False
240232
self._xCounter: int = 0
241233
self._cCounter: int = 0
242234
self._varnameCounter: int = 0
@@ -277,9 +269,10 @@ def objective(
277269
self, obj: Objective | LinearExpression | QuadraticExpression
278270
) -> Objective:
279271
if not isinstance(obj, Objective):
280-
obj = Objective(obj, self)
281-
282-
self._objective = obj
272+
expr = obj
273+
else:
274+
expr = obj.expression
275+
self.add_objective(expr=expr, overwrite=True, allow_constant=False)
283276
return self._objective
284277

285278
@property
@@ -789,17 +782,6 @@ def add_constraints(
789782
self.constraints.add(constraint)
790783
return constraint
791784

792-
@property
793-
def allow_constant_objective(self) -> bool:
794-
"""
795-
Whether constant terms in the objective function are allowed.
796-
"""
797-
return self._allow_constant_objective
798-
799-
@allow_constant_objective.setter
800-
def allow_constant_objective(self, allow: bool) -> None:
801-
self._allow_constant_objective = allow
802-
803785
def add_objective(
804786
self,
805787
expr: Variable
@@ -808,7 +790,7 @@ def add_objective(
808790
| Sequence[tuple[ConstantLike, VariableLike]],
809791
overwrite: bool = False,
810792
sense: str = "min",
811-
allow_constant_objective: bool | None = None,
793+
allow_constant: bool = False,
812794
) -> None:
813795
"""
814796
Add an objective function to the model.
@@ -819,8 +801,8 @@ def add_objective(
819801
Expression describing the objective function.
820802
overwrite : False, optional
821803
Whether to overwrite the existing objective. The default is False.
822-
allow_constant_objective : bool, optional
823-
Set the `Model.allow_constant_objective` attribute. If True, the objective is allowed to contain a constant term.
804+
allow_constant: bool, optional
805+
If True, the objective is allowed to contain a constant term. The default is False
824806
825807
Returns
826808
-------
@@ -835,16 +817,14 @@ def add_objective(
835817
if isinstance(expr, Variable):
836818
expr = 1 * expr
837819

838-
self.objective.expression = expr
839-
self.objective.sense = sense
840-
if allow_constant_objective is not None:
841-
self.allow_constant_objective = allow_constant_objective
842-
843-
if not self.allow_constant_objective and self.objective.has_constant:
820+
if not allow_constant and expr.has_constant:
844821
raise ConstantObjectiveError(
845-
"Objective function contains constant terms but this is not allowed as Model.allow_constant_objective=False. Either remove constants from the expression with expr.drop_constants() or pass allow_constant_objective=True.",
822+
"Objective contains constant term. Either remove constants from the expression with expr.drop_constants() or use model.add_objective(..., allow_constant=True).",
846823
)
847824

825+
objective = Objective(expression=expr, model=self, sense=sense)
826+
self._objective = objective
827+
848828
def remove_variables(self, name: str) -> None:
849829
"""
850830
Remove all variables stored under reference name `name` from the model.

test/test_optimization.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -956,21 +956,24 @@ def test_model_resolve(
956956
assert np.isclose(model.objective.value or 0, 5.25)
957957

958958

959-
def test_model_with_constant_in_objective_feasible(model: Model) -> None:
959+
def test_constant_not_allowed_unless_specified_explicitly(model: Model) -> None:
960960
objective = model.objective.expression + 1
961961

962962
with pytest.raises(ConstantObjectiveError):
963-
model.add_objective(
964-
expr=objective, overwrite=True, allow_constant_objective=False
965-
)
966-
967-
model.add_objective(expr=objective, overwrite=True, allow_constant_objective=True)
968-
model.allow_constant_objective = False
963+
model.add_objective(expr=objective, overwrite=True, allow_constant=False)
964+
with pytest.raises(ConstantObjectiveError):
965+
model.add_objective(expr=objective, overwrite=True)
969966

970967
with pytest.raises(ConstantObjectiveError):
971-
status, _ = model.solve(solver_name="highs")
968+
model.objective = objective
969+
970+
model.add_objective(expr=objective, overwrite=True, allow_constant=True)
971+
972+
973+
def test_constant_feasible(model: Model) -> None:
974+
objective = model.objective.expression + 1
975+
model.add_objective(expr=objective, overwrite=True, allow_constant=True)
972976

973-
model.allow_constant_objective = True
974977
status, _ = model.solve(solver_name="highs")
975978
assert status == "ok"
976979
# x = -0.1, y = 1.7
@@ -979,9 +982,9 @@ def test_model_with_constant_in_objective_feasible(model: Model) -> None:
979982
assert model.objective.expression.solution == 4.3
980983

981984

982-
def test_model_with_constant_in_objective_infeasible(model: Model) -> None:
985+
def test_constant_infeasible(model: Model) -> None:
983986
objective = model.objective.expression + 1
984-
model.add_objective(expr=objective, overwrite=True, allow_constant_objective=True)
987+
model.add_objective(expr=objective, overwrite=True, allow_constant=True)
985988
model.add_constraints([(1, "x")], "<=", 0)
986989
model.add_constraints([(1, "y")], "<=", 0)
987990

@@ -992,9 +995,9 @@ def test_model_with_constant_in_objective_infeasible(model: Model) -> None:
992995
assert model.objective.expression.const == 1
993996

994997

995-
def test_model_with_constant_in_objective_error(model: Model) -> None:
998+
def test_constant_error(model: Model) -> None:
996999
objective = model.objective.expression + 1
997-
model.add_objective(expr=objective, overwrite=True, allow_constant_objective=True)
1000+
model.add_objective(expr=objective, overwrite=True, allow_constant=True)
9981001
model.add_constraints([(1, "x")], "<=", 0)
9991002
model.add_constraints([(1, "y")], "<=", 0)
10001003

0 commit comments

Comments
 (0)