Skip to content

Commit 7b56d72

Browse files
committed
medmnist classification tasks: share parameters between training and evaluation
1 parent 9db5b20 commit 7b56d72

File tree

6 files changed

+76
-52
lines changed

6 files changed

+76
-52
lines changed

tasks/class_medmnist/README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# MedMNIST Classification
2+
3+
This example uses a classification NCA model.

tasks/class_medmnist/eval_class_bloodmnist.py

Whitespace-only changes.

tasks/class_medmnist/eval_class_dermamnist.py

Lines changed: 29 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -5,63 +5,71 @@
55
root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
66
sys.path.append(root_dir)
77

8-
from ncalab import ClassificationNCAModel, WEIGHTS_PATH, get_compute_device, pad_input
8+
from ncalab import ClassificationNCAModel, get_compute_device
99

1010
import click
1111

12-
from medmnist import PathMNIST # type: ignore[import-untyped]
12+
from medmnist import INFO, DermaMNIST # type: ignore[import-untyped]
1313

1414
import torch # type: ignore[import-untyped]
1515
from torchvision import transforms # type: ignore[import-untyped]
1616
from torchvision.transforms import v2 # type: ignore[import-untyped]
1717

18-
import numpy as np
19-
2018
import torchmetrics
2119
import torchmetrics.classification
2220

2321
from tqdm import tqdm
2422

23+
from train_class_dermamnist import (
24+
pad_noise,
25+
alive_mask,
26+
use_temporal_encoding,
27+
fire_rate,
28+
WEIGHTS_PATH,
29+
)
2530

2631
T = transforms.Compose(
2732
[
2833
v2.ToImage(),
2934
v2.ToDtype(torch.float, scale=True),
3035
v2.ConvertImageDtype(dtype=torch.float32),
31-
transforms.RandomHorizontalFlip(),
32-
transforms.RandomVerticalFlip(),
36+
transforms.Normalize((0.5,), (0.225,)),
3337
]
3438
)
3539

3640

37-
def eval_selfclass_pathmnist(
41+
def eval_selfclass_dermamnist(
3842
hidden_channels: int,
3943
gpu,
4044
gpu_index,
4145
):
4246
device = get_compute_device(f"cuda:{gpu_index}" if gpu else "cpu")
4347

44-
dataset_test = PathMNIST(split="test", download=True, transform=T)
48+
dataset_test = DermaMNIST(split="test", download=True, transform=T)
4549
loader_test = torch.utils.data.DataLoader(
4650
dataset_test, shuffle=False, batch_size=32
4751
)
4852

49-
num_classes = 9
53+
num_classes = len(INFO["dermamnist"]["label"])
5054
nca = ClassificationNCAModel(
5155
device,
5256
num_image_channels=3,
5357
num_hidden_channels=hidden_channels,
5458
hidden_size=128,
5559
num_classes=num_classes,
56-
use_alive_mask=False,
57-
fire_rate=0.5,
60+
use_alive_mask=alive_mask,
61+
fire_rate=fire_rate,
62+
use_temporal_encoding=use_temporal_encoding,
63+
pad_noise=pad_noise,
5864
)
5965
nca.load_state_dict(
60-
torch.load(WEIGHTS_PATH / "selfclass_pathmnist.best.pth", weights_only=True)
66+
torch.load(
67+
WEIGHTS_PATH / "classification_dermamnist" / "best_model.pth",
68+
weights_only=True,
69+
)
6170
)
6271

63-
model_parameters = filter(lambda p: p.requires_grad, nca.parameters())
64-
params = sum([np.prod(p.size()) for p in model_parameters])
72+
params = nca.num_trainable_parameters()
6573
print(f"Trainable parameters: {params}")
6674
print(f"That is {4 * params / 1000} kB")
6775

@@ -70,24 +78,24 @@ def eval_selfclass_pathmnist(
7078

7179
macro_acc = torchmetrics.classification.MulticlassAccuracy(
7280
average="macro", num_classes=num_classes
73-
)
81+
).to(device)
7482
micro_acc = torchmetrics.classification.MulticlassAccuracy(
7583
average="micro", num_classes=num_classes
76-
)
84+
).to(device)
7785
macro_auc = torchmetrics.classification.MulticlassAUROC(
7886
average="macro",
7987
num_classes=num_classes,
80-
)
88+
).to(device)
8189
micro_f1 = torchmetrics.classification.MulticlassF1Score(
8290
average="micro", num_classes=num_classes
83-
)
91+
).to(device)
8492
for sample in tqdm(iter(loader_test)):
8593
x, y = sample
86-
x = pad_input(x, nca, noise=True)
8794
x = x.float().to(device)
8895
steps = 72
8996
y_prob = nca.classify(x, steps, reduce=False)
9097
y = y.squeeze().to(device)
98+
9199
macro_acc.update(y_prob, y)
92100
micro_acc.update(y_prob, y)
93101
macro_auc.update(y_prob, y)
@@ -104,15 +112,15 @@ def eval_selfclass_pathmnist(
104112

105113

106114
@click.command()
107-
@click.option("--hidden-channels", "-H", default=20, type=int)
115+
@click.option("--hidden-channels", "-H", default=30, type=int)
108116
@click.option(
109117
"--gpu/--no-gpu", is_flag=True, default=True, help="Try using the GPU if available."
110118
)
111119
@click.option(
112120
"--gpu-index", type=int, default=0, help="Index of GPU to use, if --gpu in use."
113121
)
114122
def main(hidden_channels, gpu: bool, gpu_index: int):
115-
eval_selfclass_pathmnist(
123+
eval_selfclass_dermamnist(
116124
hidden_channels=hidden_channels,
117125
gpu=gpu,
118126
gpu_index=gpu_index,

tasks/class_medmnist/eval_class_pathmnist.py

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,31 +5,35 @@
55
root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
66
sys.path.append(root_dir)
77

8-
from ncalab import ClassificationNCAModel, WEIGHTS_PATH, get_compute_device, pad_input
8+
from ncalab import ClassificationNCAModel, get_compute_device
99

1010
import click
1111

12-
from medmnist import PathMNIST # type: ignore[import-untyped]
12+
from medmnist import INFO, PathMNIST # type: ignore[import-untyped]
1313

1414
import torch # type: ignore[import-untyped]
1515
from torchvision import transforms # type: ignore[import-untyped]
1616
from torchvision.transforms import v2 # type: ignore[import-untyped]
1717

18-
import numpy as np
19-
2018
import torchmetrics
2119
import torchmetrics.classification
2220

2321
from tqdm import tqdm
2422

23+
from train_class_pathmnist import (
24+
pad_noise,
25+
alive_mask,
26+
use_temporal_encoding,
27+
fire_rate,
28+
WEIGHTS_PATH,
29+
)
2530

2631
T = transforms.Compose(
2732
[
2833
v2.ToImage(),
2934
v2.ToDtype(torch.float, scale=True),
3035
v2.ConvertImageDtype(dtype=torch.float32),
31-
transforms.RandomHorizontalFlip(),
32-
transforms.RandomVerticalFlip(),
36+
transforms.Normalize((0.5,), (0.225,)),
3337
]
3438
)
3539

@@ -46,22 +50,26 @@ def eval_selfclass_pathmnist(
4650
dataset_test, shuffle=False, batch_size=32
4751
)
4852

49-
num_classes = 9
53+
num_classes = len(INFO["pathmnist"]["label"])
5054
nca = ClassificationNCAModel(
5155
device,
5256
num_image_channels=3,
5357
num_hidden_channels=hidden_channels,
5458
hidden_size=128,
5559
num_classes=num_classes,
56-
use_alive_mask=False,
57-
fire_rate=0.5,
60+
use_alive_mask=alive_mask,
61+
fire_rate=fire_rate,
62+
pad_noise=pad_noise,
63+
use_temporal_encoding=use_temporal_encoding,
5864
)
5965
nca.load_state_dict(
60-
torch.load(WEIGHTS_PATH / "selfclass_pathmnist.best.pth", weights_only=True)
66+
torch.load(
67+
WEIGHTS_PATH / "classification_pathmnist" / "best_model.pth",
68+
weights_only=True,
69+
)
6170
)
6271

63-
model_parameters = filter(lambda p: p.requires_grad, nca.parameters())
64-
params = sum([np.prod(p.size()) for p in model_parameters])
72+
params = nca.num_trainable_parameters()
6573
print(f"Trainable parameters: {params}")
6674
print(f"That is {4 * params / 1000} kB")
6775

@@ -70,24 +78,24 @@ def eval_selfclass_pathmnist(
7078

7179
macro_acc = torchmetrics.classification.MulticlassAccuracy(
7280
average="macro", num_classes=num_classes
73-
)
81+
).to(device)
7482
micro_acc = torchmetrics.classification.MulticlassAccuracy(
7583
average="micro", num_classes=num_classes
76-
)
84+
).to(device)
7785
macro_auc = torchmetrics.classification.MulticlassAUROC(
7886
average="macro",
7987
num_classes=num_classes,
80-
)
88+
).to(device)
8189
micro_f1 = torchmetrics.classification.MulticlassF1Score(
8290
average="micro", num_classes=num_classes
83-
)
91+
).to(device)
8492
for sample in tqdm(iter(loader_test)):
8593
x, y = sample
86-
x = pad_input(x, nca, noise=True)
8794
x = x.float().to(device)
8895
steps = 72
8996
y_prob = nca.classify(x, steps, reduce=False)
9097
y = y.squeeze().to(device)
98+
9199
macro_acc.update(y_prob, y)
92100
micro_acc.update(y_prob, y)
93101
macro_auc.update(y_prob, y)

tasks/class_medmnist/train_class_dermamnist.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,12 @@
2727
WEIGHTS_PATH = TASK_PATH / "weights"
2828
WEIGHTS_PATH.mkdir(exist_ok=True)
2929

30+
gradient_clipping = False
31+
pad_noise = True
32+
alive_mask = False
33+
use_temporal_encoding = True
34+
fire_rate = 0.8
35+
3036

3137
def train_class_dermamnist(
3238
batch_size: int,
@@ -36,11 +42,6 @@ def train_class_dermamnist(
3642
):
3743
print_NCALab_banner()
3844

39-
gradient_clipping = False
40-
pad_noise = True
41-
alive_mask = False
42-
use_temporal_encoding = True
43-
4445
comment = "DermaMNIST"
4546
comment += f"_hidden_{hidden_channels}"
4647
comment += f"_gc_{gradient_clipping}"
@@ -106,7 +107,7 @@ def train_class_dermamnist(
106107
num_hidden_channels=hidden_channels,
107108
num_classes=7,
108109
use_alive_mask=alive_mask,
109-
fire_rate=0.8,
110+
fire_rate=fire_rate,
110111
pad_noise=pad_noise,
111112
use_temporal_encoding=use_temporal_encoding,
112113
)

tasks/class_medmnist/train_class_pathmnist.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import click
1616

17-
from medmnist import PathMNIST # type: ignore[import-untyped]
17+
from medmnist import INFO, PathMNIST # type: ignore[import-untyped]
1818

1919
import torch # type: ignore[import-untyped]
2020
from torchvision import transforms # type: ignore[import-untyped]
@@ -25,6 +25,12 @@
2525
WEIGHTS_PATH = TASK_PATH / "weights"
2626
WEIGHTS_PATH.mkdir(exist_ok=True)
2727

28+
gradient_clipping = False
29+
pad_noise = True
30+
alive_mask = False
31+
use_temporal_encoding = True
32+
fire_rate = 0.8
33+
2834

2935
def train_class_pathmnist(
3036
batch_size: int,
@@ -33,10 +39,6 @@ def train_class_pathmnist(
3339
gpu_index: int,
3440
lambda_activity: float,
3541
):
36-
gradient_clipping = False
37-
pad_noise = True
38-
alive_mask = False
39-
4042
writer = SummaryWriter(
4143
comment=f"L.act_{lambda_activity}_c.hidden_{hidden_channels}_gc_{gradient_clipping}_noise_{pad_noise}_AM_{alive_mask}"
4244
)
@@ -50,6 +52,7 @@ def train_class_pathmnist(
5052
v2.ConvertImageDtype(dtype=torch.float32),
5153
transforms.RandomHorizontalFlip(),
5254
transforms.RandomVerticalFlip(),
55+
transforms.Normalize((0.5,), (0.225,)),
5356
]
5457
)
5558

@@ -63,20 +66,21 @@ def train_class_pathmnist(
6366
dataset_val, shuffle=True, batch_size=32, drop_last=True
6467
)
6568

69+
num_classes = len(INFO["pathmnist"]["label"])
6670
nca = ClassificationNCAModel(
6771
device,
6872
num_image_channels=3,
6973
num_hidden_channels=hidden_channels,
70-
num_classes=9,
74+
num_classes=num_classes,
7175
use_alive_mask=alive_mask,
72-
fire_rate=0.5,
76+
fire_rate=fire_rate,
7377
lambda_activity=lambda_activity,
7478
filter_padding="circular",
7579
pad_noise=pad_noise,
7680
)
7781
trainer = BasicNCATrainer(
7882
nca,
79-
WEIGHTS_PATH / "selfclass_pathmnist",
83+
WEIGHTS_PATH / "classification_pathmnist",
8084
batch_repeat=2,
8185
max_epochs=100,
8286
gradient_clipping=gradient_clipping,

0 commit comments

Comments
 (0)