Skip to content

Commit a1df436

Browse files
authored
Tune augmentations with CLI and config for contrastive models (#126)
* fix initial crop size * make representation a sub-package * re-export ScaleIntensityRangePercentilesd for jsonargparse * add cli module for contrastive models with triplet data * update docstring
1 parent b63d4f7 commit a1df436

File tree

4 files changed

+72
-2
lines changed

4 files changed

+72
-2
lines changed

viscy/cli/contrastive_triplet.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import logging
2+
import os
3+
from datetime import datetime
4+
5+
import torch
6+
from jsonargparse import lazy_instance
7+
from lightning.pytorch.cli import LightningCLI
8+
from lightning.pytorch.loggers import TensorBoardLogger
9+
10+
from viscy.data.triplet import TripletDataModule
11+
from viscy.light.engine import ContrastiveModule
12+
13+
14+
class ContrastiveLightningCLI(LightningCLI):
15+
"""Lightning CLI with default logger."""
16+
17+
def add_arguments_to_parser(self, parser):
18+
parser.set_defaults(
19+
{
20+
"trainer.logger": lazy_instance(
21+
TensorBoardLogger,
22+
save_dir="",
23+
version=datetime.now().strftime(r"%Y%m%d-%H%M%S"),
24+
log_graph=True,
25+
)
26+
}
27+
)
28+
29+
30+
def main():
31+
"""Main Lightning CLI entry point."""
32+
log_level = os.getenv("VISCY_LOG_LEVEL", logging.INFO)
33+
logging.getLogger("lightning.pytorch").setLevel(log_level)
34+
torch.set_float32_matmul_precision("high")
35+
_ = ContrastiveLightningCLI(
36+
model_class=ContrastiveModule, datamodule_class=TripletDataModule
37+
)
38+
39+
40+
if __name__ == "__main__":
41+
main()

viscy/data/triplet.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ def _setup_fit(self, dataset_settings: dict):
262262
self.train_dataset = TripletDataset(
263263
positions=train_positions,
264264
tracks_tables=train_tracks_tables,
265-
initial_yx_patch_size=self.yx_patch_size,
265+
initial_yx_patch_size=self.initial_yx_patch_size,
266266
anchor_transform=no_aug_transform,
267267
positive_transform=augment_transform,
268268
negative_transform=augment_transform,
@@ -273,7 +273,7 @@ def _setup_fit(self, dataset_settings: dict):
273273
self.val_dataset = TripletDataset(
274274
positions=val_positions,
275275
tracks_tables=val_tracks_tables,
276-
initial_yx_patch_size=self.yx_patch_size,
276+
initial_yx_patch_size=self.initial_yx_patch_size,
277277
anchor_transform=no_aug_transform,
278278
positive_transform=augment_transform,
279279
negative_transform=augment_transform,

viscy/representation/__init__.py

Whitespace-only changes.

viscy/transforms.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
RandomizableTransform,
1212
RandScaleIntensityd,
1313
RandWeightedCropd,
14+
ScaleIntensityRangePercentilesd,
1415
)
1516
from monai.transforms.transform import Randomizable
1617
from numpy.random.mtrand import RandomState as RandomState
@@ -127,6 +128,34 @@ def __init__(
127128
)
128129

129130

131+
class ScaleIntensityRangePercentilesd(ScaleIntensityRangePercentilesd):
132+
def __init__(
133+
self,
134+
keys: Union[Sequence[str], str],
135+
lower: float,
136+
upper: float,
137+
b_min: float | None,
138+
b_max: float | None,
139+
clip: bool = False,
140+
relative: bool = False,
141+
channel_wise: bool = False,
142+
dtype: Union[Sequence[str], str] = None,
143+
allow_missing_keys: bool = False,
144+
):
145+
super().__init__(
146+
keys=keys,
147+
lower=lower,
148+
upper=upper,
149+
b_min=b_min,
150+
b_max=b_max,
151+
clip=clip,
152+
relative=relative,
153+
channel_wise=channel_wise,
154+
dtype=dtype,
155+
allow_missing_keys=allow_missing_keys,
156+
)
157+
158+
130159
class NormalizeSampled(MapTransform):
131160
"""
132161
Normalize the sample

0 commit comments

Comments
 (0)