1 parent 7ce3702 commit 1752dfd
tests/conftest.py
@@ -146,11 +146,13 @@ def get_device_properties(device: torch.device) -> dict:
 
 
 def clear_cuda_cache(device: torch.device) -> None:
-    if (
-        torch.cuda.memory_allocated()
-        or torch.cuda.memory_reserved()
-        > 0.8 * get_device_properties(device).total_memory
-    ):
+    total_memory = get_device_properties(device).total_memory
+    reserved_memory = torch.cuda.memory_reserved()
+
+    # FLASHINFER_TEST_MEMORY_THRESHOLD: threshold for PyTorch reserved memory usage (default: 0.9)
+    threshold = float(os.environ.get("FLASHINFER_TEST_MEMORY_THRESHOLD", "0.9"))
+
+    if reserved_memory > threshold * total_memory:
         gc.collect()
         torch.cuda.empty_cache()
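For reference, a minimal sketch of how the updated helper reads after this commit, assuming tests/conftest.py already imports gc, os, and torch; the get_device_properties stub below is a hypothetical stand-in for the existing helper named in the hunk header:

import gc
import os

import torch


def get_device_properties(device: torch.device):
    # Hypothetical stand-in for the existing conftest helper; the real one
    # presumably exposes a total_memory field for the given device.
    return torch.cuda.get_device_properties(device)


def clear_cuda_cache(device: torch.device) -> None:
    total_memory = get_device_properties(device).total_memory
    reserved_memory = torch.cuda.memory_reserved()

    # Threshold for PyTorch reserved memory usage, overridable per test run (default: 0.9).
    threshold = float(os.environ.get("FLASHINFER_TEST_MEMORY_THRESHOLD", "0.9"))

    # Only pay for gc + empty_cache when reserved memory is actually close to the cap.
    if reserved_memory > threshold * total_memory:
        gc.collect()
        torch.cuda.empty_cache()

Example usage: lowering the threshold for a memory-constrained runner, e.g. FLASHINFER_TEST_MEMORY_THRESHOLD=0.7 pytest tests/, makes the cache clearing kick in earlier; leaving it unset keeps the 0.9 default.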