Skip to content

Commit 1ec591a

Browse files
Merge pull request #128 from colleenjg/cjg-dev
Progress bar for creating animations
2 parents 9f111b1 + 959225f commit 1ec591a

File tree

2 files changed

+18
-3
lines changed

2 files changed

+18
-3
lines changed

ratinabox/Agent.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -843,6 +843,7 @@ def animate_trajectory(
843843
t_end=None,
844844
fps=15,
845845
speed_up=5, #by default the animation is 5x faster than real time
846+
progress_bar=False,
846847
autosave=None,
847848
**kwargs
848849
):
@@ -852,6 +853,7 @@ def animate_trajectory(
852853
t_end (_type_, optional): _description_. Defaults to None.
853854
fps: frames per second of end video
854855
speed_up: #times real speed animation should come out at
856+
progress_bar (bool): if True, a progress bar will be shown as the animation is created. Defaults to False.
855857
autosave (bool): whether to automatical try and save this. Defaults to None in which case looks for global constant ratinabox.autosave_plots
856858
kwargs: passed to trajectory plotting function (chuck anything you wish in here). A particularly useful kwarg is 'additional_plot_func': any function which takes a fig, ax and t as input. The animation wll be passed through this each time after plotting the trajectory, use it to modify your animations however you like
857859
@@ -895,13 +897,18 @@ def animate_(i, fig, ax, t_start, t_max, speed_up, dt, kwargs):
895897
t_start=0, t_end=10 * self.dt, xlim=t_end / 60, autosave=False, **kwargs
896898
)
897899

900+
frames = int((t_end - t_start) / (dt * speed_up))
901+
if progress_bar:
902+
from tqdm import tqdm
903+
frames = tqdm(range(frames), position=0, leave=True)
904+
898905
from matplotlib import animation
899906

900907
anim = matplotlib.animation.FuncAnimation(
901908
fig,
902909
animate_,
903910
interval=1000 * dt,
904-
frames=int((t_end - t_start) / (dt * speed_up)),
911+
frames=frames,
905912
blit=False,
906913
fargs=(fig, ax, t_start, t_end, speed_up, dt, kwargs),
907914
)

ratinabox/Neurons.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -698,6 +698,7 @@ def animate_rate_timeseries(
698698
chosen_neurons="all",
699699
fps=15,
700700
speed_up=1,
701+
progress_bar=False,
701702
autosave=None,
702703
**kwargs,
703704
):
@@ -710,8 +711,10 @@ def animate_rate_timeseries(
710711
Args:
711712
• t_end (_type_, optional): _description_. Defaults to None.
712713
• chosen_neurons: Which neurons to plot. string "10" or 10 will plot ten of them, "all" will plot all of them, "12rand" will plot 12 random ones. A list like [1,4,5] will plot cells indexed 1, 4 and 5. Defaults to "all".
714+
• fps: frames per second of end video. Defaults to 15.
715+
• speed_up: #times real speed animation should come out at. Defaults to 1.
716+
• progress_bar: if True, a progress bar will be shown as the animation is created. Default to False.
713717
714-
• speed_up: #times real speed animation should come out at.
715718
716719
Returns:
717720
animation
@@ -753,13 +756,18 @@ def animate_(i, fig, ax, chosen_neurons, t_start, t_max, dt, speed_up):
753756
**kwargs,
754757
)
755758

759+
frames = int((t_end - t_start) / (dt * speed_up))
760+
if progress_bar:
761+
from tqdm import tqdm
762+
frames = tqdm(range(frames), position=0, leave=True)
763+
756764
from matplotlib import animation
757765

758766
anim = matplotlib.animation.FuncAnimation(
759767
fig,
760768
animate_,
761769
interval=1000 * dt,
762-
frames=int((t_end - t_start) / (dt * speed_up)),
770+
frames=frames,
763771
blit=False,
764772
fargs=(fig, ax, chosen_neurons, t_start, t_end, dt, speed_up),
765773
)

0 commit comments

Comments
 (0)