 import dataclasses
 import numpy as np
 import torch._C
+import threading
 from typing import Dict, Iterable, Literal, Tuple, Union, List, Set
 from torch.storage import UntypedStorage
 from dataclasses import field
     # Used for reference-count adjustment and for judging whether a buffer can be reclaimed
     def custom_del(self: torch.Tensor):
         global g_cache_manager
-        if hasattr(self, "storage_weak_ptr"):
-            storage_weak_ptr = self.storage_weak_ptr
-        else:
-            storage_weak_ptr = self.untyped_storage()._weak_ref()
-        UntypedStorage._free_weak_ref(storage_weak_ptr)
-        if storage_weak_ptr in g_cache_manager.ptr_to_bufnode:
-            g_cache_manager.changed_ptr.add(storage_weak_ptr)
-        return
+        g_cache_manager.tensor_lock.acquire()
+        try:
+            if hasattr(self, "storage_weak_ptr"):
+                storage_weak_ptr = self.storage_weak_ptr
+            else:
+                storage_weak_ptr = self.untyped_storage()._weak_ref()
+            UntypedStorage._free_weak_ref(storage_weak_ptr)
+            if storage_weak_ptr in g_cache_manager.ptr_to_bufnode:
+                g_cache_manager.changed_ptr.add(storage_weak_ptr)
+            return
+        finally:
+            g_cache_manager.tensor_lock.release()
 
     @dataclasses.dataclass
     class BufNode:
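
It is worth pinning down the two private torch internals this file leans on: `UntypedStorage._weak_ref()` hands back an address identifying a tensor's storage (the dict key the manager uses everywhere), and installing `__del__` on `torch.Tensor` lets the manager observe every tensor destruction while the cache env is active. A self-contained sketch of just that mechanism; `probe_del` and `seen` are illustrative names, not part of this PR:

```python
import torch
from torch.storage import UntypedStorage

seen = set()

def probe_del(self: torch.Tensor):
    # Take a storage weak ref (the same private API custom_del uses),
    # release it immediately, and remember the storage address we saw die.
    ptr = self.untyped_storage()._weak_ref()
    UntypedStorage._free_weak_ref(ptr)
    seen.add(ptr)

torch.Tensor.__del__ = probe_del      # what cache_env_in does via setattr
t = torch.empty(4)
addr = t.untyped_storage()._weak_ref()
UntypedStorage._free_weak_ref(addr)
del t                                 # fires probe_del for t's storage
del torch.Tensor.__del__              # what cache_env_out does via delattr
print(addr in seen)                   # expected True: the weak-ref address is stable per storage
```

The PR's membership test `storage_weak_ptr in g_cache_manager.ptr_to_bufnode` only works because repeated `_weak_ref()` calls on the same storage yield the same address, which is what the sketch demonstrates.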
@@ -105,6 +110,7 @@ def __init__(self):
             self.free_shape_dtype_to_bufs: Dict[Tuple, List[BufNode]] = collections.defaultdict(list)
             self.calcu_shape_cache: Dict[torch.Size, int] = {}
             self.changed_ptr: Set[int] = set()
+            self.tensor_lock = threading.Lock()
             from torch._C import _storage_Use_Count as use_count
 
             # use_count reports how many tensors actually reference this piece of GPU memory
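
This `use_count` import is what drives the reclamation rule in `alloc_tensor` below: a buffer is considered free again when `use_count(ptr) == 1 + len(shape_to_tensor)`, i.e. only the BufNode's `inner_tensor` plus its cached views still hold the storage. A small illustration of that counting (hypothetical example on CPU tensors; the printed values are what this logic assumes):

```python
import torch
from torch.storage import UntypedStorage
from torch._C import _storage_Use_Count as use_count

base = torch.empty(8)
ptr = base.untyped_storage()._weak_ref()
v1, v2 = base.view(2, 4), base.view(4, 2)
print(use_count(ptr))   # 3: base plus the two views share one storage
del v1, v2
print(use_count(ptr))   # back to 1 once the views die
UntypedStorage._free_weak_ref(ptr)
```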
@@ -117,25 +123,36 @@ def __init__(self):
         def cache_env_in(
             self, is_cuda_graph: bool = False, cur_batch_size: int = 0, cuda_graph_max_batch_size: int = 0
         ):
-            self.managed_total_tensor_bytes = 0
-            setattr(torch.Tensor, "__del__", custom_del)
-            self.is_cuda_graph = is_cuda_graph
-            if self.is_cuda_graph:
-                if self.inner_cuda_graph_manager is None:
-                    self.inner_cuda_graph_manager = CudaGraphCacheTensorManager(cuda_graph_max_batch_size)
-                else:
-                    assert self.inner_cuda_graph_manager.cuda_graph_max_batch_size == cuda_graph_max_batch_size
-                self.cuda_graph_cur_batch_size = cur_batch_size
-                assert cur_batch_size != 0
-            return
+            if not self.tensor_lock.acquire(blocking=False):
+                return
+            try:
+                self.managed_total_tensor_bytes = 0
+                setattr(torch.Tensor, "__del__", custom_del)
+                self.is_cuda_graph = is_cuda_graph
+                if self.is_cuda_graph:
+                    if self.inner_cuda_graph_manager is None:
+                        self.inner_cuda_graph_manager = CudaGraphCacheTensorManager(cuda_graph_max_batch_size)
+                    else:
+                        assert self.inner_cuda_graph_manager.cuda_graph_max_batch_size == cuda_graph_max_batch_size
+                    self.cuda_graph_cur_batch_size = cur_batch_size
+                    assert cur_batch_size != 0
+                return
+            finally:
+                self.tensor_lock.release()
 
         def cache_env_out(self):
-            delattr(torch.Tensor, "__del__")
-            self.ptr_to_bufnode.clear()
-            self.free_shape_dtype_to_bufs.clear()
-            self.calcu_shape_cache.clear()
-            self.changed_ptr.clear()
-            return
+            if not self.tensor_lock.acquire(blocking=False):
+                return
+            try:
+                if hasattr(torch.Tensor, "__del__"):
+                    delattr(torch.Tensor, "__del__")
+                self.ptr_to_bufnode.clear()
+                self.free_shape_dtype_to_bufs.clear()
+                self.calcu_shape_cache.clear()
+                self.changed_ptr.clear()
+                return
+            finally:
+                self.tensor_lock.release()
 
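
Note the locking asymmetry introduced by this PR: `custom_del` blocks on `tensor_lock`, while `cache_env_in` and `cache_env_out` use `acquire(blocking=False)` and silently skip the whole env switch when the lock is contended, presumably so a racing `__del__` hook can never stall the serving loop. A generic sketch of that non-blocking pattern (hypothetical `enter_env`, not this module's API):

```python
import threading

lock = threading.Lock()  # stands in for tensor_lock

def enter_env() -> bool:
    if not lock.acquire(blocking=False):
        return False           # lock busy: skip instead of blocking
    try:
        # ... mutate shared cache state here ...
        return True
    finally:
        lock.release()

print(enter_env())  # True when uncontended
```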
         def alloc_tensor(
             self,
@@ -144,55 +161,56 @@ def alloc_tensor(
             device: str = "cuda",
             is_graph_out: bool = False,
         ) -> torch.Tensor:
-            # normalize the shape type
-            if isinstance(shape, list):
-                shape = torch.Size(shape)
-            # under cuda graph, allocation is taken over by the cuda graph manager
-            if self.is_cuda_graph:
-                return self.inner_cuda_graph_manager.alloc_tensor_for_cuda_graph(
-                    self.cuda_graph_cur_batch_size, shape, data_type, device, is_graph_out
-                )
-
-            # reclaim tensors that may have died
-            for ptr in self.changed_ptr:
-                t_buf_node = self.ptr_to_bufnode[ptr]
-                if self.use_count(ptr) == 1 + len(t_buf_node.shape_to_tensor):
-                    self.free_shape_dtype_to_bufs[t_buf_node.shape_key].append(t_buf_node)
-            self.changed_ptr.clear()
-
-            if shape not in self.calcu_shape_cache:
-                size = np.prod(shape)
-                self.calcu_shape_cache[shape] = size
-            else:
-                size = self.calcu_shape_cache[shape]
+            with self.tensor_lock:
+                # normalize the shape type
+                if isinstance(shape, list):
+                    shape = torch.Size(shape)
+                # under cuda graph, allocation is taken over by the cuda graph manager
+                if self.is_cuda_graph:
+                    return self.inner_cuda_graph_manager.alloc_tensor_for_cuda_graph(
+                        self.cuda_graph_cur_batch_size, shape, data_type, device, is_graph_out
+                    )
 
-            key = (size, data_type)
-            buf_node_list = self.free_shape_dtype_to_bufs[key]
-            if buf_node_list:
-                buf_node = buf_node_list.pop()
-                if shape not in buf_node.shape_to_tensor:
-                    mark_tensor = buf_node.inner_tensor.view(shape)
-                    buf_node.shape_to_tensor[shape] = mark_tensor
+                # reclaim tensors that may have died
+                for ptr in self.changed_ptr:
+                    t_buf_node = self.ptr_to_bufnode[ptr]
+                    if self.use_count(ptr) == 1 + len(t_buf_node.shape_to_tensor):
+                        self.free_shape_dtype_to_bufs[t_buf_node.shape_key].append(t_buf_node)
+                self.changed_ptr.clear()
+
+                if shape not in self.calcu_shape_cache:
+                    size = np.prod(shape)
+                    self.calcu_shape_cache[shape] = size
                 else:
-                    mark_tensor = buf_node.shape_to_tensor[shape]
+                    size = self.calcu_shape_cache[shape]
+
+                key = (size, data_type)
+                buf_node_list = self.free_shape_dtype_to_bufs[key]
+                if buf_node_list:
+                    buf_node = buf_node_list.pop()
+                    if shape not in buf_node.shape_to_tensor:
+                        mark_tensor = buf_node.inner_tensor.view(shape)
+                        buf_node.shape_to_tensor[shape] = mark_tensor
+                    else:
+                        mark_tensor = buf_node.shape_to_tensor[shape]
+                    ans = mark_tensor.data  # return a new reference, otherwise the reference count cannot be judged
+                    ans.storage_weak_ptr = buf_node.storage_weak_ptr
+                    return ans
+
+                buf_tensor = torch.empty((size,), dtype=data_type, device=device, requires_grad=False)
+                # important log lines for debugging GPU memory usage
+                # self.managed_total_tensor_bytes += buf_tensor.element_size() * buf_tensor.numel()
+                # logger.info(f"gpu cache managed_total_tensor_bytes: {self.managed_total_tensor_bytes}")
+                storage_weak_ptr = buf_tensor.untyped_storage()._weak_ref()
+                buf_node = BufNode(buf_tensor, key, storage_weak_ptr)
+                self.ptr_to_bufnode[storage_weak_ptr] = buf_node
+                if shape not in buf_node.shape_to_tensor:
+                    buf_node.shape_to_tensor[shape] = buf_node.inner_tensor.view(shape)
+                mark_tensor = buf_node.shape_to_tensor[shape]
                 ans = mark_tensor.data  # return a new reference, otherwise the reference count cannot be judged
                 ans.storage_weak_ptr = buf_node.storage_weak_ptr
                 return ans
 
-            buf_tensor = torch.empty((size,), dtype=data_type, device=device, requires_grad=False)
-            # important log lines for debugging GPU memory usage
-            # self.managed_total_tensor_bytes += buf_tensor.element_size() * buf_tensor.numel()
-            # logger.info(f"gpu cache managed_total_tensor_bytes: {self.managed_total_tensor_bytes}")
-            storage_weak_ptr = buf_tensor.untyped_storage()._weak_ref()
-            buf_node = BufNode(buf_tensor, key, storage_weak_ptr)
-            self.ptr_to_bufnode[storage_weak_ptr] = buf_node
-            if shape not in buf_node.shape_to_tensor:
-                buf_node.shape_to_tensor[shape] = buf_node.inner_tensor.view(shape)
-            mark_tensor = buf_node.shape_to_tensor[shape]
-            ans = mark_tensor.data  # return a new reference, otherwise the reference count cannot be judged
-            ans.storage_weak_ptr = buf_node.storage_weak_ptr
-            return ans
-
 else:
     logger.info("USE_GPU_TENSOR_CACHE is OFF")
 
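Stripped of the locking and weak-ref accounting, the policy in `alloc_tensor` reduces to a free list keyed by `(numel, dtype)` whose buffers are handed out as views, so any retired buffer with a matching element count and dtype can serve a request of a different shape. A self-contained approximation of that idea (simplified sketch, not the module's API):

```python
import collections
import torch

free_bufs = collections.defaultdict(list)  # (numel, dtype) -> retired flat buffers

def alloc(shape, dtype, device="cpu"):
    size = 1
    for d in shape:
        size *= d
    key = (size, dtype)
    if free_bufs[key]:
        return free_bufs[key].pop().view(shape)  # reuse: no new allocation
    return torch.empty((size,), dtype=dtype, device=device).view(shape)

def release(t):
    free_bufs[(t.numel(), t.dtype)].append(t.reshape(-1))

a = alloc((4, 4), torch.float16)
release(a)
b = alloc((2, 8), torch.float16)  # different shape, same numel and dtype
print(a.untyped_storage().data_ptr() == b.untyped_storage().data_ptr())  # True: same memory
```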