11import numpy as np
22
3+
34class CartPoleConfigModule ():
45 # parameters
56 ENV_NAME = "CartPole-v0"
@@ -12,7 +13,7 @@ class CartPoleConfigModule():
1213 DT = 0.02
1314 # cost parameters
1415 R = np .diag ([0.01 ]) # 0.01 is worked for MPPI and CEM and MPPIWilliams
15- # 1. is worked for iLQR
16+ # 1. is worked for iLQR
1617 TERMINAL_WEIGHT = 1.
1718 Q = None
1819 Sf = None
@@ -39,41 +40,41 @@ def __init__(self):
3940 "num_elites" : 50 ,
4041 "max_iters" : 15 ,
4142 "alpha" : 0.3 ,
42- "init_var" :9. ,
43- "threshold" :0.001
43+ "init_var" : 9. ,
44+ "threshold" : 0.001
4445 },
45- "MPPI" :{
46- "beta" : 0.6 ,
46+ "MPPI" : {
47+ "beta" : 0.6 ,
4748 "popsize" : 5000 ,
4849 "kappa" : 0.9 ,
4950 "noise_sigma" : 0.5 ,
5051 },
51- "MPPIWilliams" :{
52+ "MPPIWilliams" : {
5253 "popsize" : 5000 ,
5354 "lambda" : 1. ,
5455 "noise_sigma" : 0.9 ,
5556 },
56- "iLQR" :{
57+ "iLQR" : {
5758 "max_iter" : 500 ,
5859 "init_mu" : 1. ,
5960 "mu_min" : 1e-6 ,
6061 "mu_max" : 1e10 ,
6162 "init_delta" : 2. ,
6263 "threshold" : 1e-6 ,
63- },
64- "DDP" :{
64+ },
65+ "DDP" : {
6566 "max_iter" : 500 ,
6667 "init_mu" : 1. ,
6768 "mu_min" : 1e-6 ,
6869 "mu_max" : 1e10 ,
6970 "init_delta" : 2. ,
7071 "threshold" : 1e-6 ,
71- },
72- "NMPC-CGMRES" :{
73- },
74- "NMPC-Newton" :{
75- },
76- }
72+ },
73+ "NMPC-CGMRES" : {
74+ },
75+ "NMPC-Newton" : {
76+ },
77+ }
7778
7879 @staticmethod
7980 def input_cost_fn (u ):
@@ -87,7 +88,7 @@ def input_cost_fn(u):
8788 shape(pop_size, pred_len, input_size)
8889 """
8990 return (u ** 2 ) * np .diag (CartPoleConfigModule .R )
90-
91+
9192 @staticmethod
9293 def state_cost_fn (x , g_x ):
9394 """ state cost function
@@ -103,21 +104,21 @@ def state_cost_fn(x, g_x):
103104 """
104105
105106 if len (x .shape ) > 2 :
106- return (6. * (x [:, :, 0 ]** 2 ) \
107- + 12. * ((np .cos (x [:, :, 2 ]) + 1. )** 2 ) \
108- + 0.1 * (x [:, :, 1 ]** 2 ) \
109- + 0.1 * (x [:, :, 3 ]** 2 ))[:, :, np .newaxis ]
107+ return (6. * (x [:, :, 0 ]** 2 )
108+ + 12. * ((np .cos (x [:, :, 2 ]) + 1. )** 2 )
109+ + 0.1 * (x [:, :, 1 ]** 2 )
110+ + 0.1 * (x [:, :, 3 ]** 2 ))[:, :, np .newaxis ]
110111
111112 elif len (x .shape ) > 1 :
112- return (6. * (x [:, 0 ]** 2 ) \
113- + 12. * ((np .cos (x [:, 2 ]) + 1. )** 2 ) \
114- + 0.1 * (x [:, 1 ]** 2 ) \
115- + 0.1 * (x [:, 3 ]** 2 ))[:, np .newaxis ]
116-
113+ return (6. * (x [:, 0 ]** 2 )
114+ + 12. * ((np .cos (x [:, 2 ]) + 1. )** 2 )
115+ + 0.1 * (x [:, 1 ]** 2 )
116+ + 0.1 * (x [:, 3 ]** 2 ))[:, np .newaxis ]
117+
117118 return 6. * (x [0 ]** 2 ) \
118- + 12. * ((np .cos (x [2 ]) + 1. )** 2 ) \
119- + 0.1 * (x [1 ]** 2 ) \
120- + 0.1 * (x [3 ]** 2 )
119+ + 12. * ((np .cos (x [2 ]) + 1. )** 2 ) \
120+ + 0.1 * (x [1 ]** 2 ) \
121+ + 0.1 * (x [3 ]** 2 )
121122
122123 @staticmethod
123124 def terminal_state_cost_fn (terminal_x , terminal_g_x ):
@@ -134,45 +135,45 @@ def terminal_state_cost_fn(terminal_x, terminal_g_x):
134135 """
135136
136137 if len (terminal_x .shape ) > 1 :
137- return (6. * (terminal_x [:, 0 ]** 2 ) \
138- + 12. * ((np .cos (terminal_x [:, 2 ]) + 1. )** 2 ) \
139- + 0.1 * (terminal_x [:, 1 ]** 2 ) \
140- + 0.1 * (terminal_x [:, 3 ]** 2 ))[:, np .newaxis ] \
141- * CartPoleConfigModule .TERMINAL_WEIGHT
142-
143- return (6. * (terminal_x [0 ]** 2 ) \
144- + 12. * ((np .cos (terminal_x [2 ]) + 1. )** 2 ) \
145- + 0.1 * (terminal_x [1 ]** 2 ) \
146- + 0.1 * (terminal_x [3 ]** 2 )) \
138+ return (6. * (terminal_x [:, 0 ]** 2 )
139+ + 12. * ((np .cos (terminal_x [:, 2 ]) + 1. )** 2 )
140+ + 0.1 * (terminal_x [:, 1 ]** 2 )
141+ + 0.1 * (terminal_x [:, 3 ]** 2 ))[:, np .newaxis ] \
147142 * CartPoleConfigModule .TERMINAL_WEIGHT
148-
143+
144+ return (6. * (terminal_x [0 ]** 2 )
145+ + 12. * ((np .cos (terminal_x [2 ]) + 1. )** 2 )
146+ + 0.1 * (terminal_x [1 ]** 2 )
147+ + 0.1 * (terminal_x [3 ]** 2 )) \
148+ * CartPoleConfigModule .TERMINAL_WEIGHT
149+
149150 @staticmethod
150151 def gradient_cost_fn_with_state (x , g_x , terminal = False ):
151152 """ gradient of costs with respect to the state
152153
153154 Args:
154155 x (numpy.ndarray): state, shape(pred_len, state_size)
155156 g_x (numpy.ndarray): goal state, shape(pred_len, state_size)
156-
157+
157158 Returns:
158159 l_x (numpy.ndarray): gradient of cost, shape(pred_len, state_size)
159160 or shape(1, state_size)
160161 """
161162 if not terminal :
162- cost_dx0 = 12. * x [:, 0 ]
163+ cost_dx0 = 12. * x [:, 0 ]
163164 cost_dx1 = 0.2 * x [:, 1 ]
164165 cost_dx2 = 24. * (1 + np .cos (x [:, 2 ])) * - np .sin (x [:, 2 ])
165166 cost_dx3 = 0.2 * x [:, 3 ]
166- cost_dx = np .stack ((cost_dx0 , cost_dx1 ,\
167+ cost_dx = np .stack ((cost_dx0 , cost_dx1 ,
167168 cost_dx2 , cost_dx3 ), axis = 1 )
168169 return cost_dx
169-
170- cost_dx0 = 12. * x [0 ]
170+
171+ cost_dx0 = 12. * x [0 ]
171172 cost_dx1 = 0.2 * x [1 ]
172173 cost_dx2 = 24. * (1 + np .cos (x [2 ])) * - np .sin (x [2 ])
173174 cost_dx3 = 0.2 * x [3 ]
174175 cost_dx = np .array ([[cost_dx0 , cost_dx1 , cost_dx2 , cost_dx3 ]])
175-
176+
176177 return cost_dx * CartPoleConfigModule .TERMINAL_WEIGHT
177178
178179 @staticmethod
@@ -206,21 +207,21 @@ def hessian_cost_fn_with_state(x, g_x, terminal=False):
206207 hessian [:, 0 , 0 ] = 12.
207208 hessian [:, 1 , 1 ] = 0.2
208209 hessian [:, 2 , 2 ] = 24. * - np .sin (x [:, 2 ]) \
209- * (- np .sin (x [:, 2 ])) \
210- + 24. * (1. + np .cos (x [:, 2 ])) \
211- * - np .cos (x [:, 2 ])
210+ * (- np .sin (x [:, 2 ])) \
211+ + 24. * (1. + np .cos (x [:, 2 ])) \
212+ * - np .cos (x [:, 2 ])
212213 hessian [:, 3 , 3 ] = 0.2
213214
214215 return hessian
215-
216+
216217 state_size = len (x )
217218 hessian = np .eye (state_size )
218219 hessian [0 , 0 ] = 12.
219220 hessian [1 , 1 ] = 0.2
220221 hessian [2 , 2 ] = 24. * - np .sin (x [2 ]) \
221- * (- np .sin (x [2 ])) \
222- + 24. * (1. + np .cos (x [2 ])) \
223- * - np .cos (x [2 ])
222+ * (- np .sin (x [2 ])) \
223+ + 24. * (1. + np .cos (x [2 ])) \
224+ * - np .cos (x [2 ])
224225 hessian [3 , 3 ] = 0.2
225226
226227 return hessian [np .newaxis , :, :] * CartPoleConfigModule .TERMINAL_WEIGHT
@@ -239,7 +240,7 @@ def hessian_cost_fn_with_input(x, u):
239240 (pred_len , _ ) = u .shape
240241
241242 return np .tile (2. * CartPoleConfigModule .R , (pred_len , 1 , 1 ))
242-
243+
243244 @staticmethod
244245 def hessian_cost_fn_with_input_state (x , u ):
245246 """ hessian costs with respect to the state and input
@@ -254,4 +255,4 @@ def hessian_cost_fn_with_input_state(x, u):
254255 (_ , state_size ) = x .shape
255256 (pred_len , input_size ) = u .shape
256257
257- return np .zeros ((pred_len , input_size , state_size ))
258+ return np .zeros ((pred_len , input_size , state_size ))
0 commit comments