@@ -16,8 +16,11 @@
 
 import torch
 
+from vllm.logger import init_logger
 from vllm.utils import is_pin_memory_available
 
+logger = init_logger(__name__)
+
 
 def find_loaded_library(lib_name) -> Optional[str]:
     """
@@ -165,6 +168,9 @@ def _python_malloc_callback(self, allocation_handle: HandleType) -> None:
         py_d_mem = allocation_handle[2]
         self.pointer_to_data[py_d_mem] = AllocationData(
             allocation_handle, self.current_tag)
+        logger.debug(
+            "Allocated %s bytes for %s with address %s from cumem allocator",
+            allocation_handle[1], self.current_tag, py_d_mem)
         return
 
     def _python_free_callback(self, ptr: int) -> HandleType:
@@ -174,6 +180,9 @@ def _python_free_callback(self, ptr: int) -> HandleType:
         data = self.pointer_to_data.pop(ptr)
         if data.cpu_backup_tensor is not None:
             data.cpu_backup_tensor = None
+        logger.debug(
+            "Freed %s bytes for %s with address %s from cumem allocator",
+            data.handle[1], data.tag, ptr)
         return data.handle
 
     def sleep(
@@ -197,9 +206,14 @@ def sleep(
 
         assert isinstance(offload_tags, tuple)
 
+        total_bytes = 0
+        backup_bytes = 0
+
         for ptr, data in self.pointer_to_data.items():
             handle = data.handle
+            total_bytes += handle[1]
             if data.tag in offload_tags:
+                backup_bytes += handle[1]
                 size_in_bytes = handle[1]
                 cpu_backup_tensor = torch.empty(
                     size_in_bytes,
@@ -211,6 +225,12 @@ def sleep(
                 data.cpu_backup_tensor = cpu_backup_tensor
             unmap_and_release(handle)
 
+        logger.info(
+            "CuMemAllocator: sleep freed %.2f GiB memory in total, of which "
+            "%.2f GiB is backed up in CPU and the rest %.2f GiB is discarded "
+            "directly.", total_bytes / 1024**3, backup_bytes / 1024**3,
+            (total_bytes - backup_bytes) / 1024**3)
+
         gc.collect()
         torch.cuda.empty_cache()
 
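The total_bytes / backup_bytes counters only feed the new summary log line; the offload behaviour of sleep() is unchanged. A minimal, hypothetical usage sketch of how that line surfaces (the tag names, sizes, and the CuMemAllocator singleton API from vllm.device_allocator.cumem are assumptions for illustration, not part of this diff):

import torch
from vllm.device_allocator.cumem import CuMemAllocator

allocator = CuMemAllocator.get_instance()

# Allocate two illustrative 1 GiB buffers under different tags.
with allocator.use_memory_pool(tag="weights"):
    weights = torch.empty(1 << 30, dtype=torch.uint8, device="cuda")
with allocator.use_memory_pool(tag="kv_cache"):
    kv_cache = torch.empty(1 << 30, dtype=torch.uint8, device="cuda")

# Only "weights" allocations are copied to pinned CPU memory; the rest are
# discarded. The added logger.info then reports roughly:
#   CuMemAllocator: sleep freed 2.00 GiB memory in total, of which 1.00 GiB
#   is backed up in CPU and the rest 1.00 GiB is discarded directly.
allocator.sleep(offload_tags=("weights", ))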
@@ -267,12 +287,17 @@ def use_memory_pool(self, tag: Optional[str] = None):
             # when using pluggable allocator, see
             # https://github.com/pytorch/pytorch/issues/145168 .
             # if we have some memory allocated and then freed,
-            # the memory will not be released.
-            # right now it is fine, because we only use this allocator
-            # during weight loading and kv cache creation, where we only
-            # allocate memory.
-            # TODO: we need to find a way to release the memory,
-            # i.e. calling torch.cuda.empty_cache()
+            # the memory will not be released, e.g. in online quantization,
+            # where the model is created in higher precision, and then
+            # quantized in lower precision.
+            # Find all unused allocations and manually release them.
+            # TODO: we should expose `empty_cache` method in the memory pool.
+            # TODO: ask for help from PyTorch team to expose this method.
+            allocations = data[0].snapshot()
+            for allocation in allocations:
+                if allocation["allocated_size"] == 0:
+                    handle = self._python_free_callback(allocation["address"])
+                    unmap_and_release(handle)
             self.current_tag = old_tag
 
     def get_current_usage(self) -> int:
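The new release loop targets patterns like online quantization, where tensors are created and then freed inside the pool. A hypothetical sketch of that scenario (tag name and shapes are illustrative; it assumes the same CuMemAllocator singleton as above):

import torch
from vllm.device_allocator.cumem import CuMemAllocator

allocator = CuMemAllocator.get_instance()

with allocator.use_memory_pool(tag="weights"):
    # Create a high-precision weight, derive a lower-precision copy, then
    # drop the original, leaving its pool segment without live allocations.
    fp16_weight = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")
    int8_weight = (fp16_weight * 127).round().clamp(-128, 127).to(torch.int8)
    del fp16_weight
# On exiting the context, the added code walks data[0].snapshot() and calls
# unmap_and_release() on every segment whose "allocated_size" is 0, so memory
# that only backed the dropped fp16 tensor can be returned to the driver.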