66# @Email: hvgazula@users.noreply.github.com
77# @Create At: 2024-03-29 09:08:29
88# @Last Modified By: Harsha
9- # @Last Modified At: 2024-03-29 10:31:56
9+ # @Last Modified At: 2024-04-01 16:04:23
1010# @Description: This is description.
1111
1212import os
13+ import sys
1314
1415# ruff: noqa: E402
1516os .environ ["TF_CPP_MIN_LOG_LEVEL" ] = "2"
16- import tensorflow as tf
1717import glob
18- from nobrainer .dataset import Dataset
19- import nobrainer
18+ from datetime import datetime
2019
21- from nobrainer .processing .segmentation import Segmentation
22- from nobrainer .models import unet
23- import numpy as np
2420import nibabel as nib
25- from datetime import datetime
26- import copy
21+ import nobrainer
22+ import numpy as np
23+ import tensorflow as tf
24+ from nobrainer .dataset import Dataset
25+ from nobrainer .models import unet
26+ from nobrainer .processing .segmentation import Segmentation
2727
28- NUM_GPUS = len ( tf .config . list_physical_devices ( "GPU" ) )
28+ # tf.data.experimental.enable_debug_mode( )
2929
3030
3131def main_timer (func ):
@@ -117,9 +117,11 @@ def load_sample_tfrec(target: str = "train"):
117117def load_custom_tfrec (target : str = "train" ):
118118
119119 if target == "train" :
120- data_pattern = "/nese/mit/group/sig/data/kwyk/tfrecords/*train*000*"
120+ data_pattern = "/nese/mit/group/sig/data/kwyk/tfrecords/*train*"
121+ data_pattern = "/om2/scratch/Fri/hgazula/kwyk_full/*train*"
121122 else :
122- data_pattern = "/nese/mit/group/sig/data/kwyk/tfrecords/*eval*000*"
123+ data_pattern = "/nese/mit/group/sig/data/kwyk/tfrecords/*eval*"
124+ data_pattern = "/om2/scratch/Fri/hgazula/kwyk_full/*eval*"
123125
124126 volume_shape = (256 , 256 , 256 )
125127 block_shape = None
@@ -128,7 +130,6 @@ def load_custom_tfrec(target: str = "train"):
128130 file_pattern = data_pattern ,
129131 volume_shape = volume_shape ,
130132 block_shape = block_shape ,
131- n_volumes = None ,
132133 )
133134
134135 return dataset
@@ -148,29 +149,35 @@ def get_label_count():
148149
149150# @main_timer
150151def main ():
151- print (tf .config .list_physical_devices ("GPU" ))
152+ NUM_GPUS = len (tf .config .list_physical_devices ("GPU" ))
152153
153- n_epochs = 10
154+ if not NUM_GPUS :
155+ sys .exit ("GPU not found" )
154156
155- model_string = "bem"
157+ n_epochs = 20
156158
157159 print ("loading data" )
158- if True :
159- # run of the following two lines (but not both)
160- # dataset_train, dataset_eval = load_sample_files()
161- dataset_train , dataset_eval = (
162- load_sample_tfrec ("train" ),
163- load_sample_tfrec ("eval" ),
164- )
165- save_freq = "epoch"
160+ if False :
161+ # run one of the following two lines (but not both)
162+ # the second line won't succeed unless the first one is run at least once
163+
164+ dataset_train , dataset_eval = load_sample_files ()
165+ # dataset_train, dataset_eval = (
166+ # load_sample_tfrec("train"),
167+ # load_sample_tfrec("eval"),
168+ # )
169+ # model_string = "bem_test"
170+ # save_freq = "epoch"
166171 else :
167172 dataset_train , dataset_eval = (
168173 load_custom_tfrec ("train" ),
169174 load_custom_tfrec ("eval" ),
170175 )
171- save_freq = 100
176+ model_string = "bem4"
177+ save_freq = 250
172178
173- dataset_train .repeat (n_epochs ).shuffle (1 ).batch (NUM_GPUS )
179+ dataset_train .shuffle (NUM_GPUS ).batch (NUM_GPUS )
180+ dataset_eval .map_labels ()
174181
175182 print ("creating callbacks" )
176183 callback_model_checkpoint = tf .keras .callbacks .ModelCheckpoint (
@@ -189,10 +196,10 @@ def main():
189196 )
190197
191198 callbacks = [
192- # callback_model_checkpoint,
199+ callback_model_checkpoint ,
193200 callback_tensorboard ,
194201 callback_early_stopping ,
195- # callback_backup,
202+ callback_backup ,
196203 ]
197204
198205 print ("creating model" )
@@ -203,10 +210,6 @@ def main():
203210 checkpoint_filepath = f"output/{ model_string } /nobrainer_ckpts" ,
204211 )
205212
206- # Segmentation.init_with_checkpoints(
207- # "unet", checkpoint_filepath=f"output/{model_string}/nobrainer_ckpts"
208- # )
209-
210213 print ("training" )
211214 _ = bem .fit (
212215 dataset_train = dataset_train ,
0 commit comments