Commit 093952e

rehm-g authored and facebook-github-bot committed
Don't CUDA sync by default on timer (#1029)
Summary:
Pull Request resolved: #1029

Only CUDA sync when it is explicitly requested. Previously, we were essentially performing CUDA sync by default.

Reviewed By: JKSenthil, diego-urgell

Differential Revision: D81793986

fbshipit-source-id: 32e61c2f0ad425699b31298953dee9ea9b094bce
1 parent c1e2ca1 commit 093952e
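
To illustrate the behavior change, here is a minimal usage sketch. It assumes only the torchtnt Timer API visible in the diffs below (Timer(), Timer(cuda_sync=True), and timer.time(...)); the surrounding script is illustrative, not part of this commit.

import time
import torch
from torchtnt.utils.timer import Timer

# After this change, Timer() no longer triggers torch.cuda.synchronize()
# by default, even when a GPU is available.
timer = Timer()
with timer.time("action_1"):
    time.sleep(0.5)  # timed work; no implicit CUDA sync

# Synchronization is now opt-in. Per the timer.py diff, the constructor
# raises ValueError when cuda_sync is requested but CUDA is unavailable.
if torch.cuda.is_available():
    sync_timer = Timer(cuda_sync=True)
    with sync_timer.time("action_1"):
        time.sleep(0.5)  # the timer itself calls torch.cuda.synchronize()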

File tree

2 files changed: +20 -4 lines changed

tests/utils/test_timer_gpu.py

Lines changed: 19 additions & 1 deletion

@@ -20,14 +20,32 @@ class TimerGPUTest(unittest.TestCase):
     @skip_if_not_gpu
     @patch("torch.cuda.synchronize")
     def test_timer_synchronize(self, mock_synchronize: MagicMock) -> None:
-        """Make sure that torch.cuda.synchronize() is called when GPU is present."""
+        """Make sure that torch.cuda.synchronize() is not called by default when GPU is present."""

         start_event = torch.cuda.Event(enable_timing=True)
         end_event = torch.cuda.Event(enable_timing=True)
         timer = Timer()

         # Do not explicitly call synchronize, timer must call it for test to pass.

+        with timer.time("action_1"):
+            start_event.record()
+            time.sleep(0.5)
+            end_event.record()
+
+        self.assertEqual(mock_synchronize.call_count, 0)
+
+    @skip_if_not_gpu
+    @patch("torch.cuda.synchronize")
+    def test_timer_synchronize_when_explicit(self, mock_synchronize: MagicMock) -> None:
+        """Make sure that torch.cuda.synchronize() is called when GPU is present and sync is explicit."""
+
+        start_event = torch.cuda.Event(enable_timing=True)
+        end_event = torch.cuda.Event(enable_timing=True)
+        timer = Timer(cuda_sync=True)
+
+        # Do not explicitly call synchronize, timer must call it for test to pass.
+
         with timer.time("action_1"):
             start_event.record()
             time.sleep(0.5)

torchtnt/utils/timer.py

Lines changed: 1 addition & 3 deletions

@@ -146,9 +146,7 @@ def __init__(
             raise ValueError(
                 "CUDA must be available in order to enable CUDA synchronization."
             )
-        self.cuda_sync: bool = (
-            cuda_sync if cuda_sync is not None else torch.cuda.is_available()
-        )
+        self.cuda_sync: bool = cuda_sync if cuda_sync is not None else False
         self.verbose = verbose
         self.recorded_durations: Dict[str, List[float]] = defaultdict(list)

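As background for the event-based pattern used in the test above: CUDA kernels execute asynchronously, so host-side wall-clock timing can under-measure GPU work unless something synchronizes. A minimal sketch using standard PyTorch CUDA events (not part of this commit) that waits only on the recorded events instead of a device-wide sync:

import torch

start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

start_event.record()
x = torch.randn(4096, 4096, device="cuda")
y = x @ x  # GPU work is enqueued asynchronously
end_event.record()

# Block only until end_event completes, rather than calling the
# device-wide torch.cuda.synchronize(); elapsed_time is in milliseconds.
end_event.synchronize()
elapsed_ms = start_event.elapsed_time(end_event)
print(f"matmul took {elapsed_ms:.2f} ms")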