Skip to content

Commit 32bec88

Browse files
Merge pull request #126 from colleenjg/cjg-dev
Speed mean and std in 2D
2 parents 5ce2a27 + 50b8a51 commit 32bec88

File tree

2 files changed

+14
-2
lines changed

2 files changed

+14
-2
lines changed

ratinabox/Agent.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,8 @@ class Agent:
7272
# These defaults are fit to match data from Sargolini et al. (2016)
7373
# also given are the parameter names as refered to in the methods section of the paper
7474
"speed_coherence_time": 0.7, # time over which speed decoheres, τ_v1 & τ_v2
75-
"speed_mean": 0.08, # mean of speed, σ_v2 μ_v1
76-
"speed_std": 0.08, # std of speed (meaningless in 2D where speed ~rayleigh), σ_v1
75+
"speed_mean": 0.08, # mean of speed in 1D environment, μ_v1 (and in 2D environment if speed_std is set to 0). Otherwise, std of speed ~rayleigh in 2D environment, σ_v2. Can be computed based on a target speed mean using utils.get_rayleigh_sigma().
76+
"speed_std": 0.08, # std of speed in 1D environment, σ_v1 (ignored in 2D where speed ~rayleigh, unless set to 0. If set to 0, speed_mean is used.)
7777
"rotational_velocity_coherence_time": 0.08, # time over which speed decoheres, τ_w
7878
"rotational_velocity_std": (120 * (np.pi / 180)), # std of rotational speed, σ_w wall following parameter
7979
"head_direction_smoothing_timescale" : 0.15, # timescale over which head direction is smoothed (head dir = normalised smoothed velocity).

ratinabox/utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,18 @@ def interpolate_and_smooth(x, y, sigma=None, resolution_increase=10):
394394
return x_new, y_interpolated
395395

396396

397+
def get_rayleigh_sigma(mean):
398+
"""Returns the standard deviation (sigma) of a Rayleigh distribution based on its mean"""
399+
sigma = mean / np.sqrt(np.pi / 2)
400+
return sigma
401+
402+
403+
def get_rayleigh_mean(sigma):
404+
"""Returns the mean of a Rayleigh distribution based on its standard deviation (sigma)."""
405+
mean = sigma * np.sqrt(np.pi / 2)
406+
return mean
407+
408+
397409
def normal_to_rayleigh(x, sigma=1):
398410
"""Converts a normally distributed variable (mean 0, var 1) to a rayleigh distributed variable (sigma)"""
399411
x = stats.norm.cdf(x) # norm to uniform)

0 commit comments

Comments
 (0)