Skip to content

Commit 1390c86

Browse files
committed
python<=3.8 bug fix
1 parent 383a821 commit 1390c86

File tree

4 files changed

+53
-30
lines changed

4 files changed

+53
-30
lines changed

README.md

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# RatInABox ![Tests](https://github.com/RatInABox-Lab/RatInABox/actions/workflows/test.yml/badge.svg) [![PyPI version](https://badge.fury.io/py/ratinabox.svg)](https://badge.fury.io/py/ratinabox) [![Downloads](https://static.pepy.tech/badge/ratinabox)](https://pepy.tech/project/ratinabox)<img align="right" src=".images/readme/logo.png" width=150>
22

3-
`RatInABox` (see [paper](https://www.biorxiv.org/content/10.1101/2022.08.10.503541v5)) is a toolkit for generating synthetic behaviour and neural data for spatially and/or velocity selective cell types in complex continuous environments.
3+
`RatInABox` (see [paper](https://elifesciences.org/articles/85274)) is a toolkit for generating synthetic behaviour and neural data for spatially and/or velocity selective cell types in complex continuous environments.
44

55
[**Install**](#installing-and-importing) | [**Demos**](#get-started) | [**Features**](#feature-run-down) | [**Contributions and Questions**](#contribute) | [**Cite**](#cite)
66

@@ -443,26 +443,30 @@ Questions? Just ask! Ideally via opening an issue so others can see the answer t
443443
Thanks to all contributors so far:
444444
![GitHub Contributors Image](https://contrib.rocks/image?repo=RatInABox-Lab/RatInABox)
445445

446-
## Cite [![](http://img.shields.io/badge/bioRxiv-10.1101/2022.08.10.503541-B31B1B.svg)](https://doi.org/10.1101/2022.08.10.503541)
446+
## Cite
447447

448448
If you use `RatInABox` in your research or educational material, please cite the work as follows:
449+
449450
Bibtex:
450451
```
451-
@article{ratinabox2022,
452-
doi = {10.1101/2022.08.10.503541},
453-
url = {https://doi.org/10.1101%2F2022.08.10.503541},
454-
year = 2022,
455-
month = {aug},
456-
publisher = {Cold Spring Harbor Laboratory},
457-
author = {Tom M George and William de Cothi and Claudia Clopath and Kimberly Stachenfeld and Caswell Barry},
458-
title = {{RatInABox}: A toolkit for modelling locomotion and neuronal activity in continuous environments}
452+
@article{George2024,
453+
title = {RatInABox, a toolkit for modelling locomotion and neuronal activity in continuous environments},
454+
volume = {13},
455+
ISSN = {2050-084X},
456+
url = {http://dx.doi.org/10.7554/eLife.85274},
457+
DOI = {10.7554/elife.85274},
458+
journal = {eLife},
459+
publisher = {eLife Sciences Publications, Ltd},
460+
author = {George, Tom M and Rastogi, Mehul and de Cothi, William and Clopath, Claudia and Stachenfeld, Kimberly and Barry, Caswell},
461+
year = {2024},
462+
month = feb
459463
}
460464
```
461465

462466
Formatted:
463467
```
464-
Tom M George, William de Cothi, Claudia Clopath, Kimberly Stachenfeld, Caswell Barry. "RatInABox: A toolkit for modelling locomotion and neuronal activity in continuous environments" (2022).
468+
Tom M George, Mehul Rastogi, William de Cothi, Claudia Clopath, Kimberly Stachenfeld, Caswell Barry. "RatInABox, a toolkit for modelling locomotion and neuronal activity in continuous environments" (2024), eLife, https://doi.org/10.7554/eLife.85274 .
465469
```
466-
The research paper corresponding to the above citation can be found [here](https://www.biorxiv.org/content/10.1101/2022.08.10.503541v4).
470+
The research paper corresponding to the above citation can be found [here](https://elifesciences.org/articles/85274).
467471

468472

ratinabox/Environment.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010

1111
import warnings
12-
from typing import Union
12+
from typing import Union, List
1313

1414
from ratinabox import utils
1515
from ratinabox.Agent import Agent
@@ -87,7 +87,7 @@ def __init__(self, params={}):
8787
utils.update_class_params(self, self.params, get_all_defaults=True)
8888
utils.check_params(self, params.keys())
8989

90-
self.Agents : list[Agent] = [] # each new Agent will append itself to this list
90+
self.Agents : List[Agent] = [] # each new Agent will append itself to this list
9191
self.agents_dict = {} # this is a dictionary which allows you to lookup a agent by name
9292

9393
if self.dimensionality == "1D":
@@ -206,17 +206,17 @@ def get_all_default_params(cls, verbose=False):
206206
return all_default_params
207207

208208

209-
def agent_lookup(self, agent_names:Union[str, list[str]] = None) -> list[Agent]:
209+
def agent_lookup(self, agent_names:Union[str, List[str]] = None) -> List[Agent]:
210210
'''
211211
This function will lookup a agent by name and return it. This assumes that the agent has been
212212
added to the Environment.agents list and that each agent object has a unique name associated with it.
213213
214214
215215
Args:
216-
agent_names (str, list[str]): the name of the agent you want to lookup.
216+
agent_names (str, List[str]): the name of the agent you want to lookup.
217217
218218
Returns:
219-
agents (list[Agent]): a list of agents that match the agent_names. If agent_names is a string, then a list of length 1 is returned. If agent_names is None, then None is returned
219+
agents (List[Agent]): a list of agents that match the agent_names. If agent_names is a string, then a list of length 1 is returned. If agent_names is None, then None is returned
220220
221221
'''
222222

@@ -226,7 +226,7 @@ def agent_lookup(self, agent_names:Union[str, list[str]] = None) -> list[Agent]
226226
if isinstance(agent_names, str):
227227
agent_names = [agent_names]
228228

229-
agents: list[Agent] = []
229+
agents: List[Agent] = []
230230

231231
for agent_name in agent_names:
232232
agent = self._agent_lookup(agent_name)
@@ -846,7 +846,7 @@ def apply_boundary_conditions(self, pos):
846846
returns new_pos
847847
TODO update this so if pos is in one of the holes the Agent is returned to the ~nearest legal location inside the Environment
848848
"""
849-
if self.check_if_position_is_in_environment(pos) is True: return
849+
if self.check_if_position_is_in_environment(pos) is True: return pos
850850

851851
if self.dimensionality == "1D":
852852
if self.boundary_conditions == "periodic":

ratinabox/Neurons.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -475,7 +475,9 @@ def plot_rate_map(
475475
interpolation="bicubic", # smooths rate maps but this does slow down the plotting a bit
476476
)
477477
elif method == "history":
478-
bin_size = kwargs.get("bin_size", 0.05)
478+
default_2D_bin_size = 0.05
479+
bin_size = kwargs.get("bin_size", default_2D_bin_size)
480+
print(f"Using bin size of {bin_size} for rate map calculation")
479481
rate_timeseries_ = rate_timeseries[chosen_neurons[i], :]
480482
rate_map, zero_bins = utils.bin_data_for_histogramming(
481483
data=pos,
@@ -537,25 +539,32 @@ def plot_rate_map(
537539

538540
# PLOT 1D
539541
elif self.Agent.Environment.dimensionality == "1D":
542+
zero_bins = None
540543
if method == "groundtruth":
541544
rate_maps = rate_maps[chosen_neurons, :]
542545
x = self.Agent.Environment.flattened_discrete_coords[:, 0]
543546
if method == "history":
544547
ex = self.Agent.Environment.extent
548+
default_1D_bin_size = 0.01
549+
bin_size = kwargs.get("bin_size", default_1D_bin_size)
545550
pos_ = pos[:, 0]
546551
rate_maps = []
547552
for neuron_id in chosen_neurons:
548-
rate_map, x = utils.bin_data_for_histogramming(
553+
(rate_map, x, zero_bins) = utils.bin_data_for_histogramming(
549554
data=pos_,
550555
extent=ex,
551-
dx=0.01,
556+
dx=bin_size,
552557
weights=rate_timeseries[neuron_id, :],
553558
norm_by_bincount=True,
559+
return_zero_bins=True,
554560
)
555-
x, rate_map = utils.interpolate_and_smooth(x, rate_map, sigma=0.03)
561+
resolution_increase = 10
562+
x, rate_map = utils.interpolate_and_smooth(x, rate_map, sigma=0.01, resolution_increase=resolution_increase)
556563
rate_maps.append(rate_map)
564+
zero_bins = np.repeat(zero_bins, resolution_increase)
557565
rate_maps = np.array(rate_maps)
558566

567+
559568
if fig is None and ax is None:
560569
fig, ax = plt.subplots(
561570
figsize=(
@@ -569,7 +578,7 @@ def plot_rate_map(
569578

570579
if method != "neither":
571580
fig, ax = utils.mountain_plot(
572-
X=x, NbyX=rate_maps, color=self.color, fig=fig, ax=ax, **kwargs
581+
X=x, NbyX=rate_maps, color=self.color, nan_bins=zero_bins, fig=fig, ax=ax, **kwargs
573582
)
574583

575584
if spikes is True:

ratinabox/utils.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -368,20 +368,22 @@ def ornstein_uhlenbeck(dt, x, drift=0.0, noise_scale=0.2, coherence_time=5.0):
368368
return dx
369369

370370

371-
def interpolate_and_smooth(x, y, sigma=None):
371+
def interpolate_and_smooth(x, y, sigma=None, resolution_increase=10):
372372
"""Interpolates with cublic spline x and y to 10x resolution then smooths these with a gaussian kernel of width sigma.
373373
Currently this only works for 1-dimensional x.
374374
Args:
375375
x
376376
y
377377
sigma
378+
resolution_increase
378379
Returns (x_new,y_new)
379380
"""
380381
from scipy.ndimage.filters import gaussian_filter1d
381382
from scipy.interpolate import interp1d
382383

383384
y_cubic = interp1d(x, y, kind="cubic")
384-
x_new = np.arange(x[0], x[-1], (x[1] - x[0]) / 10)
385+
# x_new = np.arange(x[0], x[-1], (x[1] - x[0]) / resolution_increase)
386+
x_new = np.linspace(x[0], x[-1], len(x) * resolution_increase)
385387
y_interpolated = y_cubic(x_new)
386388
if sigma is not None:
387389
y_smoothed = gaussian_filter1d(
@@ -541,16 +543,20 @@ def bin_data_for_histogramming(data, extent, dx, weights=None, norm_by_bincount=
541543
542544
Returns:
543545
(heatmap,bin_centres): if 1D
544-
(heatmap): if 2D
546+
(heatmap): if 2D --> you should be able ot infer the bin centres from the extent and dx you passed
547+
in either case if return_zero_bins is True, the zero_bins array is also returned as the last element of the tuple
545548
"""
546549
if len(extent) == 2: # dimensionality = "1D"
547550
bins = np.arange(extent[0], extent[1] + dx, dx)
548551
heatmap, xedges = np.histogram(data, bins=bins, weights=weights)
549552
if norm_by_bincount:
550553
bincount = np.histogram(data, bins=bins)[0]
551-
bincount[bincount == 0] = 1
554+
zero_bins = (bincount == 0)
555+
bincount[zero_bins] = 1
552556
heatmap = heatmap / bincount
553557
centres = (xedges[1:] + xedges[:-1]) / 2
558+
if return_zero_bins:
559+
return (heatmap, centres, zero_bins)
554560
return (heatmap, centres)
555561

556562
elif len(extent) == 4: # dimensionality = "2D"
@@ -578,6 +584,7 @@ def mountain_plot(
578584
xlabel="",
579585
ylabel="",
580586
xlim=None,
587+
nan_bins=None,
581588
fig=None,
582589
ax=None,
583590
norm_by="max",
@@ -599,6 +606,7 @@ def mountain_plot(
599606
xlabel (str, optional): x axis label. Defaults to "".
600607
ylabel (str, optional): y axis label. Defaults to "".
601608
xlim (_type_, optional): fix xlim to this is desired. Defaults to None.
609+
nan_bins (array, optional): Optionally pass a boolean array of the same shape as X which is True where you want to plot a gap in the mountain plot. Defaults to None (ie skipped).
602610
fig (_type_, optional): fig to plot over if desired. Defaults to None.
603611
ax (_type_, optional): ax to plot on if desider. Defaults to None.
604612
norm_by: what to normalise each line of the mountainplot by.
@@ -630,11 +638,13 @@ def mountain_plot(
630638
)
631639

632640
zorder = 1
641+
X_ = X.copy()
642+
if nan_bins is not None: X_[nan_bins] = np.nan
633643
for i in range(len(NbyX)):
634-
ax.plot(X, NbyX[i] + i + 1, c=c, zorder=zorder, lw=linewidth)
644+
ax.plot(X_, NbyX[i] + i + 1, c=c, zorder=zorder, lw=linewidth)
635645
zorder -= 0.01
636646
ax.fill_between(
637-
X, NbyX[i] + i + 1, i + 1, color=fc, zorder=zorder, alpha=0.8, linewidth=0
647+
X_, NbyX[i] + i + 1, i + 1, color=fc, zorder=zorder, alpha=0.8, linewidth=0
638648
)
639649
zorder -= 0.01
640650
ax.spines["left"].set_bounds(1, len(NbyX))

0 commit comments

Comments
 (0)