|
9 | 9 |
|
10 | 10 | from __future__ import annotations |
11 | 11 |
|
12 | | -import os |
13 | 12 | import tempfile |
14 | 13 | import unittest |
15 | 14 | from unittest.mock import Mock, patch |
16 | 15 |
|
17 | | -import torch.distributed.launcher as launcher |
18 | 16 | from tensorboard.backend.event_processing.event_accumulator import EventAccumulator |
19 | | -from torch import distributed as dist |
20 | 17 |
|
21 | 18 | from torchtnt.utils.loggers.tensorboard import TensorBoardLogger |
22 | | -from torchtnt.utils.test_utils import get_pet_launch_config, skip_if_not_distributed |
23 | 19 |
|
24 | 20 |
|
25 | 21 | class TensorBoardLoggerTest(unittest.TestCase): |
@@ -74,26 +70,6 @@ def test_log_rank_zero(self: TensorBoardLoggerTest) -> None: |
74 | 70 | logger = TensorBoardLogger(path=log_dir) |
75 | 71 | self.assertEqual(logger.writer, None) |
76 | 72 |
|
77 | | - @staticmethod |
78 | | - def _test_distributed() -> None: |
79 | | - dist.init_process_group("gloo") |
80 | | - rank = dist.get_rank() |
81 | | - with tempfile.TemporaryDirectory() as log_dir: |
82 | | - test_path = "correct" |
83 | | - invalid_path = "invalid" |
84 | | - if rank == 0: |
85 | | - logger = TensorBoardLogger(os.path.join(log_dir, test_path)) |
86 | | - else: |
87 | | - logger = TensorBoardLogger(os.path.join(log_dir, invalid_path)) |
88 | | - |
89 | | - assert test_path in logger.path |
90 | | - assert invalid_path not in logger.path |
91 | | - |
92 | | - @skip_if_not_distributed |
93 | | - def test_multiple_workers(self: TensorBoardLoggerTest) -> None: |
94 | | - config = get_pet_launch_config(2) |
95 | | - launcher.elastic_launch(config, entrypoint=self._test_distributed)() |
96 | | - |
97 | 73 | def test_add_scalars_call_is_correctly_passed_to_summary_writer( |
98 | 74 | self: TensorBoardLoggerTest, |
99 | 75 | ) -> None: |
|
0 commit comments