Skip to content

Commit 76be21a

Browse files
committed
add gpustats callback
1 parent 11347b1 commit 76be21a

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

1.2.0/kwyk_train.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
77
# @Create At: 2024-03-29 09:08:29
88
# @Last Modified By: Harsha
9-
# @Last Modified At: 2024-04-01 17:44:15
9+
# @Last Modified At: 2024-04-01 19:23:08
1010
# @Description: This is description.
1111

1212
import os
@@ -23,8 +23,9 @@
2323
import tensorflow as tf
2424
from nobrainer.dataset import Dataset
2525
from nobrainer.models import unet
26-
from nobrainer.processing.segmentation import Segmentation
2726
from nobrainer.models.bayesian_meshnet import variational_meshnet
27+
from nobrainer.processing.segmentation import Segmentation
28+
from nvitop.callbacks.keras import GpuStatsLogger
2829

2930
# tf.data.experimental.enable_debug_mode()
3031

@@ -80,7 +81,6 @@ def create_filepaths(path_to_data: str, sample: bool = False) -> None:
8081

8182
@main_timer
8283
def load_sample_files():
83-
8484
if True:
8585
csv_path = nobrainer.utils.get_data()
8686
filepaths = nobrainer.io.read_csv(csv_path)
@@ -116,7 +116,6 @@ def load_sample_tfrec(target: str = "train"):
116116

117117
@main_timer
118118
def load_custom_tfrec(target: str = "train"):
119-
120119
if target == "train":
121120
data_pattern = "/nese/mit/group/sig/data/kwyk/tfrecords/*train*"
122121
data_pattern = "/om2/scratch/Fri/hgazula/kwyk_full/*train*"
@@ -151,6 +150,7 @@ def get_label_count():
151150
# @main_timer
152151
def main():
153152
gpus = tf.config.list_physical_devices("GPU")
153+
gpu_names = [item.name for item in gpus]
154154
for gpu in gpus:
155155
tf.config.experimental.set_memory_growth(gpu, True)
156156
NUM_GPUS = len(gpus)
@@ -198,12 +198,14 @@ def main():
198198
callback_backup = tf.keras.callbacks.BackupAndRestore(
199199
backup_dir=f"output/{model_string}/backup", save_freq=save_freq
200200
)
201+
callback_gpustats = GpuStatsLogger(gpu_names)
201202

202203
callbacks = [
203204
callback_model_checkpoint,
204205
callback_tensorboard,
205206
callback_early_stopping,
206207
callback_backup,
208+
callback_gpustats,
207209
]
208210

209211
print("creating model")
@@ -220,6 +222,7 @@ def main():
220222
dataset_validate=dataset_eval,
221223
epochs=n_epochs,
222224
callbacks=callbacks,
225+
verbose=1,
223226
)
224227

225228
print("Success")

0 commit comments

Comments
 (0)