Skip to content

Commit b28e4fa

Browse files
committed
Add type variable for solver's environment type
1 parent 9a5a128 commit b28e4fa

File tree

3 files changed

+64
-17
lines changed

3 files changed

+64
-17
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
.coverage
22
.eggs
33
.DS_Store
4+
.mypy_cache
45
linopy/__pycache__
56
test/__pycache__
67
linopy.egg-info

linopy/solvers.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from collections import namedtuple
1818
from collections.abc import Callable, Generator
1919
from pathlib import Path
20-
from typing import TYPE_CHECKING, Any
20+
from typing import TYPE_CHECKING, Any, Generic, TypeVar
2121

2222
import numpy as np
2323
import pandas as pd
@@ -36,6 +36,8 @@
3636

3737
from linopy.model import Model
3838

39+
EnvType = TypeVar("EnvType")
40+
3941
QUADRATIC_SOLVERS = [
4042
"gurobi",
4143
"xpress",
@@ -197,7 +199,7 @@ def maybe_adjust_objective_sign(
197199
return solution
198200

199201

200-
class Solver(ABC):
202+
class Solver(ABC, Generic[EnvType]):
201203
"""
202204
Abstract base class for solving a given linear problem.
203205
@@ -244,7 +246,7 @@ def solve_problem_from_model(
244246
log_fn: Path | None = None,
245247
warmstart_fn: Path | None = None,
246248
basis_fn: Path | None = None,
247-
env: None = None,
249+
env: EnvType | None = None,
248250
explicit_coordinate_names: bool = False,
249251
) -> Result:
250252
"""
@@ -264,7 +266,7 @@ def solve_problem_from_file(
264266
log_fn: Path | None = None,
265267
warmstart_fn: Path | None = None,
266268
basis_fn: Path | None = None,
267-
env: None = None,
269+
env: EnvType | None = None,
268270
) -> Result:
269271
"""
270272
Abstract method to solve a linear problem from a problem file.
@@ -283,7 +285,7 @@ def solve_problem(
283285
log_fn: Path | None = None,
284286
warmstart_fn: Path | None = None,
285287
basis_fn: Path | None = None,
286-
env: None = None,
288+
env: EnvType | None = None,
287289
explicit_coordinate_names: bool = False,
288290
) -> Result:
289291
"""
@@ -324,7 +326,7 @@ def solver_name(self) -> SolverName:
324326
return SolverName[self.__class__.__name__]
325327

326328

327-
class CBC(Solver):
329+
class CBC(Solver[None]):
328330
"""
329331
Solver subclass for the CBC solver.
330332
@@ -505,7 +507,7 @@ def get_solver_solution() -> Solution:
505507
return Result(status, solution, CbcModel(mip_gap, runtime))
506508

507509

508-
class GLPK(Solver):
510+
class GLPK(Solver[None]):
509511
"""
510512
Solver subclass for the GLPK solver.
511513
@@ -675,7 +677,7 @@ def get_solver_solution() -> Solution:
675677
return Result(status, solution)
676678

677679

678-
class Highs(Solver):
680+
class Highs(Solver[None]):
679681
"""
680682
Solver subclass for the Highs solver. Highs must be installed
681683
for usage. Find the documentation at https://www.maths.ed.ac.uk/hall/HiGHS/.
@@ -921,7 +923,7 @@ def get_solver_solution() -> Solution:
921923
return Result(status, solution, h)
922924

923925

924-
class Gurobi(Solver):
926+
class Gurobi(Solver[gurobipy.Env | dict[str, Any] | None]):
925927
"""
926928
Solver subclass for the gurobi solver.
927929
@@ -1156,7 +1158,7 @@ def get_solver_solution() -> Solution:
11561158
return Result(status, solution, m)
11571159

11581160

1159-
class Cplex(Solver):
1161+
class Cplex(Solver[None]):
11601162
"""
11611163
Solver subclass for the Cplex solver.
11621164
@@ -1312,7 +1314,7 @@ def get_solver_solution() -> Solution:
13121314
return Result(status, solution, m)
13131315

13141316

1315-
class SCIP(Solver):
1317+
class SCIP(Solver[None]):
13161318
"""
13171319
Solver subclass for the SCIP solver.
13181320
@@ -1465,7 +1467,7 @@ def get_solver_solution() -> Solution:
14651467
return Result(status, solution, m)
14661468

14671469

1468-
class Xpress(Solver):
1470+
class Xpress(Solver[None]):
14691471
"""
14701472
Solver subclass for the xpress solver.
14711473
@@ -1602,7 +1604,7 @@ def get_solver_solution() -> Solution:
16021604
mosek_bas_re = re.compile(r" (XL|XU)\s+([^ \t]+)\s+([^ \t]+)| (LL|UL|BS)\s+([^ \t]+)")
16031605

16041606

1605-
class Mosek(Solver):
1607+
class Mosek(Solver[None]):
16061608
"""
16071609
Solver subclass for the Mosek solver.
16081610
@@ -1932,7 +1934,7 @@ def get_solver_solution() -> Solution:
19321934
return Result(status, solution)
19331935

19341936

1935-
class COPT(Solver):
1937+
class COPT(Solver[None]):
19361938
"""
19371939
Solver subclass for the COPT solver.
19381940
@@ -2073,7 +2075,7 @@ def get_solver_solution() -> Solution:
20732075
return Result(status, solution, m)
20742076

20752077

2076-
class MindOpt(Solver):
2078+
class MindOpt(Solver[None]):
20772079
"""
20782080
Solver subclass for the MindOpt solver.
20792081
@@ -2216,7 +2218,7 @@ def get_solver_solution() -> Solution:
22162218
return Result(status, solution, m)
22172219

22182220

2219-
class PIPS(Solver):
2221+
class PIPS(Solver[None]):
22202222
"""
22212223
Solver subclass for the PIPS solver.
22222224
"""

test/test_solvers.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,11 @@
77

88
from pathlib import Path
99

10+
import pandas as pd
1011
import pytest
12+
import xarray as xr
1113

12-
from linopy import solvers
14+
from linopy import LESS_EQUAL, Model, solvers
1315

1416
free_mps_problem = """NAME sample_mip
1517
ROWS
@@ -44,6 +46,22 @@
4446
"""
4547

4648

49+
@pytest.fixture
50+
def model() -> Model:
51+
m = Model()
52+
53+
x = m.add_variables(4, pd.Series([8, 10]), name="x")
54+
y = m.add_variables(0, pd.DataFrame([[1, 2], [3, 4]]), name="y")
55+
56+
m.add_constraints(x + y, LESS_EQUAL, 10)
57+
58+
m.add_objective(2 * x + 3 * y)
59+
60+
m.parameters["param"] = xr.DataArray([1, 2, 3, 4], dims=["x"])
61+
62+
return m
63+
64+
4765
@pytest.mark.parametrize("solver", set(solvers.available_solvers))
4866
def test_free_mps_solution_parsing(solver: str, tmp_path: Path) -> None:
4967
try:
@@ -64,3 +82,29 @@ def test_free_mps_solution_parsing(solver: str, tmp_path: Path) -> None:
6482

6583
assert result.status.is_ok
6684
assert result.solution.objective == 30.0
85+
86+
87+
@pytest.mark.skipif(
88+
"gurobi" not in set(solvers.available_solvers), reason="Gurobi is not installed"
89+
)
90+
def test_gurobi_environment_parameters(model: Model, tmp_path: Path) -> None:
91+
gurobi = solvers.Gurobi()
92+
93+
mps_file = tmp_path / "problem.mps"
94+
mps_file.write_text(free_mps_problem)
95+
sol_file = tmp_path / "solution.sol"
96+
97+
log1_file = tmp_path / "gurobi1.log"
98+
result = gurobi.solve_problem(
99+
problem_fn=mps_file, solution_fn=sol_file, env={"LogFile": str(log1_file)}
100+
)
101+
102+
assert result.status.is_ok
103+
assert log1_file.exists()
104+
105+
log2_file = tmp_path / "gurobi2.log"
106+
gurobi.solve_problem(
107+
model=model, solution_fn=sol_file, env={"LogFile": str(log2_file)}
108+
)
109+
assert result.status.is_ok
110+
assert log2_file.exists()

0 commit comments

Comments
 (0)