Skip to content

Commit 11347b1

Browse files
committed
variational meshnet for binary segmentation
1 parent 4e4abbe commit 11347b1

File tree

1 file changed

+229
-0
lines changed

1 file changed

+229
-0
lines changed

1.2.0/kwyk_train.py

Lines changed: 229 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,229 @@
1+
# Copyright (c) 2024 MIT
2+
#
3+
# -*- coding:utf-8 -*-
4+
# @Script: kwyk_train.py
5+
# @Author: Harsha
6+
# @Email: hvgazula@users.noreply.github.com
7+
# @Create At: 2024-03-29 09:08:29
8+
# @Last Modified By: Harsha
9+
# @Last Modified At: 2024-04-01 17:44:15
10+
# @Description: Train a variational MeshNet segmentation model (nobrainer) on kwyk data.
11+
12+
import os
13+
import sys
14+
15+
# ruff: noqa: E402
16+
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
17+
import glob
18+
from datetime import datetime
19+
20+
import nibabel as nib
21+
import nobrainer
22+
import numpy as np
23+
import tensorflow as tf
24+
from nobrainer.dataset import Dataset
25+
from nobrainer.models import unet
26+
from nobrainer.processing.segmentation import Segmentation
27+
from nobrainer.models.bayesian_meshnet import variational_meshnet
28+
29+
# tf.data.experimental.enable_debug_mode()
30+
31+
32+
def main_timer(func):
    """Decorator that prints the wall-clock runtime of every call to ``func``.

    Args:
        func: Callable to wrap.

    Returns:
        A wrapper that forwards all positional and keyword arguments to
        ``func``, prints the elapsed time, and returns ``func``'s result.
    """
    # Local import: functools is not among this file's module-level imports.
    import functools

    @functools.wraps(func)  # keep the wrapped callable's __name__ / __doc__
    def function_wrapper(*args, **kwargs):
        start_time = datetime.now()
        # Keyword arguments must be expanded with ``**``; a single ``*`` would
        # pass the dict's keys as extra positional arguments.
        result = func(*args, **kwargs)
        end_time = datetime.now()
        print(
            f"Function: {func.__name__} Total runtime: {end_time - start_time} (HH:MM:SS)"
        )
        return result

    return function_wrapper
47+
48+
49+
def sort_function(item):
    """Return the integer sort key embedded in a file path.

    The key is the second ``_``-separated field of the path's base name,
    converted to ``int``.
    """
    file_name = os.path.basename(item)
    fields = file_name.split("_")
    return int(fields[1])
51+
52+
53+
def create_filepaths(path_to_data: str, sample: bool = False) -> None:
    """Write a two-column CSV pairing each feature volume with a label volume.

    Args:
        path_to_data (str): Directory searched for ``*orig*.nii.gz`` (features)
            and ``*aseg*.nii.gz`` (labels). A falsy value selects the default
            kwyk raw-data directory.
        sample (bool, optional): When True the output file is
            ``filepaths_sample.csv`` instead of ``filepaths.csv``; only the
            file name changes. Defaults to False.
    """
    data_dir = path_to_data or "/nese/mit/group/sig/data/kwyk/rawdata"

    def _sorted_matches(pattern):
        # Order by the numeric key in the file name so that features and
        # labels line up by position.
        return sorted(glob.glob(os.path.join(data_dir, pattern)), key=sort_function)

    feature_paths = _sorted_matches("*orig*.nii.gz")
    label_paths = _sorted_matches("*aseg*.nii.gz")

    assert len(feature_paths) == len(
        label_paths
    ), "Mismatch between feature and label paths"

    if sample:
        file_name = "filepaths_sample.csv"
    else:
        file_name = "filepaths.csv"

    with open(file_name, "w") as f:
        f.writelines(
            f"{feature},{label}\n" for feature, label in zip(feature_paths, label_paths)
        )
79+
80+
81+
@main_timer
def load_sample_files():
    """Build train/eval datasets from nobrainer's sample data.

    Obtains a CSV path from ``nobrainer.utils.get_data()``, reads it with
    ``nobrainer.io.read_csv`` and hands the result to ``Dataset.from_files``
    with ``out_tfrec_dir="data/binseg"``.

    Returns:
        tuple: ``(dataset_train, dataset_eval)`` as produced by
        ``Dataset.from_files``.
    """
    sample_csv = nobrainer.utils.get_data()
    sample_filepaths = nobrainer.io.read_csv(sample_csv)

    from_files_options = dict(
        out_tfrec_dir="data/binseg",
        shard_size=3,
        num_parallel_calls=None,
        n_classes=1,
    )
    dataset_train, dataset_eval = Dataset.from_files(
        sample_filepaths, **from_files_options
    )
    return dataset_train, dataset_eval
96+
97+
98+
def load_sample_tfrec(target: str = "train"):
    """Load the sample dataset from TFRecord files under ``data/binseg``.

    Args:
        target: ``"train"`` selects files matching ``*train*``; any other
            value selects files matching ``*eval*``.

    Returns:
        The object returned by ``Dataset.from_tfrecords``.
    """
    data_pattern = "data/binseg/*train*" if target == "train" else "data/binseg/*eval*"

    return Dataset.from_tfrecords(
        file_pattern=data_pattern,
        volume_shape=(256, 256, 256),
        block_shape=None,
        n_volumes=None,
    )
115+
116+
117+
@main_timer
def load_custom_tfrec(target: str = "train"):
    """Load the kwyk dataset from TFRecord files on scratch storage.

    Args:
        target: ``"train"`` selects files matching ``*train*``; any other
            value selects files matching ``*eval*``.

    Returns:
        The object returned by ``Dataset.from_tfrecords``.
    """
    # NOTE: /nese/mit/group/sig/data/kwyk/tfrecords/ is an alternative
    # location (presumably holding the same records -- confirm); the scratch
    # copy below is the one actually used.
    if target == "train":
        data_pattern = "/om2/scratch/Fri/hgazula/kwyk_full/*train*"
    else:
        data_pattern = "/om2/scratch/Fri/hgazula/kwyk_full/*eval*"

    return Dataset.from_tfrecords(
        file_pattern=data_pattern,
        volume_shape=(256, 256, 256),
        block_shape=None,
    )
137+
138+
139+
@main_timer
def get_label_count():
    """Print the set of unique-value counts over the first 500 label volumes.

    Reads ``filepaths.csv`` (one ``feature,label`` pair per line), loads each
    label file with nibabel, counts the distinct values in its data array and
    prints the set of those counts.
    """
    with open("filepaths.csv", "r") as f:
        rows = f.readlines()[:500]

    # Each row must split into exactly two fields; only the label is used.
    pairs = (row.strip().split(",") for row in rows)
    label_count = [
        len(np.unique(nib.load(label).get_fdata())) for _, label in pairs
    ]

    print(set(label_count))
149+
150+
151+
# @main_timer
152+
def main():
    """Train a variational MeshNet with nobrainer's ``Segmentation`` wrapper.

    Steps, in order: enable memory growth on every visible GPU (exiting if
    there is none), load the train/eval datasets, create the checkpoint,
    TensorBoard, early-stopping and backup callbacks, build the model wrapper
    and call ``fit``. All outputs go under ``output/<model_string>/``.
    """
    gpus = tf.config.list_physical_devices("GPU")
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
    NUM_GPUS = len(gpus)

    if not NUM_GPUS:
        sys.exit("GPU not found")

    n_epochs = 20

    # Flip to True to train on nobrainer's sample data instead of kwyk.
    use_sample_data = False

    print("loading data")
    if use_sample_data:
        # run one of the following two lines (but not both)
        # the second line won't succeed unless the first one is run at least once

        dataset_train, dataset_eval = load_sample_files()
        # dataset_train, dataset_eval = (
        #     load_sample_tfrec("train"),
        #     load_sample_tfrec("eval"),
        # )
        # Both names are read by the callbacks below, so they must be bound
        # in this branch too; otherwise enabling it fails with
        # UnboundLocalError.
        model_string = "bem_test"
        save_freq = "epoch"
    else:
        dataset_train, dataset_eval = (
            load_custom_tfrec("train"),
            load_custom_tfrec("eval"),
        )
        model_string = "kwyk"
        save_freq = 250

    # The return values are discarded, so these calls only take effect if the
    # nobrainer Dataset methods modify the object in place.
    # NOTE(review): confirm against nobrainer.dataset.Dataset.
    dataset_train.shuffle(NUM_GPUS).batch(NUM_GPUS)
    # NOTE(review): map_labels() is applied to the eval dataset only -- confirm
    # that leaving the training dataset unmapped is intentional.
    dataset_eval.map_labels()

    print("creating callbacks")
    callback_model_checkpoint = tf.keras.callbacks.ModelCheckpoint(
        os.path.join(f"output/{model_string}/model_ckpts", "model_{epoch:03d}.keras")
    )
    callback_tensorboard = tf.keras.callbacks.TensorBoard(
        log_dir=f"output/{model_string}/logs/", histogram_freq=1
    )
    callback_early_stopping = tf.keras.callbacks.EarlyStopping(
        monitor="val_loss",
        min_delta=1e-4,
        patience=10,
    )
    callback_backup = tf.keras.callbacks.BackupAndRestore(
        backup_dir=f"output/{model_string}/backup", save_freq=save_freq
    )

    callbacks = [
        callback_model_checkpoint,
        callback_tensorboard,
        callback_early_stopping,
        callback_backup,
    ]

    print("creating model")
    kwyk = Segmentation(
        variational_meshnet,
        model_args=dict(no_examples=9200, filters=21),
        multi_gpu=True,
        checkpoint_filepath=f"output/{model_string}/nobrainer_ckpts",
    )

    print("training")
    _ = kwyk.fit(
        dataset_train=dataset_train,
        dataset_validate=dataset_eval,
        epochs=n_epochs,
        callbacks=callbacks,
    )

    print("Success")
226+
227+
228+
# Run training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)