
Commit 59c3534

prepare for training AZ
1 parent d2aa6d8 commit 59c3534

5 files changed: +65 −16 lines changed


run_sbatch_revision.sbatch

Lines changed: 5 additions & 3 deletions
@@ -2,10 +2,12 @@
 #SBATCH -c 4 #4 #8
 #SBATCH --mem 256G #120G #32G #64G #256G
 #SBATCH -p grete:shared #grete:shared #grete-h100:shared
-#SBATCH -t 24:00:00 #6:00:00 #48:00:00 #SBATCH -G A100:1 #V100:1 #2 #A100:1 #gtx1080:2 #v100:1 #H100:1
+#SBATCH -t 48:00:00 #6:00:00 #48:00:00
+#SBATCH -G A100:1 #V100:1 #2 #A100:1 #gtx1080:2 #v100:1 #H100:1
 #SBATCH --output=/user/muth9/u12095/synapse-net/slurm_revision/slurm-%j.out
-#SBATCH -A nim00007 #SBATCH --constraint 80gb
+#SBATCH -A nim00007
+#SBATCH --constraint 80gb
 
 source ~/.bashrc
 conda activate synapse-net
-python /user/muth9/u12095/synapse-net/scripts/cooper/revision/merge_az.py -v 6
+python /user/muth9/u12095/synapse-net/scripts/cooper/revision/train_az.py -k az_merged_v6
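
Note: the job now launches train_az.py with -k az_merged_v6 instead of running merge_az.py -v 6. A minimal sketch of how the -k key presumably maps to the versioned HDF5 label dataset written by the merge step; the argument parsing, the labels/ prefix, and the file name below are assumptions for illustration, not taken from train_az.py:

# Sketch (assumption): resolve the -k argument to an HDF5 label key.
import argparse

import h5py

parser = argparse.ArgumentParser()
parser.add_argument("-k", "--key", default="az_merged_v6")
args = parser.parse_args()

label_key = f"labels/{args.key}"  # e.g. "labels/az_merged_v6", as written by merge_az.py
with h5py.File("some_tomogram.h5", "r") as f:  # hypothetical training file
    labels = f[label_key][:]
print("loaded labels with shape", labels.shape)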

scripts/cooper/revision/common.py

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@
 # The root folder which contains the new AZ training data.
 INPUT_ROOT = "/mnt/ceph-hdd/cold/nim00007/new_AZ_train_data"
 # The output folder for AZ predictions.
-OUTPUT_ROOT = "/mnt/ceph-hdd/cold/nim00007/AZ_predictions_new"
+OUTPUT_ROOT = "/mnt/ceph-hdd/cold/nim00007/AZ_prediction_new_copy"
 
 # The names of all datasets for which to run prediction / evaluation.
 # This excludes 'endbulb_of_held_cropped', which is a duplicate of 'endbulb_of_held',

scripts/cooper/revision/merge_az.py

Lines changed: 17 additions & 3 deletions
@@ -1,5 +1,6 @@
 import argparse
 import os
+from glob import glob
 
 import h5py
 import napari
@@ -19,13 +20,19 @@
 # STEM CROPPED IS OFTEN TOO SMALL!
 def merge_az(name, version, check):
     split_folder = get_split_folder(version)
-    file_names = get_file_names(name, split_folder, split_names=["train", "val", "test"])
+
+    if name == "stem_cropped":
+        file_paths = glob(os.path.join("/mnt/ceph-hdd/cold/nim00007/new_AZ_train_data/stem_cropped", "*.h5"))
+        file_names = [os.path.basename(path) for path in file_paths]
+    else:
+        file_names = get_file_names(name, split_folder, split_names=["train", "val", "test"])
     seg_paths, gt_paths = get_paths(name, file_names)
 
     for seg_path, gt_path in zip(seg_paths, gt_paths):
 
         with h5py.File(gt_path, "r") as f:
-            if not check and ("labels/az_merged" in f):
+            #if not check and ("labels/az_merged" in f):
+            if f"labels/az_merged_v{version}" in f :
                 continue
             raw = f["raw"][:]
             gt = f["labels/az"][:]
@@ -56,9 +63,16 @@ def merge_az(name, version, check):
             v.title = f"{name}/{fname}"
             napari.run()
 
+            print(f"gt_path {gt_path}")
+            with h5py.File(gt_path, "a") as f:
+                f.create_dataset(f"labels/az_merged_v{version}", data=az_merged, compression="lzf")
+
         else:
-            with h5py.File(seg_path, "a") as f:
+            print(f"gt_path {gt_path}")
+            with h5py.File(gt_path, "a") as f:
                 f.create_dataset(f"labels/az_merged_v{version}", data=az_merged, compression="lzf")
+            '''with h5py.File(seg_path, "a") as f:
+                f.create_dataset(f"labels/az_merged_v{version}", data=az_merged, compression="lzf")'''
 
 
 def visualize_merge(args):
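
The change above makes the merge idempotent per version: files that already contain labels/az_merged_v{version} are skipped, and the merged AZ is now written into the ground-truth file (gt_path) rather than the segmentation file. A minimal sketch of that check-then-write pattern in isolation; the file name and the source of az_merged are placeholders, not part of the commit:

# Sketch (assumption): skip files that already hold the versioned label, otherwise write it.
import h5py

version = 6
key = f"labels/az_merged_v{version}"

with h5py.File("example_tomogram.h5", "a") as f:  # hypothetical file
    if key in f:
        print(f"{key} already present, skipping")
    else:
        az_merged = f["labels/az"][:]  # placeholder for the actual merged AZ volume
        f.create_dataset(key, data=az_merged, compression="lzf")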

scripts/cooper/revision/remove_az_thin.py

Lines changed: 30 additions & 2 deletions
@@ -1,4 +1,4 @@
-import h5py
+'''import h5py
 
 files = [
     "/mnt/ceph-hdd/cold/nim00007/new_AZ_train_data/stem_cropped2_rescaled/36859_H2_SP_02_rec_2Kb1dawbp_crop_crop1.h5",
@@ -15,4 +15,32 @@
         del f["labels/az_thin"]
 
         # Recreate the dataset with the new data
-        f.create_dataset("labels/az_thin", data=gt)
+        f.create_dataset("labels/az_thin", data=gt)
+'''
+import h5py
+import numpy as np
+import os
+from glob import glob
+
+folder = "/mnt/ceph-hdd/cold/nim00007/new_AZ_train_data/stem_cropped/"
+
+# List of file names to process
+file_names = [
+    "36859_H2_SP_01_rec_2Kb1dawbp_crop_cropped_noAZ.h5",
+    "36859_H2_SP_02_rec_2Kb1dawbp_crop_cropped_noAZ.h5",
+    "36859_H2_SP_03_rec_2Kb1dawbp_crop_cropped_noAZ.h5",
+    "36859_H3_SP_05_rec_2kb1dawbp_crop_cropped_noAZ.h5",
+    "36859_H3_SP_07_rec_2kb1dawbp_crop_cropped_noAZ.h5",
+    "36859_H3_SP_10_rec_2kb1dawbp_crop_cropped_noAZ.h5"
+]
+
+file_paths = glob(os.path.join("/mnt/ceph-hdd/cold/nim00007/new_AZ_train_data/endbulb_of_held_cropped", "*.h5"))
+
+for fname in file_paths:
+    #file_path = os.path.join(folder, fname)
+
+    with h5py.File(fname, "a") as f:
+        az_merged = f["/labels/az_merged"][:]
+        f.create_dataset("/labels/az_merged_v6", data=az_merged, compression="lzf")
+
+    print(f"Updated file: {fname}")

scripts/cooper/revision/train_az.py

Lines changed: 12 additions & 7 deletions
@@ -9,12 +9,12 @@
 
 from synapse_net.training import supervised_training, AZDistanceLabelTransform
 
-TRAIN_ROOT = "/mnt/ceph-hdd/cold_store/projects/nim00007/new_AZ_train_data"
+TRAIN_ROOT = "/mnt/ceph-hdd/cold/nim00007/new_AZ_train_data"
 OUTPUT_ROOT = "./models_az_thin"
 
 
 def _require_train_val_test_split(datasets):
-    train_ratio, val_ratio, test_ratio = 0.70, 0.1, 0.2
+    train_ratio, val_ratio, test_ratio = 0.60, 0.2, 0.2
 
     def _train_val_test_split(names):
         train, test = train_test_split(names, test_size=1 - train_ratio, shuffle=True)
@@ -87,17 +87,22 @@ def train(key, ignore_label=None, use_distances=False, training_2D=False, testse
 
     os.makedirs(OUTPUT_ROOT, exist_ok=True)
 
-    datasets = ["tem", "chemical_fixation", "stem", "stem_cropped", "endbulb_of_held", "endbulb_of_held_cropped"]
-    train_paths = get_paths("train", datasets=datasets, testset=testset)
-    val_paths = get_paths("val", datasets=datasets, testset=testset)
+    datasets_with_testset_true = ["tem", "chemical_fixation", "stem", "endbulb_of_held"]
+    datasets_with_testset_false = ["stem_cropped", "endbulb_of_held_cropped"]
+
+    train_paths = get_paths("train", datasets=datasets_with_testset_true, testset=True)
+    val_paths = get_paths("val", datasets=datasets_with_testset_true, testset=True)
+
+    train_paths += get_paths("train", datasets=datasets_with_testset_false, testset=False)
+    val_paths += get_paths("val", datasets=datasets_with_testset_false, testset=False)
 
     print("Start training with:")
     print(len(train_paths), "tomograms for training")
    print(len(val_paths), "tomograms for validation")
 
     # patch_shape = [48, 256, 256]
     patch_shape = [48, 384, 384]
-    model_name = "v6"
+    model_name = "v7"
 
     # checking for 2D training
     if training_2D:
@@ -121,7 +126,7 @@ def train(key, ignore_label=None, use_distances=False, training_2D=False, testse
         sampler=torch_em.data.sampler.MinInstanceSampler(min_num_instances=1, p_reject=0.85),
         n_samples_train=None, n_samples_val=100,
         check=check,
-        save_root=OUTPUT_ROOT,
+        save_root="/mnt/lustre-emmy-hdd/usr/u12095/synapse_net/models/ConstantinAZ",
         n_iterations=int(2e5),
         ignore_label=ignore_label,
         label_transform=label_transform,
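
The split ratios move from 0.70/0.1/0.2 to 0.60/0.2/0.2, and the datasets are now grouped into those that keep a dedicated test split and those used entirely for train/val. A small sketch of how a two-stage train_test_split with the new ratios yields roughly 60/20/20; only the first split line appears in the diff, so the second stage below is an assumed completion for illustration:

# Sketch (assumption): two-stage split producing ~60% train, ~20% val, ~20% test.
from sklearn.model_selection import train_test_split

names = [f"tomo_{i}" for i in range(10)]  # hypothetical file names
train_ratio, val_ratio, test_ratio = 0.60, 0.2, 0.2

# First split: keep 60% for training, hold out the remaining 40%.
train, rest = train_test_split(names, test_size=1 - train_ratio, shuffle=True)
# Second split: divide the held-out 40% equally into validation and test.
val, test = train_test_split(rest, test_size=test_ratio / (val_ratio + test_ratio))

print(len(train), len(val), len(test))  # e.g. 6 2 2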
