Commit e5f3a28

ENH: Add LP solution method to DiscreteDP (#585)

1 parent b812edd

4 files changed (+377, -107 lines)
Lines changed: 188 additions & 0 deletions
@@ -0,0 +1,188 @@
import numpy as np
from numba import jit
from .utilities import _find_indices
from ..optimize.linprog_simplex import solve_tableau, PivOptions
from ..optimize.pivoting import _pivoting


@jit(nopython=True, cache=True)
def ddp_linprog_simplex(R, Q, beta, a_indices, a_indptr, sigma,
                        max_iter=10**6, piv_options=PivOptions(),
                        tableau=None, basis=None, v=None):
    r"""
    Numba jit compiled function to solve a discrete dynamic program via
    linear programming, using `optimize.linprog_simplex` routines. The
    problem has to be represented in state-action pair form with 1-dim
    reward ndarray `R` of shape (L,), 2-dim transition probability
    ndarray `Q` of shape (L, n), and discount factor `beta`, where n is
    the number of states and L is the number of feasible state-action
    pairs.

    The approach exploits the fact that the optimal value function is
    the smallest function that satisfies :math:`v \geq T v`, where
    :math:`T` is the Bellman operator, and hence it is the (unique)
    solution to the linear program:

    minimize::

        \sum_{s \in S} v(s)

    subject to::

        v(s) \geq r(s, a) + \beta \sum_{s' \in S} q(s'|s, a) v(s')
        \quad ((s, a) \in \mathit{SA}).

    This function solves its dual problem:

    maximize::

        \sum_{(s, a) \in \mathit{SA}} r(s, a) y(s, a)

    subject to::

        \sum_{a: (s', a) \in \mathit{SA}} y(s', a) -
        \sum_{(s, a) \in \mathit{SA}} \beta q(s'|s, a) y(s, a) = 1
        \quad (s' \in S),

        y(s, a) \geq 0 \quad ((s, a) \in \mathit{SA}),

    where the optimal value function is obtained as an optimal dual
    solution and an optimal policy as an optimal basis.

    Parameters
    ----------
    R : ndarray(float, ndim=1)
        Reward ndarray, of shape (L,).

    Q : ndarray(float, ndim=2)
        Transition probability ndarray, of shape (L, n).

    beta : scalar(float)
        Discount factor. Must be in [0, 1).

    a_indices : ndarray(int, ndim=1)
        Action index ndarray, of shape (L,).

    a_indptr : ndarray(int, ndim=1)
        Action index pointer ndarray, of shape (n+1,).

    sigma : ndarray(int, ndim=1)
        ndarray containing the initial feasible policy, of shape (n,).
        To be modified in place to store the output optimal policy.

    max_iter : int, optional(default=10**6)
        Maximum number of iterations in the linear programming solver.

    piv_options : PivOptions, optional
        PivOptions namedtuple to set the tolerance values used in the
        linear programming solver.

    tableau : ndarray(float, ndim=2), optional
        Temporary ndarray of shape (n+1, L+n+1) to store the tableau.
        Modified in place.

    basis : ndarray(int, ndim=1), optional
        Temporary ndarray of shape (n,) to store the basic variables.
        Modified in place.

    v : ndarray(float, ndim=1), optional
        Output ndarray of shape (n,) to store the optimal value
        function. Modified in place.

    Returns
    -------
    success : bool
        True if the algorithm succeeded in finding an optimal solution.

    num_iter : int
        The number of iterations performed.

    v : ndarray(float, ndim=1)
        Optimal value function (view to `v` if supplied).

    sigma : ndarray(int, ndim=1)
        Optimal policy (view to `sigma`).

    """
    L, n = Q.shape

    if tableau is None:
        tableau = np.empty((n+1, L+n+1))
    if basis is None:
        basis = np.empty(n, dtype=np.int_)
    if v is None:
        v = np.empty(n)

    _initialize_tableau(R, Q, beta, a_indptr, tableau)
    _find_indices(a_indices, a_indptr, sigma, out=basis)

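    # `basis[i]` holds the column (state-action pair) index of
    # (i, sigma[i]); Phase 1 pivots these columns into the basis, so
    # the initial policy becomes the starting feasible basis.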
    # Phase 1
    for i in range(n):
        _pivoting(tableau, basis[i], i)

    # Phase 2
    success, status, num_iter = \
        solve_tableau(tableau, basis, max_iter-n, skip_aux=True,
                      piv_options=piv_options)

    # Obtain solution
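    # The optimal value function is the optimal dual solution: the
    # simplex multipliers appear, with sign flipped, in the objective
    # row under the identity columns L..L+n-1.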
    for i in range(n):
        v[i] = tableau[-1, L+i] * (-1)

    for i in range(n):
        sigma[i] = a_indices[basis[i]]

    return success, num_iter+n, v, sigma


@jit(nopython=True, cache=True)
def _initialize_tableau(R, Q, beta, a_indptr, tableau):
    """
    Initialize the `tableau` array.

    Parameters
    ----------
    R : ndarray(float, ndim=1)
        Reward ndarray, of shape (L,).

    Q : ndarray(float, ndim=2)
        Transition probability ndarray, of shape (L, n).

    beta : scalar(float)
        Discount factor. Must be in [0, 1).

    a_indptr : ndarray(int, ndim=1)
        Action index pointer ndarray, of shape (n+1,).

    tableau : ndarray(float, ndim=2)
        Empty ndarray of shape (n+1, L+n+1) to store the tableau.
        Modified in place.

    Returns
    -------
    tableau : ndarray(float, ndim=2)
        View to `tableau`.

    """
    L, n = Q.shape

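    # Constraint block of the dual: entry (i, j) equals
    # 1{pair j belongs to state i} - beta * Q[j, i], built in the two
    # loops below.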
    for j in range(L):
        for i in range(n):
            tableau[i, j] = Q[j, i] * (-beta)

    for i in range(n):
        for j in range(a_indptr[i], a_indptr[i+1]):
            tableau[i, j] += 1

    tableau[:n, L:-1] = 0

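    # Identity columns L..L+n-1 (used in `ddp_linprog_simplex` above to
    # read off the value function) and the right-hand side of ones.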
    for i in range(n):
        tableau[i, L+i] = 1
        tableau[i, -1] = 1

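    # Objective row: rewards under the structural columns, zeros under
    # the remaining columns.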
    for j in range(L):
        tableau[-1, j] = R[j]

    tableau[-1, L:] = 0

    return tableau
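
For illustration, here is a minimal sketch of calling the new routine directly
on a small two-state problem with three feasible state-action pairs. The
import path is an assumption (the diff above does not show the new file's
name), and per the commit title the solver is presumably also exposed through
`DiscreteDP` rather than called directly:

    import numpy as np

    # Path assumed for illustration; not shown in this diff.
    from quantecon.markov.ddp_linprog_simplex import ddp_linprog_simplex

    # n = 2 states, L = 3 feasible state-action pairs:
    # (s=0, a=0), (s=0, a=1), (s=1, a=0).
    R = np.array([5., 10., -1.])     # one reward per pair, shape (L,)
    Q = np.array([[0.5, 0.5],        # transition probabilities, shape (L, n)
                  [0.0, 1.0],
                  [0.0, 1.0]])
    beta = 0.95
    a_indices = np.array([0, 1, 0])  # action of each pair
    a_indptr = np.array([0, 2, 3])   # pairs of state s: a_indptr[s]:a_indptr[s+1]
    sigma = np.array([0, 0])         # initial feasible policy, overwritten

    success, num_iter, v, sigma = ddp_linprog_simplex(
        R, Q, beta, a_indices, a_indptr, sigma)

    print(success, sigma, v)
    # Expected: True, sigma = [0 0], v approx [-8.5714, -20.]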
