|
| 1 | +import numpy as np |
| 2 | +from numba import jit |
| 3 | +from .utilities import _find_indices |
| 4 | +from ..optimize.linprog_simplex import solve_tableau, PivOptions |
| 5 | +from ..optimize.pivoting import _pivoting |
| 6 | + |
| 7 | + |
@jit(nopython=True, cache=True)
def ddp_linprog_simplex(R, Q, beta, a_indices, a_indptr, sigma,
                        max_iter=10**6, piv_options=PivOptions(),
                        tableau=None, basis=None, v=None):
    r"""
    Numba jit compiled function to solve a discrete dynamic program via
    linear programming, using `optimize.linprog_simplex` routines.

    The problem must be given in state-action pair form: a 1-dim reward
    ndarray `R` with one entry per state-action pair, a 2-dim transition
    probability ndarray `Q` of shape (L, n), and a discount factor
    `beta`, where n is the number of states and L is the number of
    feasible state-action pairs.

    The approach exploits the fact that the optimal value function is
    the smallest function that satisfies :math:`v \geq T v`, where
    :math:`T` is the Bellman operator, and hence it is a (unique)
    solution to the linear program:

    minimize::

        \sum_{s \in S} v(s)

    subject to ::

        v(s) \geq r(s, a) + \beta \sum_{s' \in S} q(s'|s, a) v(s')
        \quad ((s, a) \in \mathit{SA}).

    This function solves its dual problem:

    maximize::

        \sum_{(s, a) \in \mathit{SA}} r(s, a) y(s, a)

    subject to::

        \sum_{a: (s', a) \in \mathit{SA}} y(s', a) -
        \sum_{(s, a) \in \mathit{SA}} \beta q(s'|s, a) y(s, a) = 1
        \quad (s' \in S),

        y(s, a) \geq 0 \quad ((s, a) \in \mathit{SA}),

    where the optimal value function is obtained as an optimal dual
    solution and an optimal policy as an optimal basis.

    Parameters
    ----------
    R : ndarray(float, ndim=1)
        Reward ndarray, of shape (L,): it is read as `R[j]` for each of
        the L state-action pairs when the tableau is initialized.

    Q : ndarray(float, ndim=2)
        Transition probability ndarray, of shape (L, n).

    beta : scalar(float)
        Discount factor. Must be in [0, 1).

    a_indices : ndarray(int, ndim=1)
        Action index ndarray, of shape (L,).

    a_indptr : ndarray(int, ndim=1)
        Action index pointer ndarray, of shape (n+1,).

    sigma : ndarray(int, ndim=1)
        ndarray containing the initial feasible policy, of shape (n,).
        Overwritten in place with the output policy.

    max_iter : int, optional(default=10**6)
        Maximum number of iterations. The n pivots of Phase 1 are
        always performed; Phase 2 is given `max_iter - n` iterations.

    piv_options : PivOptions, optional
        PivOptions namedtuple to set tolerance values used in the linear
        programming solver.

    tableau : ndarray(float, ndim=2), optional
        Work ndarray of shape (n+1, L+n+1) to hold the tableau.
        Modified in place. Allocated here if not supplied.

    basis : ndarray(int, ndim=1), optional
        Work ndarray of shape (n,) to hold the basic variables.
        Modified in place. Allocated here if not supplied.

    v : ndarray(float, ndim=1), optional
        Output ndarray of shape (n,) to hold the value function.
        Modified in place. Allocated here if not supplied.

    Returns
    -------
    success : bool
        Success flag as reported by `solve_tableau`.

    num_iter : int
        Number of iterations reported by `solve_tableau` plus the n
        pivots performed in Phase 1.

    v : ndarray(float, ndim=1)
        Value function (view to `v` if supplied).

    sigma : ndarray(int, ndim=1)
        Policy (view to `sigma`).

    """
    num_pairs, num_states = Q.shape

    # Allocate whichever work/output arrays the caller did not supply.
    if v is None:
        v = np.empty(num_states)
    if basis is None:
        basis = np.empty(num_states, dtype=np.int_)
    if tableau is None:
        tableau = np.empty((num_states+1, num_pairs+num_states+1))

    _initialize_tableau(R, Q, beta, a_indptr, tableau)
    # `basis` receives the indices that `_find_indices` derives from the
    # initial policy `sigma`.
    _find_indices(a_indices, a_indptr, sigma, out=basis)

    # Phase 1: pivot on the column basis[row] in each constraint row
    for row in range(num_states):
        _pivoting(tableau, basis[row], row)

    # Phase 2: run the simplex routine with the remaining iteration budget
    remaining_iter = max_iter - num_states
    success, _status, num_iter = solve_tableau(
        tableau, basis, remaining_iter, skip_aux=True,
        piv_options=piv_options
    )

    # Value function: negated entries of the last row in columns L..L+n-1
    objective_row = tableau[-1]
    for s in range(num_states):
        v[s] = -objective_row[num_pairs+s]

    # Policy: action index of the basic state-action pair in each row
    for s in range(num_states):
        sigma[s] = a_indices[basis[s]]

    total_iter = num_iter + num_states
    return success, total_iter, v, sigma
| 136 | + |
| 137 | + |
@jit(nopython=True, cache=True)
def _initialize_tableau(R, Q, beta, a_indptr, tableau):
    """
    Fill `tableau` in place with the initial tableau.

    With ``L, n = Q.shape``, the entries written are:

    * rows 0..n-1, columns 0..L-1: ``-beta * Q[j, i]`` in row i and
      column j, plus 1 for every column j with
      ``a_indptr[i] <= j < a_indptr[i+1]``;
    * rows 0..n-1, columns L up to (excluding) the last: zeros, with a
      1 in column ``L+i`` of row i;
    * rows 0..n-1, last column: 1;
    * last row: ``R[j]`` in columns 0..L-1 and zeros from column L on.

    Parameters
    ----------
    R : ndarray(float, ndim=1)
        Reward ndarray, of shape (L,) (indexed by state-action pair).

    Q : ndarray(float, ndim=2)
        Transition probability ndarray, of shape (L, n).

    beta : scalar(float)
        Discount factor. Must be in [0, 1).

    a_indptr : ndarray(int, ndim=1)
        Action index pointer ndarray, of shape (n+1,).

    tableau : ndarray(float, ndim=2)
        Empty ndarray of shape (n+1, L+n+1) to store the tableau.
        Modified in place.

    Returns
    -------
    tableau : ndarray(float, ndim=2)
        View to `tableau`.

    """
    num_pairs, num_states = Q.shape
    neg_beta = -beta

    # Constraint rows, one per state
    for s in range(num_states):
        row = tableau[s]

        # Discounted transition probabilities into state s ...
        for k in range(num_pairs):
            row[k] = Q[k, s] * neg_beta
        # ... plus 1 for the state-action pairs that belong to state s
        for k in range(a_indptr[s], a_indptr[s+1]):
            row[k] += 1

        # Unit column L+s, then the right-hand side
        row[num_pairs:-1] = 0
        row[num_pairs+s] = 1
        row[-1] = 1

    # Last row: rewards followed by zeros
    last_row = tableau[-1]
    for k in range(num_pairs):
        last_row[k] = R[k]
    last_row[num_pairs:] = 0

    return tableau
0 commit comments