Skip to content

Commit 3060889

Browse files
committed
growing emoji example: store gif in task/figures/ dir
1 parent fdb87c4 commit 3060889

File tree

6 files changed

+29
-21
lines changed

6 files changed

+29
-21
lines changed

ncalab/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def pad_input(x: torch.Tensor, nca: "BasicNCAModel", noise: bool = True) -> torc
6161
return x
6262

6363

64-
def NCALab_banner():
64+
def print_NCALab_banner():
6565
"""
6666
Show NCALab banner on terminal.
6767
"""

tasks/growing_emoji/.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,5 @@
11
# search results
22
*.csv
3+
4+
# generated figures
5+
figures/

tasks/growing_emoji/eval_growing_emoji.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,18 @@
11
#!/usr/bin/env python3
22
import os
33
import sys
4+
from pathlib import Path
45

56
root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
67
sys.path.append(root_dir)
78

8-
from ncalab import GrowingNCAModel, WEIGHTS_PATH, get_compute_device
9+
from ncalab import (
10+
GrowingNCAModel,
11+
WEIGHTS_PATH,
12+
get_compute_device,
13+
print_NCALab_banner,
14+
fix_random_seed,
15+
)
916

1017
import click
1118

@@ -14,6 +21,10 @@
1421
import matplotlib.pyplot as plt # type: ignore[import-untyped]
1522
import matplotlib.animation as animation # type: ignore[import-untyped]
1623

24+
TASK_PATH = Path(__file__).parent
25+
FIGURE_PATH = TASK_PATH / "figures"
26+
FIGURE_PATH.mkdir(exist_ok=True)
27+
1728

1829
@click.command()
1930
@click.option(
@@ -23,6 +34,9 @@
2334
"--gpu-index", type=int, default=0, help="Index of GPU to use, if --gpu in use."
2435
)
2536
def eval_growing_emoji(gpu: bool, gpu_index: int):
37+
print_NCALab_banner()
38+
fix_random_seed()
39+
2640
device = get_compute_device(f"cuda:{gpu_index}" if gpu else "cpu")
2741

2842
nca = GrowingNCAModel(
@@ -60,7 +74,9 @@ def update(i):
6074
repeat=True,
6175
repeat_delay=3000,
6276
)
63-
animation_fig.save("artwork/growing_emoji.gif")
77+
out_path = FIGURE_PATH / "growing_emoji.gif"
78+
animation_fig.save(out_path)
79+
click.secho(f"Done. You'll find the generated GIF in {out_path} .")
6480

6581

6682
if __name__ == "__main__":

tasks/growing_emoji/finetune_growing_emoji.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
BasicNCATrainer,
2424
WEIGHTS_PATH,
2525
get_compute_device,
26-
NCALab_banner,
26+
print_NCALab_banner,
2727
print_mascot,
2828
fix_random_seed,
2929
)
@@ -32,7 +32,7 @@
3232
def finetune_growing_emoji(
3333
batch_size: int, hidden_channels: int, gpu: bool, gpu_index: int
3434
):
35-
NCALab_banner()
35+
print_NCALab_banner()
3636
print_mascot(
3737
"Things are getting really exciting now!\n"
3838
"You're about to finetune a pre-trained NCA\n"

tasks/growing_emoji/search_growing_emoji.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
GrowingNCADataset,
1616
GrowingNCAModel,
1717
get_compute_device,
18-
NCALab_banner,
18+
print_NCALab_banner,
1919
print_mascot,
2020
ParameterSearch,
2121
ParameterSet,
@@ -35,7 +35,7 @@ def search_growing_emoji(
3535
Main function to run the "growing emoji search" example task.
3636
"""
3737
# Display prologue
38-
NCALab_banner()
38+
print_NCALab_banner()
3939
print_mascot(
4040
"You are about to run a hyperparameter search.\n"
4141
"\n"

tasks/growing_emoji/train_growing_emoji.py

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,12 @@
1414
sys.path.append(root_dir)
1515

1616
from ncalab import (
17-
AutoStepper,
1817
GrowingNCADataset,
1918
GrowingNCAModel,
2019
BasicNCATrainer,
2120
WEIGHTS_PATH,
2221
get_compute_device,
23-
NCALab_banner,
24-
print_mascot,
22+
print_NCALab_banner,
2523
fix_random_seed,
2624
Pool,
2725
)
@@ -38,20 +36,11 @@ def train_growing_emoji(
3836
:param batch_size [int]: Minibatch size.
3937
:param hidden_channels [int]: Hidden channels the NCA will use.
4038
"""
41-
# Display prologue
42-
NCALab_banner()
43-
print_mascot(
44-
"You are about to run the growing lizard emoji example,\n"
45-
"a true NCA classic! To learn more about it, visit:\n"
46-
"\n"
47-
"https://distill.pub/2020/growing-ca/ \n"
48-
"(Ctrl+click to open URL)\n"
49-
)
50-
print()
39+
print_NCALab_banner()
5140
fix_random_seed()
5241

5342
# Create tensorboard summary writer
54-
writer = SummaryWriter(comment="Growing NCA (Lizard)")
43+
writer = SummaryWriter(comment=" Growing NCA (Lizard)")
5544

5645
# Select device, try to use GPU or fall back to CPU
5746
device = get_compute_device(f"cuda:{gpu_index}" if gpu else "cpu")

0 commit comments

Comments
 (0)