@@ -51,6 +51,41 @@ def plot_environment(ax, state):
5151 ax .set_aspect ("equal" )
5252
5353
54+ def _describe_history_shapes (case_name : str , label : str , history : dict [str , Any ]) -> None :
55+ """Log the shapes of history entries for quick inspection."""
56+
57+ history = dict (history )
58+ print (f"{ case_name } - { label } history shapes:" )
59+ if not history :
60+ print (" <empty>" )
61+ return
62+
63+ for key in sorted (history ):
64+ value = history [key ]
65+ if value is None :
66+ print (f" { key } : None" )
67+ continue
68+ arr = np .asarray (value )
69+ print (f" { key } : { arr .shape } " )
70+
71+
72+ def _summarize_state_deltas (case_name : str , ra_states : np .ndarray , our_states : np .ndarray ) -> None :
73+ if ra_states .shape != our_states .shape :
74+ print (f"{ case_name } - state shape mismatch; cannot summarise deltas" )
75+ return
76+
77+ diff = ra_states - our_states
78+ if diff .ndim == 1 :
79+ distances = np .abs (diff )
80+ else :
81+ distances = np .linalg .norm (diff , axis = diff .ndim - 1 )
82+
83+ print (
84+ f"{ case_name } - state Δ summary: mean={ distances .mean ():.6g} , "
85+ f"median={ np .median (distances ):.6g} , max={ distances .max ():.6g} "
86+ )
87+
88+
5489def _split_objects (env_params : dict ) -> tuple [dict , list [Any ]]:
5590 ra_params = dict (env_params )
5691 raw_objects = list (ra_params .pop ("objects" , []) or [])
@@ -113,7 +148,6 @@ def run_case(name: str, env_params: dict, agent_configs: Sequence[dict], steps:
113148
114149 if init_pos is not None or init_vel is not None :
115150 ra_agent .reset_history ()
116- ra_agent .save_to_history ()
117151 our_agent .reset_history ()
118152
119153 ra_states = [ra_agent .pos .copy ()]
@@ -132,6 +166,15 @@ def run_case(name: str, env_params: dict, agent_configs: Sequence[dict], steps:
132166 ra_states = np .array (ra_states )
133167 our_states = np .array (our_states )
134168
169+ print (f"{ name } - RatInABox states shape:" )
170+ print (ra_states .shape )
171+ print (f"{ name } - canns-lib states shape:" )
172+ print (our_states .shape )
173+
174+ _summarize_state_deltas (name , ra_states , our_states )
175+ _describe_history_shapes (name , "RatInABox" , ra_agent .history )
176+ _describe_history_shapes (name , "canns-lib" , our_agent .history )
177+
135178 # Trajectory plot
136179 fig , ax = plt .subplots (figsize = (5 , 5 ))
137180 plot_environment (ax , our_env .render_state ())
@@ -157,6 +200,18 @@ def run_case(name: str, env_params: dict, agent_configs: Sequence[dict], steps:
157200if __name__ == "__main__" : # pragma: no cover
158201 # Scenarios roughly mirror RatInABox demos such as simple_example (uniform drift),
159202 # extensive_example (wall interactions), and path_integration/vector_cell notebooks.
203+ const_env_size = 1.5
204+ const_dt = 0.001
205+ const_duration = 2.0
206+ const_steps = int (round (const_duration / const_dt ))
207+ const_speed = 2.0
208+ const_angle = (11.0 / 12.0 ) * np .pi
209+ const_init_vel = (
210+ const_speed * np .cos (const_angle ),
211+ const_speed * np .sin (const_angle ),
212+ )
213+ const_init_pos = [const_env_size * 15.0 / 16.0 , const_env_size * 1.0 / 16.0 ]
214+
160215 cases = [
161216 (
162217 "case1_uniform" ,
@@ -385,5 +440,36 @@ def run_case(name: str, env_params: dict, agent_configs: Sequence[dict], steps:
385440 ),
386441 ]
387442
443+ for seed in [0 ]:
444+ name = f"case9_constant_speed_seed{ seed } "
445+ cases .append (
446+ (
447+ name ,
448+ {
449+ "dimensionality" : "2D" ,
450+ "boundary_conditions" : "solid" ,
451+ "scale" : const_env_size ,
452+ "aspect" : 1.0 ,
453+ },
454+ (
455+ {
456+ "params" : {
457+ "dt" : const_dt ,
458+ "speed_mean" : const_speed ,
459+ "speed_std" : 0.0 ,
460+ "speed_coherence_time" : 10.0 ,
461+ "rotational_velocity_std" : np .deg2rad (40.0 ),
462+ },
463+ "rng_seed" : seed ,
464+ "init_pos" : const_init_pos ,
465+ # "init_vel": list(const_init_vel),
466+ },
467+ {},
468+ ),
469+ const_steps ,
470+ const_dt ,
471+ )
472+ )
473+
388474 for case in cases :
389475 run_case (* case )
0 commit comments