Skip to content

Commit 873664a

Browse files
committed
slightly tune parameters in segmentation example
1 parent 9ead67c commit 873664a

File tree

2 files changed

+21
-12
lines changed

2 files changed

+21
-12
lines changed

tasks/segmentation_kvasir_seg/eval_segmentation_kvasir_seg.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -52,30 +52,31 @@ def eval_segmentation_kvasir_seg(hidden_channels: int, gpu: bool, gpu_index: int
5252
num_hidden_channels=hidden_channels,
5353
num_classes=1,
5454
pad_noise=True,
55-
fire_rate=0.8,
55+
fire_rate=0.5,
56+
use_temporal_encoding=True,
5657
)
57-
cascade = CascadeNCA(nca, [8, 4, 2, 1], [70, 20, 10, 5])
58+
cascade = CascadeNCA(nca, [4, 2, 1], [32, 16, 8])
5859

5960
T = A.Compose(
6061
[
61-
A.RandomCrop(300, 300),
6262
A.Resize(256, 256),
63-
A.RandomRotate90(),
64-
A.HorizontalFlip(),
6563
ToTensorV2(),
6664
]
6765
)
6866
dataset = KvasirSegDataset(KVASIR_SEG_PATH, transform=T)
67+
loader = torch.utils.data.DataLoader(
68+
dataset, shuffle=False, batch_size=8, drop_last=True
69+
)
6970

7071
cascade.load_state_dict(
7172
torch.load(
72-
WEIGHTS_PATH / "segmentation_kvasir_seg" / "last_model.pth",
73+
WEIGHTS_PATH / "segmentation_kvasir_seg" / "best_model.pth",
7374
weights_only=True,
7475
)
7576
)
7677

77-
seed = dataset[0][0].unsqueeze(0).to(device)
78-
animator = Animator(cascade, seed, overlay=True)
78+
seed = next(iter(loader))[0].to(device)
79+
animator = Animator(cascade, seed, overlay=True, interval=100, steps=sum(cascade.steps))
7980

8081
out_path = FIGURE_PATH / "segmentation_kvasir_seg.gif"
8182
animator.save(out_path)

tasks/segmentation_kvasir_seg/train_segmentation_kvasir_seg.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,13 +50,15 @@ def train_segmentation_kvasir_seg(
5050
num_hidden_channels=hidden_channels,
5151
num_classes=1,
5252
pad_noise=True,
53-
fire_rate=0.8,
53+
fire_rate=0.5,
54+
use_temporal_encoding=True,
55+
filter_padding="circular",
5456
)
55-
cascade = CascadeNCA(nca, [8, 4, 2, 1], [70, 20, 10, 5])
57+
cascade = CascadeNCA(nca, [4, 2, 1], [32, 16, 8])
5658

5759
T = A.Compose(
5860
[
59-
A.RandomCrop(300, 300),
61+
A.ColorJitter(),
6062
A.Resize(256, 256),
6163
A.RandomRotate90(),
6264
A.HorizontalFlip(),
@@ -81,7 +83,13 @@ def train_segmentation_kvasir_seg(
8183
val_split, shuffle=True, batch_size=batch_size, drop_last=True
8284
)
8385

84-
trainer = BasicNCATrainer(cascade, WEIGHTS_PATH / "segmentation_kvasir_seg")
86+
trainer = BasicNCATrainer(
87+
cascade,
88+
WEIGHTS_PATH / "segmentation_kvasir_seg",
89+
max_epochs=500,
90+
steps_range=(30, 40),
91+
steps_validation=35,
92+
)
8593
trainer.train(
8694
loader_train,
8795
loader_val,

0 commit comments

Comments
 (0)