@@ -79,40 +79,55 @@ def get_buffer(self, tensor_shape: list[int], dtype: torch.dtype,
7979 best_fit_block = block
8080 smallest_sufficient_size = block .buffer .numel ()
8181
82+ for block in list (candidate_blocks ):
83+ if not block .is_reserved :
84+ if best_fit_block is not None :
85+ if block is not best_fit_block :
86+ # Need to call del BufferBlock.buffer, otherwise memory isn't
87+ # released and OOM may happen.
88+ del block .buffer
89+ candidate_blocks .remove (block )
90+ else :
91+ del block .buffer
92+ candidate_blocks .remove (block )
93+
8294 if best_fit_block is not None :
8395 if reserve_buffer :
96+ # A suitable buffer was found, so reuse it.
8497 best_fit_block .is_reserved = True
85- # A suitable buffer was found, so reuse it.
86- return self ._view_as (best_fit_block .buffer , tensor_shape , dtype )
87-
88- for block in list (candidate_blocks ):
89- if not block .is_reserved :
90- # Need to call del BufferBlock.buffer, otherwise memory isn't
91- # released and OOM may happen.
92- buffer_size = block .buffer .numel ()
93- del block .buffer
94- if buffer_size >= 1024 * 1024 * 1024 :
95- torch .cuda .empty_cache ()
96- candidate_blocks .remove (block )
98+ return self ._view_as (best_fit_block .buffer , tensor_shape , dtype )
99+ else :
100+ # TODO: reuse tensors from both the graph pool and the normal pool.
101+ if best_fit_block .is_reserved :
102+ return self ._view_as (best_fit_block .buffer , tensor_shape ,
103+ dtype )
104+ else :
105+ del best_fit_block .buffer
106+ candidate_blocks .remove (best_fit_block )
107+
108+ def _create_buffer ():
109+ return torch .zeros ((required_memory_size , ),
110+ device = 'cuda' ,
111+ dtype = torch .uint8 )
97112
98113 # No suitable buffer was found, so allocate a new one.
99114 # The new buffer is created with uint8 to represent raw bytes.
100115 new_buffer_tensor = None
101116 try :
102- with torch .cuda .memory .use_mem_pool (get_shared_pool ()):
103- new_buffer_tensor = torch .empty ((required_memory_size , ),
104- device = 'cuda' ,
105- dtype = torch .uint8 )
117+ new_buffer_tensor = _create_buffer ()
106118 except Exception as ex :
107- # Need to check if this is an OOM exception
119+ # Need to check if this is an OOM exception
108120 logger .debug (
109121 f"Exception happened to create tensor from given memory pool: { str (ex )} "
110122 )
111- # if exception happens during allocating memory from shared pool, retry
112- # to allocate from default pool
113- new_buffer_tensor = torch .empty ((required_memory_size , ),
114- device = 'cuda' ,
115- dtype = torch .uint8 )
123+ # If an exception happens while allocating memory from the default pool,
124+ # retry the allocation from the shared pool. Allocating from the default pool first helps avoid fragmentation in the shared pool.
125+ mem_pool = get_shared_pool ()
126+ if mem_pool is not None :
127+ with torch .cuda .memory .use_mem_pool (mem_pool ):
128+ new_buffer_tensor = _create_buffer ()
129+ else :
130+ raise ex
116131
117132 new_block = BufferBlock (buffer = new_buffer_tensor ,
118133 is_reserved = reserve_buffer )
0 commit comments