33from lightllm .models .llama .layer_weights .pre_and_post_layer_weight import LlamaPreAndPostLayerWeight
44
55from lightllm .models .internlm2 .layer_weights .pre_and_post_layer_weight import Internlm2PreAndPostLayerWeight
6+ from lightllm .models .vit .model import VisionTransformer
7+ from lightllm .utils .envs_utils import get_env_start_args
8+ from lightllm .common .image_cache_manager import image_cache_manager
69
710
811# add key: language_model.xxx -> xxx
@@ -15,9 +18,45 @@ def rename_weight_keys(weights):
1518 weights [k [len (prefix ) :]] = weights [k ]
1619
1720
class InternVLPreAndPostLayerWeight(LlamaPreAndPostLayerWeight):
    """Pre/post layer weights built on ``LlamaPreAndPostLayerWeight``.

    When the start args set ``disable_extra_process_for_multimodal``, the
    instance also owns a ``VisionTransformer`` stored as ``visual_model``.
    """

    def __init__(self, data_type, network_config, mode):
        """Initialize the parent weights and, if requested, the in-process vision model.

        Args:
            data_type: Forwarded unchanged to the parent constructor.
            network_config: Forwarded unchanged to the parent constructor.
            mode: Forwarded unchanged to the parent constructor.
        """
        super().__init__(data_type, network_config, mode)

        # With a dedicated multimodal process there is nothing more to set up here.
        if not get_env_start_args().disable_extra_process_for_multimodal:
            return

        # No extra process serves the visual model, so build it here and
        # size the shared image cache from this process.
        # NOTE(review): self.data_type_ is presumably set by the parent constructor.
        self.visual_model = VisionTransformer(
            kvargs=dict(
                weight_dir=get_env_start_args().model_dir,
                data_type=self.data_type_,
                quant_type=get_env_start_args().vit_quant_type,
                quant_cfg=get_env_start_args().vit_quant_cfg,
                max_batch_size=get_env_start_args().visual_infer_batch_size,
            ),
        )
        image_cache_manager.set_max_size(get_env_start_args().cache_capacity * 2)

    def load_hf_weights(self, weights):
        """Run ``rename_weight_keys`` on ``weights``, then delegate to the parent loader.

        Args:
            weights: Mapping of checkpoint weights; passed to
                ``rename_weight_keys`` before the parent sees it.
        """
        rename_weight_keys(weights)
        super().load_hf_weights(weights)
42+
43+
1844class InternVLPhi3PreAndPostLayerWeight (LlamaPreAndPostLayerWeight ):
def __init__(self, data_type, network_config, mode):
    """Initialize the parent weights and, if requested, the in-process vision model.

    Args:
        data_type: Forwarded unchanged to the parent constructor.
        network_config: Forwarded unchanged to the parent constructor.
        mode: Forwarded unchanged to the parent constructor.
    """
    super().__init__(data_type, network_config, mode)

    host_vit_in_process = get_env_start_args().disable_extra_process_for_multimodal
    if host_vit_in_process:
        # Without an extra process for the visual model, this process must
        # create it and configure the image cache manager itself.
        # NOTE(review): self.data_type_ is presumably set by the parent constructor.
        vit_kvargs = dict(
            weight_dir=get_env_start_args().model_dir,
            data_type=self.data_type_,
            quant_type=get_env_start_args().vit_quant_type,
            quant_cfg=get_env_start_args().vit_quant_cfg,
            max_batch_size=get_env_start_args().visual_infer_batch_size,
        )
        self.visual_model = VisionTransformer(kvargs=vit_kvargs)
        image_cache_manager.set_max_size(get_env_start_args().cache_capacity * 2)
2261
2362 def load_hf_weights (self , weights ):
@@ -29,6 +68,19 @@ def load_hf_weights(self, weights):
2968class InternVLInternlm2PreAndPostLayerWeight (Internlm2PreAndPostLayerWeight ):
def __init__(self, data_type, network_config, mode):
    """Initialize the parent weights and, if requested, the in-process vision model.

    Args:
        data_type: Forwarded unchanged to the parent constructor.
        network_config: Forwarded unchanged to the parent constructor.
        mode: Forwarded unchanged to the parent constructor.
    """
    super().__init__(data_type, network_config, mode)

    # A separate multimodal process handles the visual model; nothing to add.
    if not get_env_start_args().disable_extra_process_for_multimodal:
        return

    # The visual model lives in this process, so construct it and set the
    # image cache size here.
    # NOTE(review): self.data_type_ is presumably set by the parent constructor.
    vision_settings = {}
    vision_settings["weight_dir"] = get_env_start_args().model_dir
    vision_settings["data_type"] = self.data_type_
    vision_settings["quant_type"] = get_env_start_args().vit_quant_type
    vision_settings["quant_cfg"] = get_env_start_args().vit_quant_cfg
    vision_settings["max_batch_size"] = get_env_start_args().visual_infer_batch_size
    self.visual_model = VisionTransformer(kvargs=vision_settings)

    doubled_capacity = get_env_start_args().cache_capacity * 2
    image_cache_manager.set_max_size(doubled_capacity)
3385
3486 def load_hf_weights (self , weights ):
@@ -40,6 +92,19 @@ def load_hf_weights(self, weights):
4092class InternVLLlamaPreAndPostLayerWeight (LlamaPreAndPostLayerWeight ):
def __init__(self, data_type, network_config, mode):
    """Initialize the parent weights and, if requested, the in-process vision model.

    Args:
        data_type: Forwarded unchanged to the parent constructor.
        network_config: Forwarded unchanged to the parent constructor.
        mode: Forwarded unchanged to the parent constructor.
    """
    super().__init__(data_type, network_config, mode)

    if get_env_start_args().disable_extra_process_for_multimodal:
        # No extra process is assigned to the visual model, so it is built
        # here and the image cache manager is sized alongside it.
        # NOTE(review): self.data_type_ is presumably set by the parent constructor.
        self.visual_model = VisionTransformer(
            kvargs=dict(
                weight_dir=get_env_start_args().model_dir,
                data_type=self.data_type_,
                quant_type=get_env_start_args().vit_quant_type,
                quant_cfg=get_env_start_args().vit_quant_cfg,
                max_batch_size=get_env_start_args().visual_infer_batch_size,
            ),
        )
        cache_limit = get_env_start_args().cache_capacity * 2
        image_cache_manager.set_max_size(cache_limit)
44109
45110 def load_hf_weights (self , weights ):
0 commit comments