Skip to content

Commit 828f803

Browse files
mattersoflight and edyoshikun
authored and committed
version lighting CLI example (#128)
* version lighting CLI example * add some documentation * ignore slurm output * add more tips
1 parent 8a137d4 commit 828f803

File tree

3 files changed

+162
-0
lines changed

3 files changed

+162
-0
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ __pycache__/
77
# written by setuptools_scm
88
*/_version.py
99

10+
# slurm output files
11+
slurm-*
12+
1013
# Distribution / packaging
1114
.Python
1215
build/
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
# Lightning CLI training configuration for the contrastive-learning demo.
# See how to configure hyper-parameters with config files:
# https://lightning.ai/docs/pytorch/stable/cli/lightning_cli_advanced.html
seed_everything: 42
trainer:
  accelerator: gpu
  strategy: ddp
  devices: 4
  num_nodes: 1
  precision: 32-true
  logger:
    class_path: lightning.pytorch.loggers.TensorBoardLogger
    init_args:
      save_dir: /hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/contrastive_tune_augmentations
      # `version` names the experiment; logs land in save_dir/lightning_logs/<version>.
      version: chocolate
      log_graph: True
  # Nesting the logger config like this is equivalent to supplying the
  # following argument to lightning.pytorch.Trainer:
  # logger=TensorBoardLogger(
  #     "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/contrastive_tune_augmentations",
  #     log_graph=True,
  #     version="vanilla",
  # )
  callbacks:
    - class_path: lightning.pytorch.callbacks.LearningRateMonitor
      init_args:
        logging_interval: step
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        monitor: loss/val
        every_n_epochs: 1
        save_top_k: 4
        save_last: true
  fast_dev_run: false
  max_epochs: 100
  log_every_n_steps: 10
  enable_checkpointing: true
  inference_mode: true
  use_distributed_sampler: true
model:
  backbone: convnext_tiny
  in_channels: 2
  log_batches_per_epoch: 3
  log_samples_per_batch: 3
  lr: 0.0002
data:
  data_path: /hpc/projects/virtual_staining/2024_02_04_A549_DENV_ZIKV_timelapse/registered_chunked.zarr
  tracks_path: /hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/7.1-seg_track/tracking_v1.zarr
  source_channel:
    - Phase3D
    - RFP
  z_range: [25, 40]
  batch_size: 32
  num_workers: 12
  initial_yx_patch_size: [384, 384]
  final_yx_patch_size: [192, 192]
  # Per-channel normalization applied before augmentation.
  normalizations:
    - class_path: viscy.transforms.NormalizeSampled
      init_args:
        keys: [Phase3D]
        level: fov_statistics
        subtrahend: mean
        divisor: std
    - class_path: viscy.transforms.ScaleIntensityRangePercentilesd
      init_args:
        keys: [RFP]
        lower: 50
        upper: 99
        b_min: 0.0
        b_max: 1.0
  # Stochastic augmentations for contrastive training.
  augmentations:
    - class_path: viscy.transforms.RandAffined
      init_args:
        keys: [Phase3D, RFP]
        prob: 0.8
        scale_range: [0, 0.2, 0.2]
        rotate_range: [3.14, 0.0, 0.0]
        shear_range: [0.0, 0.01, 0.01]
        padding_mode: zeros
    - class_path: viscy.transforms.RandAdjustContrastd
      init_args:
        keys: [RFP]
        prob: 0.5
        gamma: [0.7, 1.3]
    - class_path: viscy.transforms.RandAdjustContrastd
      init_args:
        keys: [Phase3D]
        prob: 0.5
        gamma: [0.8, 1.2]
    - class_path: viscy.transforms.RandScaleIntensityd
      init_args:
        keys: [RFP]
        prob: 0.7
        factors: 0.5
    - class_path: viscy.transforms.RandScaleIntensityd
      init_args:
        keys: [Phase3D]
        prob: 0.5
        factors: 0.5
    - class_path: viscy.transforms.RandGaussianSmoothd
      init_args:
        keys: [Phase3D, RFP]
        prob: 0.5
        sigma_x: [0.25, 0.75]
        sigma_y: [0.25, 0.75]
        sigma_z: [0.0, 0.0]
    - class_path: viscy.transforms.RandGaussianNoised
      init_args:
        keys: [RFP]
        prob: 0.5
        mean: 0.0
        std: 0.5
    - class_path: viscy.transforms.RandGaussianNoised
      init_args:
        keys: [Phase3D]
        prob: 0.5
        mean: 0.0
        std: 0.2
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
#!/bin/bash
# SLURM submission script for the Lightning CLI contrastive training demo:
# one node, 4 GPUs, one task per GPU.

#SBATCH --job-name=contrastive_origin
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=4
#SBATCH --gres=gpu:4
#SBATCH --partition=gpu
#SBATCH --cpus-per-task=14
#SBATCH --mem-per-cpu=15G
#SBATCH --time=0-20:00:00

# Debugging flags (optional), see:
# https://lightning.ai/docs/pytorch/stable/clouds/cluster_advanced.html
export NCCL_DEBUG=INFO
export PYTHONFAULTHANDLER=1

# Remove temporary zarr stores written by this job.
function cleanup() {
    rm -rf /tmp/$SLURM_JOB_ID/*.zarr
    echo "Cleanup Completed."
}

# Invoke cleanup when the process receives the EXIT signal.
trap cleanup EXIT

# Activate the conda environment - specific to your installation!
module load anaconda/2022.05
# You'll need to replace this path with path to your own conda environment.
conda activate /hpc/mydata/$USER/envs/viscy

config=./demo_cli_fit.yml

# Printing this to the stdout lets us connect the job id to config.
scontrol show job $SLURM_JOB_ID
cat $config

# Run the training CLI
srun python -m viscy.cli.contrastive_triplet fit -c $config

# Tips:
# 1. run this script with `sbatch demo_cli_fit_slurm.sh`
# 2. check the status of the job with `squeue -u $USER`
# 3. use turm to monitor the job with `turm -u first.last`. Use module load turm to load the turm module.

0 commit comments

Comments
 (0)