Commit da554f9

[Bug] Fix Negative Cuda Memory Usage (vllm-project#25683)
Signed-off-by: yewentao256 <[email protected]>
1 parent aac622e commit da554f9


vllm/v1/worker/gpu_model_runner.py

Lines changed: 4 additions & 2 deletions
@@ -3517,7 +3517,6 @@ def capture_model(self) -> int:
         compilation_counter.num_gpu_runner_capture_triggers += 1
 
         start_time = time.perf_counter()
-        start_free_gpu_memory = torch.cuda.mem_get_info()[0]
 
         @contextmanager
         def freeze_gc():
@@ -3540,6 +3539,7 @@ def freeze_gc():
         # can reuse the memory pool allocated for the large shapes.
         set_cudagraph_capturing_enabled(True)
         with freeze_gc(), graph_capture(device=self.device):
+            start_free_gpu_memory = torch.cuda.mem_get_info()[0]
             cudagraph_mode = self.compilation_config.cudagraph_mode
             assert cudagraph_mode is not None
             if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE:
@@ -3568,6 +3568,9 @@ def freeze_gc():
                     cudagraph_runtime_mode=CUDAGraphMode.FULL,
                     uniform_decode=True)
 
+            torch.cuda.synchronize()
+            end_free_gpu_memory = torch.cuda.mem_get_info()[0]
+
         # Disable cudagraph capturing globally, so any unexpected cudagraph
         # capturing will be detected and raise an error after here.
         # Note: We don't put it into graph_capture context manager because
@@ -3576,7 +3579,6 @@ def freeze_gc():
         set_cudagraph_capturing_enabled(False)
 
         end_time = time.perf_counter()
-        end_free_gpu_memory = torch.cuda.mem_get_info()[0]
         elapsed_time = end_time - start_time
         cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory
         # This usually takes 5~20 seconds.
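The change narrows the measurement window: the start probe now runs inside the graph_capture context, and the end probe runs after a torch.cuda.synchronize() but before the context exits, so the reported cuda_graph_size only covers memory consumed by the capture work and is no longer skewed negative by allocations made outside that window. Below is a minimal sketch of that measurement pattern; it is illustrative rather than vLLM code, and measure_capture_memory / capture_fn are hypothetical names.

```python
import torch


def measure_capture_memory(capture_fn) -> int:
    """Return the GPU bytes consumed while `capture_fn` runs (illustrative sketch).

    Pattern from the fix: probe free memory immediately before and after the
    capture work, and synchronize before the final probe so outstanding
    allocations are reflected in the reading.
    """
    assert torch.cuda.is_available(), "requires a CUDA device"
    start_free, _ = torch.cuda.mem_get_info()  # (free_bytes, total_bytes)
    capture_fn()                               # e.g. the CUDA graph captures
    torch.cuda.synchronize()                   # flush pending work before re-probing
    end_free, _ = torch.cuda.mem_get_info()
    # Non-negative as long as the probes bracket only the capture work.
    return start_free - end_free
```

As a rough usage example, measure_capture_memory(lambda: torch.empty(1 << 20, device="cuda")) should report at least the bytes reserved for that tensor: even though the tensor is freed when the lambda returns, PyTorch's caching allocator keeps the segment reserved, and mem_get_info reads driver-level free memory.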
