@@ -5456,67 +5456,68 @@ def test_mempool_limited_memory_with_allocator(self):
         nelem_1mb = 1024 * 1024 // 4
 
         self._setup_mempool_limited_memory_test(80)
-        # remaining free mem: 80 mb
-        # mempool_use [] 0 mb
-        # mempool_do_not_use [] 0 mb
-        # default pool [] 0 mb
-        with torch.cuda.use_mem_pool(pool_do_not_use):
-            a = torch.randn(40 * nelem_1mb, device="cuda")
-        with torch.cuda.use_mem_pool(pool_use):
-            b = torch.randn(40 * nelem_1mb, device="cuda")
-        a_dataptr = a.data_ptr()
-        b_dataptr = b.data_ptr()
-        # remaining free mem: 0 mb
-        # mempool_do_not_use [aaaa] 40 mb
-        # mempool_use [bbbb] 40 mb
-        # default pool [] 0 mb
-        with self.assertRaises(torch.OutOfMemoryError):
-            # out of memory
-            c = torch.randn(40 * nelem_1mb, device="cuda")
-
-        del a, b
-        # remaining free mem: 0 mb
-        # mempool_do_not_use [____] 40 mb
-        # mempool_use [____] 40 mb
-        # default pool [] 0 mb
-
-        # c should not oom and instead can use mempool_use as fallback
-        c = torch.randn(30 * nelem_1mb, device="cuda")
-        c_dataptr = c.data_ptr()
-        # remaining free mem: 0 mb
-        # mempool_do_not_use [____] 40 mb
-        # mempool_use [ccc_] 40 mb
-        # default pool [] 0 mb
-        with self.assertRaises(torch.OutOfMemoryError):
-            # out of memory since can't use mempool_do_not_use
-            d = torch.randn(30 * nelem_1mb, device="cuda")
-
-        del c
-        # remaining free mem: 0 mb
-        # mempool_do_not_use [____] 40 mb
-        # mempool_use [____] 40 mb
-        # default pool [] 0 mb
-
-        # expect that we used the same memory address for both b and c
-        self.assertEqual(b_dataptr, c_dataptr)
-
-        # make sure we can still use mempool_use as intended after c is deleted
-        with torch.cuda.use_mem_pool(pool_use):
-            e = torch.randn(20 * nelem_1mb, device="cuda")
-        # remaining free mem: 0 mb
-        # mempool_do_not_use [____] 40 mb
-        # mempool_use [ee__] 40 mb
-        # default pool [] 0 mb
-
-        e_dataptr = e.data_ptr()
-        del e
-
-        self.assertEqual(e_dataptr, c_dataptr)
+        try:
+            # remaining free mem: 80 mb
+            # mempool_use [] 0 mb
+            # mempool_do_not_use [] 0 mb
+            # default pool [] 0 mb
+            with torch.cuda.use_mem_pool(pool_do_not_use):
+                a = torch.randn(40 * nelem_1mb, device="cuda")
+            with torch.cuda.use_mem_pool(pool_use):
+                b = torch.randn(40 * nelem_1mb, device="cuda")
+            a_dataptr = a.data_ptr()
+            b_dataptr = b.data_ptr()
+            # remaining free mem: 0 mb
+            # mempool_do_not_use [aaaa] 40 mb
+            # mempool_use [bbbb] 40 mb
+            # default pool [] 0 mb
+            with self.assertRaises(torch.OutOfMemoryError):
+                # out of memory
+                c = torch.randn(40 * nelem_1mb, device="cuda")
 
-        # pool's destructor calls emptyCache()
-        del pool_use, pool_do_not_use
+            del a, b
+            # remaining free mem: 0 mb
+            # mempool_do_not_use [____] 40 mb
+            # mempool_use [____] 40 mb
+            # default pool [] 0 mb
+
+            # c should not oom and instead can use mempool_use as fallback
+            c = torch.randn(30 * nelem_1mb, device="cuda")
+            c_dataptr = c.data_ptr()
+            # remaining free mem: 0 mb
+            # mempool_do_not_use [____] 40 mb
+            # mempool_use [ccc_] 40 mb
+            # default pool [] 0 mb
+            with self.assertRaises(torch.OutOfMemoryError):
+                # out of memory since can't use mempool_do_not_use
+                d = torch.randn(30 * nelem_1mb, device="cuda")
+
+            del c
+            # remaining free mem: 0 mb
+            # mempool_do_not_use [____] 40 mb
+            # mempool_use [____] 40 mb
+            # default pool [] 0 mb
 
-        self._teardown_mempool_limited_memory_test()
+            # expect that we used the same memory address for both b and c
+            self.assertEqual(b_dataptr, c_dataptr)
+
+            # make sure we can still use mempool_use as intended after c is deleted
+            with torch.cuda.use_mem_pool(pool_use):
+                e = torch.randn(20 * nelem_1mb, device="cuda")
+            # remaining free mem: 0 mb
+            # mempool_do_not_use [____] 40 mb
+            # mempool_use [ee__] 40 mb
+            # default pool [] 0 mb
+
+            e_dataptr = e.data_ptr()
+            del e
+
+            self.assertEqual(e_dataptr, c_dataptr)
+
+            # pool's destructor calls emptyCache()
+            del pool_use, pool_do_not_use
+        finally:
+            self._teardown_mempool_limited_memory_test()
 
     def test_mempool_multithread(self):
         pool_ids = []
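
For context, a minimal sketch of the pattern this change introduces: the test body runs under try so the limited-memory allocator is always torn down, even when an allocation inside the body raises torch.OutOfMemoryError. The helper function run_with_limited_pool and its arguments are hypothetical; torch.cuda.use_mem_pool and the _setup/_teardown helpers are the ones used by the test above.

import torch

def run_with_limited_pool(test_case, pool, size_mb=40):
    # Hypothetical helper mirroring the structure of the test above.
    nelem_1mb = 1024 * 1024 // 4  # float32 elements per MiB
    test_case._setup_mempool_limited_memory_test(80)  # helper from the test class
    try:
        # Allocations made inside this context are routed to `pool`.
        with torch.cuda.use_mem_pool(pool):
            t = torch.randn(size_mb * nelem_1mb, device="cuda")
        return t.data_ptr()
    finally:
        # Runs even if the allocation above raises torch.OutOfMemoryError,
        # so later tests are not left with the limited-memory allocator installed.
        test_case._teardown_mempool_limited_memory_test()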