Commit 519ecac

Add Dense layer for 2_Dense/ modules (#660)
1 parent 45df4fa commit 519ecac

33 files changed · +8,744 −56 lines changed

README.md

Lines changed: 12 additions & 2 deletions
@@ -231,6 +231,16 @@ Options:
 
           [env: DEFAULT_PROMPT=]
 
+      --dense-path <DENSE_PATH>
+          Optionally, define the path to the Dense module required for some embedding models.
+
+          Some embedding models require an extra `Dense` module, which contains a single Linear layer and an activation function. By default, those `Dense` modules are stored under the `2_Dense` directory, but there are cases where different `Dense` modules are provided to convert the pooled embeddings into different dimensions, available as `2_Dense_<dims>`, e.g. https://huggingface.co/NovaSearch/stella_en_400M_v5.
+
+          Note that this argument is optional; it only needs to be set when the path to the `Dense` module is other than `2_Dense`. It also applies when leveraging the `candle` backend.
+
+          [env: DENSE_PATH=]
+          [default: 2_Dense]
+
       --hf-token <HF_TOKEN>
           Your Hugging Face Hub token
 
@@ -304,10 +314,10 @@ Options:
 
           [env: CORS_ALLOW_ORIGIN=]
 
-  -h, --help
+      -h, --help
           Print help (see a summary with '-h')
 
-  -V, --version
+      -V, --version
           Print version
 ```
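The `Dense` module described above is conceptually just one Linear projection plus an activation applied to the pooled embeddings. A minimal standalone sketch of that computation (a sketch only, written against `candle_core`; the 1024→768 shapes and random weights are invented for illustration, not taken from any model):

```rust
use candle_core::{DType, Device, Tensor};

fn main() -> candle_core::Result<()> {
    let device = Device::Cpu;
    // Hypothetical dimensions: project 1024-dim pooled embeddings down to 768.
    let (in_features, out_features) = (1024usize, 768usize);
    let pooled = Tensor::randn(0f32, 1f32, (2, in_features), &device)?; // [batch, in]
    let w = Tensor::randn(0f32, 1f32, (out_features, in_features), &device)?;
    let b = Tensor::zeros(out_features, DType::F32, &device)?;
    // One Linear layer followed by an activation: y = tanh(x · Wᵀ + b)
    let projected = pooled.matmul(&w.t()?)?.broadcast_add(&b)?.tanh()?;
    assert_eq!(projected.dims(), &[2, out_features]);
    Ok(())
}
```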
backends/candle/src/layers/linear.rs

Lines changed: 7 additions & 3 deletions
@@ -68,15 +68,19 @@ impl Linear {
             ),
         }
     } else {
-        let w = match x.dims() {
-            &[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?,
-            _ => self.weight.t()?,
+        let (x, w) = match x.dims() {
+            &[bsize, _, _] => (x, self.weight.broadcast_left(bsize)?.t()?),
+            // Metal devices require contiguous tensors for 2D matrix multiplication apparently
+            _ if matches!(x.device(), Device::Metal(_)) => (&x.contiguous()?, self.weight.t()?),
+            _ => (x, self.weight.t()?),
         };
         let x = x.matmul(&w)?;
+
         let x = match &self.bias {
             None => Ok(x),
             Some(bias) => x.broadcast_add(bias),
         }?;
+
         if let Some(act) = &self.act {
             match act {
                 HiddenAct::Gelu => x.gelu(),
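For context on the Metal branch above: in candle, `t()` returns a strided view over the same storage rather than moving data, so the transposed tensor is no longer contiguous in memory, and `.contiguous()` materializes a row-major copy. A standalone sketch of that distinction (written against `candle_core`, assuming the backend's `candle` import is the same crate under a different name):

```rust
use candle_core::{Device, Tensor};

fn main() -> candle_core::Result<()> {
    let device = Device::Cpu; // the copy only matters on Device::Metal(_), per the comment above
    let x = Tensor::arange(0f32, 6f32, &device)?.reshape((2, 3))?;
    let xt = x.t()?; // a view: same storage, swapped strides
    assert!(!xt.is_contiguous());
    let xc = xt.contiguous()?; // materializes a row-major copy
    assert!(xc.is_contiguous());
    Ok(())
}
```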

backends/candle/src/lib.rs

Lines changed: 61 additions & 3 deletions
@@ -11,9 +11,10 @@ use crate::compute_cap::{
     compatible_compute_cap, get_compile_compute_cap, get_runtime_compute_cap,
 };
 use crate::models::{
-    BertConfig, BertModel, DistilBertConfig, DistilBertModel, GTEConfig, GTEModel, JinaBertModel,
-    JinaCodeBertModel, MPNetConfig, MPNetModel, MistralConfig, Model, ModernBertConfig,
-    ModernBertModel, NomicBertModel, NomicConfig, Qwen2Config, Qwen3Config, Qwen3Model,
+    BertConfig, BertModel, Dense, DenseConfig, DenseLayer, DistilBertConfig, DistilBertModel,
+    GTEConfig, GTEModel, JinaBertModel, JinaCodeBertModel, MPNetConfig, MPNetModel, MistralConfig,
+    Model, ModernBertConfig, ModernBertModel, NomicBertModel, NomicConfig, Qwen2Config,
+    Qwen3Config, Qwen3Model,
 };
 #[cfg(feature = "cuda")]
 use crate::models::{
@@ -114,13 +115,15 @@ enum Config {
 pub struct CandleBackend {
     device: Device,
     model: Box<dyn Model + Send>,
+    dense: Option<Box<dyn DenseLayer + Send>>,
 }
 
 impl CandleBackend {
     pub fn new(
         model_path: &Path,
         dtype: String,
         model_type: ModelType,
+        dense_path: Option<&Path>,
     ) -> Result<Self, BackendError> {
         // Default files
         let default_safetensors = model_path.join("model.safetensors");
@@ -468,9 +471,50 @@ impl CandleBackend {
             }
         };
 
+        // If `2_Dense/model.safetensors` or `2_Dense/pytorch_model.bin` is amongst the downloaded artifacts, then create a Dense
+        // block and provide it to the `CandleBackend`, otherwise, None
+        let dense = if let Some(dense_path) = dense_path {
+            let dense_safetensors = dense_path.join("model.safetensors");
+            let dense_pytorch = dense_path.join("pytorch_model.bin");
+
+            if dense_safetensors.exists() || dense_pytorch.exists() {
+                let dense_config_path = dense_path.join("config.json");
+
+                let dense_config_str =
+                    std::fs::read_to_string(&dense_config_path).map_err(|err| {
+                        BackendError::Start(format!(
+                            "Unable to read `{dense_path:?}/config.json` file: {err:?}",
+                        ))
+                    })?;
+                let dense_config: DenseConfig =
+                    serde_json::from_str(&dense_config_str).map_err(|err| {
+                        BackendError::Start(format!(
+                            "Unable to parse `{dense_path:?}/config.json`: {err:?}",
+                        ))
+                    })?;
+
+                let dense_vb = if dense_safetensors.exists() {
+                    unsafe {
+                        VarBuilder::from_mmaped_safetensors(&[dense_safetensors], dtype, &device)
+                    }
+                    .s()?
+                } else {
+                    VarBuilder::from_pth(&dense_pytorch, dtype, &device).s()?
+                };
+
+                Some(Box::new(Dense::load(dense_vb, &dense_config).s()?)
+                    as Box<dyn DenseLayer + Send>)
+            } else {
+                None
+            }
+        } else {
+            None
+        };
+
         Ok(Self {
             device,
             model: model?,
+            dense,
         })
     }
 }
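For illustration, a hypothetical caller-side sketch of the extended constructor. The crate paths (`text_embeddings_backend_candle`, `text_embeddings_backend_core`), the local model directory, and the pooling choice are assumptions for the example, not part of this diff:

```rust
use std::path::Path;

use text_embeddings_backend_candle::CandleBackend; // assumed crate path
use text_embeddings_backend_core::{ModelType, Pool}; // assumed crate path

fn build() -> Result<CandleBackend, Box<dyn std::error::Error>> {
    let model_path = Path::new("/data/stella_en_400M_v5"); // hypothetical local snapshot
    // Only needed when the Dense module lives somewhere other than `2_Dense`,
    // e.g. a 1024-dim variant shipped as `2_Dense_1024`.
    let dense_path = model_path.join("2_Dense_1024");
    let backend = CandleBackend::new(
        model_path,
        "float32".to_string(),
        ModelType::Embedding(Pool::Mean),
        Some(dense_path.as_path()),
    )?;
    Ok(backend)
}
```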
@@ -507,6 +551,19 @@ impl Backend for CandleBackend {
         // Run forward
         let (pooled_embeddings, raw_embeddings) = self.model.embed(batch).e()?;
 
+        // Apply dense layer if available
+        let pooled_embeddings = match pooled_embeddings {
+            None => None,
+            Some(pooled_embeddings) => {
+                let pooled_embeddings = if let Some(ref dense) = self.dense {
+                    dense.forward(&pooled_embeddings).e()?
+                } else {
+                    pooled_embeddings
+                };
+                Some(pooled_embeddings)
+            }
+        };
+
         // Device => Host data transfer
         let pooled_embeddings = match pooled_embeddings {
             None => vec![],
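A note on the shape of the dense-application block above: `?` cannot be used inside a plain `Option::map` closure, which is why the commit spells out the `match`. The same Option-of-fallible-transform pattern can also be flattened with `Option::transpose`; a self-contained sketch of that technique with plain types:

```rust
// `map` wraps the Option's payload in a Result; `transpose` turns
// Option<Result<T, E>> into Result<Option<T>, E> so `?` applies once outside.
fn double_even(v: Option<i32>) -> Result<Option<i32>, String> {
    v.map(|n| {
        if n % 2 == 0 {
            Ok(n * 2)
        } else {
            Err(format!("{n} is odd"))
        }
    })
    .transpose()
}

fn main() {
    assert_eq!(double_even(None), Ok(None));
    assert_eq!(double_even(Some(4)), Ok(Some(8)));
    assert!(double_even(Some(3)).is_err());
}
```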
@@ -540,6 +597,7 @@ impl Backend for CandleBackend {
         let batch_size = batch.len();
 
         let results = self.model.predict(batch).e()?;
+
         let results = results.to_dtype(DType::F32).e()?.to_vec2().e()?;
 
         let mut predictions =

backends/candle/src/models/dense.rs

Lines changed: 75 additions & 0 deletions
@@ -0,0 +1,75 @@
+use crate::layers::Linear;
+use candle::{Result, Tensor};
+use candle_nn::VarBuilder;
+use serde::Deserialize;
+
+#[derive(Debug, Clone, Deserialize, PartialEq)]
+/// The activation functions in `2_Dense/config.json` are defined as PyTorch imports
+pub enum DenseActivation {
+    #[serde(rename = "torch.nn.modules.activation.Tanh")]
+    /// e.g. https://huggingface.co/sentence-transformers/LaBSE/blob/main/2_Dense/config.json
+    Tanh,
+    #[serde(rename = "torch.nn.modules.linear.Identity")]
+    /// e.g. https://huggingface.co/NovaSearch/stella_en_400M_v5/blob/main/2_Dense/config.json
+    Identity,
+}
+
+impl DenseActivation {
+    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        match self {
+            Self::Tanh => x.tanh(),
+            Self::Identity => Ok(x.clone()),
+        }
+    }
+}
+
+#[derive(Debug, Clone, PartialEq, Deserialize)]
+pub struct DenseConfig {
+    in_features: usize,
+    out_features: usize,
+    bias: bool,
+    activation_function: Option<DenseActivation>,
+}
+
+pub trait DenseLayer {
+    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor>;
+}
+
+#[derive(Debug)]
+pub struct Dense {
+    linear: Linear,
+    activation: DenseActivation,
+    span: tracing::Span,
+}
+
+impl Dense {
+    pub fn load(vb: VarBuilder, config: &DenseConfig) -> Result<Self> {
+        let weight = vb.get((config.out_features, config.in_features), "linear.weight")?;
+        let bias = if config.bias {
+            Some(vb.get(config.out_features, "linear.bias")?)
+        } else {
+            None
+        };
+        let linear = Linear::new(weight, bias, None);
+
+        let activation = config
+            .activation_function
+            .clone()
+            .unwrap_or(DenseActivation::Identity);
+
+        Ok(Self {
+            linear,
+            activation,
+            span: tracing::span!(tracing::Level::TRACE, "dense"),
+        })
+    }
+}
+
+impl DenseLayer for Dense {
+    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+
+        let hidden_states = self.linear.forward(hidden_states)?;
+        self.activation.forward(&hidden_states)
+    }
+}
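To make the expected on-disk format concrete: a standalone sketch that deserializes a sample `2_Dense/config.json` using the same serde layout as `DenseConfig` above. The struct is re-declared locally because the real fields are private, and the sample values are modeled on LaBSE's published 768→768 Tanh config:

```rust
use serde::Deserialize;

#[derive(Debug, Deserialize, PartialEq)]
enum DenseActivation {
    #[serde(rename = "torch.nn.modules.activation.Tanh")]
    Tanh,
    #[serde(rename = "torch.nn.modules.linear.Identity")]
    Identity,
}

#[derive(Debug, Deserialize)]
struct DenseConfig {
    in_features: usize,
    out_features: usize,
    bias: bool,
    activation_function: Option<DenseActivation>,
}

fn main() -> serde_json::Result<()> {
    // Example `2_Dense/config.json` payload.
    let raw = r#"{
        "in_features": 768,
        "out_features": 768,
        "bias": true,
        "activation_function": "torch.nn.modules.activation.Tanh"
    }"#;
    let config: DenseConfig = serde_json::from_str(raw)?;
    assert_eq!(config.in_features, 768);
    assert_eq!(config.out_features, 768);
    assert!(config.bias);
    assert_eq!(config.activation_function, Some(DenseActivation::Tanh));
    Ok(())
}
```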

backends/candle/src/models/mod.rs

Lines changed: 2 additions & 0 deletions
@@ -5,6 +5,7 @@ extern crate intel_mkl_src;
 extern crate accelerate_src;
 
 mod bert;
+mod dense;
 mod distilbert;
 mod jina;
 mod jina_code;
@@ -49,6 +50,7 @@ mod qwen3;
 
 pub use bert::{BertConfig, BertModel, PositionEmbeddingType};
 use candle::{Result, Tensor};
+pub use dense::{Dense, DenseConfig, DenseLayer};
 pub use distilbert::{DistilBertConfig, DistilBertModel};
 #[allow(unused_imports)]
 pub use gte::{GTEClassificationHead, GTEConfig, GTEModel, GTEMLP};

backends/candle/tests/common.rs

Lines changed: 35 additions & 0 deletions
@@ -106,6 +106,7 @@ pub fn sort_embeddings(embeddings: Embeddings) -> (Vec<Vec<f32>>, Vec<Vec<f32>>)
 pub fn download_artifacts(
     model_id: &'static str,
     revision: Option<&'static str>,
+    dense_path: Option<&'static str>,
 ) -> Result<PathBuf> {
     let mut builder = ApiBuilder::from_env().with_progress(false);
 
@@ -140,6 +141,40 @@
             vec![p]
         }
     };
+
+    // Download dense path files if specified
+    if let Some(dense_path) = dense_path {
+        let dense_config_path = format!("{}/config.json", dense_path);
+        match api_repo.get(&dense_config_path) {
+            Ok(_) => tracing::info!("Downloaded dense config: {}", dense_config_path),
+            Err(err) => tracing::warn!(
+                "Could not download dense config {}: {}",
+                dense_config_path,
+                err
+            ),
+        }
+
+        // Try to download dense model files (safetensors first, then pytorch)
+        let dense_safetensors_path = format!("{}/model.safetensors", dense_path);
+        match api_repo.get(&dense_safetensors_path) {
+            Ok(_) => tracing::info!("Downloaded dense safetensors: {}", dense_safetensors_path),
+            Err(_) => {
+                tracing::warn!("Dense safetensors not found. Trying pytorch_model.bin");
+                let dense_pytorch_path = format!("{}/pytorch_model.bin", dense_path);
+                match api_repo.get(&dense_pytorch_path) {
+                    Ok(_) => {
+                        tracing::info!("Downloaded dense pytorch model: {}", dense_pytorch_path)
+                    }
+                    Err(err) => tracing::warn!(
+                        "Could not download dense pytorch model {}: {}",
+                        dense_pytorch_path,
+                        err
+                    ),
+                }
+            }
+        }
+    }
+
     let model_root = model_files[0].parent().unwrap().to_path_buf();
     Ok(model_root)
 }
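A hypothetical test-side call showing how the new argument threads through. The model id and dense directory are illustrative only; since `download_artifacts` returns the model root, the dense directory is joined back on afterwards:

```rust
#[test]
fn downloads_dense_artifacts() -> Result<()> {
    // Illustrative model id and dense directory, not taken from this diff.
    let model_root = download_artifacts("NovaSearch/stella_en_400M_v5", None, Some("2_Dense_1024"))?;
    let dense_path = model_root.join("2_Dense_1024");
    assert!(dense_path.join("config.json").exists());
    Ok(())
}
```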
