Commit 88e734e

add record() method to basicNCA model, create Animator class

1 parent 6f65582

File tree: 4 files changed, +88 −66 lines

ncalab/models/basicNCA.py

Lines changed: 24 additions & 1 deletion
@@ -1,5 +1,5 @@
 from __future__ import annotations
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, Optional, Tuple, List
 
 import numpy as np
 
@@ -336,6 +336,29 @@ def predict(self, image: torch.Tensor, steps: int = 100) -> Prediction:
         prediction = self.forward(x, steps=steps)
         return prediction
 
+    def record(self, image: torch.Tensor, steps: int = 100) -> List[Prediction]:
+        """
+        Record predictions for all time steps and return the resulting
+        sequence of predictions.
+
+        :param image [torch.Tensor]: Input image, BCWH.
+
+        :returns [List[Prediction]]: List of Prediction objects.
+        """
+        assert steps >= 1
+        assert image.shape[1] <= self.num_channels
+        self.eval()
+        sequence = []
+        with torch.no_grad():
+            x = image.clone()
+            x = pad_input(x, self, noise=self.pad_noise)
+            x = self.prepare_input(x)
+            for _ in range(steps):
+                prediction = self.forward(x, steps=1)
+                sequence.append(prediction)
+                x = prediction.output_image
+        return sequence
+
     def validate(
         self, image: torch.Tensor, label: torch.Tensor, steps: int
     ) -> Optional[Tuple[Dict[str, float], Prediction]]:
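
For downstream consumers, record() replaces hand-rolled loops over forward() whenever the intermediate states are needed. A minimal usage sketch, assuming an already trained BasicNCA-derived model nca and a 4-image-channel configuration (both placeholders, not part of this commit):

import torch

image = torch.zeros((1, 4, 48, 48))         # BCWH input; channels <= nca.num_channels
predictions = nca.record(image, steps=100)  # one Prediction object per time step
x_final = predictions[-1].output_image      # full state tensor after the last step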

ncalab/models/growingNCA.py

Lines changed: 18 additions & 29 deletions
@@ -74,40 +74,29 @@ def validate(
         """
         return None
 
+    def make_seed(self, width: int, height: int) -> torch.Tensor:
+        x = torch.zeros((1, self.num_channels, width, height)).to(self.device)
+        # set seed in center
+        x[:, 3:, width // 2, height // 2] = 1.0
+        return x
+
     def grow(
-        self, width: int, height: int, steps: int = 100, save_steps=False
-    ) -> np.ndarray | List[np.ndarray]:
+        self, seed: torch.Tensor, steps: int = 100
+    ) -> List[np.ndarray]:
         """
-        Run the growth process and return the resulting output image.
+        Run the growth process and return the resulting output sequence.
 
-        :param width [int]: Output image width.
-        :param height [int]: Output image height.
+        :param seed [torch.Tensor]: Seed image, can be generated through make_seed.
         :param steps [int]: Number of inference steps. Defaults to 100.
 
-        :returns [np.ndarray]: Image channels of the output image.
+        :returns [List[np.ndarray]]: Sequence of output images.
         """
         with torch.no_grad():
-            # TODO make use of autostepper, if available
             self.eval()
-            x = torch.zeros((1, self.num_channels, width, height)).to(self.device)
-            # set seed in center
-            x[:, 3:, width // 2, height // 2] = 1.0
-
-            if save_steps:
-                step_outs = []
-                for _ in range(steps):
-                    prediction = self.forward(x, steps=1)  # type: ignore[assignment]
-                    step_outs.append(
-                        np.clip(
-                            prediction.image_channels.squeeze(0).detach().cpu().numpy(),
-                            0,
-                            1,
-                        )
-                    )
-                    x = prediction.output_image
-                return step_outs
-            else:
-                prediction = self.forward(x, steps=steps)  # type: ignore[assignment]
-                out_np = prediction.image_channels.detach().cpu().numpy().squeeze(0)
-                out_np = np.clip(out_np, 0, 1)
-                return out_np
+            output = []
+            sequence = self.record(seed, steps=steps)
+            for prediction in sequence:
+                output.append(
+                    np.clip(prediction.image_channels_np.squeeze(0), 0, 1)
+                )
+            return output
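
After this refactor, seed creation is explicit and grow() always returns the full frame sequence instead of switching behavior on save_steps. A minimal sketch, assuming a trained GrowingNCA instance nca (placeholder name):

seed = nca.make_seed(48, 48)           # 1 x num_channels x 48 x 48, seed cell lit in the center
frames = nca.grow(seed, steps=100)     # list of CHW numpy arrays, clipped to [0, 1]
final = frames[-1].transpose(1, 2, 0)  # CHW -> HWC, e.g. for plt.imshow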

ncalab/visualization/animation.py

Lines changed: 42 additions & 10 deletions
@@ -2,33 +2,65 @@
 
 import matplotlib.pyplot as plt  # type: ignore[import-untyped]
 import matplotlib.animation as animation  # type: ignore[import-untyped]
+import numpy as np
+import torch
 
 
-class NCAAnimator:
-    def __init__(self, nca, x, steps=100):
-        """ """
+class Animator:
+    def __init__(
+        self,
+        nca,
+        seed: torch.Tensor,
+        steps=100,
+        interval=100,
+        repeat=True,
+        repeat_delay=3000,
+    ):
+        nca.eval()
+
         fig, ax = plt.subplots()
         fig.set_size_inches(2, 2)
-        im = ax.imshow(x[0, :3].permute(0, 2, 3, 1).cpu(), animated=True)
+
+        # first frame is input image
+        if nca.immutable_image_channels:
+            first_frame = seed[0, -nca.num_output_channels :]
+        else:
+            first_frame = seed[0, : nca.num_image_channels]
+        first_frame = first_frame.permute(1, 2, 0).detach().cpu().numpy()
+        im = ax.imshow(
+            first_frame,
+            animated=True,
+        )
+
+        predictions = nca.record(seed, steps)
+        images = []
+        for prediction in predictions:
+            if nca.immutable_image_channels:
+                output_image = prediction.output_channels_np[0]
+            else:
+                output_image = prediction.image_channels_np[0]
+            output_image = output_image.transpose(1, 2, 0)
+            output_image = np.clip(output_image, 0, 1)
+            images.append(output_image)
+
         ax.set_axis_off()
         plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
         plt.margins(0, 0)
         plt.tight_layout()
 
         def update(i):
-            nonlocal x
-            x = nca(x)  # --> BWHC
-            im.set_array(x)
+            nonlocal images
+            im.set_array(images[i])
             return (im,)
 
         self.animation_fig = animation.FuncAnimation(
             fig,
             update,
             frames=steps,
-            interval=100,
+            interval=interval,
             blit=True,
-            repeat=True,
-            repeat_delay=3000,
+            repeat=repeat,
+            repeat_delay=repeat_delay,
         )
 
     def save(self, path: str | Path):
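
The Animator precomputes all frames via record() and feeds them to matplotlib's FuncAnimation. A minimal usage sketch, assuming a trained model nca that provides make_seed (e.g. GrowingNCA); the variable names and output path are illustrative:

seed = nca.make_seed(48, 48)
animator = Animator(nca, seed, steps=100, interval=50)  # 50 ms between frames
animator.save("growing.gif")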

tasks/growing_emoji/eval_growing_emoji.py

Lines changed: 4 additions & 26 deletions
@@ -11,14 +11,13 @@
     get_compute_device,
     print_NCALab_banner,
     fix_random_seed,
+    Animator,
 )
 
 import click
 
 import torch
 
-import matplotlib.pyplot as plt  # type: ignore[import-untyped]
-import matplotlib.animation as animation  # type: ignore[import-untyped]
 
 TASK_PATH = Path(__file__).parent
 FIGURE_PATH = TASK_PATH / "figures"
@@ -53,33 +52,12 @@ def eval_growing_emoji(gpu: bool, gpu_index: int):
             weights_only=True,
         )
     )
-    nca.eval()
 
-    images = nca.grow(48, 48, steps=100, save_steps=True)
+    seed = nca.make_seed(48, 48)
+    animator = Animator(nca, seed)
 
-    fig, ax = plt.subplots()
-    fig.set_size_inches(2, 2)
-    im = ax.imshow(images[0].transpose(1, 2, 0), animated=True)
-    ax.set_axis_off()
-    plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
-    plt.margins(0, 0)
-    plt.tight_layout()
-
-    def update(i):
-        im.set_array(images[i].transpose(1, 2, 0))
-        return (im,)
-
-    animation_fig = animation.FuncAnimation(
-        fig,
-        update,
-        frames=len(images),
-        interval=10,
-        blit=True,
-        repeat=True,
-        repeat_delay=3000,
-    )
     out_path = FIGURE_PATH / "growing_emoji.gif"
-    animation_fig.save(out_path)
+    animator.save(out_path)
     click.secho(f"Done. You'll find the generated GIF in {out_path} .")
 