Skip to content

Commit 8b74ae5

Browse files
authored
Parse modules.json to identify default Dense modules (#701)
1 parent 4fcd45d commit 8b74ae5

31 files changed

+7672
-8525
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

backends/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ homepage.workspace = true
88
[dependencies]
99
clap = { workspace = true, optional = true }
1010
hf-hub = { workspace = true }
11+
serde = { workspace = true }
1112
serde_json = { workspace = true }
1213
text-embeddings-backend-core = { path = "core" }
1314
text-embeddings-backend-python = { path = "python", optional = true }

backends/candle/src/lib.rs

Lines changed: 138 additions & 126 deletions
Large diffs are not rendered by default.

backends/candle/tests/common.rs

Lines changed: 80 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -103,11 +103,34 @@ pub fn sort_embeddings(embeddings: Embeddings) -> (Vec<Vec<f32>>, Vec<Vec<f32>>)
103103
(pooled_embeddings, raw_embeddings)
104104
}
105105

106+
#[derive(Deserialize, PartialEq)]
107+
enum ModuleType {
108+
#[serde(rename = "sentence_transformers.models.Dense")]
109+
Dense,
110+
#[serde(rename = "sentence_transformers.models.Normalize")]
111+
Normalize,
112+
#[serde(rename = "sentence_transformers.models.Pooling")]
113+
Pooling,
114+
#[serde(rename = "sentence_transformers.models.Transformer")]
115+
Transformer,
116+
}
117+
118+
#[derive(Deserialize)]
119+
struct ModuleConfig {
120+
#[allow(dead_code)]
121+
idx: usize,
122+
#[allow(dead_code)]
123+
name: String,
124+
path: String,
125+
#[serde(rename = "type")]
126+
module_type: ModuleType,
127+
}
128+
106129
pub fn download_artifacts(
107130
model_id: &'static str,
108131
revision: Option<&'static str>,
109132
dense_path: Option<&'static str>,
110-
) -> Result<PathBuf> {
133+
) -> Result<(PathBuf, Option<Vec<String>>)> {
111134
let mut builder = ApiBuilder::from_env().with_progress(false);
112135

113136
if let Some(cache_dir) = std::env::var_os("HUGGINGFACE_HUB_CACHE") {
@@ -142,41 +165,35 @@ pub fn download_artifacts(
142165
}
143166
};
144167

145-
// Download dense path files if specified
146-
if let Some(dense_path) = dense_path {
147-
let dense_config_path = format!("{}/config.json", dense_path);
148-
match api_repo.get(&dense_config_path) {
149-
Ok(_) => tracing::info!("Downloaded dense config: {}", dense_config_path),
150-
Err(err) => tracing::warn!(
151-
"Could not download dense config {}: {}",
152-
dense_config_path,
153-
err
154-
),
155-
}
156-
157-
// Try to download dense model files (safetensors first, then pytorch)
158-
let dense_safetensors_path = format!("{}/model.safetensors", dense_path);
159-
match api_repo.get(&dense_safetensors_path) {
160-
Ok(_) => tracing::info!("Downloaded dense safetensors: {}", dense_safetensors_path),
161-
Err(_) => {
162-
tracing::warn!("Dense safetensors not found. Trying pytorch_model.bin");
163-
let dense_pytorch_path = format!("{}/pytorch_model.bin", dense_path);
164-
match api_repo.get(&dense_pytorch_path) {
165-
Ok(_) => {
166-
tracing::info!("Downloaded dense pytorch model: {}", dense_pytorch_path)
168+
let dense_paths = if let Ok(modules_path) = api_repo.get("modules.json") {
169+
match parse_dense_paths_from_modules(&modules_path) {
170+
Ok(paths) => match paths.len() {
171+
0 => None,
172+
1 => {
173+
let path = if let Some(path) = dense_path {
174+
path.to_string()
175+
} else {
176+
paths[0].clone()
177+
};
178+
179+
download_dense_module(&api_repo, &path)?;
180+
Some(vec![path])
181+
}
182+
_ => {
183+
for path in &paths {
184+
download_dense_module(&api_repo, &path)?;
167185
}
168-
Err(err) => tracing::warn!(
169-
"Could not download dense pytorch model {}: {}",
170-
dense_pytorch_path,
171-
err
172-
),
186+
Some(paths)
173187
}
174-
}
188+
},
189+
_ => None,
175190
}
176-
}
191+
} else {
192+
None
193+
};
177194

178195
let model_root = model_files[0].parent().unwrap().to_path_buf();
179-
Ok(model_root)
196+
Ok((model_root, dense_paths))
180197
}
181198

182199
fn download_safetensors(api: &ApiRepo) -> Result<Vec<PathBuf>, ApiError> {
@@ -218,6 +235,38 @@ fn download_safetensors(api: &ApiRepo) -> Result<Vec<PathBuf>, ApiError> {
218235
Ok(safetensors_files)
219236
}
220237

238+
fn parse_dense_paths_from_modules(modules_path: &PathBuf) -> Result<Vec<String>, std::io::Error> {
239+
let content = std::fs::read_to_string(modules_path)?;
240+
let modules: Vec<ModuleConfig> = serde_json::from_str(&content)
241+
.map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidData, err))?;
242+
243+
Ok(modules
244+
.into_iter()
245+
.filter(|module| module.module_type == ModuleType::Dense)
246+
.map(|module| module.path)
247+
.collect::<Vec<String>>())
248+
}
249+
250+
fn download_dense_module(api: &ApiRepo, dense_path: &str) -> Result<PathBuf, ApiError> {
251+
let config_file = format!("{}/config.json", dense_path);
252+
tracing::info!("Downloading `{}`", config_file);
253+
let config_path = api.get(&config_file)?;
254+
255+
let safetensors_file = format!("{}/model.safetensors", dense_path);
256+
tracing::info!("Downloading `{}`", safetensors_file);
257+
match api.get(&safetensors_file) {
258+
Ok(_) => {}
259+
Err(err) => {
260+
tracing::warn!("Could not download `{}`: {}", safetensors_file, err);
261+
let pytorch_file = format!("{}/pytorch_model.bin", dense_path);
262+
tracing::info!("Downloading `{}`", pytorch_file);
263+
api.get(&pytorch_file)?;
264+
}
265+
}
266+
267+
Ok(config_path.parent().unwrap().to_path_buf())
268+
}
269+
221270
#[allow(unused)]
222271
pub(crate) fn relative_matcher() -> YamlMatcher<SnapshotScores> {
223272
YamlMatcher::new()

0 commit comments

Comments
 (0)