test/test_cuda.py (119 changes: 60 additions & 59 deletions)
@@ -5456,67 +5456,68 @@ def test_mempool_limited_memory_with_allocator(self):
nelem_1mb = 1024 * 1024 // 4

self._setup_mempool_limited_memory_test(80)
# remaining free mem: 80 mb
# mempool_use [] 0 mb
# mempool_do_not_use [] 0 mb
# default pool [] 0 mb
with torch.cuda.use_mem_pool(pool_do_not_use):
a = torch.randn(40 * nelem_1mb, device="cuda")
with torch.cuda.use_mem_pool(pool_use):
b = torch.randn(40 * nelem_1mb, device="cuda")
a_dataptr = a.data_ptr()
b_dataptr = b.data_ptr()
# remaining free mem: 0 mb
# mempool_do_not_use [aaaa] 40 mb
# mempool_use [bbbb] 40 mb
# default pool [] 0 mb
with self.assertRaises(torch.OutOfMemoryError):
# out of memory
c = torch.randn(40 * nelem_1mb, device="cuda")

del a, b
# remaining free mem: 0 mb
# mempool_do_not_use [____] 40 mb
# mempool_use [____] 40 mb
# default pool [] 0 mb

# c should not oom and instead can use mempool_use as fallback
c = torch.randn(30 * nelem_1mb, device="cuda")
c_dataptr = c.data_ptr()
# remaining free mem: 0 mb
# mempool_do_not_use [____] 40 mb
# mempool_use [ccc_] 40 mb
# default pool [] 0 mb
with self.assertRaises(torch.OutOfMemoryError):
# out of memory since d can't use mempool_do_not_use as a fallback
d = torch.randn(30 * nelem_1mb, device="cuda")

del c
# remaining free mem: 0 mb
# mempool_do_not_use [____] 40 mb
# mempool_use [____] 40 mb
# default pool [] 0 mb

# expect that we used the same memory address for both b and c
self.assertEqual(b_dataptr, c_dataptr)

# make sure we can still use mempool_use as intended after c is deleted
with torch.cuda.use_mem_pool(pool_use):
e = torch.randn(20 * nelem_1mb, device="cuda")
# remaining free mem: 0 mb
# mempool_do_not_use [____] 40 mb
# mempool_use [ee__] 40 mb
# default pool [] 0 mb

e_dataptr = e.data_ptr()
del e

self.assertEqual(e_dataptr, c_dataptr)
try:
# remaining free mem: 80 mb
# mempool_use [] 0 mb
# mempool_do_not_use [] 0 mb
# default pool [] 0 mb
with torch.cuda.use_mem_pool(pool_do_not_use):
a = torch.randn(40 * nelem_1mb, device="cuda")
with torch.cuda.use_mem_pool(pool_use):
b = torch.randn(40 * nelem_1mb, device="cuda")
a_dataptr = a.data_ptr()
b_dataptr = b.data_ptr()
# remaining free mem: 0 mb
# mempool_do_not_use [aaaa] 40 mb
# mempool_use [bbbb] 40 mb
# default pool [] 0 mb
with self.assertRaises(torch.OutOfMemoryError):
# out of memory
c = torch.randn(40 * nelem_1mb, device="cuda")

# pool's destructor calls emptyCache()
del pool_use, pool_do_not_use
del a, b
# remaining free mem: 0 mb
# mempool_do_not_use [____] 40 mb
# mempool_use [____] 40 mb
# default pool [] 0 mb

# c should not oom and instead can use mempool_use as fallback
c = torch.randn(30 * nelem_1mb, device="cuda")
c_dataptr = c.data_ptr()
# remaining free mem: 0 mb
# mempool_do_not_use [____] 40 mb
# mempool_use [ccc_] 40 mb
# default pool [] 0 mb
with self.assertRaises(torch.OutOfMemoryError):
# out of memory since d can't use mempool_do_not_use as a fallback
d = torch.randn(30 * nelem_1mb, device="cuda")

del c
# remaining free mem: 0 mb
# mempool_do_not_use [____] 40 mb
# mempool_use [____] 40 mb
# default pool [] 0 mb

self._teardown_mempool_limited_memory_test()
# expect that we used the same memory address for both b and c
self.assertEqual(b_dataptr, c_dataptr)

# make sure we can still use mempool_use as intended after c is deleted
with torch.cuda.use_mem_pool(pool_use):
e = torch.randn(20 * nelem_1mb, device="cuda")
# remaining free mem: 0 mb
# mempool_do_not_use [____] 40 mb
# mempool_use [ee__] 40 mb
# default pool [] 0 mb

e_dataptr = e.data_ptr()
del e

self.assertEqual(e_dataptr, c_dataptr)

# pool's destructor calls emptyCache()
del pool_use, pool_do_not_use
finally:
self._teardown_mempool_limited_memory_test()

def test_mempool_multithread(self):
pool_ids = []
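For readers skimming the hunk, here is a minimal sketch (not part of the PR) of the allocation pattern the test exercises and of the try/finally teardown shape the added lines introduce. It assumes a CUDA device and the public torch.cuda.MemPool / torch.cuda.use_mem_pool API of recent PyTorch releases; the mempool_roundtrip helper and the 4 MiB size are illustrative only, and whether the second allocation reuses the cached block is caching-allocator behavior, not a guarantee.

import torch

# Minimal sketch: route allocations into a private MemPool via use_mem_pool,
# rely on the caching allocator to reuse a freed block from that pool, and
# (as this PR does for the test) guarantee cleanup with try/finally.
def mempool_roundtrip():
    if not torch.cuda.is_available():
        return None

    pool = torch.cuda.MemPool()       # private pool in the CUDA caching allocator
    nelem_1mb = 1024 * 1024 // 4      # float32 elements per 1 MiB, as in the test
    try:
        with torch.cuda.use_mem_pool(pool):
            x = torch.randn(4 * nelem_1mb, device="cuda")  # 4 MiB served from pool
        ptr = x.data_ptr()
        del x                         # block is cached in the pool, not returned to the driver
        with torch.cuda.use_mem_pool(pool):
            y = torch.randn(4 * nelem_1mb, device="cuda")  # typically reuses the cached block
        same = ptr == y.data_ptr()    # the data_ptr() equality the test's assertions rely on
        del y
        return same
    finally:
        # Mirrors the PR's change: cleanup runs even if something above raises.
        del pool                      # pool's destructor calls emptyCache()
        torch.cuda.empty_cache()      # release any remaining cached blocks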