
Commit f363ae8

reset per process memory fraction in test_cuda.py test_mempool_limited_memory_with_allocator (#2811)
Use a try/finally block. This follows a similar pattern elsewhere in test_cuda.py. Fixes ROCm/TheRock#2118.
1 parent 6ecd7c5 commit f363ae8
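
For context, the change applies the usual setup / try / finally / teardown shape so the per-process memory fraction is restored even when an assertion in the body fails. A minimal, illustrative sketch of that pattern (not the actual test; the real helpers in test_cuda.py wrap more state than shown here, and the function name below is hypothetical):

import torch

def limited_memory_check():
    # Cap this process to a fraction of the device's memory before the checks.
    torch.cuda.set_per_process_memory_fraction(0.5, device=0)
    try:
        # ... allocations and OOM assertions against the reduced budget ...
        x = torch.randn(1024, device="cuda")
        del x
    finally:
        # Always restore the default fraction, even if an assertion above
        # raised, so later tests see the full device again.
        torch.cuda.set_per_process_memory_fraction(1.0, device=0)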


test/test_cuda.py

Lines changed: 60 additions & 59 deletions
@@ -5456,67 +5456,68 @@ def test_mempool_limited_memory_with_allocator(self):
         nelem_1mb = 1024 * 1024 // 4

         self._setup_mempool_limited_memory_test(80)
-        # remaining free mem: 80 mb
-        # mempool_use [] 0 mb
-        # mempool_do_not_use [] 0 mb
-        # default pool [] 0 mb
-        with torch.cuda.use_mem_pool(pool_do_not_use):
-            a = torch.randn(40 * nelem_1mb, device="cuda")
-        with torch.cuda.use_mem_pool(pool_use):
-            b = torch.randn(40 * nelem_1mb, device="cuda")
-        a_dataptr = a.data_ptr()
-        b_dataptr = b.data_ptr()
-        # remaining free mem: 0 mb
-        # mempool_do_not_use [aaaa] 40 mb
-        # mempool_use [bbbb] 40 mb
-        # default pool [] 0 mb
-        with self.assertRaises(torch.OutOfMemoryError):
-            # out of memory
-            c = torch.randn(40 * nelem_1mb, device="cuda")
-
-        del a, b
-        # remaining free mem: 0 mb
-        # mempool_do_not_use [____] 40 mb
-        # mempool_use [____] 40 mb
-        # default pool [] 0 mb
-
-        # c should not oom and instead can use mempool_use as fallback
-        c = torch.randn(30 * nelem_1mb, device="cuda")
-        c_dataptr = c.data_ptr()
-        # remaining free mem: 0 mb
-        # mempool_do_not_use [____] 40 mb
-        # mempool_use [ccc_] 40 mb
-        # default pool [] 0 mb
-        with self.assertRaises(torch.OutOfMemoryError):
-            # out of memory since can't use mempool_do_not_use
-            d = torch.randn(30 * nelem_1mb, device="cuda")
-
-        del c
-        # remaining free mem: 0 mb
-        # mempool_do_not_use [____] 40 mb
-        # mempool_use [____] 40 mb
-        # default pool [] 0 mb
-
-        # expect that we used same memory address for both a and c
-        self.assertEqual(b_dataptr, c_dataptr)
-
-        # make sure we can still use mempool_use as intended after c is deleted
-        with torch.cuda.use_mem_pool(pool_use):
-            e = torch.randn(20 * nelem_1mb, device="cuda")
-        # remaining free mem: 0 mb
-        # mempool_do_not_use [____] 40 mb
-        # mempool_use [ee__] 40 mb
-        # default pool [] 0 mb
-
-        e_dataptr = e.data_ptr()
-        del e
-
-        self.assertEqual(e_dataptr, c_dataptr)
+        try:
+            # remaining free mem: 80 mb
+            # mempool_use [] 0 mb
+            # mempool_do_not_use [] 0 mb
+            # default pool [] 0 mb
+            with torch.cuda.use_mem_pool(pool_do_not_use):
+                a = torch.randn(40 * nelem_1mb, device="cuda")
+            with torch.cuda.use_mem_pool(pool_use):
+                b = torch.randn(40 * nelem_1mb, device="cuda")
+            a_dataptr = a.data_ptr()
+            b_dataptr = b.data_ptr()
+            # remaining free mem: 0 mb
+            # mempool_do_not_use [aaaa] 40 mb
+            # mempool_use [bbbb] 40 mb
+            # default pool [] 0 mb
+            with self.assertRaises(torch.OutOfMemoryError):
+                # out of memory
+                c = torch.randn(40 * nelem_1mb, device="cuda")

-        # pool's destructor calls emptyCache()
-        del pool_use, pool_do_not_use
+            del a, b
+            # remaining free mem: 0 mb
+            # mempool_do_not_use [____] 40 mb
+            # mempool_use [____] 40 mb
+            # default pool [] 0 mb
+
+            # c should not oom and instead can use mempool_use as fallback
+            c = torch.randn(30 * nelem_1mb, device="cuda")
+            c_dataptr = c.data_ptr()
+            # remaining free mem: 0 mb
+            # mempool_do_not_use [____] 40 mb
+            # mempool_use [ccc_] 40 mb
+            # default pool [] 0 mb
+            with self.assertRaises(torch.OutOfMemoryError):
+                # out of memory since can't use mempool_do_not_use
+                d = torch.randn(30 * nelem_1mb, device="cuda")
+
+            del c
+            # remaining free mem: 0 mb
+            # mempool_do_not_use [____] 40 mb
+            # mempool_use [____] 40 mb
+            # default pool [] 0 mb

-        self._teardown_mempool_limited_memory_test()
+            # expect that we used same memory address for both a and c
+            self.assertEqual(b_dataptr, c_dataptr)
+
+            # make sure we can still use mempool_use as intended after c is deleted
+            with torch.cuda.use_mem_pool(pool_use):
+                e = torch.randn(20 * nelem_1mb, device="cuda")
+            # remaining free mem: 0 mb
+            # mempool_do_not_use [____] 40 mb
+            # mempool_use [ee__] 40 mb
+            # default pool [] 0 mb
+
+            e_dataptr = e.data_ptr()
+            del e
+
+            self.assertEqual(e_dataptr, c_dataptr)
+
+            # pool's destructor calls emptyCache()
+            del pool_use, pool_do_not_use
+        finally:
+            self._teardown_mempool_limited_memory_test()

     def test_mempool_multithread(self):
         pool_ids = []
