
Commit cf9df54

Author: niushengxiao (committed)
Message: add second coroutine
1 parent: deb2ddb

File tree

7 files changed: +286 −216 lines

lightllm/common/basemodel/cuda_graph.py
Lines changed: 1 addition & 1 deletion

@@ -83,7 +83,7 @@ def warmup(self, model, stream_id):
             b_ready_cache_len=b_ready_cache_len,
             is_prefill=True,
             multimodal_params=[],
-            # stream_id=stream_id,
+            stream_id=stream_id,
         )
         mem_indexes = None
         prob_out = torch.softmax(logics, dim=-1)
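
The only functional change here is that the previously commented-out stream_id keyword is now forwarded to the prefill call during CUDA-graph warmup, in line with the commit message about adding a second coroutine. A minimal sketch of the underlying idea, with illustrative names only (warmup_on_stream, run_prefill, and the stream table are assumptions, not the lightllm API):

import torch

# Hypothetical helper, not lightllm code: keep one CUDA stream per stream_id so
# two concurrently warmed-up workers do not serialize on the default stream.
_streams = {0: torch.cuda.Stream(), 1: torch.cuda.Stream()}

def warmup_on_stream(stream_id: int, run_prefill):
    # run_prefill stands in for the model forward invoked by warmup()
    with torch.cuda.stream(_streams[stream_id]):
        logits = run_prefill(stream_id=stream_id)  # forward the id, as the diff now does
    _streams[stream_id].synchronize()
    return logits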

lightllm/common/basemodel/layer_infer/cache_tensor_manager.py
Lines changed: 86 additions & 68 deletions

@@ -4,6 +4,7 @@
 import dataclasses
 import numpy as np
 import torch._C
+import threading
 from typing import Dict, Iterable, Literal, Tuple, Union, List, Set
 from torch.storage import UntypedStorage
 from dataclasses import field
@@ -19,14 +20,18 @@
     # used for reference-count adjustment and liveness checks
     def custom_del(self: torch.Tensor):
         global g_cache_manager
-        if hasattr(self, "storage_weak_ptr"):
-            storage_weak_ptr = self.storage_weak_ptr
-        else:
-            storage_weak_ptr = self.untyped_storage()._weak_ref()
-        UntypedStorage._free_weak_ref(storage_weak_ptr)
-        if storage_weak_ptr in g_cache_manager.ptr_to_bufnode:
-            g_cache_manager.changed_ptr.add(storage_weak_ptr)
-        return
+        g_cache_manager.tensor_lock.acquire()
+        try:
+            if hasattr(self, "storage_weak_ptr"):
+                storage_weak_ptr = self.storage_weak_ptr
+            else:
+                storage_weak_ptr = self.untyped_storage()._weak_ref()
+            UntypedStorage._free_weak_ref(storage_weak_ptr)
+            if storage_weak_ptr in g_cache_manager.ptr_to_bufnode:
+                g_cache_manager.changed_ptr.add(storage_weak_ptr)
+            return
+        finally:
+            g_cache_manager.tensor_lock.release()
 
     @dataclasses.dataclass
     class BufNode:
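
For context, custom_del is installed as torch.Tensor.__del__ by cache_env_in (see the hunks below), so it can fire on whichever thread or coroutine drops the last reference to a cached tensor; the new tensor_lock serializes its bookkeeping against alloc_tensor. A self-contained sketch of that pattern, with illustrative names (_tracking_del, _maybe_free) that are not part of lightllm:

import threading
import torch

_lock = threading.Lock()
_maybe_free = set()

def _tracking_del(self: torch.Tensor):
    # __del__ may run on any thread, so guard the shared set with the same lock
    # an allocator would take before recycling buffers.
    with _lock:
        _maybe_free.add(self.data_ptr())

setattr(torch.Tensor, "__del__", _tracking_del)  # same installation trick as cache_env_in
t = torch.empty(4)
addr = t.data_ptr()
del t                                            # _tracking_del records the freed address
assert addr in _maybe_free
delattr(torch.Tensor, "__del__")                 # same cleanup as cache_env_out
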
@@ -105,6 +110,7 @@ def __init__(self):
             self.free_shape_dtype_to_bufs: Dict[Tuple, List[BufNode]] = collections.defaultdict(list)
             self.calcu_shape_cache: Dict[torch.Size, int] = {}
             self.changed_ptr: Set[int] = set()
+            self.tensor_lock = threading.Lock()
             from torch._C import _storage_Use_Count as use_count
 
             # the use_count function can be used to get how many tensors actually reference this piece of GPU memory
@@ -117,25 +123,36 @@ def __init__(self):
         def cache_env_in(
             self, is_cuda_graph: bool = False, cur_batch_size: int = 0, cuda_graph_max_batch_size: int = 0
         ):
-            self.managed_total_tensor_bytes = 0
-            setattr(torch.Tensor, "__del__", custom_del)
-            self.is_cuda_graph = is_cuda_graph
-            if self.is_cuda_graph:
-                if self.inner_cuda_graph_manager is None:
-                    self.inner_cuda_graph_manager = CudaGraphCacheTensorManager(cuda_graph_max_batch_size)
-                else:
-                    assert self.inner_cuda_graph_manager.cuda_graph_max_batch_size == cuda_graph_max_batch_size
-                self.cuda_graph_cur_batch_size = cur_batch_size
-                assert cur_batch_size != 0
-            return
+            if not self.tensor_lock.acquire(blocking=False):
+                return
+            try:
+                self.managed_total_tensor_bytes = 0
+                setattr(torch.Tensor, "__del__", custom_del)
+                self.is_cuda_graph = is_cuda_graph
+                if self.is_cuda_graph:
+                    if self.inner_cuda_graph_manager is None:
+                        self.inner_cuda_graph_manager = CudaGraphCacheTensorManager(cuda_graph_max_batch_size)
+                    else:
+                        assert self.inner_cuda_graph_manager.cuda_graph_max_batch_size == cuda_graph_max_batch_size
+                    self.cuda_graph_cur_batch_size = cur_batch_size
+                    assert cur_batch_size != 0
+                return
+            finally:
+                self.tensor_lock.release()
 
         def cache_env_out(self):
-            delattr(torch.Tensor, "__del__")
-            self.ptr_to_bufnode.clear()
-            self.free_shape_dtype_to_bufs.clear()
-            self.calcu_shape_cache.clear()
-            self.changed_ptr.clear()
-            return
+            if not self.tensor_lock.acquire(blocking=False):
+                return
+            try:
+                if hasattr(torch.Tensor, "__del__"):
+                    delattr(torch.Tensor, "__del__")
+                self.ptr_to_bufnode.clear()
+                self.free_shape_dtype_to_bufs.clear()
+                self.calcu_shape_cache.clear()
+                self.changed_ptr.clear()
+                return
+            finally:
+                self.tensor_lock.release()
 
         def alloc_tensor(
             self,
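
Note that cache_env_in and cache_env_out take the lock with acquire(blocking=False): if another worker is already switching the cache environment, the call returns immediately instead of stalling. A minimal illustration of that non-blocking pattern (generic names, not lightllm code):

import threading

_env_lock = threading.Lock()

def try_switch_env(setup):
    # Non-blocking acquire: the first caller performs setup, concurrent callers
    # simply skip instead of waiting on the lock.
    if not _env_lock.acquire(blocking=False):
        return False
    try:
        setup()
        return True
    finally:
        _env_lock.release()

print(try_switch_env(lambda: None))  # True: the lock was free
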
@@ -144,55 +161,56 @@ def alloc_tensor(
             device: str = "cuda",
             is_graph_out: bool = False,
         ) -> torch.Tensor:
-            # shape type conversion
-            if isinstance(shape, list):
-                shape = torch.Size(shape)
-            # when running under cuda graph, allocation is taken over by the cuda graph manager
-            if self.is_cuda_graph:
-                return self.inner_cuda_graph_manager.alloc_tensor_for_cuda_graph(
-                    self.cuda_graph_cur_batch_size, shape, data_type, device, is_graph_out
-                )
-
-            # reclaim tensors that may have been released
-            for ptr in self.changed_ptr:
-                t_buf_node = self.ptr_to_bufnode[ptr]
-                if self.use_count(ptr) == 1 + len(t_buf_node.shape_to_tensor):
-                    self.free_shape_dtype_to_bufs[t_buf_node.shape_key].append(t_buf_node)
-            self.changed_ptr.clear()
-
-            if shape not in self.calcu_shape_cache:
-                size = np.prod(shape)
-                self.calcu_shape_cache[shape] = size
-            else:
-                size = self.calcu_shape_cache[shape]
+            with self.tensor_lock:
+                # shape type conversion
+                if isinstance(shape, list):
+                    shape = torch.Size(shape)
+                # when running under cuda graph, allocation is taken over by the cuda graph manager
+                if self.is_cuda_graph:
+                    return self.inner_cuda_graph_manager.alloc_tensor_for_cuda_graph(
+                        self.cuda_graph_cur_batch_size, shape, data_type, device, is_graph_out
+                    )
 
-            key = (size, data_type)
-            buf_node_list = self.free_shape_dtype_to_bufs[key]
-            if buf_node_list:
-                buf_node = buf_node_list.pop()
-                if shape not in buf_node.shape_to_tensor:
-                    mark_tensor = buf_node.inner_tensor.view(shape)
-                    buf_node.shape_to_tensor[shape] = mark_tensor
+                # reclaim tensors that may have been released
+                for ptr in self.changed_ptr:
+                    t_buf_node = self.ptr_to_bufnode[ptr]
+                    if self.use_count(ptr) == 1 + len(t_buf_node.shape_to_tensor):
+                        self.free_shape_dtype_to_bufs[t_buf_node.shape_key].append(t_buf_node)
+                self.changed_ptr.clear()
+
+                if shape not in self.calcu_shape_cache:
+                    size = np.prod(shape)
+                    self.calcu_shape_cache[shape] = size
                 else:
-                    mark_tensor = buf_node.shape_to_tensor[shape]
+                    size = self.calcu_shape_cache[shape]
+
+                key = (size, data_type)
+                buf_node_list = self.free_shape_dtype_to_bufs[key]
+                if buf_node_list:
+                    buf_node = buf_node_list.pop()
+                    if shape not in buf_node.shape_to_tensor:
+                        mark_tensor = buf_node.inner_tensor.view(shape)
+                        buf_node.shape_to_tensor[shape] = mark_tensor
+                    else:
+                        mark_tensor = buf_node.shape_to_tensor[shape]
+                    ans = mark_tensor.data  # return a new reference, otherwise the reference count cannot be judged
+                    ans.storage_weak_ptr = buf_node.storage_weak_ptr
+                    return ans
+
+                buf_tensor = torch.empty((size,), dtype=data_type, device=device, requires_grad=False)
+                # important log lines for debugging GPU memory usage
+                # self.managed_total_tensor_bytes += buf_tensor.element_size() * buf_tensor.numel()
+                # logger.info(f"gpu cache managed_total_tensor_bytes: {self.managed_total_tensor_bytes}")
+                storage_weak_ptr = buf_tensor.untyped_storage()._weak_ref()
+                buf_node = BufNode(buf_tensor, key, storage_weak_ptr)
+                self.ptr_to_bufnode[storage_weak_ptr] = buf_node
+                if shape not in buf_node.shape_to_tensor:
+                    buf_node.shape_to_tensor[shape] = buf_node.inner_tensor.view(shape)
+                mark_tensor = buf_node.shape_to_tensor[shape]
                 ans = mark_tensor.data  # return a new reference, otherwise the reference count cannot be judged
                 ans.storage_weak_ptr = buf_node.storage_weak_ptr
                 return ans
 
-            buf_tensor = torch.empty((size,), dtype=data_type, device=device, requires_grad=False)
-            # important log lines for debugging GPU memory usage
-            # self.managed_total_tensor_bytes += buf_tensor.element_size() * buf_tensor.numel()
-            # logger.info(f"gpu cache managed_total_tensor_bytes: {self.managed_total_tensor_bytes}")
-            storage_weak_ptr = buf_tensor.untyped_storage()._weak_ref()
-            buf_node = BufNode(buf_tensor, key, storage_weak_ptr)
-            self.ptr_to_bufnode[storage_weak_ptr] = buf_node
-            if shape not in buf_node.shape_to_tensor:
-                buf_node.shape_to_tensor[shape] = buf_node.inner_tensor.view(shape)
-            mark_tensor = buf_node.shape_to_tensor[shape]
-            ans = mark_tensor.data  # return a new reference, otherwise the reference count cannot be judged
-            ans.storage_weak_ptr = buf_node.storage_weak_ptr
-            return ans
-
 else:
     logger.info("USE_GPU_TENSOR_CACHE is OFF")
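
Taken together, the lock makes the cache manager safe to drive from more than one coroutine. A rough usage sketch under assumptions (the import path, g_cache_manager, and the keyword names shape/data_type are taken from this diff; the full public signature of alloc_tensor is not shown here, and running it requires a CUDA device since device defaults to "cuda"):

import torch
# Assumed import path; g_cache_manager is the module-level manager referenced by custom_del.
from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager

g_cache_manager.cache_env_in()   # installs custom_del as torch.Tensor.__del__
a = g_cache_manager.alloc_tensor(shape=[1, 1024], data_type=torch.float16)
del a                            # custom_del marks a's storage as possibly reusable
b = g_cache_manager.alloc_tensor(shape=[1, 1024], data_type=torch.float16)  # may reuse a's buffer
g_cache_manager.cache_env_out()  # removes the hook and clears the caches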