Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions lightllm/distributed/communication_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,24 +79,24 @@ def init_custom_reduce(self) -> None:
if not HAS_VLLM or not has_nvlink() or self.dp_world_size not in [2, 4, 6, 8]:
return
args = get_env_start_args()
if not args.enable_custom_allreduce:
if args.disable_custom_allreduce:
return
ranks = list([get_global_rank() - get_current_rank_in_dp() + i for i in range(self.dp_world_size)])
cpu_group = dist.new_group(ranks, backend="gloo")
self.custom_reduce = CustomAllreduce(cpu_group, torch.cuda.current_device())
logger.info("Enable VLLM ALLReduce.")
logger.info("Enable Custom ALLReduce. You can disable it by settting --disable_custom_allreduce.")

def init_custom_gather(self) -> None:
if not HAS_LIGHTLLM_KERNEL or not has_nvlink() or self.dp_world_size not in [2, 4, 6, 8]:
return

args = get_env_start_args()
if not args.enable_custom_allgather:
if args.disable_custom_allgather:
return
ranks = list([get_global_rank() - get_current_rank_in_dp() + i for i in range(self.dp_world_size)])
cpu_group = dist.new_group(ranks, backend="gloo")
self.custom_gather = CustomAllgather(cpu_group, torch.cuda.current_device())
logger.info("Enable Custom ALLGather.")
logger.info("Enable Custom ALLGather. You can disable it by settting --disable_custom_allgather")

def all_reduce(self, input_: torch.Tensor) -> None:
if self.custom_reduce is not None and self.custom_reduce.should_custom_ar(input_):
Expand Down
10 changes: 8 additions & 2 deletions lightllm/distributed/custom_all_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,23 @@
import ctypes
from contextlib import contextmanager
from typing import List, Optional, Union

import os
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup

from lightllm.common.vllm_kernel import _custom_ops as ops
from lightllm.common.cuda_wrapper import CudaRTLibrary
from lightllm.utils.log_utils import init_logger
from lightllm.utils.vllm_utils import is_full_nvlink
from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager

use_vllm_custom_allreduce = os.getenv("USE_VLLM_CUSTOM_ALLREDUCE", "1").upper() in ["1", "TRUE", "ON"]
if use_vllm_custom_allreduce:
from lightllm.common.vllm_kernel import _custom_ops as ops
else:
import sgl_kernel
import sgl_kernel.allreduce as ops

ops.meta_size()
custom_ar = True

Expand Down
10 changes: 5 additions & 5 deletions lightllm/models/internvl/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def encode(self, prompt, multimodal_params: MultimodalParams = None, **kwargs):
return input_ids


@ModelRegistry(["internvl_chat"], condition=llm_model_type_is("phi3"))
@ModelRegistry(["internvl_chat"], is_multimodal=True, condition=llm_model_type_is("phi3"))
class InternVLPhi3TpPartModel(Phi3TpPartModel):
# weight class
pre_and_post_weight_class = InternVLPhi3PreAndPostLayerWeight
Expand All @@ -177,7 +177,7 @@ def _init_config(self):
return


@ModelRegistry(["internvl_chat"], condition=llm_model_type_is("internlm2"))
@ModelRegistry(["internvl_chat"], is_multimodal=True, condition=llm_model_type_is("internlm2"))
class InternVLInternlm2TpPartModel(Internlm2TpPartModel):
# weight class
pre_and_post_weight_class = InternVLInternlm2PreAndPostLayerWeight
Expand All @@ -201,7 +201,7 @@ def _init_config(self):
return


@ModelRegistry(["internvl_chat"], condition=llm_model_type_is("llama"))
@ModelRegistry(["internvl_chat"], is_multimodal=True, condition=llm_model_type_is("llama"))
class InternVLLlamaTpPartModel(LlamaTpPartModel):
# weight class
pre_and_post_weight_class = InternVLLlamaPreAndPostLayerWeight
Expand All @@ -225,7 +225,7 @@ def _init_config(self):
return


@ModelRegistry(["internvl_chat"], condition=llm_model_type_is("qwen2"))
@ModelRegistry(["internvl_chat"], is_multimodal=True, condition=llm_model_type_is("qwen2"))
class InternVLQwen2TpPartModel(Qwen2TpPartModel):
# weight class
pre_and_post_weight_class = InternVLLlamaPreAndPostLayerWeight
Expand All @@ -249,7 +249,7 @@ def _init_config(self):
return


@ModelRegistry(["internvl_chat"], condition=llm_model_type_is(["deepseek_v2", "deepseek_v3"]))
@ModelRegistry(["internvl_chat"], is_multimodal=True, condition=llm_model_type_is(["deepseek_v2", "deepseek_v3"]))
class InternVLDeepSeek2TpPartModel(Deepseek2TpPartModel):
# support Deepseek2,3,R1
# weight class
Expand Down
2 changes: 1 addition & 1 deletion lightllm/models/llava/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def encode(self, prompt, multimodal_params: MultimodalParams = None):
return input_ids


@ModelRegistry("llava")
@ModelRegistry("llava", is_multimodal=True)
class LlavaTpPartModel(LlamaTpPartModel):
# weight class
pre_and_post_weight_class = LlavaPreAndPostLayerWeight
Expand Down
2 changes: 1 addition & 1 deletion lightllm/models/qwen2_vl/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def encode(self, prompt, multimodal_params: MultimodalParams = None, **kwargs):
return input_ids


@ModelRegistry(["qwen2_vl", "qwen2_5_vl"])
@ModelRegistry(["qwen2_vl", "qwen2_5_vl"], is_multimodal=True)
class Qwen2VLTpPartModel(Qwen2TpPartModel):

pre_layer_infer_class = LlamaMultimodalPreLayerInfer
Expand Down
2 changes: 1 addition & 1 deletion lightllm/models/qwen_vl/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def encode(self, prompt, multimodal_params: MultimodalParams = None):
return input_ids


@ModelRegistry("qwen", condition=lambda cfg: "visual" in cfg)
@ModelRegistry("qwen", is_multimodal=True, condition=lambda cfg: "visual" in cfg)
class QWenVLTpPartModel(QWenTpPartModel):

# infer class
Expand Down
32 changes: 31 additions & 1 deletion lightllm/models/vit/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import torchvision.transforms as T
from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data
from PIL import Image
from typing import List, Union
from typing import List, Union, final
from io import BytesIO
from rpyc.utils.classic import obtain
from lightllm.common.quantization import Quantcfg
Expand Down Expand Up @@ -46,13 +46,38 @@ def __init__(self, kvargs):
self.quant_type = kvargs.get("quant_type", None)
self.quant_cfg_path = kvargs.get("quant_cfg", None)
self.load_image_func = get_load_image_func(self.weight_dir_)
self.max_batch_size = kvargs.get("max_batch_size", 1)

self._init_datatype()
self._init_config()
self._padding_hidden_size()
self._init_quant()
self._init_weights()
self._init_infer_layer()
self._check_max_len_infer()
return

@final
@torch.no_grad()
def _check_max_len_infer(self):
    """Run one worst-case dummy forward pass at init to fail fast if the
    ViT cannot serve the configured maximum batch (OOM or other runtime
    error) instead of failing later on live traffic.

    Set the DISABLE_CHECK_MAX_LEN_INFER environment variable (to any
    value) to skip the check entirely.

    Raises:
        Exception: if the dummy forward pass raises a RuntimeError or a
            CUDA out-of-memory error; the original error is chained.
    """
    # Opt-out switch: mere presence of the env var (any value) disables it.
    disable_check_max_len_infer = os.getenv("DISABLE_CHECK_MAX_LEN_INFER", None) is not None
    if disable_check_max_len_infer:
        return

    try:
        # Worst-case input: MAX_PATH_NUM patches per image times the max
        # batch size. NOTE(review): assumes self.IMAGE_H / self.IMAGE_W /
        # self.MAX_PATH_NUM / self.max_batch_size were set during init —
        # confirm against _init_config / __init__.
        dummy_images = torch.randn(
            (self.MAX_PATH_NUM * self.max_batch_size, 3, self.IMAGE_H, self.IMAGE_W), dtype=self.data_type
        ).cuda()
        all_img_embeds = self.forward(dummy_images)
        del all_img_embeds
        logger.info(f"vit check max_len {self.batch_max_tokens} infer ok")
    except (RuntimeError, torch.OutOfMemoryError) as e:
        logger.exception(str(e))
        # Fix: the two adjacent literals previously concatenated with no
        # separator, yielding "...you can try:1.Set...". Add the space.
        exception_str = (
            "Vit check max len infer fail, you can try: "
            "1.Set the --visual_infer_batch_size to a smaller value."
        )
        logger.error(exception_str)
        # Chain the original exception so the root cause is not lost.
        raise Exception(exception_str) from e
    return

def _init_config(self):
Expand All @@ -66,6 +91,11 @@ def _init_config(self):
repair_config(self.config, same_names=["hidden_size", "n_embd", "n_embed"])
repair_config(self.config, same_names=["num_hidden_layers", "n_layer"])
self.layers_num = self.config["num_hidden_layers"]

# infer info
self.IMAGE_H = int(os.getenv("IMAGE_H", 448))
self.IMAGE_W = int(os.getenv("IMAGE_W", 448))
self.MAX_PATH_NUM = os.getenv("MAX_PATH_NUM", 13)
return

def _padding_hidden_size(self):
Expand Down
1 change: 1 addition & 0 deletions lightllm/models/vit/triton_kernel/gelu_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def gelu_fwd(input, use_custom_tensor_mananger=False):
output = g_cache_manager.alloc_tensor(shape, dtype, device=device)
else:
output = torch.empty_like(input)
assert input.is_contiguous(), "Input tensor must be contiguous"
n_elements = input.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
gelu_kernel[grid](output, input, n_elements, BLOCK_SIZE=1024)
Expand Down
4 changes: 2 additions & 2 deletions lightllm/server/api_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,8 +225,8 @@ def make_argument_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--enable_mps", action="store_true", help="Whether to enable nvidia mps for multimodal service."
)
parser.add_argument("--enable_custom_allreduce", action="store_true", help="Whether to disable cutom allreduce.")
parser.add_argument("--enable_custom_allgather", action="store_true", help="Whether to enable cutom allgather.")
parser.add_argument("--disable_custom_allreduce", action="store_true", help="Whether to disable cutom allreduce.")
parser.add_argument("--disable_custom_allgather", action="store_true", help="Whether to enable cutom allgather.")
parser.add_argument(
"--enable_tpsp_mix_mode",
action="store_true",
Expand Down
4 changes: 1 addition & 3 deletions lightllm/server/api_start.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,8 @@ def normal_or_p_d_start(args):
set_unique_server_name(args)

if args.enable_mps:
from lightllm.utils.device_utils import enable_mps, set_gpu_exclusive_mode
from lightllm.utils.device_utils import enable_mps

for i in range(args.tp):
set_gpu_exclusive_mode(gpu_index=i)
enable_mps()

if args.run_mode not in ["normal", "prefill", "decode"]:
Expand Down
1 change: 1 addition & 0 deletions lightllm/server/visualserver/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ async def wait_to_model_ready(self):
"visual_gpu_ids": self.args.visual_gpu_ids,
"quant_type": self.args.vit_quant_type,
"quant_cfg": self.args.vit_quant_cfg,
"max_batch_size": min(self.infer_batch_size // self.vit_dp, 1),
}
init_model_ret.append(self.model_rpcs[dp_rank_id][tp_rank_id].init_model(kvargs))
await asyncio.gather(*init_model_ret)
Expand Down
5 changes: 1 addition & 4 deletions lightllm/server/visualserver/model_infer/model_rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def exposed_init_model(self, kvargs):
"data_type": self.data_type,
"quant_type": kvargs["quant_type"],
"quant_cfg": kvargs["quant_cfg"],
"max_batch_size": kvargs["max_batch_size"],
}
self.model = VisionTransformer(kvargs)
# self.model = InternVLVisionModel()
Expand Down Expand Up @@ -149,10 +150,6 @@ async def encode(self, images: List[ImageItem]):
def _init_env(port, device_id):
# 注册graceful 退出的处理
graceful_registry(inspect.currentframe().f_code.co_name)
from lightllm.utils.device_utils import set_sm_limit

if get_env_start_args().enable_mps:
set_sm_limit(60, device_id) # the visual server can take up to 60% of the sm

t = ThreadedServer(VisualModelRpcServer(), port=port, protocol_config={"allow_pickle": True})
t.start()
Expand Down
5 changes: 1 addition & 4 deletions lightllm/utils/start_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,10 @@ def kill_recursive(proc):

# recover the gpu compute mode
is_enable_mps = get_env_start_args().enable_mps
world_size = get_env_start_args().tp
if is_enable_mps:
from lightllm.utils.device_utils import stop_mps, set_gpu_default_mode
from lightllm.utils.device_utils import stop_mps

stop_mps()
for i in range(world_size):
set_gpu_default_mode(gpu_index=i)
logger.info("All processes terminated gracefully.")


Expand Down
12 changes: 5 additions & 7 deletions test/model/model_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from multiprocessing import Queue
import multiprocessing
from transformers import PretrainedConfig
from lightllm.utils.dist_utils import init_distributed_env
from lightllm.utils.dist_utils import init_distributed_env, get_current_rank_in_dp
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.models import get_model
from lightllm.common.basemodel.microbatch_overlap_objs import DecodeMicroBatch, PrefillMicroBatch
Expand Down Expand Up @@ -39,7 +39,7 @@ def test_model_inference(args):
"batch_max_tokens": args.batch_size * args.input_len,
"run_mode": "normal",
"max_seq_length": args.max_req_total_len,
"disable_cudagraph": True if args.profile else False,
"disable_cudagraph": args.disable_cudagraph,
}
proc = multiprocessing.Process(
target=tppart_model_infer,
Expand Down Expand Up @@ -173,15 +173,13 @@ def torch_profile(fn, log_dir=None):
torch.cuda.synchronize()
with profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
schedule=torch.profiler.schedule(wait=1, warmup=1, active=1),
record_shapes=False,
profile_memory=False,
on_trace_ready=torch.profiler.tensorboard_trace_handler(log_dir),
) as prof:
for _ in range(3):
fn()
prof.step()
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
fn()
if get_current_rank_in_dp() == 0:
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))


def tppart_model_infer(args, model_kvargs, batch_size, input_len, output_len, ans_queue):
Expand Down