1414import torchvision .transforms as T
1515from lightllm .server .embed_cache .utils import read_shm , get_shm_name_data
1616from PIL import Image
17- from typing import List , Union
17+ from typing import List , Union , final
1818from io import BytesIO
1919from rpyc .utils .classic import obtain
2020from lightllm .common .quantization import Quantcfg
@@ -46,13 +46,38 @@ def __init__(self, kvargs):
4646 self .quant_type = kvargs .get ("quant_type" , None )
4747 self .quant_cfg_path = kvargs .get ("quant_cfg" , None )
4848 self .load_image_func = get_load_image_func (self .weight_dir_ )
49+ self .max_batch_size = kvargs .get ("max_batch_size" , 1 )
4950
5051 self ._init_datatype ()
5152 self ._init_config ()
5253 self ._padding_hidden_size ()
5354 self ._init_quant ()
5455 self ._init_weights ()
5556 self ._init_infer_layer ()
57+ self ._check_max_len_infer ()
58+ return
59+
60+ @final
61+ @torch .no_grad ()
62+ def _check_max_len_infer (self ):
63+ disable_check_max_len_infer = os .getenv ("DISABLE_CHECK_MAX_LEN_INFER" , None ) is not None
64+ if disable_check_max_len_infer :
65+ return
66+
67+ try :
68+ dummy_images = torch .randn (
69+ (self .MAX_PATH_NUM * self .max_batch_size , 3 , self .IMAGE_H , self .IMAGE_W ), dtype = self .data_type
70+ ).cuda ()
71+ all_img_embeds = self .forward (dummy_images )
72+ del all_img_embeds
73+ logger .info (f"vit check max_len { self .batch_max_tokens } infer ok" )
74+ except (RuntimeError , torch .OutOfMemoryError ) as e :
75+ logger .exception (str (e ))
76+ exception_str = (
77+ "Vit check max len infer fail, you can try:" "1.Set the --visual_infer_batch_size to a smaller value."
78+ )
79+ logger .error (exception_str )
80+ raise Exception (exception_str )
5681 return
5782
5883 def _init_config (self ):
@@ -66,6 +91,11 @@ def _init_config(self):
6691 repair_config (self .config , same_names = ["hidden_size" , "n_embd" , "n_embed" ])
6792 repair_config (self .config , same_names = ["num_hidden_layers" , "n_layer" ])
6893 self .layers_num = self .config ["num_hidden_layers" ]
94+
95+ # infer info
96+ self .IMAGE_H = int (os .getenv ("IMAGE_H" , 448 ))
97+ self .IMAGE_W = int (os .getenv ("IMAGE_W" , 448 ))
98+ self .MAX_PATH_NUM = os .getenv ("MAX_PATH_NUM" , 13 )
6999 return
70100
71101 def _padding_hidden_size (self ):
0 commit comments