11import random
22import numpy as np
33from scipy .ndimage .filters import gaussian_filter
4- from .utils import pol2cart
4+ from .utils import pol2cart , cart2pol
55
66
77class State :
@@ -49,7 +49,10 @@ def __init__(self, n_targets=1, prob=0.9, target_speed=None, target_speed_range=
4949 # Setup an initial random state
5050 self .target_state = None
5151 if simulated :
52+ self .update_state = self .update_sim_state
5253 self .target_state = self .init_target_state ()
54+ else :
55+ self .update_state = self .update_real_state
5356 # Setup an initial sensor state
5457 self .sensor_state = self .init_sensor_state ()
5558
@@ -241,7 +244,7 @@ def entropy_collision_reward(self, state, action=None, action_idx=None, particle
241244
242245
243246 # returns new state given last state and action (control)
244- def update_state (self , state , control , target_update = False , transition_overwrite = None ):
247+ def update_sim_state (self , state , control = None , transition_overwrite = None , ** kwargs ):
245248 """Update state based on state and action
246249
247250 Parameters
@@ -258,6 +261,9 @@ def update_state(self, state, control, target_update=False, transition_overwrite
258261 """
259262 # Get current state vars
260263 r , theta , crs , spd = state
264+
265+ spd = random .randint (0 ,1 )
266+
261267 control_spd = control [1 ]
262268
263269 theta = theta % 360
@@ -276,24 +282,12 @@ def update_state(self, state, control, target_update=False, transition_overwrite
276282 x , y = pol2cart (r , np .radians (theta ))
277283
278284 # Generate next course given current course
279- if target_update :
280- spd = random .choice (self .target_speed_range )
281- if self .target_movement == 'circular' :
282- d_crs , circ_spd = self .circular_control (50 )
283- crs += d_crs
284- spd = circ_spd
285- else :
286- if random .random () >= self .prob_target_change_crs :
287- crs += random .choice ([- 1 , 1 ]) * 30
288- else :
289- if random .random () >= self .prob_target_change_crs :
290- crs += random .choice ([- 1 , 1 ]) * 30
285+ if random .random () >= self .prob_target_change_crs :
286+ crs += random .choice ([- 1 , 1 ]) * 30
291287 crs %= 360
292288 if crs < 0 :
293289 crs += 360
294290
295- spd = random .randint (0 ,1 )
296-
297291 # Transform changes to coords to cartesian
298292 dx , dy = pol2cart (spd , np .radians (crs ))
299293 if transition_overwrite :
@@ -308,8 +302,81 @@ def update_state(self, state, control, target_update=False, transition_overwrite
308302
309303 return [r , theta , crs , spd ]
310304
305+ # returns new state given last state and action (control)
306+ def update_real_state (self , state , distance = None , course = None , heading = None , ** kwargs ):
307+ """Update state based on state and action
308+
309+ Parameters
310+ ----------
311+ state_vars : list
312+ List of current state variables
313+ control : action (tuple)
314+ Action tuple
315+
316+ Returns
317+ -------
318+ State (array_like)
319+ Updated state values array
320+ """
321+ if distance is None :
322+ distance = 0
323+ if course is None :
324+ course = 0
325+ if heading is None :
326+ heading = self .sensor_state [2 ]
327+
328+ # Get current state vars
329+ r , theta_deg , crs , spd = state
330+ if random .random () >= self .prob_target_change_crs :
331+ crs += random .choice ([- 1 , 1 ]) * 30
332+ spd = random .randint (0 ,1 )
333+ control_spd = distance
334+ control_course = course % 360
335+ control_delta_heading = (heading - self .sensor_state [2 ]) % 360
336+
337+ # polar -> cartesian
338+ x , y = pol2cart (r , np .radians (theta_deg ))
339+
340+ # translate sensor movement
341+ dx , dy = pol2cart (control_spd , np .radians (control_course ))
342+ pos = [x - dx , y - dy ]
343+
344+ # translate target movement
345+ dx , dy = pol2cart (spd , np .radians (crs ))
346+ pos = [pos [0 ] + dx , pos [1 ] + dy ]
347+
348+ # cartesian -> polar
349+ r , theta = cart2pol (pos [0 ], pos [1 ])
350+ theta_deg = np .degrees (theta )
351+
352+ # rotation
353+ theta_deg -= control_delta_heading
354+ theta_deg %= 360
355+ crs -= control_delta_heading
356+ crs %= 360
357+
358+ return [r , theta_deg , crs , spd ]
359+
360+ def update_real_sensor (self , distance , course , heading ):
361+
362+ r , theta_deg , prev_heading , spd = self .sensor_state
363+ heading = heading if heading else prev_heading
364+
365+ if distance and course :
366+ spd = distance
367+ crs = course % 360
368+ dx , dy = pol2cart (spd , np .radians (crs ))
369+ x , y = pol2cart (r , np .radians (theta_deg ))
370+ pos = [x + dx , y + dy ]
371+
372+ r = np .sqrt (pos [0 ]** 2 + pos [1 ]** 2 )
373+ theta_deg = np .degrees (np .arctan2 (pos [1 ], pos [0 ]))
374+ theta_deg %= 360
375+
376+ self .sensor_state = np .array ([r , theta_deg , heading , spd ])
377+
311378 def update_sensor (self , control , bearing = None ):
312- r , theta_deg , crs , spd = self .sensor_state
379+ r , theta_deg , crs , old_spd = self .sensor_state
313380
314381 spd = control [1 ]
315382
0 commit comments