@@ -57,27 +57,28 @@ def get_buffer(self, tensor_shape: list[int], dtype: torch.dtype,
5757
5858 candidate_blocks = self .buffers .get (buffer_name , [])
5959
60- # Find the best-fit available buffer.
61- best_fit_block : Optional [BufferBlock ] = None
62- smallest_sufficient_size = float ('inf' )
63- for block in candidate_blocks :
64- # Skip buffers that are too small.
65- if block .buffer .numel () < required_memory_size :
66- continue
67-
68- # Find the smallest buffer that is still large enough (best-fit).
69- if block .buffer .numel () < smallest_sufficient_size :
70- # Use reserved block if find one.
71- if best_fit_block is not None and best_fit_block .is_reserved and not block .is_reserved :
60+ if reserve_buffer :
61+ # Find the best-fit available buffer.
62+ best_fit_block : Optional [BufferBlock ] = None
63+ smallest_sufficient_size = float ('inf' )
64+ for block in candidate_blocks :
65+ # Skip buffers that are too small.
66+ if block .buffer .numel () < required_memory_size :
7267 continue
7368
74- best_fit_block = block
75- smallest_sufficient_size = block .buffer .numel ()
69+ # Find the smallest buffer that is still large enough (best-fit).
70+ if block .buffer .numel () < smallest_sufficient_size :
71+ # Use reserved block if find one.
72+ if best_fit_block is not None and best_fit_block .is_reserved and not block .is_reserved :
73+ continue
7674
77- if reserve_buffer and best_fit_block is not None :
78- # A suitable buffer was found, so reuse it.
79- best_fit_block .is_reserved = True
80- return self ._view_as (best_fit_block .buffer , tensor_shape , dtype )
75+ best_fit_block = block
76+ smallest_sufficient_size = block .buffer .numel ()
77+
78+ if best_fit_block is not None :
79+ # A suitable buffer was found, so reuse it.
80+ best_fit_block .is_reserved = True
81+ return self ._view_as (best_fit_block .buffer , tensor_shape , dtype )
8182
8283 for block in list (candidate_blocks ):
8384 if not block .is_reserved :
@@ -88,22 +89,27 @@ def get_buffer(self, tensor_shape: list[int], dtype: torch.dtype,
8889
8990 # No suitable buffer was found, so allocate a new one.
9091 # The new buffer is created with uint8 to represent raw bytes.
def _create_buffer():
    """Allocate `required_memory_size` zeroed bytes on the CUDA device.

    Shared by the pooled attempt and the default-pool fallback so both
    paths build the buffer in exactly the same way.
    """
    return torch.zeros((required_memory_size, ),
                       device='cuda',
                       dtype=torch.uint8)

new_buffer_tensor = None
try:
    mem_pool = get_shared_pool()
    if mem_pool is not None:
        # The pool must be passed explicitly: use_mem_pool() takes it as
        # a required argument. Calling it with no argument fails, and the
        # broad handler below would hide that failure, silently routing
        # every allocation to the default pool.
        with torch.cuda.memory.use_mem_pool(mem_pool):
            new_buffer_tensor = _create_buffer()
    else:
        # No shared pool is available; allocate from the default pool.
        new_buffer_tensor = _create_buffer()
except Exception as ex:
    # Need to check if this is an OOM exception
    logger.debug(
        f"Exception happened to create tensor from given memory pool: {str(ex)}"
    )
    # if exception happens during allocating memory from shared pool, retry
    # to allocate from default pool
    new_buffer_tensor = _create_buffer()
107113
108114 new_block = BufferBlock (buffer = new_buffer_tensor ,
109115 is_reserved = reserve_buffer )
0 commit comments