|
18 | 18 | from lightning.pytorch.loggers import CometLogger |
19 | 19 | from torch import tensor |
20 | 20 |
|
| 21 | +FRAMEWORK_NAME = "pytorch-lightning" |
| 22 | + |
21 | 23 |
|
22 | 24 | def _patch_comet_atexit(monkeypatch): |
23 | 25 | """Prevent comet logger from trying to print at exit, since pytest's stdout/stderr redirection breaks it.""" |
@@ -148,6 +150,46 @@ def test_comet_epoch_logging(comet_mock, tmp_path, monkeypatch): |
148 | 150 | ) |
149 | 151 |
|
150 | 152 |
|
| 153 | +@mock.patch.dict(os.environ, {}) |
| 154 | +def test_comet_log_hyperparams(comet_mock, tmp_path, monkeypatch): |
| 155 | + """Test that CometLogger.log_hyperparams calls internal API method.""" |
| 156 | + _patch_comet_atexit(monkeypatch) |
| 157 | + |
| 158 | + logger = CometLogger(project_name="test") |
| 159 | + hyperparams = { |
| 160 | + "batch_size": 256, |
| 161 | + "config": { |
| 162 | + "SLURM Job ID": "22334455", |
| 163 | + "RGB slurm jobID": "12345678", |
| 164 | + "autoencoder_model": False, |
| 165 | + }, |
| 166 | + } |
| 167 | + logger.log_hyperparams(hyperparams) |
| 168 | + |
| 169 | + logger.experiment.__internal_api__log_parameters__.assert_called_once_with( |
| 170 | + parameters=hyperparams, |
| 171 | + framework=FRAMEWORK_NAME, |
| 172 | + flatten_nested=True, |
| 173 | + source="manual", |
| 174 | + ) |
| 175 | + |
| 176 | + |
| 177 | +@mock.patch.dict(os.environ, {}) |
| 178 | +def test_comet_log_graph(comet_mock, tmp_path, monkeypatch): |
| 179 | + """Test that CometLogger.log_hyperparams calls internal API method.""" |
| 180 | + _patch_comet_atexit(monkeypatch) |
| 181 | + |
| 182 | + logger = CometLogger(project_name="test") |
| 183 | + model = Mock() |
| 184 | + |
| 185 | + logger.log_graph(model=model) |
| 186 | + |
| 187 | + logger.experiment.__internal_api__set_model_graph__.assert_called_once_with( |
| 188 | + graph=model, |
| 189 | + framework="pytorch-lightning", |
| 190 | + ) |
| 191 | + |
| 192 | + |
151 | 193 | @mock.patch.dict(os.environ, {}) |
152 | 194 | def test_comet_metrics_safe(comet_mock, tmp_path, monkeypatch): |
153 | 195 | """Test that CometLogger.log_metrics doesn't do inplace modification of metrics.""" |
|
0 commit comments