diff --git a/test/test_cuda.py b/test/test_cuda.py
index d293601fad138..eabfc0652edf6 100644
--- a/test/test_cuda.py
+++ b/test/test_cuda.py
@@ -5456,67 +5456,68 @@ def test_mempool_limited_memory_with_allocator(self):
         nelem_1mb = 1024 * 1024 // 4
 
         self._setup_mempool_limited_memory_test(80)
-        # remaining free mem: 80 mb
-        # mempool_use [] 0 mb
-        # mempool_do_not_use [] 0 mb
-        # default pool [] 0 mb
-        with torch.cuda.use_mem_pool(pool_do_not_use):
-            a = torch.randn(40 * nelem_1mb, device="cuda")
-        with torch.cuda.use_mem_pool(pool_use):
-            b = torch.randn(40 * nelem_1mb, device="cuda")
-        a_dataptr = a.data_ptr()
-        b_dataptr = b.data_ptr()
-        # remaining free mem: 0 mb
-        # mempool_do_not_use [aaaa] 40 mb
-        # mempool_use [bbbb] 40 mb
-        # default pool [] 0 mb
-        with self.assertRaises(torch.OutOfMemoryError):
-            # out of memory
-            c = torch.randn(40 * nelem_1mb, device="cuda")
-
-        del a, b
-        # remaining free mem: 0 mb
-        # mempool_do_not_use [____] 40 mb
-        # mempool_use [____] 40 mb
-        # default pool [] 0 mb
-
-        # c should not oom and instead can use mempool_use as fallback
-        c = torch.randn(30 * nelem_1mb, device="cuda")
-        c_dataptr = c.data_ptr()
-        # remaining free mem: 0 mb
-        # mempool_do_not_use [____] 40 mb
-        # mempool_use [ccc_] 40 mb
-        # default pool [] 0 mb
-        with self.assertRaises(torch.OutOfMemoryError):
-            # out of memory since can't use mempool_do_not_use
-            d = torch.randn(30 * nelem_1mb, device="cuda")
-
-        del c
-        # remaining free mem: 0 mb
-        # mempool_do_not_use [____] 40 mb
-        # mempool_use [____] 40 mb
-        # default pool [] 0 mb
-
-        # expect that we used same memory address for both a and c
-        self.assertEqual(b_dataptr, c_dataptr)
-
-        # make sure we can still use mempool_use as intended after c is deleted
-        with torch.cuda.use_mem_pool(pool_use):
-            e = torch.randn(20 * nelem_1mb, device="cuda")
-        # remaining free mem: 0 mb
-        # mempool_do_not_use [____] 40 mb
-        # mempool_use [ee__] 40 mb
-        # default pool [] 0 mb
-
-        e_dataptr = e.data_ptr()
-        del e
-
-        self.assertEqual(e_dataptr, c_dataptr)
-
-        # pool's destructor calls emptyCache()
-        del pool_use, pool_do_not_use
-
-        self._teardown_mempool_limited_memory_test()
+        try:
+            # remaining free mem: 80 mb
+            # mempool_use [] 0 mb
+            # mempool_do_not_use [] 0 mb
+            # default pool [] 0 mb
+            with torch.cuda.use_mem_pool(pool_do_not_use):
+                a = torch.randn(40 * nelem_1mb, device="cuda")
+            with torch.cuda.use_mem_pool(pool_use):
+                b = torch.randn(40 * nelem_1mb, device="cuda")
+            a_dataptr = a.data_ptr()
+            b_dataptr = b.data_ptr()
+            # remaining free mem: 0 mb
+            # mempool_do_not_use [aaaa] 40 mb
+            # mempool_use [bbbb] 40 mb
+            # default pool [] 0 mb
+            with self.assertRaises(torch.OutOfMemoryError):
+                # out of memory
+                c = torch.randn(40 * nelem_1mb, device="cuda")
+
+            del a, b
+            # remaining free mem: 0 mb
+            # mempool_do_not_use [____] 40 mb
+            # mempool_use [____] 40 mb
+            # default pool [] 0 mb
+
+            # c should not oom and instead can use mempool_use as fallback
+            c = torch.randn(30 * nelem_1mb, device="cuda")
+            c_dataptr = c.data_ptr()
+            # remaining free mem: 0 mb
+            # mempool_do_not_use [____] 40 mb
+            # mempool_use [ccc_] 40 mb
+            # default pool [] 0 mb
+            with self.assertRaises(torch.OutOfMemoryError):
+                # out of memory since can't use mempool_do_not_use
+                d = torch.randn(30 * nelem_1mb, device="cuda")
+
+            del c
+            # remaining free mem: 0 mb
+            # mempool_do_not_use [____] 40 mb
+            # mempool_use [____] 40 mb
+            # default pool [] 0 mb
+
+            # expect that we used same memory address for both a and c
+            self.assertEqual(b_dataptr, c_dataptr)
+
+            # make sure we can still use mempool_use as intended after c is deleted
+            with torch.cuda.use_mem_pool(pool_use):
+                e = torch.randn(20 * nelem_1mb, device="cuda")
+            # remaining free mem: 0 mb
+            # mempool_do_not_use [____] 40 mb
+            # mempool_use [ee__] 40 mb
+            # default pool [] 0 mb
+
+            e_dataptr = e.data_ptr()
+            del e
+
+            self.assertEqual(e_dataptr, c_dataptr)
+
+            # pool's destructor calls emptyCache()
+            del pool_use, pool_do_not_use
+        finally:
+            self._teardown_mempool_limited_memory_test()
 
     def test_mempool_multithread(self):
         pool_ids = []
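The pattern the patch applies is worth seeing in isolation: the test body may raise (an unexpected OOM or a failing assertion), so the teardown of the limited-memory setup is moved into a finally block that always runs. Below is a minimal, self-contained sketch of that pattern using plain unittest with no CUDA dependency; LimitedResource, _setup_limited_resource, and _teardown_limited_resource are made-up stand-ins for the real _setup_mempool_limited_memory_test / _teardown_mempool_limited_memory_test helpers, not part of the patch.

    import unittest


    class LimitedResource:
        """Stand-in for the limited-memory state the real test configures."""

        def __init__(self):
            self.active = True

        def release(self):
            self.active = False


    class TryFinallyTeardownTest(unittest.TestCase):
        def _setup_limited_resource(self):
            self._resource = LimitedResource()

        def _teardown_limited_resource(self):
            self._resource.release()

        def test_body_failure_still_tears_down(self):
            self._setup_limited_resource()
            try:
                # Anything in here may raise; because the teardown sits in the
                # finally block, it still runs and later tests are not left
                # with the limited-resource state active.
                with self.assertRaises(RuntimeError):
                    raise RuntimeError("simulated OOM")
            finally:
                self._teardown_limited_resource()
            self.assertFalse(self._resource.active)


    if __name__ == "__main__":
        unittest.main()

Without the try/finally, an exception between setup and teardown would leave the restricted allocator limit in place for every subsequent test in the process, which is exactly the leak the diff above closes.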