@@ -242,7 +242,7 @@ def get_non_negative_vision_feature_layers(v_hparams):
242242 the model as an unset value. If no vision feature layer is found, we leave it unset.
243243 """
244244 num_hidden_layers = v_hparams ["num_hidden_layers" ]
245- to_uint = lambda layer_idx : layer_idx if layer_idx >= 0 else num_hidden_layers + layer_idx + 1
245+ to_non_negative = lambda layer_idx : layer_idx if layer_idx >= 0 else num_hidden_layers + layer_idx + 1
246246 feature_layers_key = None
247247 # Key used for llava models in transformers
248248 if "vision_feature_layer" in config :
@@ -254,11 +254,12 @@ def get_non_negative_vision_feature_layers(v_hparams):
254254 feature_layers = config [feature_layers_key ]
255255 if isinstance (feature_layers , int ):
256256 feature_layers = [feature_layers ]
257- return [to_uint (feature_layer ) for feature_layer in feature_layers ]
257+ return [to_non_negative (feature_layer ) for feature_layer in feature_layers ]
258258
259- if has_vision_encoder :
260- feature_layers = get_non_negative_vision_feature_layers (v_hparams )
259+ # Determine if we have explicitly specified vision feature layers in our config
260+ feature_layers = get_non_negative_vision_feature_layers (v_hparams )
261261
262+ if has_vision_encoder :
262263 # Siglip does not have a visual projector; set projection dim to 0
263264 if args .clip_model_is_siglip :
264265 visual_projection_dim = 0
@@ -273,7 +274,10 @@ def get_non_negative_vision_feature_layers(v_hparams):
273274 fout .add_uint32 ("clip.vision.projection_dim" , visual_projection_dim )
274275 fout .add_uint32 (k (KEY_ATTENTION_HEAD_COUNT , VISION ), v_hparams ["num_attention_heads" ])
275276 fout .add_float32 (k (KEY_ATTENTION_LAYERNORM_EPS , VISION ), v_hparams ["layer_norm_eps" ])
276- block_count = v_hparams ["num_hidden_layers" ]
277+ if feature_layers :
278+ block_count = max (feature_layers )
279+ else :
280+ block_count = v_hparams ["num_hidden_layers" ] - 1 if has_llava_projector else v_hparams ["num_hidden_layers" ]
277281 fout .add_uint32 (k (KEY_BLOCK_COUNT , VISION ), block_count )
278282 # /**
279283 # "image_grid_pinpoints": [
@@ -342,6 +346,13 @@ def get_non_negative_vision_feature_layers(v_hparams):
342346
343347
344348if has_llava_projector :
349+ # By default, we drop the last layer for llava projector
350+ # models unless we have explicitly set vision feature layers
351+ if feature_layers is None :
352+ model .vision_model .encoder .layers .pop (- 1 )
353+ else :
354+ model .vision_model .encoder .layers = model .vision_model .encoder .layers [:max (feature_layers )]
355+
345356 projector = torch .load (args .llava_projector )
346357 for name , data in projector .items ():
347358 name = get_tensor_name (name )
0 commit comments