@@ -3517,7 +3517,6 @@ def capture_model(self) -> int:
3517
3517
compilation_counter .num_gpu_runner_capture_triggers += 1
3518
3518
3519
3519
start_time = time .perf_counter ()
3520
- start_free_gpu_memory = torch .cuda .mem_get_info ()[0 ]
3521
3520
3522
3521
@contextmanager
3523
3522
def freeze_gc ():
@@ -3540,6 +3539,7 @@ def freeze_gc():
3540
3539
# can reuse the memory pool allocated for the large shapes.
3541
3540
set_cudagraph_capturing_enabled (True )
3542
3541
with freeze_gc (), graph_capture (device = self .device ):
3542
+ start_free_gpu_memory = torch .cuda .mem_get_info ()[0 ]
3543
3543
cudagraph_mode = self .compilation_config .cudagraph_mode
3544
3544
assert cudagraph_mode is not None
3545
3545
if cudagraph_mode .mixed_mode () != CUDAGraphMode .NONE :
@@ -3568,6 +3568,9 @@ def freeze_gc():
3568
3568
cudagraph_runtime_mode = CUDAGraphMode .FULL ,
3569
3569
uniform_decode = True )
3570
3570
3571
+ torch .cuda .synchronize ()
3572
+ end_free_gpu_memory = torch .cuda .mem_get_info ()[0 ]
3573
+
3571
3574
# Disable cudagraph capturing globally, so any unexpected cudagraph
3572
3575
# capturing will be detected and raise an error after here.
3573
3576
# Note: We don't put it into graph_capture context manager because
@@ -3576,7 +3579,6 @@ def freeze_gc():
3576
3579
set_cudagraph_capturing_enabled (False )
3577
3580
3578
3581
end_time = time .perf_counter ()
3579
- end_free_gpu_memory = torch .cuda .mem_get_info ()[0 ]
3580
3582
elapsed_time = end_time - start_time
3581
3583
cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory
3582
3584
# This usually takes 5~20 seconds.
0 commit comments