
Commit 5a9afb1

Ir1d authored and williamFalcon committed
change print to logging (#457)
* change print to logging
* always use logging.info
* use f-strings
* update code style
* set logging configs
* remove unused code
1 parent 9a5307d commit 5a9afb1

12 files changed: +56 / -41 lines changed

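The substitution applied across the files below follows one pattern: a %-formatted print becomes a logging.info call with an f-string, and the Trainer configures the logging level once (see the trainer.py hunk further down). A minimal sketch of that pattern; the message text here is illustrative, not taken from the diff:

```python
import logging
import os

# make INFO records visible; the Trainer now calls this in __init__
logging.basicConfig(level=logging.INFO)

# before: always writes to stdout, no level, no handler control
print('checkpoint dir: %s' % os.getcwd())

# after: routed through the root logger, so it can be filtered, formatted, or redirected
logging.info(f'checkpoint dir: {os.getcwd()}')
```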

README.md

Lines changed: 2 additions & 2 deletions
@@ -164,8 +164,8 @@ trainer = Trainer(max_nb_epochs=1, train_percent_check=0.1)
 trainer.fit(model)
 
 # view tensorboard logs
-print('View tensorboard logs by running\ntensorboard --logdir %s' % os.getcwd())
-print('and going to http://localhost:6006 on your browser')
+logging.info(f'View tensorboard logs by running\ntensorboard --logdir {os.getcwd()}')
+logging.info('and going to http://localhost:6006 on your browser')
 ```
 
 When you're all done you can even run the test set separately.
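Note that the updated README snippet assumes logging is imported and configured to show INFO records; by default the root logger only emits WARNING and above. A minimal sketch of the surrounding setup, which the README hunk itself does not show:

```python
import logging
import os

# without this, the two logging.info calls below produce no output by default
logging.basicConfig(level=logging.INFO)

# the two lines from the updated README example
logging.info(f'View tensorboard logs by running\ntensorboard --logdir {os.getcwd()}')
logging.info('and going to http://localhost:6006 on your browser')
```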

docs/examples/Examples.md

Lines changed: 1 addition & 1 deletion
@@ -119,7 +119,7 @@ def optimize_on_cluster(hyperparams):
     job_display_name = job_display_name[0:3]
 
     # run hopt
-    print('submitting jobs...')
+    logging.info('submitting jobs...')
     cluster.optimize_parallel_cluster_gpu(
         main,
         nb_trials=hyperparams.nb_hopt_trials,

pl_examples/basic_examples/lightning_module_template.py

Lines changed: 4 additions & 3 deletions
@@ -2,6 +2,7 @@
 Example template for defining a system
 """
 import os
+import logging
 from argparse import ArgumentParser
 from collections import OrderedDict
 
@@ -214,17 +215,17 @@ def __dataloader(self, train):
 
     @pl.data_loader
     def train_dataloader(self):
-        print('training data loader called')
+        logging.info('training data loader called')
         return self.__dataloader(train=True)
 
     @pl.data_loader
     def val_dataloader(self):
-        print('val data loader called')
+        logging.info('val data loader called')
         return self.__dataloader(train=False)
 
     @pl.data_loader
     def test_dataloader(self):
-        print('test data loader called')
+        logging.info('test data loader called')
         return self.__dataloader(train=False)
 
     @staticmethod

pytorch_lightning/callbacks/pt_callbacks.py

Lines changed: 22 additions & 18 deletions
@@ -1,6 +1,7 @@
 import os
 import shutil
-
+import logging
+import warnings
 import numpy as np
 
 from pytorch_lightning.pt_overrides.override_data_parallel import LightningDistributedDataParallel
@@ -91,7 +92,7 @@ def __init__(self, monitor='val_loss',
         self.stopped_epoch = 0
 
         if mode not in ['auto', 'min', 'max']:
-            print('EarlyStopping mode %s is unknown, fallback to auto mode.' % mode)
+            logging.info(f'EarlyStopping mode {mode} is unknown, fallback to auto mode.')
             mode = 'auto'
 
         if mode == 'min':
@@ -121,9 +122,10 @@ def on_epoch_end(self, epoch, logs=None):
         current = logs.get(self.monitor)
         stop_training = False
         if current is None:
-            print('Early stopping conditioned on metric `%s` '
-                  'which is not available. Available metrics are: %s' %
-                  (self.monitor, ','.join(list(logs.keys()))), RuntimeWarning)
+            warnings.warn(
+                f'Early stopping conditioned on metric `{self.monitor}`'
+                f' which is not available. Available metrics are: {",".join(list(logs.keys()))}',
+                RuntimeWarning)
             stop_training = True
         return stop_training
 
@@ -141,7 +143,7 @@ def on_epoch_end(self, epoch, logs=None):
 
     def on_train_end(self, logs=None):
        if self.stopped_epoch > 0 and self.verbose > 0:
-            print('Epoch %05d: early stopping' % (self.stopped_epoch + 1))
+            logging.info(f'Epoch {self.stopped_epoch + 1:05d}: early stopping')
 
 
 class ModelCheckpoint(Callback):
@@ -187,8 +189,9 @@ def __init__(self, filepath, monitor='val_loss', verbose=0,
         self.prefix = prefix
 
         if mode not in ['auto', 'min', 'max']:
-            print('ModelCheckpoint mode %s is unknown, '
-                  'fallback to auto mode.' % (mode), RuntimeWarning)
+            warnings.warn(
+                f'ModelCheckpoint mode {mode} is unknown, '
+                'fallback to auto mode.', RuntimeWarning)
             mode = 'auto'
 
         if mode == 'min':
@@ -232,25 +235,26 @@ def on_epoch_end(self, epoch, logs=None):
         if self.save_best_only:
             current = logs.get(self.monitor)
             if current is None:
-                print('Can save best model only with %s available,'
-                      ' skipping.' % (self.monitor), RuntimeWarning)
+                warnings.warn(
+                    f'Can save best model only with {self.monitor} available,'
+                    ' skipping.', RuntimeWarning)
             else:
                 if self.monitor_op(current, self.best):
                     if self.verbose > 0:
-                        print('\nEpoch %05d: %s improved from %0.5f to %0.5f,'
-                              ' saving model to %s'
-                              % (epoch + 1, self.monitor, self.best,
-                                 current, filepath))
+                        logging.info(
+                            f'\nEpoch {epoch + 1:05d}: {self.monitor} improved'
+                            f' from {self.best:0.5f} to {current:0.5f},',
+                            f' saving model to {filepath}')
                     self.best = current
                     self.save_model(filepath, overwrite=True)
 
                 else:
                     if self.verbose > 0:
-                        print('\nEpoch %05d: %s did not improve' %
-                              (epoch + 1, self.monitor))
+                        logging.info(
+                            f'\nEpoch {epoch + 1:05d}: {self.monitor} did not improve')
         else:
             if self.verbose > 0:
-                print('\nEpoch %05d: saving model to %s' % (epoch + 1, filepath))
+                logging.info(f'\nEpoch {epoch + 1:05d}: saving model to {filepath}')
             self.save_model(filepath, overwrite=False)
 
 
@@ -291,6 +295,6 @@ def on_epoch_begin(self, epoch, trainer):
     losses = [10, 9, 8, 8, 6, 4.3, 5, 4.4, 2.8, 2.5]
     for i, loss in enumerate(losses):
         should_stop = c.on_epoch_end(i, logs={'val_loss': loss})
-        print(loss)
+        logging.info(loss)
         if should_stop:
             break
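This file also draws a line between two destinations: messages that flag misuse (an unknown mode, a missing monitored metric) now go through warnings.warn with RuntimeWarning, while routine progress messages go through logging.info. A short sketch of that distinction, using a hypothetical metric name and epoch index:

```python
import logging
import warnings

logging.basicConfig(level=logging.INFO)

monitor = 'val_loss'  # hypothetical metric name
epoch = 3             # hypothetical epoch index

# user-facing caution: subject to the warnings filter and can be escalated
# to an exception with `python -W error`
warnings.warn(f'Early stopping conditioned on metric `{monitor}` which is not available.',
              RuntimeWarning)

# routine progress message: controlled by the logging level, never raised as an exception
logging.info(f'Epoch {epoch + 1:05d}: early stopping')
```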

pytorch_lightning/root_module/memory.py

Lines changed: 2 additions & 1 deletion
@@ -8,6 +8,7 @@
 import numpy as np
 import pandas as pd
 import torch
+import logging
 
 
 class ModelSummary(object):
@@ -166,7 +167,7 @@ def print_mem_stack():  # pragma: no cover
     for obj in gc.get_objects():
         try:
             if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
-                print(type(obj), obj.size())
+                logging.info(type(obj), obj.size())
         except Exception:
             pass
 
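One thing to keep in mind when reading the new print_mem_stack line: logging.info treats everything after the first argument as %-style formatting arguments for the message, so multiple positional arguments do not behave like print, which simply joins them. A hedged sketch of the conventional forms, with an illustrative tensor standing in for obj:

```python
import logging
import torch

logging.basicConfig(level=logging.INFO)
obj = torch.zeros(2, 3)  # illustrative tensor

# lazy %-style formatting: the extra args are substituted into the message at emit time
logging.info('%s %s', type(obj), obj.size())

# or format eagerly with an f-string, matching the style used elsewhere in this commit
logging.info(f'{type(obj)} {obj.size()}')
```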

pytorch_lightning/root_module/root_module.py

Lines changed: 2 additions & 1 deletion
@@ -10,6 +10,7 @@
 from pytorch_lightning.root_module.memory import ModelSummary
 from pytorch_lightning.root_module.model_saving import ModelIO
 from pytorch_lightning.trainer.trainer_io import load_hparams_from_tags_csv
+import logging
 
 
 class LightningModule(GradInformation, ModelIO, ModelHooks):
@@ -240,7 +241,7 @@ def load_from_checkpoint(cls, checkpoint_path):
 
     def summarize(self, mode):
         model_summary = ModelSummary(self, mode=mode)
-        print(model_summary)
+        logging.info(model_summary)
 
     def freeze(self):
         for param in self.parameters():

pytorch_lightning/trainer/amp_mixin.py

Lines changed: 2 additions & 1 deletion
@@ -4,14 +4,15 @@
     APEX_AVAILABLE = True
 except ImportError:
     APEX_AVAILABLE = False
+import logging
 
 
 class TrainerAMPMixin(object):
 
     def init_amp(self, use_amp):
         self.use_amp = use_amp and APEX_AVAILABLE
         if self.use_amp:
-            print('using 16bit precision')
+            logging.info('using 16bit precision')
 
         if use_amp and not APEX_AVAILABLE:  # pragma: no cover
             msg = """

pytorch_lightning/trainer/ddp_mixin.py

Lines changed: 3 additions & 2 deletions
@@ -1,6 +1,7 @@
 import os
 import re
 import warnings
+import logging
 
 import torch
 import torch.distributed as dist
@@ -59,7 +60,7 @@ def set_distributed_mode(self, distributed_backend, nb_gpu_nodes):
                 'To silence this warning set distributed_backend=ddp'
             warnings.warn(w)
 
-        print('gpu available: {}, used: {}'.format(torch.cuda.is_available(), self.on_gpu))
+        logging.info(f'gpu available: {torch.cuda.is_available()}, used: {self.on_gpu}')
 
     def configure_slurm_ddp(self, nb_gpu_nodes):
         self.is_slurm_managing_tasks = False
@@ -107,7 +108,7 @@ def set_nvidia_flags(self, is_slurm_managing_tasks, data_parallel_device_ids):
         gpu_str = ','.join([str(x) for x in data_parallel_device_ids])
         os.environ["CUDA_VISIBLE_DEVICES"] = gpu_str
 
-        print(f'VISIBLE GPUS: {os.environ["CUDA_VISIBLE_DEVICES"]}')
+        logging.info(f'VISIBLE GPUS: {os.environ["CUDA_VISIBLE_DEVICES"]}')
 
     def ddp_train(self, gpu_nb, model):
         """

pytorch_lightning/trainer/trainer.py

Lines changed: 5 additions & 1 deletion
@@ -4,6 +4,7 @@
 
 import os
 import warnings
+import logging
 
 import torch
 import torch.distributed as dist
@@ -148,7 +149,7 @@ def __init__(self,
             Running in fast_dev_run mode: will run a full train,
             val loop using a single batch
             '''
-            print(m)
+            logging.info(m)
 
         # set default save path if user didn't provide one
         self.default_save_path = default_save_path
@@ -234,6 +235,9 @@ def __init__(self,
         self.amp_level = amp_level
         self.init_amp(use_amp)
 
+        # set logging options
+        logging.basicConfig(level=logging.INFO)
+
     @property
     def slurm_job_id(self):
         try:
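The new logging.basicConfig(level=logging.INFO) call in Trainer.__init__ is what makes the logging.info messages elsewhere in this diff visible at runtime: basicConfig attaches a stream handler to the root logger and lowers its level, and it does nothing if the root logger already has handlers, so logging set up by the user beforehand is left alone. A minimal sketch of that behaviour, with illustrative messages:

```python
import logging

# first call wins: attaches a StreamHandler to the root logger at INFO level
logging.basicConfig(level=logging.INFO)
logging.info('visible: INFO records now reach the console')

# a later basicConfig call is ignored because the root logger already has a handler
logging.basicConfig(level=logging.WARNING)
logging.info('still visible: the earlier INFO configuration is unchanged')
```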

pytorch_lightning/trainer/trainer_io.py

Lines changed: 9 additions & 8 deletions
@@ -3,6 +3,7 @@
 import signal
 import warnings
 from subprocess import call
+import logging
 
 import torch
 import torch.distributed as dist
@@ -87,7 +88,7 @@ def restore_state_if_checkpoint_exists(self, model):
         if last_ckpt_name is not None:
             last_ckpt_path = os.path.join(self.checkpoint_callback.filepath, last_ckpt_name)
             self.restore(last_ckpt_path, self.on_gpu)
-            print(f'model and trainer restored from checkpoint: {last_ckpt_path}')
+            logging.info(f'model and trainer restored from checkpoint: {last_ckpt_path}')
             did_restore = True
 
         return did_restore
@@ -106,36 +107,36 @@ def register_slurm_signal_handlers(self):
             pass
 
         if on_slurm:
-            print('set slurm handle signals')
+            logging.info('set slurm handle signals')
             signal.signal(signal.SIGUSR1, self.sig_handler)
             signal.signal(signal.SIGTERM, self.term_handler)
 
     def sig_handler(self, signum, frame):
         if self.proc_rank == 0:
             # save weights
-            print('handling SIGUSR1')
+            logging.info('handling SIGUSR1')
             self.hpc_save(self.weights_save_path, self.logger)
 
             # find job id
             job_id = os.environ['SLURM_JOB_ID']
             cmd = 'scontrol requeue {}'.format(job_id)
 
             # requeue job
-            print('\nrequeing job {}...'.format(job_id))
+            logging.info('\nrequeing job {job_id}...')
             result = call(cmd, shell=True)
 
             # print result text
             if result == 0:
-                print('requeued exp ', job_id)
+                logging.info('requeued exp {job_id}')
             else:
-                print('requeue failed...')
+                logging.info('requeue failed...')
 
             # close experiment to avoid issues
             self.logger.close()
 
     def term_handler(self, signum, frame):
         # save
-        print("bypassing sigterm")
+        logging.info("bypassing sigterm")
 
     # --------------------
     # MODEL SAVE CHECKPOINT
@@ -328,7 +329,7 @@ def hpc_load(self, folderpath, on_gpu):
         # call model hook
         model.on_hpc_load(checkpoint)
 
-        print(f'restored hpc model from: {filepath}')
+        logging.info(f'restored hpc model from: {filepath}')
 
     def max_ckpt_in_folder(self, path, name_key='ckpt_'):
         files = os.listdir(path)