File tree Expand file tree Collapse file tree 4 files changed +72
-2
lines changed Expand file tree Collapse file tree 4 files changed +72
-2
lines changed Original file line number Diff line number Diff line change
1
+ import logging
2
+ import os
3
+ from datetime import datetime
4
+
5
+ import torch
6
+ from jsonargparse import lazy_instance
7
+ from lightning .pytorch .cli import LightningCLI
8
+ from lightning .pytorch .loggers import TensorBoardLogger
9
+
10
+ from viscy .data .triplet import TripletDataModule
11
+ from viscy .light .engine import ContrastiveModule
12
+
13
+
14
+ class ContrastiveLightningCLI (LightningCLI ):
15
+ """Lightning CLI with default logger."""
16
+
17
+ def add_arguments_to_parser (self , parser ):
18
+ parser .set_defaults (
19
+ {
20
+ "trainer.logger" : lazy_instance (
21
+ TensorBoardLogger ,
22
+ save_dir = "" ,
23
+ version = datetime .now ().strftime (r"%Y%m%d-%H%M%S" ),
24
+ log_graph = True ,
25
+ )
26
+ }
27
+ )
28
+
29
+
30
+ def main ():
31
+ """Main Lightning CLI entry point."""
32
+ log_level = os .getenv ("VISCY_LOG_LEVEL" , logging .INFO )
33
+ logging .getLogger ("lightning.pytorch" ).setLevel (log_level )
34
+ torch .set_float32_matmul_precision ("high" )
35
+ _ = ContrastiveLightningCLI (
36
+ model_class = ContrastiveModule , datamodule_class = TripletDataModule
37
+ )
38
+
39
+
40
+ if __name__ == "__main__" :
41
+ main ()
Original file line number Diff line number Diff line change @@ -262,7 +262,7 @@ def _setup_fit(self, dataset_settings: dict):
262
262
self .train_dataset = TripletDataset (
263
263
positions = train_positions ,
264
264
tracks_tables = train_tracks_tables ,
265
- initial_yx_patch_size = self .yx_patch_size ,
265
+ initial_yx_patch_size = self .initial_yx_patch_size ,
266
266
anchor_transform = no_aug_transform ,
267
267
positive_transform = augment_transform ,
268
268
negative_transform = augment_transform ,
@@ -273,7 +273,7 @@ def _setup_fit(self, dataset_settings: dict):
273
273
self .val_dataset = TripletDataset (
274
274
positions = val_positions ,
275
275
tracks_tables = val_tracks_tables ,
276
- initial_yx_patch_size = self .yx_patch_size ,
276
+ initial_yx_patch_size = self .initial_yx_patch_size ,
277
277
anchor_transform = no_aug_transform ,
278
278
positive_transform = augment_transform ,
279
279
negative_transform = augment_transform ,
Original file line number Diff line number Diff line change 11
11
RandomizableTransform ,
12
12
RandScaleIntensityd ,
13
13
RandWeightedCropd ,
14
+ ScaleIntensityRangePercentilesd ,
14
15
)
15
16
from monai .transforms .transform import Randomizable
16
17
from numpy .random .mtrand import RandomState as RandomState
@@ -127,6 +128,34 @@ def __init__(
127
128
)
128
129
129
130
131
+ class ScaleIntensityRangePercentilesd (ScaleIntensityRangePercentilesd ):
132
+ def __init__ (
133
+ self ,
134
+ keys : Union [Sequence [str ], str ],
135
+ lower : float ,
136
+ upper : float ,
137
+ b_min : float | None ,
138
+ b_max : float | None ,
139
+ clip : bool = False ,
140
+ relative : bool = False ,
141
+ channel_wise : bool = False ,
142
+ dtype : Union [Sequence [str ], str ] = None ,
143
+ allow_missing_keys : bool = False ,
144
+ ):
145
+ super ().__init__ (
146
+ keys = keys ,
147
+ lower = lower ,
148
+ upper = upper ,
149
+ b_min = b_min ,
150
+ b_max = b_max ,
151
+ clip = clip ,
152
+ relative = relative ,
153
+ channel_wise = channel_wise ,
154
+ dtype = dtype ,
155
+ allow_missing_keys = allow_missing_keys ,
156
+ )
157
+
158
+
130
159
class NormalizeSampled (MapTransform ):
131
160
"""
132
161
Normalize the sample
You can’t perform that action at this time.
0 commit comments