Skip to content

Commit ad2d319

Browse files
committed
add unit tests for public/double underscore methods
1 parent a785341 commit ad2d319

File tree

2 files changed

+43
-0
lines changed

2 files changed

+43
-0
lines changed

tests/tests_pytorch/loggers/conftest.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ def comet_mock(monkeypatch):
9898
comet_experiment = MagicMock(name="CommonExperiment")
9999
setattr(comet_experiment, "__internal_api__set_model_graph__", MagicMock())
100100
setattr(comet_experiment, "__internal_api__log_metrics__", MagicMock())
101+
setattr(comet_experiment, "__internal_api__log_parameters__", MagicMock())
101102

102103
comet.Experiment = MagicMock(name="Experiment", return_value=comet_experiment)
103104
comet.ExistingExperiment = MagicMock(name="ExistingExperiment", return_value=comet_experiment)

tests/tests_pytorch/loggers/test_comet.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
from lightning.pytorch.loggers import CometLogger
1919
from torch import tensor
2020

21+
FRAMEWORK_NAME = "pytorch-lightning"
22+
2123

2224
def _patch_comet_atexit(monkeypatch):
2325
"""Prevent comet logger from trying to print at exit, since pytest's stdout/stderr redirection breaks it."""
@@ -148,6 +150,46 @@ def test_comet_epoch_logging(comet_mock, tmp_path, monkeypatch):
148150
)
149151

150152

153+
@mock.patch.dict(os.environ, {})
154+
def test_comet_log_hyperparams(comet_mock, tmp_path, monkeypatch):
155+
"""Test that CometLogger.log_hyperparams calls internal API method."""
156+
_patch_comet_atexit(monkeypatch)
157+
158+
logger = CometLogger(project_name="test")
159+
hyperparams = {
160+
"batch_size": 256,
161+
"config": {
162+
"SLURM Job ID": "22334455",
163+
"RGB slurm jobID": "12345678",
164+
"autoencoder_model": False,
165+
},
166+
}
167+
logger.log_hyperparams(hyperparams)
168+
169+
logger.experiment.__internal_api__log_parameters__.assert_called_once_with(
170+
parameters=hyperparams,
171+
framework=FRAMEWORK_NAME,
172+
flatten_nested=True,
173+
source="manual",
174+
)
175+
176+
177+
@mock.patch.dict(os.environ, {})
178+
def test_comet_log_graph(comet_mock, tmp_path, monkeypatch):
179+
"""Test that CometLogger.log_hyperparams calls internal API method."""
180+
_patch_comet_atexit(monkeypatch)
181+
182+
logger = CometLogger(project_name="test")
183+
model = Mock()
184+
185+
logger.log_graph(model=model)
186+
187+
logger.experiment.__internal_api__set_model_graph__.assert_called_once_with(
188+
graph=model,
189+
framework="pytorch-lightning",
190+
)
191+
192+
151193
@mock.patch.dict(os.environ, {})
152194
def test_comet_metrics_safe(comet_mock, tmp_path, monkeypatch):
153195
"""Test that CometLogger.log_metrics doesn't do inplace modification of metrics."""

0 commit comments

Comments
 (0)