Skip to content

Commit ccf73ca

Browse files
committed
animation: add type hints, make output more appealing
1 parent f3d82fc commit ccf73ca

File tree

1 file changed

+15
-8
lines changed

1 file changed

+15
-8
lines changed

ncalab/visualization/animation.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from pathlib import Path
2+
from typing import Tuple
23

34
import matplotlib.pyplot as plt # type: ignore[import-untyped]
45
import matplotlib.animation as animation # type: ignore[import-untyped]
@@ -15,12 +16,13 @@ def __init__(
1516
self,
1617
nca,
1718
seed: torch.Tensor,
18-
steps=100,
19-
interval=100,
20-
repeat=True,
21-
repeat_delay=10000,
22-
overlay=False,
23-
show_timestep=True,
19+
steps: int = 100,
20+
interval: int = 100,
21+
repeat: bool = True,
22+
repeat_delay: int = 10000,
23+
overlay: bool = False,
24+
show_timestep: bool = True,
25+
overlay_color: Tuple[float, float, float] = (0.0, 1.0, 0.0),
2426
):
2527
"""
2628
:param nca: NCA model instance
@@ -46,6 +48,7 @@ def __init__(
4648
fig.set_size_inches(2, 2)
4749
plt.rcParams["font.family"] = "sans-serif"
4850
plt.rcParams["font.sans-serif"] = ["Calibri", "Arial"]
51+
plt.rcParams["axes.titlecolor"] = (0.2, 0.2, 0.2)
4952

5053
# first frame is input image
5154
if nca.immutable_image_channels and not overlay:
@@ -81,7 +84,7 @@ def update(i):
8184
if not nca.immutable_image_channels:
8285
arr = arr[:, :, : nca.num_image_channels]
8386
elif overlay:
84-
color = np.ones((arr.shape[0], arr.shape[1], 3)) * (0.0, 0.0, 1.0)
87+
color = np.ones((arr.shape[0], arr.shape[1], 3)) * overlay_color
8588
A = np.clip(arr[:, :, : nca.num_image_channels], 0, 1)
8689
mask = np.clip(arr[:, :, -nca.num_output_channels :].squeeze(-1), 0, 1)
8790
alpha = 0.5
@@ -95,7 +98,11 @@ def update(i):
9598
arr = np.clip(arr, 0, 1)
9699
im.set_array(arr)
97100
if show_timestep:
98-
ax.set_title(f"Time step {i % steps}".upper())
101+
ax.set_title(
102+
r"$\mathbf{TIME STEP\:"
103+
+ f"{i % steps:3d}".replace(" ", r"\:")
104+
+ r"}$"
105+
)
99106
return (im,)
100107

101108
self.animation_fig = animation.FuncAnimation(

0 commit comments

Comments
 (0)