11from pathlib import Path
2+ from typing import Tuple
23
34import matplotlib .pyplot as plt # type: ignore[import-untyped]
45import matplotlib .animation as animation # type: ignore[import-untyped]
@@ -15,12 +16,13 @@ def __init__(
1516 self ,
1617 nca ,
1718 seed : torch .Tensor ,
18- steps = 100 ,
19- interval = 100 ,
20- repeat = True ,
21- repeat_delay = 10000 ,
22- overlay = False ,
23- show_timestep = True ,
19+ steps : int = 100 ,
20+ interval : int = 100 ,
21+ repeat : bool = True ,
22+ repeat_delay : int = 10000 ,
23+ overlay : bool = False ,
24+ show_timestep : bool = True ,
25+ overlay_color : Tuple [float , float , float ] = (0.0 , 1.0 , 0.0 ),
2426 ):
2527 """
2628 :param nca: NCA model instance
@@ -46,6 +48,7 @@ def __init__(
4648 fig .set_size_inches (2 , 2 )
4749 plt .rcParams ["font.family" ] = "sans-serif"
4850 plt .rcParams ["font.sans-serif" ] = ["Calibri" , "Arial" ]
51+ plt .rcParams ["axes.titlecolor" ] = (0.2 , 0.2 , 0.2 )
4952
5053 # first frame is input image
5154 if nca .immutable_image_channels and not overlay :
@@ -81,7 +84,7 @@ def update(i):
8184 if not nca .immutable_image_channels :
8285 arr = arr [:, :, : nca .num_image_channels ]
8386 elif overlay :
84- color = np .ones ((arr .shape [0 ], arr .shape [1 ], 3 )) * ( 0.0 , 0.0 , 1.0 )
87+ color = np .ones ((arr .shape [0 ], arr .shape [1 ], 3 )) * overlay_color
8588 A = np .clip (arr [:, :, : nca .num_image_channels ], 0 , 1 )
8689 mask = np .clip (arr [:, :, - nca .num_output_channels :].squeeze (- 1 ), 0 , 1 )
8790 alpha = 0.5
@@ -95,7 +98,11 @@ def update(i):
9598 arr = np .clip (arr , 0 , 1 )
9699 im .set_array (arr )
97100 if show_timestep :
98- ax .set_title (f"Time step { i % steps } " .upper ())
101+ ax .set_title (
102+ r"$\mathbf{TIME STEP\:"
103+ + f"{ i % steps :3d} " .replace (" " , r"\:" )
104+ + r"}$"
105+ )
99106 return (im ,)
100107
101108 self .animation_fig = animation .FuncAnimation (
0 commit comments