6677# @Create At: 2024-03-29 09:08:29
88# @Last Modified By: Harsha
9- # @Last Modified At: 2024-04-01 17:44:15
9+ # @Last Modified At: 2024-04-01 19:23:08
1010# @Description: This is description.
1111
1212import os
2323import tensorflow as tf
2424from nobrainer .dataset import Dataset
2525from nobrainer .models import unet
26- from nobrainer .processing .segmentation import Segmentation
2726from nobrainer .models .bayesian_meshnet import variational_meshnet
27+ from nobrainer .processing .segmentation import Segmentation
28+ from nvitop .callbacks .keras import GpuStatsLogger
2829
2930# tf.data.experimental.enable_debug_mode()
3031
@@ -80,7 +81,6 @@ def create_filepaths(path_to_data: str, sample: bool = False) -> None:
8081
8182@main_timer
8283def load_sample_files ():
83-
8484 if True :
8585 csv_path = nobrainer .utils .get_data ()
8686 filepaths = nobrainer .io .read_csv (csv_path )
@@ -116,7 +116,6 @@ def load_sample_tfrec(target: str = "train"):
116116
117117@main_timer
118118def load_custom_tfrec (target : str = "train" ):
119-
120119 if target == "train" :
121120 data_pattern = "/nese/mit/group/sig/data/kwyk/tfrecords/*train*"
122121 data_pattern = "/om2/scratch/Fri/hgazula/kwyk_full/*train*"
@@ -151,6 +150,7 @@ def get_label_count():
151150# @main_timer
152151def main ():
153152 gpus = tf .config .list_physical_devices ("GPU" )
153+ gpu_names = [item .name for item in gpus ]
154154 for gpu in gpus :
155155 tf .config .experimental .set_memory_growth (gpu , True )
156156 NUM_GPUS = len (gpus )
@@ -198,12 +198,14 @@ def main():
198198 callback_backup = tf .keras .callbacks .BackupAndRestore (
199199 backup_dir = f"output/{ model_string } /backup" , save_freq = save_freq
200200 )
201+ callback_gpustats = GpuStatsLogger (gpu_names )
201202
202203 callbacks = [
203204 callback_model_checkpoint ,
204205 callback_tensorboard ,
205206 callback_early_stopping ,
206207 callback_backup ,
208+ callback_gpustats ,
207209 ]
208210
209211 print ("creating model" )
@@ -220,6 +222,7 @@ def main():
220222 dataset_validate = dataset_eval ,
221223 epochs = n_epochs ,
222224 callbacks = callbacks ,
225+ verbose = 1 ,
223226 )
224227
225228 print ("Success" )
0 commit comments