
Commit 5d40700

vit check max batch size infer (#887)

Co-authored-by: baishihao <[email protected]>
Co-authored-by: hiworldwzj <[email protected]>

1 parent 2e559dc, commit 5d40700
14 files changed, +63 -35 lines changed

lightllm/distributed/communication_op.py

Lines changed: 4 additions & 4 deletions
@@ -79,24 +79,24 @@ def init_custom_reduce(self) -> None:
         if not HAS_VLLM or not has_nvlink() or self.dp_world_size not in [2, 4, 6, 8]:
             return
         args = get_env_start_args()
-        if not args.enable_custom_allreduce:
+        if args.disable_custom_allreduce:
             return
         ranks = list([get_global_rank() - get_current_rank_in_dp() + i for i in range(self.dp_world_size)])
         cpu_group = dist.new_group(ranks, backend="gloo")
         self.custom_reduce = CustomAllreduce(cpu_group, torch.cuda.current_device())
-        logger.info("Enable VLLM ALLReduce.")
+        logger.info("Enable Custom ALLReduce. You can disable it by setting --disable_custom_allreduce.")

     def init_custom_gather(self) -> None:
         if not HAS_LIGHTLLM_KERNEL or not has_nvlink() or self.dp_world_size not in [2, 4, 6, 8]:
             return

         args = get_env_start_args()
-        if not args.enable_custom_allgather:
+        if args.disable_custom_allgather:
             return
         ranks = list([get_global_rank() - get_current_rank_in_dp() + i for i in range(self.dp_world_size)])
         cpu_group = dist.new_group(ranks, backend="gloo")
         self.custom_gather = CustomAllgather(cpu_group, torch.cuda.current_device())
-        logger.info("Enable Custom ALLGather.")
+        logger.info("Enable Custom ALLGather. You can disable it by setting --disable_custom_allgather.")

     def all_reduce(self, input_: torch.Tensor) -> None:
         if self.custom_reduce is not None and self.custom_reduce.should_custom_ar(input_):
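
These two hunks flip the flag semantics from opt-in to opt-out: the custom allreduce/allgather paths now initialize by default and are skipped only when the matching --disable_* flag is passed. A minimal sketch of the new gating, with the start args reduced to a plain dataclass (the real code obtains them via get_env_start_args(); everything but the two flags is omitted here):

    from dataclasses import dataclass

    @dataclass
    class StartArgs:
        # Only the two flags touched by this diff; illustrative stand-in for the real args object.
        disable_custom_allreduce: bool = False
        disable_custom_allgather: bool = False

    def should_init_custom_allreduce(args: StartArgs) -> bool:
        # Old behavior: `if not args.enable_custom_allreduce: return`  (off unless requested).
        # New behavior: on unless explicitly disabled.
        return not args.disable_custom_allreduce

    print(should_init_custom_allreduce(StartArgs()))                               # True: enabled by default
    print(should_init_custom_allreduce(StartArgs(disable_custom_allreduce=True)))  # False: opted out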

lightllm/distributed/custom_all_reduce.py

Lines changed: 8 additions & 2 deletions
@@ -19,17 +19,23 @@
 import ctypes
 from contextlib import contextmanager
 from typing import List, Optional, Union
-
+import os
 import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup

-from lightllm.common.vllm_kernel import _custom_ops as ops
 from lightllm.common.cuda_wrapper import CudaRTLibrary
 from lightllm.utils.log_utils import init_logger
 from lightllm.utils.vllm_utils import is_full_nvlink
 from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager

+use_vllm_custom_allreduce = os.getenv("USE_VLLM_CUSTOM_ALLREDUCE", "1").upper() in ["1", "TRUE", "ON"]
+if use_vllm_custom_allreduce:
+    from lightllm.common.vllm_kernel import _custom_ops as ops
+else:
+    import sgl_kernel
+    import sgl_kernel.allreduce as ops
+
 ops.meta_size()
 custom_ar = True

lightllm/models/internvl/model.py

Lines changed: 5 additions & 5 deletions
@@ -153,7 +153,7 @@ def encode(self, prompt, multimodal_params: MultimodalParams = None, **kwargs):
         return input_ids


-@ModelRegistry(["internvl_chat"], condition=llm_model_type_is("phi3"))
+@ModelRegistry(["internvl_chat"], is_multimodal=True, condition=llm_model_type_is("phi3"))
 class InternVLPhi3TpPartModel(Phi3TpPartModel):
     # weight class
     pre_and_post_weight_class = InternVLPhi3PreAndPostLayerWeight
@@ -177,7 +177,7 @@ def _init_config(self):
         return


-@ModelRegistry(["internvl_chat"], condition=llm_model_type_is("internlm2"))
+@ModelRegistry(["internvl_chat"], is_multimodal=True, condition=llm_model_type_is("internlm2"))
 class InternVLInternlm2TpPartModel(Internlm2TpPartModel):
     # weight class
     pre_and_post_weight_class = InternVLInternlm2PreAndPostLayerWeight
@@ -201,7 +201,7 @@ def _init_config(self):
         return


-@ModelRegistry(["internvl_chat"], condition=llm_model_type_is("llama"))
+@ModelRegistry(["internvl_chat"], is_multimodal=True, condition=llm_model_type_is("llama"))
 class InternVLLlamaTpPartModel(LlamaTpPartModel):
     # weight class
     pre_and_post_weight_class = InternVLLlamaPreAndPostLayerWeight
@@ -225,7 +225,7 @@ def _init_config(self):
         return


-@ModelRegistry(["internvl_chat"], condition=llm_model_type_is("qwen2"))
+@ModelRegistry(["internvl_chat"], is_multimodal=True, condition=llm_model_type_is("qwen2"))
 class InternVLQwen2TpPartModel(Qwen2TpPartModel):
     # weight class
     pre_and_post_weight_class = InternVLLlamaPreAndPostLayerWeight
@@ -249,7 +249,7 @@ def _init_config(self):
         return


-@ModelRegistry(["internvl_chat"], condition=llm_model_type_is(["deepseek_v2", "deepseek_v3"]))
+@ModelRegistry(["internvl_chat"], is_multimodal=True, condition=llm_model_type_is(["deepseek_v2", "deepseek_v3"]))
 class InternVLDeepSeek2TpPartModel(Deepseek2TpPartModel):
     # support Deepseek2,3,R1
     # weight class
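
Every multimodal registration in this file (and in the llava, qwen2_vl and qwen_vl files below) gains an explicit is_multimodal=True, presumably so the registry itself records which entries need the visual-encoder path instead of that being inferred elsewhere. A minimal sketch of a decorator-based registry that stores such a flag; this only illustrates the pattern and is not LightLLM's actual ModelRegistry implementation:

    from typing import Callable, Dict, List, Optional, Union

    _REGISTRY: Dict[str, List[dict]] = {}

    def model_registry(names: Union[str, List[str]], is_multimodal: bool = False,
                       condition: Optional[Callable[[dict], bool]] = None):
        # Hypothetical stand-in: keep the class together with its metadata so a loader
        # can match on model_type + condition and branch on is_multimodal.
        def _wrap(cls):
            for name in ([names] if isinstance(names, str) else names):
                _REGISTRY.setdefault(name, []).append(
                    {"cls": cls, "is_multimodal": is_multimodal, "condition": condition}
                )
            return cls
        return _wrap

    @model_registry(["internvl_chat"], is_multimodal=True,
                    condition=lambda cfg: cfg.get("llm_config", {}).get("model_type") == "llama")
    class DemoInternVLModel:
        pass

    entry = _REGISTRY["internvl_chat"][0]
    print(entry["is_multimodal"])  # True -> the loader knows to bring up the visual encoder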

lightllm/models/llava/model.py

Lines changed: 1 addition & 1 deletion
@@ -78,7 +78,7 @@ def encode(self, prompt, multimodal_params: MultimodalParams = None):
         return input_ids


-@ModelRegistry("llava")
+@ModelRegistry("llava", is_multimodal=True)
 class LlavaTpPartModel(LlamaTpPartModel):
     # weight class
     pre_and_post_weight_class = LlavaPreAndPostLayerWeight

lightllm/models/qwen2_vl/model.py

Lines changed: 1 addition & 1 deletion
@@ -91,7 +91,7 @@ def encode(self, prompt, multimodal_params: MultimodalParams = None, **kwargs):
         return input_ids


-@ModelRegistry(["qwen2_vl", "qwen2_5_vl"])
+@ModelRegistry(["qwen2_vl", "qwen2_5_vl"], is_multimodal=True)
 class Qwen2VLTpPartModel(Qwen2TpPartModel):

     pre_layer_infer_class = LlamaMultimodalPreLayerInfer

lightllm/models/qwen_vl/model.py

Lines changed: 1 addition & 1 deletion
@@ -92,7 +92,7 @@ def encode(self, prompt, multimodal_params: MultimodalParams = None):
         return input_ids


-@ModelRegistry("qwen", condition=lambda cfg: "visual" in cfg)
+@ModelRegistry("qwen", is_multimodal=True, condition=lambda cfg: "visual" in cfg)
 class QWenVLTpPartModel(QWenTpPartModel):

     # infer class

lightllm/models/vit/model.py

Lines changed: 31 additions & 1 deletion
@@ -14,7 +14,7 @@
 import torchvision.transforms as T
 from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data
 from PIL import Image
-from typing import List, Union
+from typing import List, Union, final
 from io import BytesIO
 from rpyc.utils.classic import obtain
 from lightllm.common.quantization import Quantcfg
@@ -46,13 +46,38 @@ def __init__(self, kvargs):
         self.quant_type = kvargs.get("quant_type", None)
         self.quant_cfg_path = kvargs.get("quant_cfg", None)
         self.load_image_func = get_load_image_func(self.weight_dir_)
+        self.max_batch_size = kvargs.get("max_batch_size", 1)

         self._init_datatype()
         self._init_config()
         self._padding_hidden_size()
         self._init_quant()
         self._init_weights()
         self._init_infer_layer()
+        self._check_max_len_infer()
+        return
+
+    @final
+    @torch.no_grad()
+    def _check_max_len_infer(self):
+        disable_check_max_len_infer = os.getenv("DISABLE_CHECK_MAX_LEN_INFER", None) is not None
+        if disable_check_max_len_infer:
+            return
+
+        try:
+            dummy_images = torch.randn(
+                (self.MAX_PATH_NUM * self.max_batch_size, 3, self.IMAGE_H, self.IMAGE_W), dtype=self.data_type
+            ).cuda()
+            all_img_embeds = self.forward(dummy_images)
+            del all_img_embeds
+            logger.info(f"vit check max_len {self.batch_max_tokens} infer ok")
+        except (RuntimeError, torch.OutOfMemoryError) as e:
+            logger.exception(str(e))
+            exception_str = (
+                "Vit check max len infer fail, you can try: "
+                "1. Set the --visual_infer_batch_size to a smaller value."
+            )
+            logger.error(exception_str)
+            raise Exception(exception_str)
         return

     def _init_config(self):
@@ -66,6 +91,11 @@ def _init_config(self):
         repair_config(self.config, same_names=["hidden_size", "n_embd", "n_embed"])
         repair_config(self.config, same_names=["num_hidden_layers", "n_layer"])
         self.layers_num = self.config["num_hidden_layers"]
+
+        # infer info
+        self.IMAGE_H = int(os.getenv("IMAGE_H", 448))
+        self.IMAGE_W = int(os.getenv("IMAGE_W", 448))
+        self.MAX_PATH_NUM = int(os.getenv("MAX_PATH_NUM", 13))
         return

     def _padding_hidden_size(self):
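
The new startup check materializes the worst case the ViT will be asked to handle, a batch of MAX_PATH_NUM * max_batch_size dummy image tiles of shape 3 x IMAGE_H x IMAGE_W, runs one forward pass, and fails fast if that does not fit in GPU memory. A standalone sketch of the same idea, with the encoder passed in as a plain callable and the env-var names mirroring the diff (requires a CUDA device; CUDA OOM surfaces as a RuntimeError subclass, which is what is caught here):

    import os
    import torch

    def check_max_batch_infer(encoder, max_batch_size: int, dtype=torch.float16) -> None:
        # Opt-out mirrors DISABLE_CHECK_MAX_LEN_INFER: setting it to any value skips the check.
        if os.getenv("DISABLE_CHECK_MAX_LEN_INFER") is not None:
            return
        image_h = int(os.getenv("IMAGE_H", 448))
        image_w = int(os.getenv("IMAGE_W", 448))
        max_patch_num = int(os.getenv("MAX_PATH_NUM", 13))  # env-var name kept as in the diff
        try:
            dummy = torch.randn(
                max_patch_num * max_batch_size, 3, image_h, image_w, dtype=dtype, device="cuda"
            )
            with torch.no_grad():
                encoder(dummy)  # if the worst case OOMs, it shows up at startup, not mid-serving
        except RuntimeError as e:  # torch.cuda.OutOfMemoryError subclasses RuntimeError
            raise RuntimeError(
                "ViT max batch size check failed; try a smaller --visual_infer_batch_size"
            ) from e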

lightllm/models/vit/triton_kernel/gelu_vit.py

Lines changed: 1 addition & 0 deletions
@@ -30,6 +30,7 @@ def gelu_fwd(input, use_custom_tensor_mananger=False):
         output = g_cache_manager.alloc_tensor(shape, dtype, device=device)
     else:
         output = torch.empty_like(input)
+    assert input.is_contiguous(), "Input tensor must be contiguous"
    n_elements = input.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    gelu_kernel[grid](output, input, n_elements, BLOCK_SIZE=1024)
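
The kernel treats the input as a flat buffer of numel() elements, which is only correct when the tensor's memory is contiguous; the new assert rejects strided views before they silently produce wrong results. A short illustration of the failure mode it guards against (plain PyTorch, no Triton required):

    import torch

    x = torch.randn(4, 8)
    view = x.t()                 # transposed view: same storage, non-contiguous strides
    print(view.is_contiguous())  # False -> flat 1D indexing would walk the elements in the wrong order
    safe = view.contiguous()     # copy into a dense buffer before handing it to a flat-index kernel
    print(safe.is_contiguous())  # True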

lightllm/server/api_cli.py

Lines changed: 2 additions & 2 deletions
@@ -225,8 +225,8 @@ def make_argument_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "--enable_mps", action="store_true", help="Whether to enable nvidia mps for multimodal service."
     )
-    parser.add_argument("--enable_custom_allreduce", action="store_true", help="Whether to disable cutom allreduce.")
-    parser.add_argument("--enable_custom_allgather", action="store_true", help="Whether to enable cutom allgather.")
+    parser.add_argument("--disable_custom_allreduce", action="store_true", help="Whether to disable custom allreduce.")
+    parser.add_argument("--disable_custom_allgather", action="store_true", help="Whether to disable custom allgather.")
     parser.add_argument(
         "--enable_tpsp_mix_mode",
         action="store_true",

lightllm/server/api_start.py

Lines changed: 1 addition & 3 deletions
@@ -65,10 +65,8 @@ def normal_or_p_d_start(args):
     set_unique_server_name(args)

     if args.enable_mps:
-        from lightllm.utils.device_utils import enable_mps, set_gpu_exclusive_mode
+        from lightllm.utils.device_utils import enable_mps

-        for i in range(args.tp):
-            set_gpu_exclusive_mode(gpu_index=i)
         enable_mps()

     if args.run_mode not in ["normal", "prefill", "decode"]:
