Skip to content

Commit a804240

Browse files
committed
fix unit tests
1 parent 49722b1 commit a804240

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

tests/tests_pytorch/loggers/test_all.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ def log_metrics(self, metrics, step):
107107

108108
if logger_class == CometLogger:
109109
logger.experiment.id = "foo"
110+
logger._comet_config.offline_directory = None
110111
logger.experiment.project_name = "bar"
111112

112113
if logger_class == NeptuneLogger:
@@ -299,7 +300,9 @@ def test_logger_with_prefix_all(mlflow_mock, wandb_mock, comet_mock, neptune_moc
299300
_patch_comet_atexit(monkeypatch)
300301
logger = _instantiate_logger(CometLogger, save_dir=tmp_path, prefix=prefix)
301302
logger.log_metrics({"test": 1.0}, step=0)
302-
logger.experiment.log_metrics.assert_called_once_with({"tmp-test": 1.0}, epoch=None, step=0)
303+
logger.experiment.__internal_api__log_metrics__.assert_called_once_with(
304+
{"test": 1.0}, epoch=None, step=0, prefix=prefix, framework="pytorch-lightning"
305+
)
303306

304307
# MLflow
305308
Metric = mlflow_mock.entities.Metric

0 commit comments

Comments
 (0)