@@ -472,13 +472,24 @@ def test_neptune_scale_logger_invalid_run():
472472
473473
474474@pytest .mark .skipif (not _NEPTUNE_SCALE_AVAILABLE , reason = "Neptune-Scale is required for this test." )
475- def test_neptune_scale_logger_log_model_summary (neptune_scale_logger , caplog ):
476- """Test that log_model_summary shows warning."""
477- logger = NeptuneScaleLogger ( log_model_checkpoints = True )
475+ def test_neptune_scale_logger_log_model_summary (neptune_scale_logger ):
476+ from neptune_scale . types import File
477+
478478 model = BoringModel ()
479+ test_variants = [
480+ ({}, "training/model/summary" ),
481+ ({"prefix" : "custom_prefix" }, "custom_prefix/model/summary" ),
482+ ({"prefix" : "custom/nested/prefix" }, "custom/nested/prefix/model/summary" ),
483+ ]
479484
480- logger .log_model_summary (model )
481- assert "Neptune Scale does not support logging model summaries" in caplog .text
485+ for prefix , model_summary_key in test_variants :
486+ logger , run_instance_mock , _ = _get_logger_with_mocks (api_key = "test" , project = "project" , ** prefix )
487+
488+ logger .log_model_summary (model )
489+
490+ assert run_instance_mock .__setitem__ .call_count == 1
491+ assert run_instance_mock .__getitem__ .call_count == 0
492+ run_instance_mock .__setitem__ .assert_called_once_with (model_summary_key , File )
482493
483494
484495@pytest .mark .skipif (not _NEPTUNE_SCALE_AVAILABLE , reason = "Neptune-Scale is required for this test." )
@@ -496,3 +507,39 @@ def test_neptune_scale_logger_with_prefix(neptune_scale_logger):
496507 metrics = {"loss" : 1.23 }
497508 logger .log_metrics (metrics , step = 5 )
498509 mock_run .log_metrics .assert_called_once_with ({"training/loss" : 1.23 }, step = 5 )
510+
511+
512+ @pytest .mark .skipif (not _NEPTUNE_SCALE_AVAILABLE , reason = "Neptune-Scale is required for this test." )
513+ def test_neptune_scale_logger_after_save_checkpoint (neptune_scale_logger ):
514+ test_variants = [
515+ ({}, "training/model" ),
516+ ({"prefix" : "custom_prefix" }, "custom_prefix/model" ),
517+ ({"prefix" : "custom/nested/prefix" }, "custom/nested/prefix/model" ),
518+ ]
519+
520+ for prefix , model_key_prefix in test_variants :
521+ logger , run_instance_mock , run_attr_mock = _get_logger_with_mocks (api_key = "test" , project = "project" , ** prefix )
522+ models_root_dir = os .path .join ("path" , "to" , "models" )
523+ cb_mock = MagicMock (
524+ dirpath = models_root_dir ,
525+ last_model_path = os .path .join (models_root_dir , "last" ),
526+ best_k_models = {
527+ f"{ os .path .join (models_root_dir , 'model1' )} " : None ,
528+ f"{ os .path .join (models_root_dir , 'model2/with/slashes' )} " : None ,
529+ },
530+ best_model_path = os .path .join (models_root_dir , "best_model" ),
531+ best_model_score = None ,
532+ )
533+
534+ mock_file = neptune_scale_logger .types .File
535+ mock_file .reset_mock ()
536+ mock_file .side_effect = mock .Mock ()
537+ logger .after_save_checkpoint (cb_mock )
538+
539+ run_instance_mock .__getitem__ .assert_any_call (f"{ model_key_prefix } /checkpoints/model1" )
540+ run_instance_mock .__getitem__ .assert_any_call (f"{ model_key_prefix } /checkpoints/model2/with/slashes" )
541+
542+ run_attr_mock .upload .assert_has_calls ([
543+ call (os .path .join (models_root_dir , "model1" )),
544+ call (os .path .join (models_root_dir , "model2/with/slashes" )),
545+ ])
0 commit comments