@@ -93,6 +93,10 @@ def __init__(self):
         self.cuda_graph_cur_batch_size = None
         self.is_cuda_graph = False
         self.managed_total_tensor_bytes = 0
+        # Guard flag to prevent misuse from leaking GPU memory.
+        # If the caller has not properly invoked cache_env_in and cache_env_out,
+        # calls to the alloc_tensor interface fall back to plain torch.empty allocation.
+        self.cache_env_ok = False
 
     def cache_env_in(
         self, is_cuda_graph: bool = False, cur_batch_size: int = 0, cuda_graph_max_batch_size: int = 0
@@ -107,6 +111,7 @@ def cache_env_in(
             assert self.inner_cuda_graph_manager.cuda_graph_max_batch_size == cuda_graph_max_batch_size
         self.cuda_graph_cur_batch_size = cur_batch_size
         assert cur_batch_size != 0
+        self.cache_env_ok = True
         return
 
     def cache_env_out(self):
@@ -115,6 +120,7 @@ def cache_env_out(self):
         self.free_shape_dtype_to_bufs.clear()
         self.calcu_shape_cache.clear()
         self.changed_ptr.clear()
+        self.cache_env_ok = False
         return
 
     def alloc_tensor(
@@ -129,6 +135,11 @@ def alloc_tensor(
         # convert the shape argument to torch.Size
         if isinstance(shape, list):
             shape = torch.Size(shape)
+
+        # degrade gracefully when the cache manager env was not entered properly
+        if not self.cache_env_ok:
+            return torch.empty(shape, dtype=data_type, device=device, requires_grad=False)
+
         # under cuda graph mode, allocation is taken over by the cuda graph manager
         if self.is_cuda_graph:
             return self.inner_cuda_graph_manager.alloc_tensor_for_cuda_graph(
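
Taken together, the change makes the allocation path safe against callers that skip the env bracketing. Below is a minimal usage sketch of the intended pattern; the class name, instance, and exact keyword order are illustrative assumptions — only cache_env_in, cache_env_out, alloc_tensor, and the shape/data_type/device parameters appear in the diff itself.

import torch

manager = CacheTensorManager()  # hypothetical instantiation; class name assumed

# Proper use: bracket allocations with cache_env_in / cache_env_out so the
# manager can cache and reuse buffers across calls.
manager.cache_env_in(is_cuda_graph=False, cur_batch_size=8)
t = manager.alloc_tensor([8, 4096], data_type=torch.float16, device="cuda")
manager.cache_env_out()

# Misuse: without cache_env_in, cache_env_ok stays False, so alloc_tensor
# falls back to a plain torch.empty allocation instead of handing out a
# cached buffer that could otherwise leak.
u = manager.alloc_tensor([8, 4096], data_type=torch.float16, device="cuda")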