Skip to content

Commit caa2d6c

Browse files
committed
update & fix format
1 parent ab5d933 commit caa2d6c

File tree

6 files changed

+42
-50
lines changed

6 files changed

+42
-50
lines changed

lightllm/common/basemodel/basemodel.py

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -179,28 +179,16 @@ def _init_mem_manager(self):
179179
)
180180

181181
if self.enable_hiradix_cache:
182-
from lightllm.common.radixmem_buffer import RadixMemoryBuffer, init_shared_data, get_shared_data, MemPropties
183-
from lightllm.common.radixmem_manager import RadixBufferManager
182+
from lightllm.common.radixmem_buffer import get_shared_data, MemPropties
183+
from lightllm.common.radixmem_manager import build_radix_manager
184184
mem_propties = MemPropties(
185185
self.hiradix_cache_token_num,
186186
dtype=self.data_type,
187187
head_num=self.config["num_attention_heads"] // self.tp_world_size_,
188188
head_dim=self.config["n_embed"] // self.config["num_attention_heads"],
189189
layer_num=self.config["n_layer"]
190190
)
191-
init_shared_data(
192-
mem_propties=mem_propties,
193-
device="cpu" if not self.hiradix_cache_gpu else "cuda"
194-
)
195-
radix_mem_buffer = RadixMemoryBuffer(
196-
mem_propties,
197-
shared_data=get_shared_data(),
198-
lock=self.radix_lock,
199-
device="cpu" if not self.hiradix_cache_gpu else "cuda"
200-
)
201-
self.radix_manager = RadixBufferManager(radix_buffer=radix_mem_buffer,
202-
radix_mem_data=get_shared_data(),
203-
lock=self.radix_lock)
191+
self.radix_manager = build_radix_manager(mem_propties, self.hiradix_cache_gpu, self.radix_lock)
204192
self.mem_propties = mem_propties
205193
self.shared_mem_data = get_shared_data()
206194
return

lightllm/common/radixmem_buffer.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -137,14 +137,16 @@ def free_req_index(self, req_id: int):
137137
return
138138
index = self.req_mem_index[req_id]
139139
self._free(index)
140-
logger.info(f"Freed memory index for request {req_id} size {len(index)}, left size {self.can_use_mem_size.get_value()}")
140+
logger.info(f"Freed memory index for request {req_id} size {len(index)}, "
141+
f"left size {self.can_use_mem_size.get_value()}")
141142
del self.req_mem_index[req_id]
142143

143144
def alloc(self, need_size) -> torch.Tensor:
144145
with self.lock:
145146
if need_size > self.mark_end.get_value() - self.mark_start.get_value():
146147
logger.error(
147-
f"warn no enough cache need_size {need_size} left_size {self.can_use_mem_size.get_value()}"
148+
f"warn no enough cache need_size {need_size} "
149+
f"left_size {self.can_use_mem_size.get_value()}"
148150
)
149151
raise RuntimeError(f"Not enough memory to allocate {need_size} tokens.")
150152

@@ -160,7 +162,8 @@ def set_req_mem_index(self, req_id: int, index: List[int]):
160162
"""Set the memory index for a specific request ID."""
161163
with self.lock:
162164
if req_id in self.req_mem_index:
163-
logger.info(f"Request ID {req_id} already exists. Overwriting index {self.req_mem_index[req_id]} with {index}.")
165+
logger.info(f"Request ID {req_id} already exists. "
166+
f"Overwriting index {self.req_mem_index[req_id]} with {index}.")
164167
self.req_mem_index[req_id] = index
165168
logger.info(f"radix mem buffer insert req {req_id}, current disk work num {self._get_current_work_num()}")
166169

lightllm/common/radixmem_manager.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import torch.multiprocessing as mp
77
from collections import OrderedDict
88

9+
from .radixmem_buffer import MemPropties, init_shared_data, get_shared_data
910
from .radixmem_buffer import SharedRadixMemoryData, RadixMemoryBuffer
1011

1112
from lightllm.utils.log_utils import init_logger
@@ -116,4 +117,29 @@ def query_cache(self, tokens: List[int]) -> int:
116117
def clear(self):
117118
with self.lock:
118119
self.radix_buffer.req_mem_index.clear()
119-
self.lru_queue[:] = []
120+
self.lru_queue[:] = []
121+
122+
def build_radix_manager(mem_propties: MemPropties,
123+
use_gpu: bool,
124+
radix_lock) -> RadixBufferManager:
125+
device = "cuda" if use_gpu else "cpu"
126+
127+
init_shared_data(
128+
mem_propties=mem_propties,
129+
device=device,
130+
)
131+
132+
radix_mem_buffer = RadixMemoryBuffer(
133+
mem_propties=mem_propties,
134+
shared_data=get_shared_data(),
135+
lock=radix_lock,
136+
device=device,
137+
)
138+
139+
radix_manager = RadixBufferManager(
140+
radix_buffer=radix_mem_buffer,
141+
radix_mem_data=get_shared_data(),
142+
lock=radix_lock,
143+
)
144+
145+
return radix_manager

lightllm/models/deepseek2/model.py

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -113,28 +113,16 @@ def _init_mem_manager(self):
113113
mem_fraction=self.mem_fraction,
114114
)
115115
if self.enable_hiradix_cache:
116-
from lightllm.common.radixmem_buffer import RadixMemoryBuffer, init_shared_data, get_shared_data, MemPropties
117-
from lightllm.common.radixmem_manager import RadixBufferManager
116+
from lightllm.common.radixmem_buffer import get_shared_data, MemPropties
117+
from lightllm.common.radixmem_manager import build_radix_manager
118118
mem_propties = MemPropties(
119119
self.hiradix_cache_token_num,
120120
dtype=self.data_type,
121121
head_num=1,
122122
head_dim=self.config["kv_lora_rank"] + self.config["qk_rope_head_dim"],
123123
layer_num=self.config["num_hidden_layers"] + added_mtp_layer_num,
124124
)
125-
init_shared_data(
126-
mem_propties=mem_propties,
127-
device="cpu" if not self.hiradix_cache_gpu else "cuda"
128-
)
129-
radix_mem_buffer = RadixMemoryBuffer(
130-
mem_propties,
131-
shared_data=get_shared_data(),
132-
lock=self.radix_lock,
133-
device="cpu" if not self.hiradix_cache_gpu else "cuda"
134-
)
135-
self.radix_manager = RadixBufferManager(radix_buffer=radix_mem_buffer,
136-
radix_mem_data=get_shared_data(),
137-
lock=self.radix_lock)
125+
self.radix_manager = build_radix_manager(mem_propties, self.hiradix_cache_gpu, self.radix_lock)
138126
self.mem_propties = mem_propties
139127
self.shared_mem_data = get_shared_data()
140128
return

lightllm/models/qwen2/model.py

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -53,28 +53,16 @@ def _init_mem_manager(self):
5353
)
5454

5555
if self.enable_hiradix_cache:
56-
from lightllm.common.radixmem_buffer import RadixMemoryBuffer, init_shared_data, get_shared_data, MemPropties
57-
from lightllm.common.radixmem_manager import RadixBufferManager
56+
from lightllm.common.radixmem_buffer import get_shared_data, MemPropties
57+
from lightllm.common.radixmem_manager import build_radix_manager
5858
mem_propties = MemPropties(
5959
self.hiradix_cache_token_num,
6060
dtype=self.data_type,
6161
head_num=2 * tp_k_head_num_,
6262
head_dim=head_dim_,
6363
layer_num=self.config["num_hidden_layers"],
6464
)
65-
init_shared_data(
66-
mem_propties=mem_propties,
67-
device="cpu" if not self.hiradix_cache_gpu else "cuda"
68-
)
69-
radix_mem_buffer = RadixMemoryBuffer(
70-
mem_propties,
71-
shared_data=get_shared_data(),
72-
lock=self.radix_lock,
73-
device="cpu" if not self.hiradix_cache_gpu else "cuda"
74-
)
75-
self.radix_manager = RadixBufferManager(radix_buffer=radix_mem_buffer,
76-
radix_mem_data=get_shared_data(),
77-
lock=self.radix_lock)
65+
self.radix_manager = build_radix_manager(mem_propties, self.hiradix_cache_gpu, self.radix_lock)
7866
self.mem_propties = mem_propties
7967
self.shared_mem_data = get_shared_data()
8068
return

lightllm/server/router/model_infer/mode_backend/base_backend.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from lightllm.utils.log_utils import init_logger
88
from lightllm.models import get_model
99
from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache
10-
from lightllm.server.router.model_infer.infer_batch import InferReq
1110
from lightllm.server.router.dynamic_prompt.hiradix.hiradix_cache import HiRadixCache
1211
from lightllm.server.router.model_infer.infer_batch import InferReq, InferSamplingParams
1312
from lightllm.server.router.token_load import TokenLoad

0 commit comments

Comments
 (0)