Skip to content

Commit 85e4c27

Browse files
committed
EHN: Add linprog_simplex
1 parent 1f1ad4b commit 85e4c27

File tree

3 files changed

+490
-1
lines changed

3 files changed

+490
-1
lines changed

quantecon/optimize/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
"""
33
Initialization of the optimize subpackage
44
"""
5-
5+
from .linprog_simplex import linprog_simplex, solve_tableau, get_solution
66
from .scalar_maximization import brent_max
77
from .nelder_mead import nelder_mead
88
from .root_finding import newton, newton_halley, newton_secant, bisect, brentq

quantecon/optimize/linprog_simplex.py

Lines changed: 238 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,238 @@
1+
"""
2+
Contain a linear programming solver routine based on the Simplex Method.
3+
4+
"""
5+
from collections import namedtuple
6+
import numpy as np
7+
from numba import jit
8+
from .pivoting import _pivoting, _lex_min_ratio_test
9+
10+
11+
FEA_TOL = 1e-6
12+
13+
14+
SimplexResult = namedtuple(
15+
'SimplexResult', ['x', 'lambd', 'fun', 'success', 'status', 'num_iter']
16+
)
17+
18+
19+
@jit(nopython=True, cache=True)
20+
def linprog_simplex(c, A_ub=np.empty((0, 0)), b_ub=np.empty((0,)),
21+
A_eq=np.empty((0, 0)), b_eq=np.empty((0,)), max_iter=10**6,
22+
tableau=None, basis=None, x=None, lambd=None):
23+
n, m, k = c.shape[0], A_ub.shape[0], A_eq.shape[0]
24+
L = m + k
25+
26+
if tableau is None:
27+
tableau = np.empty((L+1, n+m+L+1))
28+
if basis is None:
29+
basis = np.empty(L, dtype=np.int_)
30+
if x is None:
31+
x = np.empty(n)
32+
if lambd is None:
33+
lambd = np.empty(L)
34+
35+
num_iter = 0
36+
fun = -np.inf
37+
38+
b_signs = np.empty(L, dtype=np.bool_)
39+
for i in range(m):
40+
b_signs[i] = True if b_ub[i] >= 0 else False
41+
for i in range(k):
42+
b_signs[m+i] = True if b_eq[i] >= 0 else False
43+
44+
# Construct initial tableau for Phase 1
45+
_initialize_tableau(A_ub, b_ub, A_eq, b_eq, tableau, basis)
46+
47+
# Phase 1
48+
success, status, num_iter_1 = \
49+
solve_tableau(tableau, basis, max_iter, skip_aux=False)
50+
num_iter += num_iter_1
51+
if not success: # max_iter exceeded
52+
return SimplexResult(x, lambd, fun, success, status, num_iter)
53+
if tableau[-1, -1] > FEA_TOL: # Infeasible
54+
success = False
55+
status = 2
56+
return SimplexResult(x, lambd, fun, success, status, num_iter)
57+
58+
# Modify the criterion row for Phase 2
59+
_set_criterion_row(c, basis, tableau)
60+
61+
# Phase 2
62+
success, status, num_iter_2 = \
63+
solve_tableau(tableau, basis, max_iter-num_iter, skip_aux=True)
64+
num_iter += num_iter_2
65+
fun = get_solution(tableau, basis, x, lambd, b_signs)
66+
67+
return SimplexResult(x, lambd, fun, success, status, num_iter)
68+
69+
70+
@jit(nopython=True, cache=True)
71+
def _initialize_tableau(A_ub, b_ub, A_eq, b_eq, tableau, basis):
72+
m, k = A_ub.shape[0], A_eq.shape[0]
73+
L = m + k
74+
n = tableau.shape[1] - (m+L+1)
75+
76+
for i in range(m):
77+
for j in range(n):
78+
tableau[i, j] = A_ub[i, j]
79+
for i in range(k):
80+
for j in range(n):
81+
tableau[m+i, j] = A_eq[i, j]
82+
83+
tableau[:L, n:-1] = 0
84+
85+
for i in range(m):
86+
tableau[i, -1] = b_ub[i]
87+
if tableau[i, -1] < 0:
88+
for j in range(n):
89+
tableau[i, j] *= -1
90+
tableau[i, n+i] = -1
91+
tableau[i, -1] *= -1
92+
else:
93+
tableau[i, n+i] = 1
94+
tableau[i, n+m+i] = 1
95+
for i in range(k):
96+
tableau[m+i, -1] = b_eq[i]
97+
if tableau[m+i, -1] < 0:
98+
for j in range(n):
99+
tableau[m+i, j] *= -1
100+
tableau[m+i, -1] *= -1
101+
tableau[m+i, n+m+m+i] = 1
102+
103+
tableau[-1, :] = 0
104+
for i in range(L):
105+
for j in range(n+m):
106+
tableau[-1, j] += tableau[i, j]
107+
tableau[-1, -1] += tableau[i, -1]
108+
109+
for i in range(L):
110+
basis[i] = n+m+i
111+
112+
return tableau, basis
113+
114+
115+
@jit(nopython=True, cache=True)
116+
def _set_criterion_row(c, basis, tableau):
117+
n = c.shape[0]
118+
L = basis.shape[0]
119+
120+
for j in range(n):
121+
tableau[-1, j] = c[j]
122+
tableau[-1, n:] = 0
123+
124+
for i in range(L):
125+
multiplier = tableau[-1, basis[i]]
126+
for j in range(tableau.shape[1]):
127+
tableau[-1, j] -= tableau[i, j] * multiplier
128+
129+
return tableau
130+
131+
132+
@jit(nopython=True, cache=True)
133+
def solve_tableau(tableau, basis, max_iter=10**6, skip_aux=True):
134+
"""
135+
Perform the simplex algorithm on a given tableau in canonical form.
136+
137+
Used to solve a linear program in the following form:
138+
139+
maximize: c @ x
140+
141+
subject to: A_ub @ x <= b_ub
142+
A_eq @ x == b_eq
143+
x >= 0
144+
145+
where A_ub is of shape (m, n) and A_eq is of shape (k, n). Thus,
146+
`tableau` is of shape (L+1, n+m+L+1), where L=m+k, and
147+
148+
* `tableau[np.arange(L), :][:, basis]` must be an identity matrix,
149+
and
150+
* the elements of `tableau[:-1, -1]` must be nonnegative.
151+
152+
Parameters
153+
----------
154+
tableau : ndarray(float, ndim=2)
155+
ndarray of shape (L+1, n+m+L+1) containing the tableau. Modified
156+
in place.
157+
158+
basis : ndarray(int, ndim=1)
159+
ndarray of shape (L,) containing the basic variables. Modified
160+
in place.
161+
162+
max_iter : scalar(int), optional(default=10**6)
163+
Maximum number of pivoting steps.
164+
165+
skip_aux : bool, optional(default=True)
166+
Whether to skip the coefficients of the auxiliary (or
167+
artificial) variables in pivot column selection.
168+
169+
"""
170+
L = tableau.shape[0] - 1
171+
172+
# Array to store row indices in lex_min_ratio_test
173+
argmins = np.empty(L, dtype=np.int_)
174+
175+
success = False
176+
status = 1
177+
num_iter = 0
178+
179+
while num_iter < max_iter:
180+
num_iter += 1
181+
182+
pivcol_found, pivcol = _pivot_col(tableau, skip_aux)
183+
184+
if not pivcol_found: # Optimal
185+
success = True
186+
status = 0
187+
break
188+
189+
aux_start = tableau.shape[1] - L - 1
190+
pivrow_found, pivrow = _lex_min_ratio_test(tableau[:-1, :], pivcol,
191+
aux_start, argmins)
192+
193+
if not pivrow_found: # Unbounded
194+
success = False
195+
status = 3
196+
break
197+
198+
_pivoting(tableau, pivcol, pivrow)
199+
basis[pivrow] = pivcol
200+
201+
return success, status, num_iter
202+
203+
204+
@jit(nopython=True, cache=True)
205+
def _pivot_col(tableau, skip_aux):
206+
L = tableau.shape[0] - 1
207+
criterion_row_stop = tableau.shape[1] - 1
208+
if skip_aux:
209+
criterion_row_stop -= L
210+
211+
found = False
212+
pivcol = -1
213+
coeff = FEA_TOL
214+
for j in range(criterion_row_stop):
215+
if tableau[-1, j] > coeff:
216+
coeff = tableau[-1, j]
217+
pivcol = j
218+
found = True
219+
220+
return found, pivcol
221+
222+
223+
@jit(nopython=True, cache=True)
224+
def get_solution(tableau, basis, x, lambd, b_signs):
225+
n, L = x.size, lambd.size
226+
aux_start = tableau.shape[1] - L - 1
227+
228+
x[:] = 0
229+
for i in range(L):
230+
if basis[i] < n:
231+
x[basis[i]] = tableau[i, -1]
232+
for j in range(L):
233+
lambd[j] = tableau[-1, aux_start+j]
234+
if lambd[j] != 0 and b_signs[j]:
235+
lambd[j] *= -1
236+
fun = tableau[-1, -1] * (-1)
237+
238+
return fun

0 commit comments

Comments
 (0)