Commit a595bdb

feat: support flash attention for Jina (#119)

1 parent 8c43390

17 files changed: +2,927 −295 lines

Cargo.lock

Lines changed: 331 additions & 255 deletions
Generated file; diff not rendered.

Cargo.toml

Lines changed: 5 additions & 4 deletions
@@ -18,12 +18,13 @@ homepage = "https://github.com/huggingface/text-embeddings-inference"
 
 [patch.crates-io]
 cudarc = { git = "https://github.com/OlivierDehaene/cudarc", rev = "c19522f1e411ab453d71bdfad3383b118cd4216f" }
-candle = { git = "https://github.com/OlivierDehaene/candle", rev = "9f2b4081b83a0e47ec1b12caa71d3cac7cc2161e", package = "candle-core" }
-candle-nn = { git = "https://github.com/OlivierDehaene/candle", rev = "9f2b4081b83a0e47ec1b12caa71d3cac7cc2161e", package = "candle-nn" }
-candle-transformers = { git = "https://github.com/OlivierDehaene/candle", rev = "9f2b4081b83a0e47ec1b12caa71d3cac7cc2161e", package = "candle-transformers" }
-candle-flash-attn = { git = "https://github.com/OlivierDehaene/candle", rev = "9f2b4081b83a0e47ec1b12caa71d3cac7cc2161e", package = "candle-flash-attn" }
+candle = { git = "https://github.com/OlivierDehaene/candle", rev = "7a181166d96480ec0302b496469427b3db0ab71b", package = "candle-core" }
+candle-nn = { git = "https://github.com/OlivierDehaene/candle", rev = "7a181166d96480ec0302b496469427b3db0ab71b", package = "candle-nn" }
+candle-transformers = { git = "https://github.com/OlivierDehaene/candle", rev = "7a181166d96480ec0302b496469427b3db0ab71b", package = "candle-transformers" }
+candle-flash-attn = { git = "https://github.com/OlivierDehaene/candle", rev = "7a181166d96480ec0302b496469427b3db0ab71b", package = "candle-flash-attn" }
 hf-hub = { git = "https://github.com/huggingface/hf-hub", rev = "b167f69692be5f49eb8003788f7f8a499a98b096" }
 
+
 [profile.release]
 debug = 1
 incremental = true

backends/Cargo.toml

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@ clap = ["dep:clap", "text-embeddings-backend-core/clap"]
 python = ["dep:text-embeddings-backend-python"]
 candle = ["dep:text-embeddings-backend-candle"]
 cuda = ["text-embeddings-backend-candle?/cuda"]
+metal = ["text-embeddings-backend-candle?/metal"]
 mkl = ["text-embeddings-backend-candle?/mkl"]
 mkl-dynamic = ["text-embeddings-backend-candle?/mkl-dynamic"]
 accelerate = ["text-embeddings-backend-candle?/accelerate"]

backends/candle/Cargo.toml

Lines changed: 8 additions & 7 deletions
@@ -8,13 +8,13 @@ homepage.workspace = true
 [dependencies]
 accelerate-src = { version = "0.3.2", optional = true }
 intel-mkl-src = { version = "0.8.1", optional = true }
-candle = { version = "0.3.0", package = "candle-core", default-features = false }
-candle-nn = { version = "0.3.0" }
-candle-transformers = { version = "0.3.0" }
-candle-flash-attn = { version = "0.3.0", optional = true }
-candle-flash-attn-v1 = { git = "https://github.com/huggingface/candle-flash-attn-v1", rev = "62b75f1ea4e0961fad7b983ee8d723ed6fd68be5", optional = true }
-candle-cublaslt = { git = "https://github.com/huggingface/candle-cublaslt", rev = "58684e116aae248c353f87846ddf0b2a8a7ed855", optional = true }
-candle-layer-norm = { git = "https://github.com/huggingface/candle-layer-norm", rev = "5ed96012a693dff9685320765dd55a57fdaecdd6", optional = true }
+candle = { version = "^0.3", package = "candle-core", default-features = false }
+candle-nn = { version = "^0.3" }
+candle-transformers = { version = "^0.3" }
+candle-flash-attn = { version = "^0.3", optional = true }
+candle-flash-attn-v1 = { git = "https://github.com/huggingface/candle-flash-attn-v1", rev = "d5b873e4555b7f460ed639d96f26cb014f2daad7", optional = true }
+candle-cublaslt = { git = "https://github.com/huggingface/candle-cublaslt", rev = "c8a810ffe649c5f4634cbe1f0aaf02f6025fe5a5", optional = true }
+candle-layer-norm = { git = "https://github.com/huggingface/candle-layer-norm", rev = "0dd5bdceb9ba7cded921c62f9ddd66e7726327ba", optional = true }
 text-embeddings-backend-core = { path = "../core" }
 tracing = "^0.1"
 safetensors = "^0.4"

@@ -36,6 +36,7 @@ anyhow = { version = "1", features = ["backtrace"] }
 
 [features]
 accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate"]
+metal = ["candle/metal", "candle-nn/metal"]
 mkl = ["dep:intel-mkl-src", "intel-mkl-src/mkl-static-lp64-iomp", "candle/mkl", "candle-nn/mkl"]
 mkl-dynamic = ["dep:intel-mkl-src", "intel-mkl-src/mkl-dynamic-lp64-iomp", "candle/mkl-dynamic", "candle-nn/mkl-dynamic"]
 cuda = ["candle/cuda", "candle-nn/cuda", "dep:candle-cublaslt", "dep:candle-layer-norm"]

backends/candle/src/alibi.rs

Lines changed: 1 addition & 1 deletion
@@ -21,7 +21,7 @@ fn get_slopes_power_of_2(n: usize) -> Vec<f64> {
     (0..n).map(|i| start * start.powi(i as i32)).collect()
 }
 
-fn alibi_head_slopes(num_attention_heads: usize) -> Vec<f64> {
+pub fn alibi_head_slopes(num_attention_heads: usize) -> Vec<f64> {
     if (num_attention_heads as f64).log2().fract() == 0.0 {
         // `num_attention_heads` is a power of 2
         get_slopes_power_of_2(num_attention_heads)
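Note: `alibi_head_slopes` becomes `pub`, presumably so the new flash Jina model can reuse the ALiBi slope computation. For reference, here is a self-contained sketch of the power-of-two case shown in the hunk; the value of `start` is defined outside this hunk, so the conventional ALiBi base 2^(-8/n) is assumed here:

```rust
/// Slopes for `n` attention heads when `n` is a power of two.
/// `start` is assumed to be the standard ALiBi base 2^(-8/n);
/// its definition sits above the hunk shown in this diff.
fn get_slopes_power_of_2(n: usize) -> Vec<f64> {
    let start = 2.0_f64.powf(-8.0 / n as f64);
    (0..n).map(|i| start * start.powi(i as i32)).collect()
}

fn main() {
    // With 8 heads the slopes are 1/2, 1/4, ..., 1/256.
    println!("{:?}", get_slopes_power_of_2(8));
}
```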

backends/candle/src/layers/layer_norm.rs

Lines changed: 2 additions & 2 deletions
@@ -27,7 +27,7 @@ impl LayerNorm {
         let _enter = self.span.enter();
 
         match hidden_states.device() {
-            Device::Cpu => {
+            Device::Cpu | Device::Metal(_) => {
                 let hidden_states = hidden_states.add(residual)?;
                 let hidden_states_dtype = hidden_states.dtype();
                 let internal_dtype = match hidden_states_dtype {

@@ -61,7 +61,7 @@ impl LayerNorm {
                     &hidden_states,
                     &residual,
                     &self.weight,
-                    &self.bias,
+                    Some(&self.bias),
                     self.epsilon,
                 )?;
                 result.reshape(original_shape)
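The first hunk routes Metal through the same non-fused path as the CPU: add the residual, then normalize in a wider dtype (the rest of that branch is outside the hunk). The second hunk adapts the CUDA call to the bumped `candle-layer-norm` revision, which now appears to take the bias as an `Option`. As a rough reference for what the non-fused path computes, a per-row sketch over plain slices (not the actual candle `Tensor` code):

```rust
/// Residual add followed by layer norm over one row: the math the
/// Cpu/Metal branch above implements on candle Tensors (sketch only).
fn add_layer_norm(hidden: &[f32], residual: &[f32], weight: &[f32], bias: &[f32], epsilon: f32) -> Vec<f32> {
    let n = hidden.len() as f32;
    // 1. residual connection
    let x: Vec<f32> = hidden.iter().zip(residual).map(|(h, r)| h + r).collect();
    // 2. mean and variance of the summed activations
    let mean = x.iter().sum::<f32>() / n;
    let var = x.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / n;
    // 3. normalize, then scale by `weight` and shift by `bias`
    x.iter()
        .zip(weight.iter().zip(bias))
        .map(|(v, (w, b))| (v - mean) / (var + epsilon).sqrt() * w + b)
        .collect()
}

fn main() {
    let out = add_layer_norm(&[1.0, 2.0], &[0.5, -0.5], &[1.0, 1.0], &[0.0, 0.0], 1e-12);
    println!("{out:?}");
}
```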

backends/candle/src/lib.rs

Lines changed: 23 additions & 7 deletions
@@ -12,6 +12,8 @@ use crate::compute_cap::{
 };
 #[cfg(feature = "cuda")]
 use crate::models::FlashBertModel;
+#[cfg(feature = "cuda")]
+use crate::models::FlashJinaBertModel;
 use crate::models::{BertModel, JinaBertModel, Model, PositionEmbeddingType};
 use candle::{DType, Device};
 use candle_nn::VarBuilder;

@@ -36,10 +38,14 @@ impl CandleBackend {
             serde_json::from_str(&config).map_err(|err| BackendError::Start(err.to_string()))?;
 
         // Get candle device
-        let device = match Device::cuda_if_available(0) {
-            Ok(device) => device,
-            Err(err) => return Err(BackendError::Start(err.to_string())),
-        };
+        let device = if candle::utils::cuda_is_available() {
+            Device::new_cuda(0)
+        } else if candle::utils::metal_is_available() {
+            Device::new_metal(0)
+        } else {
+            Ok(Device::Cpu)
+        }
+        .map_err(|err| BackendError::Start(err.to_string()))?;
 
         // Check model type
         if config.model_type != Some("bert".to_string())

@@ -79,12 +85,12 @@ impl CandleBackend {
             .s()?;
 
         let model: Box<dyn Model + Send> = match device {
-            Device::Cpu => {
+            Device::Cpu | Device::Metal(_) => {
                 if config.position_embedding_type == PositionEmbeddingType::Alibi {
-                    tracing::info!("Starting JinaBert model on CPU");
+                    tracing::info!("Starting JinaBert model on {:?}", device);
                     Box::new(JinaBertModel::load(vb, &config, model_type).s()?)
                 } else {
-                    tracing::info!("Starting Bert model on CPU");
+                    tracing::info!("Starting Bert model on {:?}", device);
                     Box::new(BertModel::load(vb, &config, model_type).s()?)
                 }
             }

@@ -108,6 +114,16 @@ impl CandleBackend {
                 {
                     tracing::info!("Starting FlashBert model on Cuda");
                     Box::new(FlashBertModel::load(vb, &config, model_type).s()?)
+                } else if cfg!(feature = "flash-attn")
+                    && dtype == DType::F16
+                    && config.position_embedding_type == PositionEmbeddingType::Alibi
+                    && &std::env::var("USE_FLASH_ATTENTION")
+                        .unwrap_or("True".to_string())
+                        .to_lowercase()
+                        == "true"
+                {
+                    tracing::info!("Starting FlashJinaBertModel model on Cuda");
+                    Box::new(FlashJinaBertModel::load(vb, &config, model_type).s()?)
                 } else if config.position_embedding_type == PositionEmbeddingType::Alibi {
                     tracing::info!("Starting JinaBert model on Cuda");
                     Box::new(JinaBertModel::load(vb, &config, model_type).s()?)
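Two behavioural changes land in this file: device selection now cascades from CUDA to Metal to CPU instead of only probing CUDA, and a new branch loads `FlashJinaBertModel` when the `flash-attn` feature is compiled in, the dtype is F16, ALiBi position embeddings are used, and `USE_FLASH_ATTENTION` is not set to false (it defaults to true). Below is a minimal standalone sketch of the device cascade, assuming only the `candle-core` crate (imported as `candle`, matching the rename in backends/candle/Cargo.toml) and simplified error handling:

```rust
use candle::{Device, Result};

/// Pick the best available device: CUDA first, then Metal, then CPU.
/// Mirrors the cascade introduced above in backends/candle/src/lib.rs.
fn select_device() -> Result<Device> {
    if candle::utils::cuda_is_available() {
        Device::new_cuda(0)
    } else if candle::utils::metal_is_available() {
        Device::new_metal(0)
    } else {
        Ok(Device::Cpu)
    }
}

fn main() -> Result<()> {
    let device = select_device()?;
    println!("running on {device:?}");
    Ok(())
}
```

The explicit cascade is needed because `Device::cuda_if_available` only ever returns a CUDA or CPU device and never considers Metal.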

backends/candle/src/models.rs

Lines changed: 9 additions & 3 deletions
@@ -6,17 +6,23 @@ extern crate accelerate_src;
 
 mod bert;
 
+#[cfg(feature = "cuda")]
+mod flash_bert;
+
+#[cfg(feature = "cuda")]
+mod flash_jina;
+mod jina;
+
 pub use bert::{BertModel, Config, PositionEmbeddingType};
 use candle::{Result, Tensor};
 pub use jina::JinaBertModel;
 use text_embeddings_backend_core::Batch;
 
 #[cfg(feature = "cuda")]
-mod flash_bert;
-mod jina;
+pub use flash_bert::FlashBertModel;
 
 #[cfg(feature = "cuda")]
-pub use flash_bert::FlashBertModel;
+pub use flash_jina::FlashJinaBertModel;
 
 pub(crate) trait Model {
     fn is_padded(&self) -> bool;

backends/candle/src/models/bert.rs

Lines changed: 0 additions & 5 deletions
@@ -437,11 +437,6 @@ impl BertModel {
             ModelType::Embedding(pool) => (pool, None),
         };
 
-        // Check pool type
-        if pool != Pool::Mean && pool != Pool::Cls {
-            candle::bail!("Pool type {pool:?} is not supported");
-        }
-
         let (embeddings, encoder) = match (
            BertEmbeddings::load(vb.pp("embeddings"), config),
            BertEncoder::load(vb.pp("encoder"), config),

backends/candle/src/models/flash_bert.rs

Lines changed: 0 additions & 5 deletions
@@ -353,11 +353,6 @@ impl FlashBertModel {
             ModelType::Embedding(pool) => (pool, None),
         };
 
-        // Check pool type
-        if pool != Pool::Mean && pool != Pool::Cls {
-            candle::bail!("Pool type {pool:?} is not supported");
-        }
-
         let (embeddings, encoder) = match (
            BertEmbeddings::load(vb.pp("embeddings"), config),
            BertEncoder::load(vb.pp("encoder"), config),
