-
Notifications
You must be signed in to change notification settings - Fork 299
Expand file tree
/
Copy pathtest_wb_logger.py
More file actions
131 lines (104 loc) · 3.87 KB
/
test_wb_logger.py
File metadata and controls
131 lines (104 loc) · 3.87 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
"""Tests `imitation.util.logger.WandbOutputFormat`."""
import sys
from typing import Any, Mapping, Optional
from unittest import mock
import pytest
import wandb
from imitation.util import logger
class MockHistory:
    """Lightweight history stand-in that buffers one pending row at a time.

    Values are merged into a pending row; flushing stamps the row with the
    current step, hands it to the registered callback (if any) and starts
    a fresh pending row.
    """

    def __init__(self):
        """Start at step 0 with an empty pending row and no callback."""
        self._callback = None
        self._data = {}
        self._step = 0

    def _set_callback(self, cb):
        """Register ``cb``; it is invoked as ``cb(row=..., step=...)`` on flush."""
        self._callback = cb

    def _row_update(self, row):
        """Merge ``row`` into the pending row without flushing."""
        self._data.update(row)

    def _row_add(self, row):
        """Merge ``row`` into the pending row, flush it, then advance the step."""
        self._data.update(row)
        self._flush()
        self._step += 1

    def _flush(self):
        """Emit the pending row to the callback; do nothing if it is empty."""
        pending = self._data
        if not pending:
            return
        # Stamp the row with the step it belongs to before handing it out.
        pending["_step"] = self._step
        if self._callback:
            self._callback(row=pending, step=self._step)
        # The emitted dict is not reused: a new pending row is started.
        self._data = {}
class MockWandb:
    """In-memory replacement for the parts of the ``wandb`` API used in tests.

    Every row flushed by the underlying ``MockHistory`` is appended to
    ``history_list`` so that tests can inspect what was logged.
    """

    def __init__(self):
        """Create an uninitialized mock with an empty history."""
        self._initialized = False
        self._init_args = None
        self._init_kwargs = None
        self.history_list = []
        self.history = MockHistory()
        # Flushed rows are routed back into ``history_list``.
        self.history._set_callback(self._history_callback)

    def init(self, *args, **kwargs):
        """Mark the mock as initialized and remember the call arguments."""
        self._init_args = args
        self._init_kwargs = kwargs
        self._initialized = True

    def log(
        self,
        data: Mapping[str, Any],
        step: Optional[int] = None,
        commit: bool = False,
        sync: bool = False,
    ):
        """Record ``data`` in the history.

        Args:
            data: key/value pairs to merge into the pending row.
            step: if given and greater than the history's current step, the
                pending row is flushed and the history jumps to this step.
            commit: if True the row is added (merged, flushed and the step
                advanced); otherwise it is only merged into the pending row.
            sync: not supported by this mock.

        Raises:
            NotImplementedError: if ``sync`` is truthy.
        """
        assert self._initialized
        if sync:
            raise NotImplementedError("usage of sync to MockWandb.log not implemented")

        history = self.history
        if step is not None and step > history._step:
            # Moving forward in time: emit whatever is pending first.
            history._flush()
            history._step = step

        write_row = history._row_add if commit else history._row_update
        write_row(data)

    def _history_callback(self, row: Mapping[str, Any], step: int) -> None:
        """Collect a flushed ``row``; ``step`` is accepted but unused."""
        self.history_list.append(row)

    def finish(self):
        """Mark the mock as no longer initialized (must be initialized)."""
        assert self._initialized
        self._initialized = False
# Module-level mock instance; its bound methods are patched onto the real
# `wandb` module by the decorators of `test_wandb_output_format`.
mock_wandb = MockWandb()
# The type is ignored on the first patch because `__init__` is not meant to
# be accessed directly, only implicitly by creating an instance.
@mock.patch.object(wandb, "__init__", mock_wandb.__init__)  # type: ignore[misc]
@mock.patch.object(wandb, "init", mock_wandb.init)
@mock.patch.object(wandb, "log", mock_wandb.log)
@mock.patch.object(wandb, "finish", mock_wandb.finish)
def test_wandb_output_format():
    """Check what reaches the patched ``wandb`` when logging with "wandb" format.

    Nothing may be logged before ``dump``; each ``dump`` must produce exactly
    one row, omitting keys recorded with ``exclude="wandb"``; and dumping a
    non-path value under the key "video" must raise ``ValueError``.
    """
    first_row = {"_step": 0, "foo": 42, "fizz": 12}
    second_row = {"_step": 3, "fizz": 21}

    wandb.init()
    log_obj = logger.configure(format_strs=["wandb"])
    assert len(mock_wandb.history_list) == 0, "nothing should be logged yet"

    # Plain text messages and un-dumped records must not create rows.
    log_obj.info("test 123")
    assert len(mock_wandb.history_list) == 0, "nothing should be logged yet"
    log_obj.record("foo", 42)
    assert len(mock_wandb.history_list) == 0, "nothing should be logged yet"

    # "fow" is excluded from wandb; "fizz" is only excluded from stdout.
    log_obj.record("fow", 24, exclude="wandb")
    log_obj.record("fizz", 12, exclude="stdout")
    log_obj.dump()
    assert len(mock_wandb.history_list) == 1, "exactly one entry should be logged"
    assert mock_wandb.history_list == [first_row]

    # An explicit step is carried through to the logged row.
    log_obj.record("fizz", 21)
    log_obj.dump(step=3)
    assert len(mock_wandb.history_list) == 2, "exactly two entries should be logged"
    assert mock_wandb.history_list == [first_row, second_row]

    with pytest.raises(ValueError, match=r"wandb.Video accepts a file path.*"):
        log_obj.record("video", 42)
        log_obj.dump(step=4)

    log_obj.close()
def test_wandb_module_import_error():
    """Check that configuring the "wandb" format fails when wandb is unimportable.

    The import failure is simulated by setting ``sys.modules["wandb"]`` to
    ``None``, which makes any ``import wandb`` raise ``ModuleNotFoundError``.
    The real module is put back afterwards so that later imports of ``wandb``
    in the same session keep working.
    """
    wandb_module = sys.modules["wandb"]
    try:
        sys.modules["wandb"] = None
        with pytest.raises(ModuleNotFoundError, match=r"Trying to log data.*"):
            logger.configure(format_strs=["wandb"])
    finally:
        # Restore under the string key "wandb". Keying by the module object
        # would leave the "wandb" entry as None and add a stray non-string
        # key to sys.modules.
        sys.modules["wandb"] = wandb_module