@@ -46,15 +46,16 @@ def fit_angle_in_range(angles, min_angle=-np.pi, max_angle=np.pi):
4646 return output .reshape (output_shape )
4747
4848
49- def update_state_with_Runge_Kutta (state , u , functions , dt = 0.01 ):
49+ def update_state_with_Runge_Kutta (state , u , functions , dt = 0.01 , batch = True ):
5050 """ update state in Runge Kutta methods
5151 Args:
5252 state (array-like): state of system
5353 u (array-like): input of system
5454 functions (list): update function of each state,
55- each function will be called like func(* state, * u)
55+ each function will be called like func(state, u)
5656 We expect that this function returns differential of each state
5757 dt (float): float in seconds
58+ batch (bool): state and u is given by batch or not
5859
5960 Returns:
6061 next_state (np.array): next state of system
@@ -68,36 +69,50 @@ def func_x(self, x_1, x_2, u):
6869
6970 Note that the function return x_dot.
7071 """
71- state_size = len (state )
72- assert state_size == len (functions ), \
73- "Invalid functions length, You need to give the state size functions"
72+ if not batch :
73+ state_size = len (state )
74+ assert state_size == len (functions ), \
75+ "Invalid functions length, You need to give the state size functions"
7476
75- k0 = np .zeros (state_size )
76- k1 = np .zeros (state_size )
77- k2 = np .zeros (state_size )
78- k3 = np .zeros (state_size )
77+ k0 = np .zeros (state_size )
78+ k1 = np .zeros (state_size )
79+ k2 = np .zeros (state_size )
80+ k3 = np .zeros (state_size )
7981
80- inputs = np .concatenate ([state , u ])
82+ for i , func in enumerate (functions ):
83+ k0 [i ] = dt * func (state , u )
8184
82- for i , func in enumerate (functions ):
83- k0 [i ] = dt * func (* inputs )
85+ for i , func in enumerate (functions ):
86+ k1 [i ] = dt * func (state + k0 / 2. , u )
8487
85- add_state = state + k0 / 2.
86- inputs = np . concatenate ([ add_state , u ] )
88+ for i , func in enumerate ( functions ):
89+ k2 [ i ] = dt * func ( state + k1 / 2. , u )
8790
88- for i , func in enumerate (functions ):
89- k1 [i ] = dt * func (* inputs )
91+ for i , func in enumerate (functions ):
92+ k3 [i ] = dt * func (state + k2 , u )
9093
91- add_state = state + k1 / 2.
92- inputs = np .concatenate ([add_state , u ])
94+ return (k0 + 2. * k1 + 2. * k2 + k3 ) / 6.
9395
94- for i , func in enumerate (functions ):
95- k2 [i ] = dt * func (* inputs )
96+ else :
97+ batch_size , state_size = state .shape
98+ assert state_size == len (functions ), \
99+ "Invalid functions length, You need to give the state size functions"
96100
97- add_state = state + k2
98- inputs = np .concatenate ([add_state , u ])
101+ k0 = np .zeros (batch_size , state_size )
102+ k1 = np .zeros (batch_size , state_size )
103+ k2 = np .zeros (batch_size , state_size )
104+ k3 = np .zeros (batch_size , state_size )
99105
100- for i , func in enumerate (functions ):
101- k3 [ i ] = dt * func (* inputs )
106+ for i , func in enumerate (functions ):
107+ k0 [:, i ] = dt * func (state , u )
102108
103- return (k0 + 2. * k1 + 2. * k2 + k3 ) / 6.
109+ for i , func in enumerate (functions ):
110+ k1 [:, i ] = dt * func (state + k0 / 2. , u )
111+
112+ for i , func in enumerate (functions ):
113+ k2 [:, i ] = dt * func (state + k1 / 2. , u )
114+
115+ for i , func in enumerate (functions ):
116+ k3 [:, i ] = dt * func (state + k2 , u )
117+
118+ return (k0 + 2. * k1 + 2. * k2 + k3 ) / 6.
0 commit comments