Skip to content

Commit 5b6e70c

Browse files
committed
example tasks: organize imports and ignore E402 for ncalab imports
1 parent 1ae62d0 commit 5b6e70c

19 files changed

+182
-207
lines changed

tasks/class_cifar10/README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# CIFAR10 Classification
2+
3+
We currently achieve about 75% test set accuracy with NCAs.

tasks/class_cifar10/eval_class_cifar10.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,9 @@
1010
from torchvision import transforms # type: ignore[import-untyped]
1111
from torchvision.transforms import v2 # type: ignore[import-untyped]
1212
from tqdm import tqdm
13-
1413
from train_class_cifar10 import (
15-
WEIGHTS_PATH,
1614
TASK_PATH,
15+
WEIGHTS_PATH,
1716
alive_mask,
1817
default_hidden_channels,
1918
fire_rate,
@@ -25,7 +24,11 @@
2524
sys.path.append(root_dir)
2625

2726

28-
from ncalab import ClassificationNCAModel, get_compute_device, fix_random_seed
27+
from ncalab import ( # noqa: E402
28+
ClassificationNCAModel,
29+
fix_random_seed,
30+
get_compute_device,
31+
)
2932

3033
T = transforms.Compose(
3134
[

tasks/class_cifar10/train_class_cifar10.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,15 @@
66
import click
77
import numpy as np # type: ignore[import-untyped]
88
import torch # type: ignore[import-untyped]
9+
import torchvision # type: ignore[import-untyped]
910
from torch.utils.data.sampler import SubsetRandomSampler # type: ignore[import-untyped]
1011
from torch.utils.tensorboard import SummaryWriter # type: ignore[import-untyped]
11-
import torchvision # type: ignore[import-untyped]
1212
from torchvision import transforms # type: ignore[import-untyped]
1313
from torchvision.transforms import v2 # type: ignore[import-untyped]
1414

1515
root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
1616
sys.path.append(root_dir)
17-
from ncalab import (
17+
from ncalab import ( # noqa: E402
1818
BasicNCATrainer,
1919
ClassificationNCAModel,
2020
get_compute_device,
@@ -61,18 +61,20 @@ def train_class_cifar10(
6161
transforms.RandomPerspective(distortion_scale=0.5, p=0.5),
6262
transforms.RandomHorizontalFlip(),
6363
transforms.RandomVerticalFlip(),
64-
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
64+
transforms.RandomCrop(32, padding=4),
65+
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
6566
]
6667
)
6768
T_val = transforms.Compose(
6869
[
6970
v2.ToImage(),
7071
v2.ToDtype(torch.float, scale=True),
7172
v2.ConvertImageDtype(dtype=torch.float32),
72-
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
73+
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
7374
]
7475
)
7576

77+
# Split train dataset into train and validation
7678
train_dataset = torchvision.datasets.CIFAR10(
7779
root=TASK_PATH / "data",
7880
train=True,
@@ -86,7 +88,6 @@ def train_class_cifar10(
8688
download=True,
8789
transform=T_val,
8890
)
89-
9091
indices = list(range(len(train_dataset)))
9192
split = int(np.floor(0.1 * len(train_dataset)))
9293
np.random.shuffle(indices)
@@ -122,8 +123,9 @@ def train_class_cifar10(
122123
"ship",
123124
"truck",
124125
]
125-
126126
num_classes = len(class_names)
127+
128+
# Create NCA model for classification
127129
nca = ClassificationNCAModel(
128130
device,
129131
num_image_channels=3,
@@ -135,11 +137,12 @@ def train_class_cifar10(
135137
use_temporal_encoding=use_temporal_encoding,
136138
class_names=class_names,
137139
)
140+
# Train the NCA model
138141
trainer = BasicNCATrainer(
139142
nca,
140143
WEIGHTS_PATH / "classification_cifar10",
141144
batch_repeat=2,
142-
max_epochs=100,
145+
max_epochs=500,
143146
gradient_clipping=gradient_clipping,
144147
steps_range=(32, 48),
145148
steps_validation=42,

tasks/class_medmnist/eval_class_bloodmnist.py

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,30 @@
11
#!/usr/bin/env python3
2-
import sys
32
import os
4-
5-
root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
6-
sys.path.append(root_dir)
7-
8-
from ncalab import ClassificationNCAModel, get_compute_device
3+
import sys
94

105
import click
11-
12-
from medmnist import INFO, BloodMNIST # type: ignore[import-untyped]
13-
146
import torch # type: ignore[import-untyped]
15-
from torchvision import transforms # type: ignore[import-untyped]
16-
from torchvision.transforms import v2 # type: ignore[import-untyped]
17-
187
import torchmetrics
198
import torchmetrics.classification
20-
9+
from medmnist import INFO, BloodMNIST # type: ignore[import-untyped]
10+
from torchvision import transforms # type: ignore[import-untyped]
11+
from torchvision.transforms import v2 # type: ignore[import-untyped]
2112
from tqdm import tqdm
22-
2313
from train_class_bloodmnist import (
24-
pad_noise,
25-
alive_mask,
26-
use_temporal_encoding,
27-
fire_rate,
2814
WEIGHTS_PATH,
15+
alive_mask,
2916
default_hidden_channels,
17+
fire_rate,
18+
pad_noise,
19+
use_temporal_encoding,
3020
)
3121

22+
root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
23+
sys.path.append(root_dir)
24+
25+
26+
from ncalab import ClassificationNCAModel, get_compute_device # noqa: E402
27+
3228
T = transforms.Compose(
3329
[
3430
v2.ToImage(),

tasks/class_medmnist/eval_class_dermamnist.py

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,29 @@
11
#!/usr/bin/env python3
2-
import sys
32
import os
4-
5-
root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
6-
sys.path.append(root_dir)
7-
8-
from ncalab import ClassificationNCAModel, get_compute_device
3+
import sys
94

105
import click
11-
12-
from medmnist import INFO, DermaMNIST # type: ignore[import-untyped]
13-
146
import torch # type: ignore[import-untyped]
15-
from torchvision import transforms # type: ignore[import-untyped]
16-
from torchvision.transforms import v2 # type: ignore[import-untyped]
17-
187
import torchmetrics
198
import torchmetrics.classification
20-
9+
from medmnist import INFO, DermaMNIST # type: ignore[import-untyped]
10+
from torchvision import transforms # type: ignore[import-untyped]
11+
from torchvision.transforms import v2 # type: ignore[import-untyped]
2112
from tqdm import tqdm
22-
2313
from train_class_dermamnist import (
24-
pad_noise,
25-
alive_mask,
26-
use_temporal_encoding,
27-
fire_rate,
2814
WEIGHTS_PATH,
15+
alive_mask,
2916
default_hidden_channels,
17+
fire_rate,
18+
pad_noise,
19+
use_temporal_encoding,
3020
)
3121

22+
root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
23+
sys.path.append(root_dir)
24+
25+
from ncalab import ClassificationNCAModel, get_compute_device # noqa: E402
26+
3227
T = transforms.Compose(
3328
[
3429
v2.ToImage(),

tasks/class_medmnist/eval_class_pathmnist.py

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,29 @@
11
#!/usr/bin/env python3
2-
import sys
32
import os
4-
5-
root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
6-
sys.path.append(root_dir)
7-
8-
from ncalab import ClassificationNCAModel, get_compute_device
3+
import sys
94

105
import click
11-
12-
from medmnist import INFO, PathMNIST # type: ignore[import-untyped]
13-
146
import torch # type: ignore[import-untyped]
15-
from torchvision import transforms # type: ignore[import-untyped]
16-
from torchvision.transforms import v2 # type: ignore[import-untyped]
17-
187
import torchmetrics
198
import torchmetrics.classification
20-
9+
from medmnist import INFO, PathMNIST # type: ignore[import-untyped]
10+
from torchvision import transforms # type: ignore[import-untyped]
11+
from torchvision.transforms import v2 # type: ignore[import-untyped]
2112
from tqdm import tqdm
22-
2313
from train_class_pathmnist import (
24-
pad_noise,
25-
alive_mask,
26-
use_temporal_encoding,
27-
fire_rate,
2814
WEIGHTS_PATH,
15+
alive_mask,
2916
default_hidden_channels,
17+
fire_rate,
18+
pad_noise,
19+
use_temporal_encoding,
3020
)
3121

22+
root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
23+
sys.path.append(root_dir)
24+
25+
from ncalab import ClassificationNCAModel, get_compute_device # noqa: E402
26+
3227
T = transforms.Compose(
3328
[
3429
v2.ToImage(),

tasks/class_medmnist/train_class_bloodmnist.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,25 @@
11
#!/usr/bin/env python3
2-
from pathlib import Path
3-
import sys
42
import os
3+
import sys
4+
from pathlib import Path
5+
6+
import click
7+
import torch # type: ignore[import-untyped]
8+
from medmnist import INFO, BloodMNIST # type: ignore[import-untyped]
9+
from torch.utils.tensorboard import SummaryWriter # type: ignore[import-untyped]
10+
from torchvision import transforms # type: ignore[import-untyped]
11+
from torchvision.transforms import v2 # type: ignore[import-untyped]
512

613
root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
714
sys.path.append(root_dir)
815

9-
from ncalab import (
10-
ClassificationNCAModel,
16+
17+
from ncalab import ( # noqa: E402
1118
BasicNCATrainer,
19+
ClassificationNCAModel,
1220
get_compute_device,
1321
)
1422

15-
import click
16-
17-
from medmnist import INFO, BloodMNIST # type: ignore[import-untyped]
18-
19-
import torch # type: ignore[import-untyped]
20-
from torchvision import transforms # type: ignore[import-untyped]
21-
from torchvision.transforms import v2 # type: ignore[import-untyped]
22-
from torch.utils.tensorboard import SummaryWriter # type: ignore[import-untyped]
23-
2423
TASK_PATH = Path(__file__).parent.resolve()
2524
WEIGHTS_PATH = TASK_PATH / "weights"
2625
WEIGHTS_PATH.mkdir(exist_ok=True)

tasks/class_medmnist/train_class_dermamnist.py

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,26 @@
11
#!/usr/bin/env python3
2-
from pathlib import Path
3-
import sys
42
import os
3+
import sys
4+
from pathlib import Path
5+
6+
import click
7+
import numpy as np
8+
import torch # type: ignore[import-untyped]
9+
from medmnist import INFO, DermaMNIST # type: ignore[import-untyped]
10+
from torch.utils.tensorboard import SummaryWriter # type: ignore[import-untyped]
11+
from torchvision import transforms # type: ignore[import-untyped]
12+
from torchvision.transforms import v2 # type: ignore[import-untyped]
513

614
root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
715
sys.path.append(root_dir)
816

9-
from ncalab import (
10-
ClassificationNCAModel,
17+
from ncalab import ( # noqa: E402
1118
BasicNCATrainer,
19+
ClassificationNCAModel,
1220
get_compute_device,
1321
print_NCALab_banner,
1422
)
1523

16-
import numpy as np
17-
import click
18-
19-
from medmnist import INFO, DermaMNIST # type: ignore[import-untyped]
20-
21-
import torch # type: ignore[import-untyped]
22-
from torchvision import transforms # type: ignore[import-untyped]
23-
from torchvision.transforms import v2 # type: ignore[import-untyped]
24-
from torch.utils.tensorboard import SummaryWriter # type: ignore[import-untyped]
25-
2624
TASK_PATH = Path(__file__).parent.resolve()
2725
WEIGHTS_PATH = TASK_PATH / "weights"
2826
WEIGHTS_PATH.mkdir(exist_ok=True)

tasks/class_medmnist/train_class_pathmnist.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,24 @@
11
#!/usr/bin/env python3
2-
from pathlib import Path
3-
import sys
42
import os
3+
import sys
4+
from pathlib import Path
5+
6+
import click
7+
import torch # type: ignore[import-untyped]
8+
from medmnist import INFO, PathMNIST # type: ignore[import-untyped]
9+
from torch.utils.tensorboard import SummaryWriter # type: ignore[import-untyped]
10+
from torchvision import transforms # type: ignore[import-untyped]
11+
from torchvision.transforms import v2 # type: ignore[import-untyped]
512

613
root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
714
sys.path.append(root_dir)
815

9-
from ncalab import (
10-
ClassificationNCAModel,
16+
from ncalab import ( # noqa: E402
1117
BasicNCATrainer,
18+
ClassificationNCAModel,
1219
get_compute_device,
1320
)
1421

15-
import click
16-
17-
from medmnist import INFO, PathMNIST # type: ignore[import-untyped]
18-
19-
import torch # type: ignore[import-untyped]
20-
from torchvision import transforms # type: ignore[import-untyped]
21-
from torchvision.transforms import v2 # type: ignore[import-untyped]
22-
from torch.utils.tensorboard import SummaryWriter # type: ignore[import-untyped]
23-
2422
TASK_PATH = Path(__file__).parent.resolve()
2523
WEIGHTS_PATH = TASK_PATH / "weights"
2624
WEIGHTS_PATH.mkdir(exist_ok=True)

tasks/growing_emoji/README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,6 @@
22

33
- Loosely based on the [Growing Cellular Automata](https://distill.pub/2020/growing-ca/) publication that sparked major interest in Neural Cellular Automata.
44
- An evaluation script is provided to generate and show an image.
5+
6+
7+
## train_growing_emoji.py

0 commit comments

Comments
 (0)