Skip to content

Commit 4e4abbe

Browse files
committed
minor refactor
1 parent ba0acd4 commit 4e4abbe

File tree

1 file changed

+35
-32
lines changed

1 file changed

+35
-32
lines changed

1.2.0/brainy_train.py

Lines changed: 35 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -6,26 +6,26 @@
66
# @Email: hvgazula@users.noreply.github.com
77
# @Create At: 2024-03-29 09:08:29
88
# @Last Modified By: Harsha
9-
# @Last Modified At: 2024-03-29 10:31:56
9+
# @Last Modified At: 2024-04-01 16:04:23
1010
# @Description: This is description.
1111

1212
import os
13+
import sys
1314

1415
# ruff: noqa: E402
1516
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
16-
import tensorflow as tf
1717
import glob
18-
from nobrainer.dataset import Dataset
19-
import nobrainer
18+
from datetime import datetime
2019

21-
from nobrainer.processing.segmentation import Segmentation
22-
from nobrainer.models import unet
23-
import numpy as np
2420
import nibabel as nib
25-
from datetime import datetime
26-
import copy
21+
import nobrainer
22+
import numpy as np
23+
import tensorflow as tf
24+
from nobrainer.dataset import Dataset
25+
from nobrainer.models import unet
26+
from nobrainer.processing.segmentation import Segmentation
2727

28-
NUM_GPUS = len(tf.config.list_physical_devices("GPU"))
28+
# tf.data.experimental.enable_debug_mode()
2929

3030

3131
def main_timer(func):
@@ -117,9 +117,11 @@ def load_sample_tfrec(target: str = "train"):
117117
def load_custom_tfrec(target: str = "train"):
118118

119119
if target == "train":
120-
data_pattern = "/nese/mit/group/sig/data/kwyk/tfrecords/*train*000*"
120+
data_pattern = "/nese/mit/group/sig/data/kwyk/tfrecords/*train*"
121+
data_pattern = "/om2/scratch/Fri/hgazula/kwyk_full/*train*"
121122
else:
122-
data_pattern = "/nese/mit/group/sig/data/kwyk/tfrecords/*eval*000*"
123+
data_pattern = "/nese/mit/group/sig/data/kwyk/tfrecords/*eval*"
124+
data_pattern = "/om2/scratch/Fri/hgazula/kwyk_full/*eval*"
123125

124126
volume_shape = (256, 256, 256)
125127
block_shape = None
@@ -128,7 +130,6 @@ def load_custom_tfrec(target: str = "train"):
128130
file_pattern=data_pattern,
129131
volume_shape=volume_shape,
130132
block_shape=block_shape,
131-
n_volumes=None,
132133
)
133134

134135
return dataset
@@ -148,29 +149,35 @@ def get_label_count():
148149

149150
# @main_timer
150151
def main():
151-
print(tf.config.list_physical_devices("GPU"))
152+
NUM_GPUS = len(tf.config.list_physical_devices("GPU"))
152153

153-
n_epochs = 10
154+
if not NUM_GPUS:
155+
sys.exit("GPU not found")
154156

155-
model_string = "bem"
157+
n_epochs = 20
156158

157159
print("loading data")
158-
if True:
159-
# run of the following two lines (but not both)
160-
# dataset_train, dataset_eval = load_sample_files()
161-
dataset_train, dataset_eval = (
162-
load_sample_tfrec("train"),
163-
load_sample_tfrec("eval"),
164-
)
165-
save_freq = "epoch"
160+
if False:
161+
# run one of the following two lines (but not both)
162+
# the second line won't succeed unless the first one is run at least once
163+
164+
dataset_train, dataset_eval = load_sample_files()
165+
# dataset_train, dataset_eval = (
166+
# load_sample_tfrec("train"),
167+
# load_sample_tfrec("eval"),
168+
# )
169+
# model_string = "bem_test"
170+
# save_freq = "epoch"
166171
else:
167172
dataset_train, dataset_eval = (
168173
load_custom_tfrec("train"),
169174
load_custom_tfrec("eval"),
170175
)
171-
save_freq = 100
176+
model_string = "bem4"
177+
save_freq = 250
172178

173-
dataset_train.repeat(n_epochs).shuffle(1).batch(NUM_GPUS)
179+
dataset_train.shuffle(NUM_GPUS).batch(NUM_GPUS)
180+
dataset_eval.map_labels()
174181

175182
print("creating callbacks")
176183
callback_model_checkpoint = tf.keras.callbacks.ModelCheckpoint(
@@ -189,10 +196,10 @@ def main():
189196
)
190197

191198
callbacks = [
192-
# callback_model_checkpoint,
199+
callback_model_checkpoint,
193200
callback_tensorboard,
194201
callback_early_stopping,
195-
# callback_backup,
202+
callback_backup,
196203
]
197204

198205
print("creating model")
@@ -203,10 +210,6 @@ def main():
203210
checkpoint_filepath=f"output/{model_string}/nobrainer_ckpts",
204211
)
205212

206-
# Segmentation.init_with_checkpoints(
207-
# "unet", checkpoint_filepath=f"output/{model_string}/nobrainer_ckpts"
208-
# )
209-
210213
print("training")
211214
_ = bem.fit(
212215
dataset_train=dataset_train,

0 commit comments

Comments (0)