Commit e1347db

Updated temp test files location and remove them (#2419)
* Updated temp test files location and remove them
* More fixes and cosmetics
1 parent: 192d721

3 files changed: +51 −45 lines
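The change all three files share: test artifacts (saved figures, serialized schedulers) are written into a per-test temporary directory and deleted afterwards, rather than left behind in the working directory. A minimal sketch of the pattern, using pytest's built-in tmp_path fixture for illustration (ignite's suite uses its own dirname fixture, but the idea is the same):

from pathlib import Path


def test_artifact_lands_in_tmp_dir(tmp_path: Path):
    # Write into the per-test temporary directory, never the CWD.
    filepath = tmp_path / "dummy.jpg"
    filepath.write_bytes(b"fake-jpeg-payload")  # stand-in for ax.figure.savefig(filepath)
    assert filepath.exists()
    filepath.unlink()  # explicit cleanup, mirroring the tests below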

tests/ignite/handlers/test_checkpoint.py

Lines changed: 2 additions & 2 deletions
@@ -1327,9 +1327,9 @@ def _test_tpu_saves_to_cpu_nprocs(index, dirname):
 
 
 @pytest.mark.tpu
-@pytest.mark.skipif("NUM_TPU_WORKERS" not in os.environ, reason="Skip if NUM_TPU_WORKERS is in env vars")
+@pytest.mark.skipif("NUM_TPU_WORKERS" not in os.environ, reason="Skip if no NUM_TPU_WORKERS is in env vars")
 @pytest.mark.skipif(not idist.has_xla_support, reason="Not on TPU device")
-def test_distrib_single_device_xla_nprocs(xmp_executor, dirname):
+def test_distrib_xla_nprocs(xmp_executor, dirname):
     n = int(os.environ["NUM_TPU_WORKERS"])
     xmp_executor(_test_tpu_saves_to_cpu_nprocs, args=(dirname,), nprocs=n)
 

tests/ignite/handlers/test_lr_finder.py

Lines changed: 44 additions & 40 deletions
@@ -1,5 +1,6 @@
 import copy
 import os
+from pathlib import Path
 from unittest.mock import MagicMock
 
 import matplotlib
@@ -483,59 +484,61 @@ def test_no_matplotlib(no_site_packages, lr_finder):
     lr_finder.plot()
 
 
-def test_plot_single_param_group(lr_finder, mnist_to_save, dummy_engine_mnist, mnist_dataloader):
+def test_plot_single_param_group(dirname, lr_finder, mnist_to_save, dummy_engine_mnist, mnist_dataloader):
 
     with lr_finder.attach(dummy_engine_mnist, mnist_to_save, end_lr=20, smooth_f=0.04) as trainer_with_finder:
         trainer_with_finder.run(mnist_dataloader)
 
+    def _test(ax):
+        assert ax is not None
+        assert ax.get_xscale() == "log"
+        assert ax.get_xlabel() == "Learning rate"
+        assert ax.get_ylabel() == "Loss"
+        filepath = Path(dirname) / "dummy.jpg"
+        ax.figure.savefig(filepath)
+        assert filepath.exists()
+        filepath.unlink()
+
     lr_finder.plot()
     ax = lr_finder.plot(skip_end=0)
-    assert ax is not None
-    assert ax.get_xscale() == "log"
-    assert ax.get_xlabel() == "Learning rate"
-    assert ax.get_ylabel() == "Loss"
-    ax.figure.savefig("dummy.jpg")
-    assert os.path.exists("dummy.jpg")
+    _test(ax)
 
     # Passing axes object
     from matplotlib import pyplot as plt
 
-    fig, ax = plt.subplots()
+    _, ax = plt.subplots()
     lr_finder.plot(skip_end=0, ax=ax)
-    assert ax.get_xscale() == "log"
-    assert ax.get_xlabel() == "Learning rate"
-    assert ax.get_ylabel() == "Loss"
-    ax.figure.savefig("dummy2.jpg")
-    assert os.path.exists("dummy2.jpg")
+    _test(ax)
 
 
 def test_plot_multiple_param_groups(
-    lr_finder, to_save_mulitple_param_groups, dummy_engine_mulitple_param_groups, dataloader_plot
+    dirname, lr_finder, to_save_mulitple_param_groups, dummy_engine_mulitple_param_groups, dataloader_plot
 ):
 
     with lr_finder.attach(
         dummy_engine_mulitple_param_groups, to_save_mulitple_param_groups, end_lr=20, smooth_f=0.04
     ) as trainer_with_finder:
         trainer_with_finder.run(dataloader_plot)
 
+    def _test(ax):
+        assert ax is not None
+        assert ax.get_xscale() == "log"
+        assert ax.get_xlabel() == "Learning rate"
+        assert ax.get_ylabel() == "Loss"
+        filepath = Path(dirname) / "dummy_muliple_param_groups.jpg"
+        ax.figure.savefig(filepath)
+        assert filepath.exists()
+        filepath.unlink()
+
     ax = lr_finder.plot(skip_start=0, skip_end=0)
-    assert ax is not None
-    assert ax.get_xscale() == "log"
-    assert ax.get_xlabel() == "Learning rate"
-    assert ax.get_ylabel() == "Loss"
-    ax.figure.savefig("dummy_muliple_param_groups.jpg")
-    assert os.path.exists("dummy_muliple_param_groups.jpg")
+    _test(ax)
 
     # Passing axes object
     from matplotlib import pyplot as plt
 
-    fig, ax = plt.subplots()
+    _, ax = plt.subplots()
     lr_finder.plot(skip_start=0, skip_end=0, ax=ax)
-    assert ax.get_xscale() == "log"
-    assert ax.get_xlabel() == "Learning rate"
-    assert ax.get_ylabel() == "Loss"
-    ax.figure.savefig("dummy_muliple_param_groups2.jpg")
-    assert os.path.exists("dummy_muliple_param_groups2.jpg")
+    _test(ax)
 
 
 def _test_distrib_log_lr_and_loss(device):
@@ -562,7 +565,7 @@ def _test_distrib_log_lr_and_loss(device):
     assert pytest.approx(lr_finder._history["loss"][-1]) == expected_loss
 
 
-def _test_distrib_integration_mnist(device):
+def _test_distrib_integration_mnist(dirname, device):
     from torch.utils.data import DataLoader
     from torchvision.datasets import MNIST
     from torchvision.transforms import Compose, Normalize, ToTensor
@@ -598,8 +601,9 @@ def forward(self, x):
 
     if idist.get_rank() == 0:
         ax = lr_finder.plot(skip_end=0)
-        ax.figure.savefig("distrib_dummy.jpg")
-        assert os.path.exists("distrib_dummy.jpg")
+        filepath = Path(dirname) / "distrib_dummy.jpg"
+        ax.figure.savefig(filepath)
+        assert filepath.exists()
 
     sug_lr = lr_finder.lr_suggestion()
     assert 1e-3 <= sug_lr <= 1
@@ -610,37 +614,37 @@ def forward(self, x):
 
 @pytest.mark.distributed
 @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
-def test_distrib_gloo_cpu_or_gpu(distributed_context_single_node_gloo):
+def test_distrib_gloo_cpu_or_gpu(dirname, distributed_context_single_node_gloo):
 
     device = idist.device()
     _test_distrib_log_lr_and_loss(device)
-    _test_distrib_integration_mnist(device)
+    _test_distrib_integration_mnist(dirname, device)
 
 
 @pytest.mark.distributed
 @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
 @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
-def test_distrib_nccl_gpu(distributed_context_single_node_nccl):
+def test_distrib_nccl_gpu(dirname, distributed_context_single_node_nccl):
 
     device = idist.device()
     _test_distrib_log_lr_and_loss(device)
-    _test_distrib_integration_mnist(device)
+    _test_distrib_integration_mnist(dirname, device)
 
 
 @pytest.mark.tpu
 @pytest.mark.skipif("NUM_TPU_WORKERS" in os.environ, reason="Skip if NUM_TPU_WORKERS is in env vars")
 @pytest.mark.skipif(not idist.has_xla_support, reason="Not on TPU device")
-def test_distrib_single_device_xla():
+def test_distrib_single_device_xla(dirname):
     device = idist.device()
     assert "xla" in device.type
     _test_distrib_log_lr_and_loss(device)
-    _test_distrib_integration_mnist(device)
+    _test_distrib_integration_mnist(dirname, device)
 
 
-def _test_distrib_log_lr_and_loss_xla_nprocs(index):
+def _test_distrib_log_lr_and_loss_xla_nprocs(index, dirname):
     device = idist.device()
     _test_distrib_log_lr_and_loss(device)
-    _test_distrib_integration_mnist(device)
+    _test_distrib_integration_mnist(dirname, device)
 
     import time
 
@@ -649,8 +653,8 @@ def _test_distrib_log_lr_and_loss_xla_nprocs(index):
 
 
 @pytest.mark.tpu
-@pytest.mark.skipif("NUM_TPU_WORKERS" not in os.environ, reason="Skip if NUM_TPU_WORKERS is in env vars")
+@pytest.mark.skipif("NUM_TPU_WORKERS" not in os.environ, reason="Skip if no NUM_TPU_WORKERS is in env vars")
 @pytest.mark.skipif(not idist.has_xla_support, reason="Not on TPU device")
-def test_distrib_single_device_xla_nprocs(xmp_executor):
+def test_distrib_xla_nprocs(dirname, xmp_executor):
     n = int(os.environ["NUM_TPU_WORKERS"])
-    xmp_executor(_test_distrib_log_lr_and_loss_xla_nprocs, args=(), nprocs=n)
+    xmp_executor(_test_distrib_log_lr_and_loss_xla_nprocs, args=(dirname,), nprocs=n)
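Worth noting in the last hunk: the temp directory reaches the spawned TPU workers through args=(dirname,), with the spawner injecting each worker's index as the first argument. A hedged sketch of that hand-off, where spawn_fn is a hypothetical stand-in for the xmp_executor fixture (which presumably wraps torch_xla's multiprocessing spawn):

def _worker(index, dirname):
    # index comes from the spawner; dirname was forwarded via args=(dirname,)
    print(f"worker {index} writes artifacts under {dirname}")


def test_workers_share_tmp_dir(dirname, spawn_fn):
    # Every worker receives the same per-test directory.
    spawn_fn(_worker, args=(dirname,), nprocs=2)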

tests/ignite/handlers/test_state_param_scheduler.py

Lines changed: 5 additions & 3 deletions
@@ -1,4 +1,5 @@
 import re
+from pathlib import Path
 from unittest.mock import patch
 
 import pytest
@@ -291,14 +292,15 @@ def _test(scheduler_cls, scheduler_kwargs):
     _test(scheduler_cls, scheduler_kwargs)
 
 
-def test_torch_save_load():
+def test_torch_save_load(dirname):
 
     lambda_state_parameter_scheduler = LambdaStateScheduler(
         param_name="custom_scheduled_param", lambda_obj=LambdaState(initial_value=10, gamma=0.99), create_new=True
     )
 
-    torch.save(lambda_state_parameter_scheduler, "dummy_lambda_state_parameter_scheduler.pt")
-    loaded_lambda_state_parameter_scheduler = torch.load("dummy_lambda_state_parameter_scheduler.pt")
+    filepath = Path(dirname) / "dummy_lambda_state_parameter_scheduler.pt"
+    torch.save(lambda_state_parameter_scheduler, filepath)
+    loaded_lambda_state_parameter_scheduler = torch.load(filepath)
 
     engine1 = Engine(lambda e, b: None)
     lambda_state_parameter_scheduler.attach(engine1, Events.EPOCH_COMPLETED)
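torch.save and torch.load accept path-like objects, so the Path built from the fixture needs no str() conversion. A minimal self-contained round-trip in the same spirit, again using pytest's tmp_path for illustration:

from pathlib import Path

import torch


def test_torch_round_trip(tmp_path: Path):
    filepath = tmp_path / "state.pt"
    torch.save({"gamma": torch.tensor(0.99)}, filepath)  # Path accepted directly
    loaded = torch.load(filepath)
    assert torch.isclose(loaded["gamma"], torch.tensor(0.99))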
