@@ -103,11 +103,34 @@ pub fn sort_embeddings(embeddings: Embeddings) -> (Vec<Vec<f32>>, Vec<Vec<f32>>)
103
103
( pooled_embeddings, raw_embeddings)
104
104
}
105
105
106
+ #[ derive( Deserialize , PartialEq ) ]
107
+ enum ModuleType {
108
+ #[ serde( rename = "sentence_transformers.models.Dense" ) ]
109
+ Dense ,
110
+ #[ serde( rename = "sentence_transformers.models.Normalize" ) ]
111
+ Normalize ,
112
+ #[ serde( rename = "sentence_transformers.models.Pooling" ) ]
113
+ Pooling ,
114
+ #[ serde( rename = "sentence_transformers.models.Transformer" ) ]
115
+ Transformer ,
116
+ }
117
+
118
+ #[ derive( Deserialize ) ]
119
+ struct ModuleConfig {
120
+ #[ allow( dead_code) ]
121
+ idx : usize ,
122
+ #[ allow( dead_code) ]
123
+ name : String ,
124
+ path : String ,
125
+ #[ serde( rename = "type" ) ]
126
+ module_type : ModuleType ,
127
+ }
128
+
106
129
pub fn download_artifacts (
107
130
model_id : & ' static str ,
108
131
revision : Option < & ' static str > ,
109
132
dense_path : Option < & ' static str > ,
110
- ) -> Result < PathBuf > {
133
+ ) -> Result < ( PathBuf , Option < Vec < String > > ) > {
111
134
let mut builder = ApiBuilder :: from_env ( ) . with_progress ( false ) ;
112
135
113
136
if let Some ( cache_dir) = std:: env:: var_os ( "HUGGINGFACE_HUB_CACHE" ) {
@@ -142,41 +165,35 @@ pub fn download_artifacts(
142
165
}
143
166
} ;
144
167
145
- // Download dense path files if specified
146
- if let Some ( dense_path) = dense_path {
147
- let dense_config_path = format ! ( "{}/config.json" , dense_path) ;
148
- match api_repo. get ( & dense_config_path) {
149
- Ok ( _) => tracing:: info!( "Downloaded dense config: {}" , dense_config_path) ,
150
- Err ( err) => tracing:: warn!(
151
- "Could not download dense config {}: {}" ,
152
- dense_config_path,
153
- err
154
- ) ,
155
- }
156
-
157
- // Try to download dense model files (safetensors first, then pytorch)
158
- let dense_safetensors_path = format ! ( "{}/model.safetensors" , dense_path) ;
159
- match api_repo. get ( & dense_safetensors_path) {
160
- Ok ( _) => tracing:: info!( "Downloaded dense safetensors: {}" , dense_safetensors_path) ,
161
- Err ( _) => {
162
- tracing:: warn!( "Dense safetensors not found. Trying pytorch_model.bin" ) ;
163
- let dense_pytorch_path = format ! ( "{}/pytorch_model.bin" , dense_path) ;
164
- match api_repo. get ( & dense_pytorch_path) {
165
- Ok ( _) => {
166
- tracing:: info!( "Downloaded dense pytorch model: {}" , dense_pytorch_path)
168
+ let dense_paths = if let Ok ( modules_path) = api_repo. get ( "modules.json" ) {
169
+ match parse_dense_paths_from_modules ( & modules_path) {
170
+ Ok ( paths) => match paths. len ( ) {
171
+ 0 => None ,
172
+ 1 => {
173
+ let path = if let Some ( path) = dense_path {
174
+ path. to_string ( )
175
+ } else {
176
+ paths[ 0 ] . clone ( )
177
+ } ;
178
+
179
+ download_dense_module ( & api_repo, & path) ?;
180
+ Some ( vec ! [ path] )
181
+ }
182
+ _ => {
183
+ for path in & paths {
184
+ download_dense_module ( & api_repo, & path) ?;
167
185
}
168
- Err ( err) => tracing:: warn!(
169
- "Could not download dense pytorch model {}: {}" ,
170
- dense_pytorch_path,
171
- err
172
- ) ,
186
+ Some ( paths)
173
187
}
174
- }
188
+ } ,
189
+ _ => None ,
175
190
}
176
- }
191
+ } else {
192
+ None
193
+ } ;
177
194
178
195
let model_root = model_files[ 0 ] . parent ( ) . unwrap ( ) . to_path_buf ( ) ;
179
- Ok ( model_root)
196
+ Ok ( ( model_root, dense_paths ) )
180
197
}
181
198
182
199
fn download_safetensors ( api : & ApiRepo ) -> Result < Vec < PathBuf > , ApiError > {
@@ -218,6 +235,38 @@ fn download_safetensors(api: &ApiRepo) -> Result<Vec<PathBuf>, ApiError> {
218
235
Ok ( safetensors_files)
219
236
}
220
237
238
+ fn parse_dense_paths_from_modules ( modules_path : & PathBuf ) -> Result < Vec < String > , std:: io:: Error > {
239
+ let content = std:: fs:: read_to_string ( modules_path) ?;
240
+ let modules: Vec < ModuleConfig > = serde_json:: from_str ( & content)
241
+ . map_err ( |err| std:: io:: Error :: new ( std:: io:: ErrorKind :: InvalidData , err) ) ?;
242
+
243
+ Ok ( modules
244
+ . into_iter ( )
245
+ . filter ( |module| module. module_type == ModuleType :: Dense )
246
+ . map ( |module| module. path )
247
+ . collect :: < Vec < String > > ( ) )
248
+ }
249
+
250
+ fn download_dense_module ( api : & ApiRepo , dense_path : & str ) -> Result < PathBuf , ApiError > {
251
+ let config_file = format ! ( "{}/config.json" , dense_path) ;
252
+ tracing:: info!( "Downloading `{}`" , config_file) ;
253
+ let config_path = api. get ( & config_file) ?;
254
+
255
+ let safetensors_file = format ! ( "{}/model.safetensors" , dense_path) ;
256
+ tracing:: info!( "Downloading `{}`" , safetensors_file) ;
257
+ match api. get ( & safetensors_file) {
258
+ Ok ( _) => { }
259
+ Err ( err) => {
260
+ tracing:: warn!( "Could not download `{}`: {}" , safetensors_file, err) ;
261
+ let pytorch_file = format ! ( "{}/pytorch_model.bin" , dense_path) ;
262
+ tracing:: info!( "Downloading `{}`" , pytorch_file) ;
263
+ api. get ( & pytorch_file) ?;
264
+ }
265
+ }
266
+
267
+ Ok ( config_path. parent ( ) . unwrap ( ) . to_path_buf ( ) )
268
+ }
269
+
221
270
#[ allow( unused) ]
222
271
pub ( crate ) fn relative_matcher ( ) -> YamlMatcher < SnapshotScores > {
223
272
YamlMatcher :: new ( )
0 commit comments