Skip to content

Commit fb8d085

Browse files
authored
Fix TrainsLogger doctest failing (switch to bypass mode in GitHub CI) (#1379)
* Fix TrainsLogger doctest failing (switch to bypass mode in GitHub CI) * fix * test ci * debug * debug CI * Fix CircleCI * Fix Any CI environment switch to bypass mode * Removed debug prints * Improve code coverage * Improve code coverage * Reverted * Improve code coverage * Test CI * test codecov * Codecov fix * remove pragma Co-authored-by: bmartinn <>
1 parent 2ae2bd2 commit fb8d085

File tree

1 file changed

+59
-22
lines changed

1 file changed

+59
-22
lines changed

pytorch_lightning/loggers/trains.py

Lines changed: 59 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def any_lightning_module_function_or_hook(...):
2424
2525
"""
2626
from argparse import Namespace
27+
from os import environ
2728
from pathlib import Path
2829
from typing import Any, Dict, Optional, Union
2930

@@ -58,9 +59,9 @@ class TrainsLogger(LightningLoggerBase):
5859
sent along side the task scalars. Defaults to True.
5960
6061
Examples:
61-
>>> logger = TrainsLogger("lightning_log", "my-test", output_uri=".") # doctest: +ELLIPSIS
62+
>>> logger = TrainsLogger("lightning_log", "my-lightning-test", output_uri=".") # doctest: +ELLIPSIS
6263
TRAINS Task: ...
63-
TRAINS results page: https://demoapp.trains.allegro.ai/.../log
64+
TRAINS results page: ...
6465
>>> logger.log_metrics({"val_loss": 1.23}, step=0)
6566
>>> logger.log_text("sample test")
6667
sample test
@@ -69,7 +70,7 @@ class TrainsLogger(LightningLoggerBase):
6970
>>> logger.log_image("passed", "Image 1", np.random.randint(0, 255, (200, 150, 3), dtype=np.uint8))
7071
"""
7172

72-
_bypass = False
73+
_bypass = None
7374

7475
def __init__(
7576
self,
@@ -83,8 +84,24 @@ def __init__(
8384
auto_resource_monitoring: bool = True
8485
) -> None:
8586
super().__init__()
86-
if self._bypass:
87+
if self.bypass_mode():
8788
self._trains = None
89+
print('TRAINS Task: running in bypass mode')
90+
print('TRAINS results page: disabled')
91+
92+
class _TaskStub(object):
93+
def __call__(self, *args, **kwargs):
94+
return self
95+
96+
def __getattr__(self, attr):
97+
if attr in ('name', 'id'):
98+
return ''
99+
return self
100+
101+
def __setattr__(self, attr, val):
102+
pass
103+
104+
self._trains = _TaskStub()
88105
else:
89106
self._trains = Task.init(
90107
project_name=project_name,
@@ -114,8 +131,9 @@ def id(self) -> Union[str, None]:
114131
"""
115132
ID is a uuid (string) representing this specific experiment in the entire system.
116133
"""
117-
if self._bypass or not self._trains:
134+
if not self._trains:
118135
return None
136+
119137
return self._trains.id
120138

121139
@rank_zero_only
@@ -126,8 +144,8 @@ def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
126144
params:
127145
The hyperparameters that passed through the model.
128146
"""
129-
if self._bypass or not self._trains:
130-
return None
147+
if not self._trains:
148+
return
131149
if not params:
132150
return
133151

@@ -147,8 +165,8 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->
147165
then the elements will be logged as "title" and "series" respectively.
148166
step: Step number at which the metrics should be recorded. Defaults to None.
149167
"""
150-
if self._bypass or not self._trains:
151-
return None
168+
if not self._trains:
169+
return
152170

153171
if not step:
154172
step = self._trains.get_last_iteration()
@@ -179,8 +197,8 @@ def log_metric(self, title: str, series: str, value: float, step: Optional[int]
179197
value: The value to log.
180198
step: Step number at which the metrics should be recorded. Defaults to None.
181199
"""
182-
if self._bypass or not self._trains:
183-
return None
200+
if not self._trains:
201+
return
184202

185203
if not step:
186204
step = self._trains.get_last_iteration()
@@ -197,8 +215,12 @@ def log_text(self, text: str) -> None:
197215
Args:
198216
text: The value of the log (data-point).
199217
"""
200-
if self._bypass or not self._trains:
201-
return None
218+
if self.bypass_mode():
219+
print(text)
220+
return
221+
222+
if not self._trains:
223+
return
202224

203225
self._trains.get_logger().report_text(text)
204226

@@ -222,8 +244,8 @@ def log_image(
222244
step:
223245
Step number at which the metrics should be recorded. Defaults to None.
224246
"""
225-
if self._bypass or not self._trains:
226-
return None
247+
if not self._trains:
248+
return
227249

228250
if not step:
229251
step = self._trains.get_last_iteration()
@@ -265,8 +287,8 @@ def log_artifact(
265287
If True local artifact will be deleted (only applies if artifact_object is a
266288
local file). Defaults to False.
267289
"""
268-
if self._bypass or not self._trains:
269-
return None
290+
if not self._trains:
291+
return
270292

271293
self._trains.upload_artifact(
272294
name=name, artifact_object=artifact, metadata=metadata,
@@ -278,8 +300,9 @@ def save(self) -> None:
278300

279301
@rank_zero_only
280302
def finalize(self, status: str = None) -> None:
281-
if self._bypass or not self._trains:
282-
return None
303+
if self.bypass_mode() or not self._trains:
304+
return
305+
283306
self._trains.close()
284307
self._trains = None
285308

@@ -288,14 +311,16 @@ def name(self) -> Union[str, None]:
288311
"""
289312
Name is a human readable non-unique name (str) of the experiment.
290313
"""
291-
if self._bypass or not self._trains:
314+
if not self._trains:
292315
return ''
316+
293317
return self._trains.name
294318

295319
@property
296320
def version(self) -> Union[str, None]:
297-
if self._bypass or not self._trains:
321+
if not self._trains:
298322
return None
323+
299324
return self._trains.id
300325

301326
@classmethod
@@ -327,9 +352,21 @@ def set_bypass_mode(cls, bypass: bool) -> None:
327352
"""
328353
cls._bypass = bypass
329354

355+
@classmethod
356+
def bypass_mode(cls) -> bool:
357+
"""
358+
bypass_mode returns the bypass mode state.
359+
Notice GITHUB_ACTIONS env will automatically set bypass_mode to True
360+
unless overridden specifically with set_bypass_mode(False)
361+
362+
:return: If True, all outside communication is skipped
363+
"""
364+
return cls._bypass if cls._bypass is not None else bool(environ.get('CI'))
365+
330366
def __getstate__(self) -> Union[str, None]:
331-
if self._bypass or not self._trains:
367+
if self.bypass_mode() or not self._trains:
332368
return ''
369+
333370
return self._trains.id
334371

335372
def __setstate__(self, state: str) -> None:

0 commit comments

Comments
 (0)