Skip to content

Commit 0ba38c7

Browse files
Migrate from wandb to tensorboard (#122)
* wip: use lightning's tensorboard logger instead of wandb * private logging methods * log center slice only * fix tensor cloning * only log metrics on epoch * add simple demo training script * fix flaky test * log graph + profiling * switch to simple profiler --------- Co-authored-by: Shalin Mehta <[email protected]>
1 parent 0a72176 commit 0ba38c7

File tree

4 files changed

+170
-314
lines changed

4 files changed

+170
-314
lines changed
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
from lightning.pytorch import Trainer
2+
from lightning.pytorch.callbacks import ModelCheckpoint
3+
from lightning.pytorch.loggers import TensorBoardLogger
4+
from lightning.pytorch.callbacks import DeviceStatsMonitor
5+
6+
7+
from viscy.data.triplet import TripletDataModule
8+
from viscy.light.engine import ContrastiveModule
9+
10+
11+
def main():
12+
dm = TripletDataModule(
13+
data_path="/hpc/projects/virtual_staining/2024_02_04_A549_DENV_ZIKV_timelapse/registered_chunked.zarr",
14+
tracks_path="/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/7.1-seg_track/tracking_v1.zarr",
15+
source_channel=["Phase3D", "RFP"],
16+
z_range=(20, 35),
17+
batch_size=16,
18+
num_workers=10,
19+
initial_yx_patch_size=(384, 384),
20+
final_yx_patch_size=(224, 224),
21+
)
22+
model = ContrastiveModule(
23+
backbone="convnext_tiny",
24+
in_channels=2,
25+
log_batches_per_epoch=2,
26+
log_samples_per_batch=3,
27+
)
28+
trainer = Trainer(
29+
max_epochs=5,
30+
limit_train_batches=10,
31+
limit_val_batches=5,
32+
logger=TensorBoardLogger(
33+
"/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/test_tb",
34+
log_graph=True,
35+
default_hp_metric=True,
36+
),
37+
log_every_n_steps=1,
38+
callbacks=[ModelCheckpoint()],
39+
profiler="simple", # other options: "advanced" uses cprofiler, "pytorch" uses pytorch profiler.
40+
)
41+
trainer.fit(model, dm)
42+
43+
44+
if __name__ == "__main__":
45+
main()

tests/preprocessing/test_pixel_ratio.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,8 @@
66
def test_sematic_class_weights(small_hcs_dataset):
77
weights = sematic_class_weights(small_hcs_dataset, "GFP")
88
assert weights.shape == (3,)
9-
assert_allclose(weights[0], 1.0)
9+
assert_allclose(weights[0], 1.0, atol=1e-5)
1010
# infinity
1111
assert weights[1] > 1.0
1212
assert weights[2] > 1.0
13-
assert sematic_class_weights(
14-
small_hcs_dataset, "GFP", num_classes=2
15-
).shape == (2,)
13+
assert sematic_class_weights(small_hcs_dataset, "GFP", num_classes=2).shape == (2,)

viscy/data/hcs.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,9 +88,11 @@ def _read_norm_meta(fov: Position) -> NormMeta | None:
8888
for channel, channel_values in norm_meta.items():
8989
for level, level_values in channel_values.items():
9090
for stat, value in level_values.items():
91-
norm_meta[channel][level][stat] = torch.tensor(
92-
value, dtype=torch.float32
93-
)
91+
if isinstance(value, Tensor):
92+
value = value.clone().float()
93+
else:
94+
value = torch.tensor(value, dtype=torch.float32)
95+
norm_meta[channel][level][stat] = value
9496
return norm_meta
9597

9698

0 commit comments

Comments
 (0)