Skip to content

Commit defaa00

Browse files
Merge pull request #41 from TomGeorge1234/successorfeatures
Successor features branch
2 parents 61d314e + 62efe3b commit defaa00

File tree

14 files changed

+181434
-287
lines changed

14 files changed

+181434
-287
lines changed

.images/demos/SF_development.gif

4.7 MB
Loading

.images/readme/multi_agents.gif

998 KB
Loading

.images/readme/rate_map.png

-861 KB
Loading

README.md

Lines changed: 41 additions & 23 deletions
Large diffs are not rendered by default.

demos/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,6 @@ In approximate order of complexity, these include:
1111
* [paper_figures.ipynb](./paper_figures.ipynb): (Almost) all plots/animations shown in the paper are produced from this script (plus some major formatting done afterwards in powerpoint). [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/TomGeorge1234/RatInABox/blob/dev/demos/paper_figures.ipynb)
1212
* [decoding_position_example.ipynb](./decoding_position_example.ipynb): Postion is decoded from neural data generated with RatInABox using linear regression. Place cells, grid cell and boundary vector cells are compared. [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/TomGeorge1234/RatInABox/blob/dev/demos/decoding_position_example.ipynb)
1313
* [reinforcement_learning_example.ipynb](./reinforcement_learning_example.ipynb): RatInABox is use to construct, train and visualise a small two-layer network capable of model free reinforcement learning in order to find a reward hidden behind a wall. [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/TomGeorge1234/RatInABox/blob/dev/demos/reinforcement_learning_example.ipynb)
14+
* [successor_features_example.ipynb](./successor_features_example.ipynb): RatInABox is use to learn successor features under random and biased motion policies. [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/TomGeorge1234/RatInABox/blob/dev/demos/successor_features_example.ipynb)
1415
* [path_integration_example.ipynb](./path_integration_example.ipynb): RatInABox is use to construct, train and visualise a large multi-layer network capable of learning a "ring attractor" capable of path integrating a position estimate using only velocity inputs. [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/TomGeorge1234/RatInABox/blob/dev/demos/path_integration_example.ipynb)
1516

demos/readme_figures.ipynb

Lines changed: 63791 additions & 24 deletions
Large diffs are not rendered by default.

demos/reinforcement_learning_example.ipynb

Lines changed: 67 additions & 128 deletions
Large diffs are not rendered by default.

demos/successor_features_example.ipynb

Lines changed: 117334 additions & 0 deletions
Large diffs are not rendered by default.

ratinabox/Agent.py

Lines changed: 94 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,9 @@ def __init__(self, Environment, params={}):
7676
params (dict, optional). Defaults to {}.
7777
"""
7878
self.Environment = Environment
79+
self.Environment.Agents.append(self)
7980

80-
self.params = copy.deepcopy(__class__.default_params)
81+
self.params = copy.deepcopy(__class__.default_params)
8182
self.params.update(params)
8283

8384
utils.update_class_params(self, self.params, get_all_defaults=True)
@@ -568,11 +569,12 @@ def plot_trajectory(
568569
framerate=10,
569570
fig=None,
570571
ax=None,
572+
plot_all_agents=False,
571573
point_size=15,
572574
decay_point_size=False,
573575
decay_point_timescale=10,
574576
plot_agent=True,
575-
color="#7b699a",
577+
color=None,
576578
alpha=0.7,
577579
xlim=None,
578580
background_color=None,
@@ -587,6 +589,7 @@ def plot_trajectory(
587589
• framerate: how many scatter points / per second of motion to display
588590
• fig, ax: the fig, ax to plot on top of, optional, if not provided used self.Environment.plot_Environment().
589591
This can be used to plot trajectory on top of receptive fields etc.
592+
• plot_all_agents: if True, this will plot the trajectory of all agents in the list Environment.Agents
590593
• point_size: size of scatter points
591594
• decay_point_size: decay trajectory point size over time (recent times = largest)
592595
• decay_point_timescale: if decay_point_size is True, this is the timescale over which sizes decay
@@ -601,83 +604,97 @@ def plot_trajectory(
601604
Returns:
602605
fig, ax
603606
"""
604-
605-
dt = self.dt
606-
t, pos = np.array(self.history["t"]), np.array(self.history["pos"])
607-
if t_end == None:
608-
t_end = t[-1]
609-
startid = np.nanargmin(np.abs(t - (t_start)))
610-
endid = np.nanargmin(np.abs(t - (t_end)))
611-
if self.Environment.dimensionality == "2D":
612-
skiprate = max(1, int((1 / framerate) / dt))
613-
trajectory = pos[startid:endid, :][::skiprate]
614-
if self.Environment.dimensionality == "1D":
615-
skiprate = max(1, int((1 / framerate) / dt))
616-
trajectory = pos[startid:endid][::skiprate]
617-
time = t[startid:endid][::skiprate]
618-
if color is None:
619-
color = ["C0"] * len(time)
620-
elif color == "changing":
621-
trajectory_cmap = matplotlib.colormaps["viridis_r"]
622-
color = [trajectory_cmap(t / len(time)) for t in range(len(time))]
623-
decay_point_size = (
624-
False # if changing colour, may as well show WHOLE trajectory
625-
)
607+
# loop over all agents in the Environment if plot_all_agents is True
608+
if plot_all_agents == False:
609+
agent_list = [self]
610+
if color is None:
611+
color = "#7b699a"
626612
else:
627-
color = [color] * len(time)
628-
629-
if self.Environment.dimensionality == "2D":
630-
fig, ax = self.Environment.plot_environment(fig=fig, ax=ax, autosave=False)
631-
s = point_size * np.ones_like(time)
632-
if decay_point_size == True:
633-
s = point_size * np.exp((time - time[-1]) / decay_point_timescale)
634-
s[(time[-1] - time) > (1.5 * decay_point_timescale)] *= 0
635-
636-
if plot_agent == True:
637-
s[-1] = 40
638-
color[-1] = "r"
639-
640-
ax.scatter(
641-
trajectory[:, 0],
642-
trajectory[:, 1],
643-
s=s,
644-
alpha=alpha,
645-
zorder=1.1,
646-
c=color,
647-
linewidth=0,
648-
)
649-
# #plot the rat? TODO haha probably never gonna do this
650-
# ratpath = os.path.join(
651-
# os.path.abspath(os.path.join(ratinabox.__file__, os.pardir)),
652-
# "data/rat.png",
653-
# )
654-
# rat = plt.imread(ratpath)
655-
# rect = 0.5, 0.4, 0.4, 0.4 # What should these values be?
656-
# newax = fig.add_axes(rect, anchor='NE', zorder=1)
657-
# newax.axis('off')
658-
# newax.imshow(rat)
613+
agent_list = self.Environment.Agents
614+
replot_env = True
615+
for i, self_ in enumerate(agent_list):
616+
dt = self_.dt
617+
t, pos = np.array(self_.history["t"]), np.array(self_.history["pos"])
618+
if t_end == None:
619+
t_end = t[-1]
620+
startid = np.nanargmin(np.abs(t - (t_start)))
621+
endid = np.nanargmin(np.abs(t - (t_end)))
622+
if self_.Environment.dimensionality == "2D":
623+
skiprate = max(1, int((1 / framerate) / dt))
624+
trajectory = pos[startid:endid, :][::skiprate]
625+
if self_.Environment.dimensionality == "1D":
626+
skiprate = max(1, int((1 / framerate) / dt))
627+
trajectory = pos[startid:endid][::skiprate]
628+
time = t[startid:endid][::skiprate]
629+
if color is None:
630+
color_list = [f"C{i}"] * len(time)
631+
elif color == "changing":
632+
trajectory_cmap = matplotlib.colormaps["viridis_r"]
633+
color_list = [trajectory_cmap(t / len(time)) for t in range(len(time))]
634+
decay_point_size = (
635+
False # if changing colour, may as well show WHOLE trajectory
636+
)
637+
else:
638+
color_list = [color] * len(time)
659639

660-
if self.Environment.dimensionality == "1D":
661-
if fig is None and ax is None:
662-
fig, ax = plt.subplots(figsize=(3, 1.5))
663-
ax.scatter(time / 60, trajectory, alpha=alpha, linewidth=0, c=color, s=5)
664-
ax.spines["left"].set_position(("data", t_start / 60))
665-
if axis_labels == True:
666-
ax.set_xlabel("Time / min")
667-
ax.set_ylabel("Position / m")
668-
ax.set_xlim([t_start / 60, t_end / 60])
669-
if xlim is not None:
670-
ax.set_xlim(right=xlim)
671-
672-
ax.set_ylim(bottom=0, top=self.Environment.extent[1])
673-
ax.spines["right"].set_color(None)
674-
ax.spines["top"].set_color(None)
675-
ax.set_xticks([t_start / 60, t_end / 60])
676-
ex = self.Environment.extent
677-
ax.set_yticks([ex[1]])
678-
if background_color is not None:
679-
ax.set_facecolor(background_color)
680-
fig.patch.set_facecolor(background_color)
640+
if self_.Environment.dimensionality == "2D":
641+
if replot_env == True:
642+
fig, ax = self_.Environment.plot_environment(
643+
fig=fig, ax=ax, autosave=False
644+
)
645+
replot_env = False
646+
s = point_size * np.ones_like(time)
647+
if decay_point_size == True:
648+
s = point_size * np.exp((time - time[-1]) / decay_point_timescale)
649+
s[(time[-1] - time) > (1.5 * decay_point_timescale)] *= 0
650+
651+
if plot_agent == True:
652+
s[-1] = 40
653+
color_list[-1] = "r"
654+
655+
ax.scatter(
656+
trajectory[:, 0],
657+
trajectory[:, 1],
658+
s=s,
659+
alpha=alpha,
660+
zorder=1.1,
661+
c=color_list,
662+
linewidth=0,
663+
)
664+
# #plot the rat? TODO haha probably never gonna do this
665+
# ratpath = os.path.join(
666+
# os.path.abspath(os.path.join(ratinabox.__file__, os.pardir)),
667+
# "data/rat.png",
668+
# )
669+
# rat = plt.imread(ratpath)
670+
# rect = 0.5, 0.4, 0.4, 0.4 # What should these values be?
671+
# newax = fig.add_axes(rect, anchor='NE', zorder=1)
672+
# newax.axis('off')
673+
# newax.imshow(rat)
674+
675+
if self_.Environment.dimensionality == "1D":
676+
if fig is None and ax is None:
677+
fig, ax = plt.subplots(figsize=(3, 1.5))
678+
ax.scatter(
679+
time / 60, trajectory, alpha=alpha, linewidth=0, c=color_list, s=5
680+
)
681+
ax.spines["left"].set_position(("data", t_start / 60))
682+
if axis_labels == True:
683+
ax.set_xlabel("Time / min")
684+
ax.set_ylabel("Position / m")
685+
ax.set_xlim([t_start / 60, t_end / 60])
686+
if xlim is not None:
687+
ax.set_xlim(right=xlim)
688+
689+
ax.set_ylim(bottom=0, top=self_.Environment.extent[1])
690+
ax.spines["right"].set_color(None)
691+
ax.spines["top"].set_color(None)
692+
ax.set_xticks([t_start / 60, t_end / 60])
693+
ex = self_.Environment.extent
694+
ax.set_yticks([ex[1]])
695+
if background_color is not None:
696+
ax.set_facecolor(background_color)
697+
fig.patch.set_facecolor(background_color)
681698

682699
ratinabox.utils.save_figure(fig, "trajectory", save=autosave)
683700

@@ -687,11 +704,6 @@ def animate_trajectory(
687704
self, t_start=None, t_end=None, fps=15, speed_up=1, autosave=None, **kwargs
688705
):
689706
"""Returns an animation (anim) of the trajectory, 25fps.
690-
Should be saved using command like
691-
>>> anim.save("./where_to_save/animations.gif",dpi=300)
692-
To display in jupyter notebook, call it:
693-
>>> anim
694-
695707
Args:
696708
t_start: Agent time at which to start animation
697709
t_end (_type_, optional): _description_. Defaults to None.

ratinabox/Environment.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ def __init__(self, params={}):
140140
self.n_object_types = 0
141141
self.object_colormap = "rainbow"
142142
self.plot_objects = True
143+
self.Agents = [] #each new Agent will append itself to this list
143144

144145
# make some other attributes
145146
left = min([c[0] for c in b])

0 commit comments

Comments
 (0)