
Commit 4356d2e

fix

1 parent b2183e6 commit 4356d2e

File tree: 1 file changed, +11 −0 lines changed


lightllm/common/basemodel/layer_infer/cache_tensor_manager.py

Lines changed: 11 additions & 0 deletions
@@ -93,6 +93,10 @@ def __init__(self):
         self.cuda_graph_cur_batch_size = None
         self.is_cuda_graph = False
         self.managed_total_tensor_bytes = 0
+        # Marker flag added to guard against misuse that would leak GPU memory.
+        # If the caller has not properly invoked cache_env_in and cache_env_out,
+        # any call to the alloc_tensor interface degrades to plain torch.empty allocation.
+        self.cache_env_ok = False
 
     def cache_env_in(
         self, is_cuda_graph: bool = False, cur_batch_size: int = 0, cuda_graph_max_batch_size: int = 0
@@ -107,6 +111,7 @@ def cache_env_in(
             assert self.inner_cuda_graph_manager.cuda_graph_max_batch_size == cuda_graph_max_batch_size
             self.cuda_graph_cur_batch_size = cur_batch_size
             assert cur_batch_size != 0
+        self.cache_env_ok = True
         return
 
     def cache_env_out(self):
@@ -115,6 +120,7 @@ def cache_env_out(self):
         self.free_shape_dtype_to_bufs.clear()
         self.calcu_shape_cache.clear()
         self.changed_ptr.clear()
+        self.cache_env_ok = False
         return
 
     def alloc_tensor(
@@ -129,6 +135,11 @@ def alloc_tensor(
         # shape type conversion
         if isinstance(shape, list):
             shape = torch.Size(shape)
+
+        # the cache manager was not set up properly; fall back
+        if not self.cache_env_ok:
+            return torch.empty(shape, dtype=data_type, device=device, requires_grad=False)
+
         # under cuda graph, the cuda graph manager takes over
         if self.is_cuda_graph:
             return self.inner_cuda_graph_manager.alloc_tensor_for_cuda_graph(
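
For context, here is a minimal sketch of the guard pattern this commit introduces. The class name CacheTensorManagerSketch, the _cached list, and the cpu device default are illustrative stand-ins, not the real lightllm implementation; only the cache_env_ok flag and the torch.empty fallback mirror the diff above.

import torch

class CacheTensorManagerSketch:
    """Toy model of the guard: allocations are only tracked between
    cache_env_in() and cache_env_out(); anything else degrades to torch.empty."""

    def __init__(self):
        self.cache_env_ok = False  # same flag the commit adds
        self._cached = []          # stand-in for the real reuse bookkeeping

    def cache_env_in(self):
        self.cache_env_ok = True

    def cache_env_out(self):
        self._cached.clear()       # release tracked buffers when the env closes
        self.cache_env_ok = False

    def alloc_tensor(self, shape, data_type, device="cpu"):
        if not self.cache_env_ok:
            # outside a valid env: plain allocation, nothing is tracked,
            # so nothing can leak if the caller never re-enters the env
            return torch.empty(shape, dtype=data_type, device=device, requires_grad=False)
        t = torch.empty(shape, dtype=data_type, device=device, requires_grad=False)
        self._cached.append(t)     # tracked for reuse while the env is open
        return t

mgr = CacheTensorManagerSketch()
a = mgr.alloc_tensor([2, 3], torch.float32)  # falls back: env never entered
mgr.cache_env_in()
b = mgr.alloc_tensor([2, 3], torch.float32)  # tracked while the env is open
mgr.cache_env_out()

Degrading to torch.empty trades reuse for safety: a stray alloc_tensor call outside the env now costs one fresh allocation instead of leaving an untracked entry in the cache that would never be freed.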
