@@ -96,13 +96,15 @@ class ModelBase:
     # Mistral format specifics
     is_mistral_format: bool = False
     disable_mistral_community_chat_template: bool = False
+    sentence_transformers_dense_modules: bool = False
 
     def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, *, is_big_endian: bool = False,
                  use_temp_file: bool = False, eager: bool = False,
                  metadata_override: Path | None = None, model_name: str | None = None,
                  split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False,
                  small_first_shard: bool = False, hparams: dict[str, Any] | None = None, remote_hf_model_id: str | None = None,
-                 disable_mistral_community_chat_template: bool = False):
+                 disable_mistral_community_chat_template: bool = False,
+                 sentence_transformers_dense_modules: bool = False):
         if type(self) is ModelBase or \
                 type(self) is TextModel or \
                 type(self) is MmprojModel:
@@ -117,6 +119,7 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path,
         self.lazy = not eager or (remote_hf_model_id is not None)
         self.dry_run = dry_run
         self.remote_hf_model_id = remote_hf_model_id
+        self.sentence_transformers_dense_modules = sentence_transformers_dense_modules
         if remote_hf_model_id is not None:
             self.is_safetensors = True
@@ -5274,6 +5277,53 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
 @ModelBase.register("Gemma3TextModel")
 class EmbeddingGemma(Gemma3Model):
     model_arch = gguf.MODEL_ARCH.GEMMA_EMBEDDING
+    module_paths = []
+    dense_features_dims = {}
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        if self.sentence_transformers_dense_modules:
+            # read modules.json to determine whether the model has Dense layers
+            modules_file = self.dir_model / "modules.json"
+            if modules_file.is_file():
+                with open(modules_file, encoding="utf-8") as modules_json_file:
+                    mods = json.load(modules_json_file)
+                    for mod in mods:
+                        if mod["type"] == "sentence_transformers.models.Dense":
+                            mod_path = mod["path"]
+                            # check that a model.safetensors file exists for the Dense layer
+                            model_tensors_file = self.dir_model / mod_path / "model.safetensors"
+                            if model_tensors_file.is_file():
+                                self.module_paths.append(mod_path)
+                                # read the Dense layer's config.json to get its in/out features
+                                mod_conf_file = self.dir_model / mod_path / "config.json"
+                                if mod_conf_file.is_file():
+                                    with open(mod_conf_file, encoding="utf-8") as mod_conf_json_file:
+                                        mod_conf = json.load(mod_conf_json_file)
+                                        # hparams dense_2_feat_out and dense_3_feat_in are required when loading the model's dense weights
+                                        prefix = self._get_dense_prefix(mod_path)
+                                        if mod_conf["in_features"] is not None and mod_conf["out_features"] is not None:
+                                            self.dense_features_dims[prefix] = (mod_conf["in_features"], mod_conf["out_features"])
+
+    def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
+        from safetensors.torch import load_file
+        module_paths = list(self.module_paths)
+        for module_path in module_paths:
+            tensors_file = self.dir_model / module_path / "model.safetensors"
+            local_tensors = load_file(tensors_file)
+            tensor_name = self._get_dense_prefix(module_path)
+            for name, local_tensor in local_tensors.items():
+                if not name.endswith(".weight"):
+                    continue
+                orig_name = name.replace("linear", tensor_name)
+                name = self.map_tensor_name(orig_name)
+                yield name, local_tensor.clone()
+
+    @staticmethod
+    def _get_dense_prefix(module_path) -> str:
+        """Get the tensor name prefix for the Dense layer from its module path."""
+        tensor_name = "dense_2" if module_path == "2_Dense" else "dense_3"
+        return tensor_name
 
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
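
For reference, the __init__ above walks the standard sentence-transformers module layout. Below is a minimal standalone sketch of the files it reads, assuming a local checkout of a checkpoint such as google/embeddinggemma-300m (the local path and the dimensions in the trailing comment are illustrative, not taken from this patch):

    # sketch: enumerate Dense modules the way EmbeddingGemma.__init__ does
    import json
    from pathlib import Path

    dir_model = Path("embeddinggemma-300m")  # hypothetical local checkout
    mods = json.loads((dir_model / "modules.json").read_text(encoding="utf-8"))
    for mod in mods:
        if mod["type"] == "sentence_transformers.models.Dense":
            conf = json.loads((dir_model / mod["path"] / "config.json").read_text(encoding="utf-8"))
            print(mod["path"], conf["in_features"], "->", conf["out_features"])
    # typically prints something like:
    #   2_Dense 768 -> 3072
    #   3_Dense 3072 -> 768
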
@@ -5290,6 +5340,10 @@ def set_gguf_parameters(self):
             logger.info(f"Using original sliding_window from config: {orig_sliding_window} "
                         f"instead of {self.hparams['sliding_window']}")
             self.gguf_writer.add_sliding_window(orig_sliding_window)
+        if self.sentence_transformers_dense_modules:
+            for dense, dims in self.dense_features_dims.items():
+                logger.info(f"Setting dense layer {dense} in/out features to {dims}")
+                self.gguf_writer.add_dense_features_dims(dense, dims[0], dims[1])
 
         self._try_set_pooling_type()
@@ -9340,6 +9394,13 @@ def parse_args() -> argparse.Namespace:
         )
     )
 
+    parser.add_argument(
+        "--sentence-transformers-dense-modules", action="store_true",
+        help=("Whether to include sentence-transformers dense modules. "
+              "This can be used for sentence-transformers models such as google/embeddinggemma-300m. "
+              "By default these modules are not included.")
+    )
+
     args = parser.parse_args()
     if not args.print_supported_models and args.model is None:
         parser.error("the following arguments are required: model")
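
A hypothetical end-to-end invocation with the new flag (the output filename is a placeholder):

    python convert_hf_to_gguf.py google/embeddinggemma-300m --remote \
        --sentence-transformers-dense-modules --outfile embeddinggemma-300m-f16.gguf
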
@@ -9402,9 +9463,13 @@ def main() -> None:
     if args.remote:
         hf_repo_id = args.model
         from huggingface_hub import snapshot_download
+        allowed_patterns = ["LICENSE", "*.json", "*.md", "*.txt", "tokenizer.model"]
+        if args.sentence_transformers_dense_modules:
+            # include the sentence-transformers Dense module safetensors files
+            allowed_patterns.append("*.safetensors")
         local_dir = snapshot_download(
             repo_id=hf_repo_id,
-            allow_patterns=["LICENSE", "*.json", "*.md", "*.txt", "tokenizer.model"])
+            allow_patterns=allowed_patterns)
         dir_model = Path(local_dir)
         logger.info(f"Downloaded config and tokenizer to {local_dir}")
     else:
@@ -9472,7 +9537,8 @@ def main() -> None:
             split_max_tensors=args.split_max_tensors,
             split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run,
             small_first_shard=args.no_tensor_first_split,
-            remote_hf_model_id=hf_repo_id, disable_mistral_community_chat_template=disable_mistral_community_chat_template
+            remote_hf_model_id=hf_repo_id, disable_mistral_community_chat_template=disable_mistral_community_chat_template,
+            sentence_transformers_dense_modules=args.sentence_transformers_dense_modules
         )
 
         if args.vocab_only: