 import copy
 import os
+from pathlib import Path
 from unittest.mock import MagicMock
 
 import matplotlib
@@ -483,59 +484,61 @@ def test_no_matplotlib(no_site_packages, lr_finder):
         lr_finder.plot()
 
 
-def test_plot_single_param_group(lr_finder, mnist_to_save, dummy_engine_mnist, mnist_dataloader):
+def test_plot_single_param_group(dirname, lr_finder, mnist_to_save, dummy_engine_mnist, mnist_dataloader):
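+    # dirname: assumed to be a pytest fixture that yields a per-test temporary directory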
 
     with lr_finder.attach(dummy_engine_mnist, mnist_to_save, end_lr=20, smooth_f=0.04) as trainer_with_finder:
         trainer_with_finder.run(mnist_dataloader)
 
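+    # checks shared by both plot() calls below: axis formatting plus a round-trip save to disk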
+    def _test(ax):
+        assert ax is not None
+        assert ax.get_xscale() == "log"
+        assert ax.get_xlabel() == "Learning rate"
+        assert ax.get_ylabel() == "Loss"
+        filepath = Path(dirname) / "dummy.jpg"
+        ax.figure.savefig(filepath)
+        assert filepath.exists()
+        filepath.unlink()
+
     lr_finder.plot()
     ax = lr_finder.plot(skip_end=0)
-    assert ax is not None
-    assert ax.get_xscale() == "log"
-    assert ax.get_xlabel() == "Learning rate"
-    assert ax.get_ylabel() == "Loss"
-    ax.figure.savefig("dummy.jpg")
-    assert os.path.exists("dummy.jpg")
+    _test(ax)
 
     # Passing axes object
     from matplotlib import pyplot as plt
 
-    fig, ax = plt.subplots()
+    _, ax = plt.subplots()
     lr_finder.plot(skip_end=0, ax=ax)
-    assert ax.get_xscale() == "log"
-    assert ax.get_xlabel() == "Learning rate"
-    assert ax.get_ylabel() == "Loss"
-    ax.figure.savefig("dummy2.jpg")
-    assert os.path.exists("dummy2.jpg")
+    _test(ax)
 
 
 def test_plot_multiple_param_groups(
-    lr_finder, to_save_mulitple_param_groups, dummy_engine_mulitple_param_groups, dataloader_plot
+    dirname, lr_finder, to_save_mulitple_param_groups, dummy_engine_mulitple_param_groups, dataloader_plot
 ):
 
     with lr_finder.attach(
         dummy_engine_mulitple_param_groups, to_save_mulitple_param_groups, end_lr=20, smooth_f=0.04
     ) as trainer_with_finder:
         trainer_with_finder.run(dataloader_plot)
 
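+    # same axis/savefig checks as in test_plot_single_param_group, written into the temp dir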
+    def _test(ax):
+        assert ax is not None
+        assert ax.get_xscale() == "log"
+        assert ax.get_xlabel() == "Learning rate"
+        assert ax.get_ylabel() == "Loss"
+        filepath = Path(dirname) / "dummy_muliple_param_groups.jpg"
+        ax.figure.savefig(filepath)
+        assert filepath.exists()
+        filepath.unlink()
+
     ax = lr_finder.plot(skip_start=0, skip_end=0)
-    assert ax is not None
-    assert ax.get_xscale() == "log"
-    assert ax.get_xlabel() == "Learning rate"
-    assert ax.get_ylabel() == "Loss"
-    ax.figure.savefig("dummy_muliple_param_groups.jpg")
-    assert os.path.exists("dummy_muliple_param_groups.jpg")
+    _test(ax)
 
     # Passing axes object
     from matplotlib import pyplot as plt
 
-    fig, ax = plt.subplots()
+    _, ax = plt.subplots()
     lr_finder.plot(skip_start=0, skip_end=0, ax=ax)
-    assert ax.get_xscale() == "log"
-    assert ax.get_xlabel() == "Learning rate"
-    assert ax.get_ylabel() == "Loss"
-    ax.figure.savefig("dummy_muliple_param_groups2.jpg")
-    assert os.path.exists("dummy_muliple_param_groups2.jpg")
+    _test(ax)
 
 
 def _test_distrib_log_lr_and_loss(device):
@@ -562,7 +565,7 @@ def _test_distrib_log_lr_and_loss(device):
     assert pytest.approx(lr_finder._history["loss"][-1]) == expected_loss
 
 
-def _test_distrib_integration_mnist(device):
+def _test_distrib_integration_mnist(dirname, device):
    from torch.utils.data import DataLoader
    from torchvision.datasets import MNIST
    from torchvision.transforms import Compose, Normalize, ToTensor
@@ -598,8 +601,9 @@ def forward(self, x):
 
     if idist.get_rank() == 0:
         ax = lr_finder.plot(skip_end=0)
-        ax.figure.savefig("distrib_dummy.jpg")
-        assert os.path.exists("distrib_dummy.jpg")
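+        # rank 0 saves the figure under the temp dir instead of the current working directory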
+        filepath = Path(dirname) / "distrib_dummy.jpg"
+        ax.figure.savefig(filepath)
+        assert filepath.exists()
 
     sug_lr = lr_finder.lr_suggestion()
     assert 1e-3 <= sug_lr <= 1
@@ -610,37 +614,37 @@ def forward(self, x):
 
 @pytest.mark.distributed
 @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
-def test_distrib_gloo_cpu_or_gpu(distributed_context_single_node_gloo):
+def test_distrib_gloo_cpu_or_gpu(dirname, distributed_context_single_node_gloo):
 
     device = idist.device()
     _test_distrib_log_lr_and_loss(device)
-    _test_distrib_integration_mnist(device)
+    _test_distrib_integration_mnist(dirname, device)
 
 
 @pytest.mark.distributed
 @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
 @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
-def test_distrib_nccl_gpu(distributed_context_single_node_nccl):
+def test_distrib_nccl_gpu(dirname, distributed_context_single_node_nccl):
 
     device = idist.device()
     _test_distrib_log_lr_and_loss(device)
-    _test_distrib_integration_mnist(device)
+    _test_distrib_integration_mnist(dirname, device)
 
 
 @pytest.mark.tpu
 @pytest.mark.skipif("NUM_TPU_WORKERS" in os.environ, reason="Skip if NUM_TPU_WORKERS is in env vars")
 @pytest.mark.skipif(not idist.has_xla_support, reason="Not on TPU device")
-def test_distrib_single_device_xla():
+def test_distrib_single_device_xla(dirname):
     device = idist.device()
     assert "xla" in device.type
     _test_distrib_log_lr_and_loss(device)
-    _test_distrib_integration_mnist(device)
+    _test_distrib_integration_mnist(dirname, device)
 
 
-def _test_distrib_log_lr_and_loss_xla_nprocs(index):
+def _test_distrib_log_lr_and_loss_xla_nprocs(index, dirname):
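+    # index is supplied by the XLA process spawner; dirname is forwarded via args=(dirname,) below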
     device = idist.device()
     _test_distrib_log_lr_and_loss(device)
-    _test_distrib_integration_mnist(device)
+    _test_distrib_integration_mnist(dirname, device)
 
     import time
 
@@ -649,8 +653,8 @@ def _test_distrib_log_lr_and_loss_xla_nprocs(index):
 
 
 @pytest.mark.tpu
-@pytest.mark.skipif("NUM_TPU_WORKERS" not in os.environ, reason="Skip if NUM_TPU_WORKERS is in env vars")
+@pytest.mark.skipif("NUM_TPU_WORKERS" not in os.environ, reason="Skip if no NUM_TPU_WORKERS in env vars")
 @pytest.mark.skipif(not idist.has_xla_support, reason="Not on TPU device")
-def test_distrib_single_device_xla_nprocs(xmp_executor):
+def test_distrib_xla_nprocs(dirname, xmp_executor):
     n = int(os.environ["NUM_TPU_WORKERS"])
-    xmp_executor(_test_distrib_log_lr_and_loss_xla_nprocs, args=(), nprocs=n)
+    xmp_executor(_test_distrib_log_lr_and_loss_xla_nprocs, args=(dirname,), nprocs=n)