11//! Python-facing `Agent` class providing accelerated RatInABox-compatible behaviour.
22
33use crate :: spatial:: environment:: Environment ;
4- use crate :: spatial:: state:: { Dimensionality , EnvironmentState } ;
5- use crate :: spatial:: utils:: { normalize_vector, ornstein_uhlenbeck, rotate_vector, vector_norm} ;
4+ use crate :: spatial:: geometry:: { check_line_wall_collision, wall_bounce} ;
5+ use crate :: spatial:: state:: { BoundaryConditions , Dimensionality , EnvironmentState } ;
6+ use crate :: spatial:: utils:: {
7+ normalize_vector, normal_to_rayleigh, ornstein_uhlenbeck, rayleigh_to_normal, rotate_vector,
8+ vector_norm,
9+ } ;
610use ndarray:: Array2 ;
711use numpy:: { IntoPyArray , PyArray1 , PyArray2 } ;
812use pyo3:: exceptions:: PyValueError ;
@@ -180,7 +184,9 @@ pub struct Agent {
180184 position : Vec < f64 > ,
181185 velocity : Vec < f64 > ,
182186 measured_velocity : Vec < f64 > ,
187+ prev_measured_velocity : Vec < f64 > , // 新增:用于计算测量的旋转速度
183188 rotational_velocity : f64 ,
189+ measured_rotational_velocity : f64 , // 新增:实际测量的角速度
184190 head_direction : Vec < f64 > ,
185191 distance_travelled : f64 ,
186192 rng : StdRng ,
@@ -223,6 +229,8 @@ impl Agent {
223229 self . velocity = vec ! [ speed] ;
224230 }
225231 Dimensionality :: D2 => {
232+ // Step 1: Update rotational velocity (angular velocity)
233+ // Using OU process with drift=0 (no preferred rotation direction)
226234 self . rotational_velocity += ornstein_uhlenbeck (
227235 self . rotational_velocity ,
228236 0.0 ,
@@ -231,41 +239,52 @@ impl Agent {
231239 dt,
232240 & mut self . rng ,
233241 ) ;
234- rotate_vector ( & mut self . velocity , self . rotational_velocity * dt) ;
235242
236- let speed = vector_norm ( & self . velocity ) ;
237- let mut new_speed = speed
243+ // Step 2: Rotate velocity vector
244+ let dtheta = self . rotational_velocity * dt;
245+ rotate_vector ( & mut self . velocity , dtheta) ;
246+
247+ // Step 3: Update speed magnitude (using Rayleigh distribution)
248+ let mut speed = vector_norm ( & self . velocity ) ;
249+
250+ // Handle zero speed edge case (matching RatInABox behavior)
251+ if speed < 1e-8 {
252+ self . velocity = vec ! [ 1e-8 , 0.0 ] ; // [1, 0] direction
253+ speed = 1e-8 ;
254+ }
255+
256+ // Transform to normal space for OU update
257+ let normal_var = rayleigh_to_normal ( speed, self . params . speed_mean ) ;
258+
259+ // Apply OU process in normal space
260+ let normal_var_new = normal_var
238261 + ornstein_uhlenbeck (
239- speed ,
240- self . params . speed_mean ,
241- self . params . speed_std ,
262+ normal_var ,
263+ 0.0 , // drift = 0 (mean is 0 in normal space)
264+ 1.0 , // noise_scale = 1 (standard normal)
242265 self . params . speed_coherence_time ,
243266 dt,
244267 & mut self . rng ,
245268 ) ;
269+
270+ // Transform back to Rayleigh space
271+ let mut new_speed = normal_to_rayleigh ( normal_var_new, self . params . speed_mean ) ;
272+
273+ // If speed_std = 0, use deterministic speed
246274 if self . params . speed_std == 0.0 {
247275 new_speed = self . params . speed_mean ;
248276 }
249- new_speed = new_speed. max ( 0.0 ) ;
250277
278+ // Step 4: Scale velocity vector to new magnitude
251279 let current_norm = vector_norm ( & self . velocity ) ;
252- if current_norm < 1e-12 {
253- self . velocity = if self . head_direction . len ( ) == 2 {
254- let mut dir = self . head_direction . clone ( ) ;
255- let _ = normalize_vector ( & mut dir) ;
256- dir. into_iter ( ) . map ( |v| v * new_speed) . collect ( )
257- } else {
258- vec ! [ new_speed, 0.0 ]
259- } ;
260- } else {
261- let scale = if current_norm > 0.0 {
262- new_speed / current_norm
263- } else {
264- 0.0
265- } ;
280+ if current_norm > 1e-12 {
281+ let scale = new_speed / current_norm;
266282 for v in & mut self . velocity {
267283 * v *= scale;
268284 }
285+ } else {
286+ // Speed near zero, reinitialize
287+ self . velocity = vec ! [ new_speed, 0.0 ] ;
269288 }
270289 }
271290 }
@@ -373,6 +392,88 @@ impl Agent {
373392 }
374393 let _ = normalize_vector ( & mut self . head_direction ) ;
375394 }
395+
396+ /// Check and handle wall collisions (fully matching RatInABox elastic reflection physics)
397+ ///
398+ /// Workflow:
399+ /// 1. Add tiny random noise (1e-9) to positions to avoid numerical issues
400+ /// 2. Check if trajectory from prev_pos to current pos crosses any wall
401+ /// 3. If collision detected, reflect velocity and recalculate position
402+ /// 4. Iterate until no collision (random noise ensures fast convergence)
403+ ///
404+ /// # Arguments
405+ /// * `dt` - Time step size
406+ /// * `prev_pos` - Previous position
407+ fn check_and_handle_wall_collisions ( & mut self , dt : f64 , prev_pos : & [ f64 ] ) {
408+ use rand_distr:: { Distribution , StandardNormal } ;
409+
410+ // Ensure 2D environment
411+ if self . dimensionality != Dimensionality :: D2 || prev_pos. len ( ) != 2 || self . position . len ( ) != 2 {
412+ return ;
413+ }
414+
415+ // Get wall list
416+ let walls = & self . env_state . walls ;
417+ if walls. is_empty ( ) {
418+ return ; // No walls
419+ }
420+
421+ // Infinite loop until no collision
422+ // Random noise ensures fast convergence in practice
423+ loop {
424+ // Add tiny normal distribution random noise (std 1e-9)
425+ // This is RatInABox's key trick for:
426+ // 1. Avoiding numerical issues from perfect parallelism
427+ // 2. Breaking perfect symmetry at wall corners
428+ // 3. Ensuring small perturbation each iteration for eventual convergence
429+ // Note: Must use normal distribution (can be positive or negative), not uniform
430+ let n1: f64 = self . rng . sample ( StandardNormal ) ;
431+ let n2: f64 = self . rng . sample ( StandardNormal ) ;
432+ let n3: f64 = self . rng . sample ( StandardNormal ) ;
433+ let n4: f64 = self . rng . sample ( StandardNormal ) ;
434+
435+ let noise_prev = [
436+ prev_pos[ 0 ] + n1 * 1e-9 ,
437+ prev_pos[ 1 ] + n2 * 1e-9 ,
438+ ] ;
439+ let noise_curr = [
440+ self . position [ 0 ] + n3 * 1e-9 ,
441+ self . position [ 1 ] + n4 * 1e-9 ,
442+ ] ;
443+
444+ // Check if trajectory intersects with walls
445+ let collision = check_line_wall_collision ( & noise_prev, & noise_curr, & walls) ;
446+
447+ if let Some ( wall_idx) = collision {
448+ let wall = walls[ wall_idx] ;
449+
450+ // Reflect velocity vector
451+ let vel_array: [ f64 ; 2 ] = [ self . velocity [ 0 ] , self . velocity [ 1 ] ] ;
452+ let reflected = wall_bounce ( & vel_array, & wall) ;
453+ self . velocity [ 0 ] = reflected[ 0 ] ;
454+ self . velocity [ 1 ] = reflected[ 1 ] ;
455+
456+ // Reduce speed to 0.5 * speed_mean
457+ // This prevents agent from immediately colliding with the same wall again
458+ let speed = vector_norm ( & self . velocity ) ;
459+ if speed > 1e-12 {
460+ let target_speed = 0.5 * self . params . speed_mean ;
461+ let scale = target_speed / speed;
462+ self . velocity [ 0 ] *= scale;
463+ self . velocity [ 1 ] *= scale;
464+ }
465+
466+ // Recalculate position with new velocity
467+ self . position [ 0 ] = prev_pos[ 0 ] + self . velocity [ 0 ] * dt;
468+ self . position [ 1 ] = prev_pos[ 1 ] + self . velocity [ 1 ] * dt;
469+
470+ // Continue checking (may hit another wall)
471+ } else {
472+ // No collision, safe to exit
473+ return ;
474+ }
475+ }
476+ }
376477}
377478
378479#[ pymethods]
@@ -443,8 +544,10 @@ impl Agent {
443544 time : 0.0 ,
444545 position,
445546 velocity : velocity. clone ( ) ,
446- measured_velocity : velocity,
547+ measured_velocity : velocity. clone ( ) ,
548+ prev_measured_velocity : velocity, // 初始化为初始速度
447549 rotational_velocity : 0.0 ,
550+ measured_rotational_velocity : 0.0 , // 初始化为 0
448551 head_direction,
449552 distance_travelled : 0.0 ,
450553 rng,
@@ -526,18 +629,56 @@ impl Agent {
526629 proposed[ idx] = base_position[ idx] + vel * step;
527630 }
528631
529- self . position = self
530- . env_state
531- . project_position ( Some ( & base_position) , proposed) ;
632+ // Set proposed position first
633+ self . position = proposed;
634+
635+ // Handle wall collisions (only for 2D Solid boundaries, matching RatInABox elastic reflection physics)
636+ // Important: Must use prev_position (true position at start of update)
637+ // not base_position (position after wall_repulsion modification)
638+ if self . dimensionality == Dimensionality :: D2
639+ && self . env_state . boundary_conditions == BoundaryConditions :: Solid
640+ {
641+ self . check_and_handle_wall_collisions ( step, & prev_position) ;
642+ // Note: For 2D Solid boundaries, wall collision handling is complete.
643+ // Don't call project_position or it will overwrite collision handling results.
644+ } else {
645+ // 1D, Periodic, or other cases: use boundary projection (including periodic wrapping)
646+ self . position = self
647+ . env_state
648+ . project_position ( Some ( & base_position) , self . position . clone ( ) ) ;
649+ }
532650
533651 let displacement_vec: Vec < f64 > = self
534652 . position
535653 . iter ( )
536654 . zip ( prev_position. iter ( ) )
537655 . map ( |( new, old) | new - old)
538656 . collect ( ) ;
657+
658+ // Calculate measured velocity
539659 self . measured_velocity = displacement_vec. iter ( ) . map ( |delta| delta / step) . collect ( ) ;
540660
661+ // Calculate measured rotational velocity (2D only)
662+ if self . dimensionality == Dimensionality :: D2 && self . measured_velocity . len ( ) == 2 && self . prev_measured_velocity . len ( ) == 2 {
663+ let angle_now = self . measured_velocity [ 1 ] . atan2 ( self . measured_velocity [ 0 ] ) ;
664+ let angle_before = self . prev_measured_velocity [ 1 ] . atan2 ( self . prev_measured_velocity [ 0 ] ) ;
665+
666+ // Normalize angle difference to [-π, π]
667+ let mut angle_diff = angle_now - angle_before;
668+ const PI : f64 = std:: f64:: consts:: PI ;
669+ while angle_diff > PI {
670+ angle_diff -= 2.0 * PI ;
671+ }
672+ while angle_diff < -PI {
673+ angle_diff += 2.0 * PI ;
674+ }
675+
676+ self . measured_rotational_velocity = angle_diff / step;
677+ }
678+
679+ // 保存当前测量速度供下一步使用
680+ self . prev_measured_velocity = self . measured_velocity . clone ( ) ;
681+
541682 let displacement = vector_norm ( & displacement_vec) ;
542683 self . distance_travelled += displacement;
543684
@@ -578,6 +719,11 @@ impl Agent {
578719 self . distance_travelled
579720 }
580721
722+ #[ getter]
723+ pub fn measured_rotational_velocity ( & self ) -> f64 {
724+ self . measured_rotational_velocity
725+ }
726+
581727 pub fn params ( & self , py : Python < ' _ > ) -> PyResult < Py < PyDict > > {
582728 self . params . to_pydict ( py)
583729 }
0 commit comments