Skip to content

Commit 1752dfd

Browse files
authored
Relax the clear_cuda_cache (#1406)
1 parent 7ce3702 commit 1752dfd

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

tests/conftest.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -146,11 +146,13 @@ def get_device_properties(device: torch.device) -> dict:
146146

147147

148148
def clear_cuda_cache(device: torch.device) -> None:
149-
if (
150-
torch.cuda.memory_allocated()
151-
or torch.cuda.memory_reserved()
152-
> 0.8 * get_device_properties(device).total_memory
153-
):
149+
total_memory = get_device_properties(device).total_memory
150+
reserved_memory = torch.cuda.memory_reserved()
151+
152+
# FLASHINFER_TEST_MEMORY_THRESHOLD: threshold for PyTorch reserved memory usage (default: 0.9)
153+
threshold = float(os.environ.get("FLASHINFER_TEST_MEMORY_THRESHOLD", "0.9"))
154+
155+
if reserved_memory > threshold * total_memory:
154156
gc.collect()
155157
torch.cuda.empty_cache()
156158

0 commit comments

Comments
 (0)