Skip to content

Commit a03aa7e

Browse files
committed
Release 0.6.2
1 parent 695d76b commit a03aa7e

File tree

6 files changed

+137
-35
lines changed

6 files changed

+137
-35
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "canns-lib"
3-
version = "0.6.1"
3+
version = "0.6.2"
44
edition = "2021"
55
license = "Apache-2.0"
66
authors = ["Sichao He <sichaohe@outlook.com>"]

example/trajectory_comparison.py

Lines changed: 87 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,41 @@ def plot_environment(ax, state):
5151
ax.set_aspect("equal")
5252

5353

54+
def _describe_history_shapes(case_name: str, label: str, history: dict[str, Any]) -> None:
55+
"""Log the shapes of history entries for quick inspection."""
56+
57+
history = dict(history)
58+
print(f"{case_name} - {label} history shapes:")
59+
if not history:
60+
print(" <empty>")
61+
return
62+
63+
for key in sorted(history):
64+
value = history[key]
65+
if value is None:
66+
print(f" {key}: None")
67+
continue
68+
arr = np.asarray(value)
69+
print(f" {key}: {arr.shape}")
70+
71+
72+
def _summarize_state_deltas(case_name: str, ra_states: np.ndarray, our_states: np.ndarray) -> None:
73+
if ra_states.shape != our_states.shape:
74+
print(f"{case_name} - state shape mismatch; cannot summarise deltas")
75+
return
76+
77+
diff = ra_states - our_states
78+
if diff.ndim == 1:
79+
distances = np.abs(diff)
80+
else:
81+
distances = np.linalg.norm(diff, axis=diff.ndim - 1)
82+
83+
print(
84+
f"{case_name} - state Δ summary: mean={distances.mean():.6g}, "
85+
f"median={np.median(distances):.6g}, max={distances.max():.6g}"
86+
)
87+
88+
5489
def _split_objects(env_params: dict) -> tuple[dict, list[Any]]:
5590
ra_params = dict(env_params)
5691
raw_objects = list(ra_params.pop("objects", []) or [])
@@ -113,7 +148,6 @@ def run_case(name: str, env_params: dict, agent_configs: Sequence[dict], steps:
113148

114149
if init_pos is not None or init_vel is not None:
115150
ra_agent.reset_history()
116-
ra_agent.save_to_history()
117151
our_agent.reset_history()
118152

119153
ra_states = [ra_agent.pos.copy()]
@@ -132,6 +166,15 @@ def run_case(name: str, env_params: dict, agent_configs: Sequence[dict], steps:
132166
ra_states = np.array(ra_states)
133167
our_states = np.array(our_states)
134168

169+
print(f"{name} - RatInABox states shape:")
170+
print(ra_states.shape)
171+
print(f"{name} - canns-lib states shape:")
172+
print(our_states.shape)
173+
174+
_summarize_state_deltas(name, ra_states, our_states)
175+
_describe_history_shapes(name, "RatInABox", ra_agent.history)
176+
_describe_history_shapes(name, "canns-lib", our_agent.history)
177+
135178
# Trajectory plot
136179
fig, ax = plt.subplots(figsize=(5, 5))
137180
plot_environment(ax, our_env.render_state())
@@ -157,6 +200,18 @@ def run_case(name: str, env_params: dict, agent_configs: Sequence[dict], steps:
157200
if __name__ == "__main__": # pragma: no cover
158201
# Scenarios roughly mirror RatInABox demos such as simple_example (uniform drift),
159202
# extensive_example (wall interactions), and path_integration/vector_cell notebooks.
203+
const_env_size = 1.5
204+
const_dt = 0.001
205+
const_duration = 2.0
206+
const_steps = int(round(const_duration / const_dt))
207+
const_speed = 2.0
208+
const_angle = (11.0 / 12.0) * np.pi
209+
const_init_vel = (
210+
const_speed * np.cos(const_angle),
211+
const_speed * np.sin(const_angle),
212+
)
213+
const_init_pos = [const_env_size * 15.0 / 16.0, const_env_size * 1.0 / 16.0]
214+
160215
cases = [
161216
(
162217
"case1_uniform",
@@ -385,5 +440,36 @@ def run_case(name: str, env_params: dict, agent_configs: Sequence[dict], steps:
385440
),
386441
]
387442

443+
for seed in [0]:
444+
name = f"case9_constant_speed_seed{seed}"
445+
cases.append(
446+
(
447+
name,
448+
{
449+
"dimensionality": "2D",
450+
"boundary_conditions": "solid",
451+
"scale": const_env_size,
452+
"aspect": 1.0,
453+
},
454+
(
455+
{
456+
"params": {
457+
"dt": const_dt,
458+
"speed_mean": const_speed,
459+
"speed_std": 0.0,
460+
"speed_coherence_time": 10.0,
461+
"rotational_velocity_std": np.deg2rad(40.0),
462+
},
463+
"rng_seed": seed,
464+
"init_pos": const_init_pos,
465+
# "init_vel": list(const_init_vel),
466+
},
467+
{},
468+
),
469+
const_steps,
470+
const_dt,
471+
)
472+
)
473+
388474
for case in cases:
389475
run_case(*case)

python/canns_lib/_version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
"""Project-wide version string."""
22

3-
__version__ = "0.6.1"
3+
__version__ = "0.6.2"

src/spatial/agent.rs

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -209,18 +209,21 @@ impl Agent {
209209
match self.dimensionality {
210210
Dimensionality::D1 => {
211211
let mut speed = self.velocity.get(0).copied().unwrap_or(0.0);
212-
speed = ornstein_uhlenbeck(
212+
speed += ornstein_uhlenbeck(
213213
speed,
214214
self.params.speed_mean,
215215
self.params.speed_std,
216216
self.params.speed_coherence_time,
217217
dt,
218218
&mut self.rng,
219219
);
220+
if self.params.speed_std == 0.0 {
221+
speed = self.params.speed_mean;
222+
}
220223
self.velocity = vec![speed];
221224
}
222225
Dimensionality::D2 => {
223-
self.rotational_velocity = ornstein_uhlenbeck(
226+
self.rotational_velocity += ornstein_uhlenbeck(
224227
self.rotational_velocity,
225228
0.0,
226229
self.params.rotational_velocity_std,
@@ -231,14 +234,15 @@ impl Agent {
231234
rotate_vector(&mut self.velocity, self.rotational_velocity * dt);
232235

233236
let speed = vector_norm(&self.velocity);
234-
let mut new_speed = ornstein_uhlenbeck(
235-
speed,
236-
self.params.speed_mean,
237-
self.params.speed_std,
238-
self.params.speed_coherence_time,
239-
dt,
240-
&mut self.rng,
241-
);
237+
let mut new_speed = speed
238+
+ ornstein_uhlenbeck(
239+
speed,
240+
self.params.speed_mean,
241+
self.params.speed_std,
242+
self.params.speed_coherence_time,
243+
dt,
244+
&mut self.rng,
245+
);
242246
if self.params.speed_std == 0.0 {
243247
new_speed = self.params.speed_mean;
244248
}
@@ -272,7 +276,7 @@ impl Agent {
272276
let ratio = drift_ratio.max(1e-6);
273277
let tau = (self.params.speed_coherence_time / ratio).max(1e-6);
274278
for (vel, target_val) in self.velocity.iter_mut().zip(target.iter()) {
275-
*vel = ornstein_uhlenbeck(*vel, *target_val, 0.0, tau, dt, &mut self.rng);
279+
*vel += ornstein_uhlenbeck(*vel, *target_val, 0.0, tau, dt, &mut self.rng);
276280
}
277281
}
278282
}
@@ -432,7 +436,7 @@ impl Agent {
432436
};
433437
}
434438

435-
let mut agent = Self {
439+
let agent = Self {
436440
dimensionality: env_state.dimensionality,
437441
env_state,
438442
params: agent_params,
@@ -452,8 +456,6 @@ impl Agent {
452456
history_rot: Vec::new(),
453457
imported: None,
454458
};
455-
456-
agent.record_history();
457459
Ok(agent)
458460
}
459461

@@ -630,7 +632,6 @@ impl Agent {
630632
self.history_head.clear();
631633
self.history_distance.clear();
632634
self.history_rot.clear();
633-
self.record_history();
634635
}
635636

636637
pub fn set_position(&mut self, position: Vec<f64>) -> PyResult<()> {
@@ -710,7 +711,6 @@ impl Agent {
710711
self.history_head.clear();
711712
self.history_distance.clear();
712713
self.history_rot.clear();
713-
self.record_history();
714714
}
715715
Ok(())
716716
}

src/spatial/utils.rs

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,18 @@ pub(crate) fn ornstein_uhlenbeck(
3939
rng: &mut StdRng,
4040
) -> f64 {
4141
if coherence_time <= 0.0 {
42-
return drift;
42+
return drift - current;
4343
}
44+
4445
let theta = 1.0 / coherence_time;
45-
let exp_term = (-theta * dt).exp();
46-
let mean = current * exp_term + drift * (1.0 - exp_term);
47-
let variance = noise_scale.powi(2) * (1.0 - exp_term * exp_term);
48-
let std = variance.max(0.0).sqrt();
49-
let noise: f64 = rng.sample(StandardNormal);
50-
mean + std * noise
46+
let drift_term = theta * (drift - current) * dt;
47+
48+
if noise_scale == 0.0 {
49+
return drift_term;
50+
}
51+
52+
let sigma = ((2.0 * noise_scale.powi(2)) / (coherence_time * dt)).sqrt();
53+
let normal: f64 = rng.sample(StandardNormal);
54+
let diffusion = sigma * normal * dt;
55+
drift_term + diffusion
5156
}

tests/test_spatial_basic.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -58,13 +58,13 @@ def test_agent_history_and_seeded_update():
5858
env = spatial.Environment()
5959
agent = spatial.Agent(env, rng_seed=1234)
6060
assert agent.t == pytest.approx(0.0)
61-
assert agent.history_positions().shape == (1, 2)
61+
assert agent.history_positions().shape == (0, 2)
6262
agent.update(dt=0.05)
6363
assert agent.t == pytest.approx(0.05)
64-
assert agent.history_positions().shape == (2, 2)
65-
assert agent.history_velocities().shape == (2, 2)
66-
assert agent.history_head_directions().shape == (2, 2)
67-
assert agent.history_distance_travelled().shape == (2,)
64+
assert agent.history_positions().shape == (1, 2)
65+
assert agent.history_velocities().shape == (1, 2)
66+
assert agent.history_head_directions().shape == (1, 2)
67+
assert agent.history_distance_travelled().shape == (1,)
6868
history = agent.history
6969
assert set(history.keys()) == {
7070
"t",
@@ -74,17 +74,21 @@ def test_agent_history_and_seeded_update():
7474
"rot_vel",
7575
"distance_travelled",
7676
}
77+
for key, value in history.items():
78+
if isinstance(value, list):
79+
assert len(value) == 1
7780
arrays = agent.history_arrays()
78-
assert arrays["pos"].shape == (2, 2)
81+
assert arrays["pos"].shape == (1, 2)
7982

8083

8184
def test_forced_position_and_reset_history():
8285
env = spatial.Environment()
8386
agent = spatial.Agent(env, rng_seed=42)
8487
agent.set_forced_next_position([0.5, 0.5])
8588
assert np.allclose(agent.pos, [0.5, 0.5])
86-
agent.reset_history()
8789
assert agent.history_positions().shape == (1, 2)
90+
agent.reset_history()
91+
assert agent.history_positions().shape == (0, 2)
8892

8993

9094
def test_agent_init_with_explicit_state():
@@ -97,13 +101,20 @@ def test_agent_init_with_explicit_state():
97101
)
98102
assert np.allclose(agent.pos, [0.2, 0.8])
99103
assert np.allclose(agent.velocity, [0.0, 0.1])
100-
assert agent.history_positions()[0, 0] == pytest.approx(0.2)
101-
assert agent.history_velocities()[0, 1] == pytest.approx(0.1)
104+
assert agent.history_positions().shape == (0, 2)
105+
agent.update(dt=0.05)
106+
pos_history = agent.history_positions()
107+
vel_history = agent.history_velocities()
108+
assert pos_history.shape == (1, 2)
109+
assert vel_history.shape == (1, 2)
110+
assert np.allclose(pos_history[-1], np.array(agent.pos))
111+
assert np.allclose(vel_history[-1], np.array(agent.measured_velocity))
102112

103113

104114
def test_agent_set_position_velocity_updates_history():
105115
env = spatial.Environment()
106116
agent = spatial.Agent(env, rng_seed=5)
117+
agent.update(dt=0.05)
107118
agent.set_position([0.3, 0.7])
108119
assert np.allclose(agent.pos, [0.3, 0.7])
109120
assert np.allclose(agent.history_positions()[-1], [0.3, 0.7])

0 commit comments

Comments
 (0)