Disable Flash Attention with USE_FLASH_ATTENTION #692

Open · wants to merge 2 commits into base: main
18 changes: 6 additions & 12 deletions .github/PULL_REQUEST_TEMPLATE.md
@@ -14,18 +14,13 @@ Once you're done, someone will review your PR shortly (see the section "Who can

Fixes # (issue)


## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link
to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the
[documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and
[here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Did you read the [contributor guideline](https://github.com/huggingface/text-embeddings-inference/blob/main/CONTRIBUTING.md)?
- [ ] Was this discussed/approved via a GitHub issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs).
- [ ] Did you write any new necessary tests? If applicable, did you include or update the `insta` snapshots?

## Who can review?

@@ -34,7 +29,6 @@ members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the right person to tag with @


@OlivierDehaene OR @Narsil
@Narsil OR @alvarobartt

-->
14 changes: 8 additions & 6 deletions .github/workflows/trufflehog.yml
@@ -7,9 +7,11 @@ jobs:
trufflehog:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Secret Scanning
uses: trufflesecurity/trufflehog@main
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Secret Scanning
uses: trufflesecurity/trufflehog@main
with:
extra_args: --results=verified,unknown --exclude-detectors=postgres
1 change: 0 additions & 1 deletion CODE_OF_CONDUCT.md
@@ -1,4 +1,3 @@

# Contributor Covenant Code of Conduct

## Our Pledge
16 changes: 8 additions & 8 deletions CONTRIBUTING.md
@@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
-->

# Contribute to text-embeddings-inference
# Contribute to Text Embeddings Inference (TEI)

Everyone is welcome to contribute, and we value everybody's contribution. Code
contributions are not the only way to help the community. Answering questions, helping
@@ -31,7 +31,7 @@ However you choose to contribute, please be mindful and respect our

## Ways to contribute

There are several ways you can contribute to text-embeddings-inference.
There are several ways you can contribute to Text Embeddings Inference (TEI).

* Fix outstanding issues with the existing code.
* Submit issues related to bugs or desired new features.
@@ -52,7 +52,7 @@ feedback.

### Did you find a bug?

The text-embeddings-inference library is robust and reliable thanks to users who report the problems they encounter.
The Text Embeddings Inference (TEI) solution is robust and reliable thanks to users who report the problems they encounter.

Before you report an issue, we would really appreciate it if you could **make sure the bug was not
already reported** (use the search bar on GitHub under Issues). Your issue should also be related to bugs in the
@@ -68,7 +68,7 @@ we can quickly resolve it:

### Do you want a new feature?

If there is a new feature you'd like to see in text-embeddings-inference, please open an issue and describe:
If there is a new feature you'd like to see in Text Embeddings Inference (TEI), please open an issue and describe:

1. What is the *motivation* behind this feature? Is it related to a problem or frustration with the library? Is it
a feature related to something you need for a project? Is it something you worked on and think it could benefit
@@ -94,7 +94,7 @@ New models are constantly released and if you want to implement a new model, ple
* Link to the implementation if it is open-sourced.
* Link to the model weights if they are available.

If you are willing to contribute the model yourself, let us know so we can help you add it to text-embeddings-inference!
If you are willing to contribute the model yourself, let us know so we can help you add it to Text Embeddings Inference (TEI)!

## Do you want to add documentation?

@@ -104,8 +104,8 @@ happy to make the changes or help you make a contribution if you're interested!

## I want to become a maintainer of the project. How do I get there?

TGI is a project led and managed by Hugging Face as it powers our internal services. However, we are happy to have
motivated individuals from other organizations join us as maintainers with the goal of making TGI the best inference
service.
Text Embeddings Inference (TEI) is a project led and managed by Hugging Face as it powers our internal services. However, we are happy to have
motivated individuals from other organizations join us as maintainers with the goal of making TEI the best inference
service for embedding models in production.

If you are such an individual (or organization), please reach out to us and let's collaborate.
16 changes: 16 additions & 0 deletions backends/candle/src/lib.rs
@@ -423,6 +423,10 @@ impl CandleBackend {
if dtype != DType::F16
|| !cfg!(feature = "flash-attn")
|| get_runtime_compute_cap().unwrap() < 80
|| &std::env::var("USE_FLASH_ATTENTION")
.unwrap_or("True".to_string())
.to_lowercase()
!= "true"
{
return Err(BackendError::Start("Mistral is only supported on Cuda devices in fp16 with flash attention v2 enabled".to_string()));
}
@@ -435,6 +439,10 @@ impl CandleBackend {
(Config::Gte(config), Device::Cuda(_)) => {
if dtype != DType::F16
|| !cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
|| &std::env::var("USE_FLASH_ATTENTION")
.unwrap_or("True".to_string())
.to_lowercase()
!= "true"
{
tracing::info!("Starting GTE model on {:?}", device);
Ok(Box::new(GTEModel::load(vb, &config, model_type).s()?))
@@ -447,6 +455,10 @@ impl CandleBackend {
(Config::Qwen2(config), Device::Cuda(_)) => {
if dtype != DType::F16
|| !cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
|| &std::env::var("USE_FLASH_ATTENTION")
.unwrap_or("True".to_string())
.to_lowercase()
!= "true"
{
return Err(BackendError::Start("Qwen2 is only supported on Cuda devices in fp16 with flash attention v2 enabled".to_string()));
}
@@ -459,6 +471,10 @@ impl CandleBackend {
(Config::Qwen3(config), Device::Cuda(_)) => {
if dtype != DType::F16
|| !cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
|| &std::env::var("USE_FLASH_ATTENTION")
.unwrap_or("True".to_string())
.to_lowercase()
!= "true"
{
tracing::info!("Starting Qwen3 model on {:?}", device);
Ok(Box::new(Qwen3Model::load(vb, &config, model_type).s()?))
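The same `USE_FLASH_ATTENTION` guard is repeated inline for each CUDA model family above (Mistral, GTE, Qwen2, Qwen3). As a minimal sketch of what that guard evaluates — factored into a hypothetical helper that is not part of this PR — it could look like this:

```rust
/// Hypothetical helper (not in this PR) mirroring the inline guard above:
/// flash attention remains enabled unless USE_FLASH_ATTENTION is set to
/// anything other than "true" (case-insensitive), e.g. USE_FLASH_ATTENTION=false.
fn flash_attention_enabled() -> bool {
    std::env::var("USE_FLASH_ATTENTION")
        .unwrap_or_else(|_| "true".to_string())
        .to_lowercase()
        == "true"
}
```

With a helper like this, each match arm's condition could read `|| !flash_attention_enabled()` instead of repeating the inline `std::env::var` chain.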
2 changes: 1 addition & 1 deletion backends/grpc-client/src/client.rs
@@ -6,7 +6,7 @@ use grpc_metadata::InjectTelemetryContext;
use tonic::transport::{Channel, Uri};
use tracing::instrument;

/// Text Generation Inference gRPC client
/// Text Embeddings Inference gRPC client
#[derive(Debug, Clone)]
pub struct Client {
stub: EmbeddingServiceClient<Channel>,
@@ -1,9 +1,10 @@
import os
import torch
from text_embeddings_server.utils.device import use_ipex, is_hpu

import torch
from loguru import logger

from text_embeddings_server.utils.device import is_hpu, use_ipex

if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
raise ImportError("`USE_FLASH_ATTENTION` is false.")

@@ -30,7 +31,7 @@
except ImportError:
raise ImportError(
"Flash Attention V2 is not installed.\n"
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
"Use the official Docker image (ghcr.io/huggingface/text-embeddings-inference:cuda-latest) "
"or install flash attention v2 with `cd server && make install install-flash-attention-v2`"
)
if not (is_sm8x or is_sm90):
@@ -45,7 +46,7 @@
except ImportError:
raise ImportError(
"Flash Attention is not installed.\n"
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
"Use the official Docker image (ghcr.io/huggingface/text-embeddings-inference:cuda-latest) "
"or install flash attention with `cd server && make install install-flash-attention`"
) from e

23 changes: 17 additions & 6 deletions backends/src/lib.rs
@@ -168,7 +168,7 @@ impl Backend {
}
}
for shape in shapes.iter() {
let batch = self.create_warmup_batch(*shape, max_token as u32);
let batch = self.create_warmup_batch(*shape, max_token as u32, seq_bucket_size as u32);
match &self.model_type {
ModelType::Classifier => self.predict(batch).await.map(|_| ()),
ModelType::Embedding(_) => self.embed(batch).await.map(|_| ()),
@@ -179,19 +179,30 @@ }
}

#[instrument(skip_all)]
pub fn create_warmup_batch(&self, shape: (u32, u32), max_token: u32) -> Batch {
pub fn create_warmup_batch(
&self,
shape: (u32, u32),
max_token: u32,
seq_bucket_size: u32,
) -> Batch {
let (batch_size, length) = shape;
let min_length = length.saturating_sub(seq_bucket_size).saturating_add(1);
let tmp_length = if min_length < length {
rand::rng().random_range(min_length..length)
} else {
length
};
let mut batched_input_ids = Vec::new();
let mut batched_token_type_ids = Vec::new();
let mut batched_position_ids = Vec::new();
let mut cumulative_seq_lengths = Vec::with_capacity(batch_size as usize + 1);
let mut pooled_indices = Vec::with_capacity(batch_size as usize);
cumulative_seq_lengths.push(0);
let input_ids: Vec<u32> = (0..length)
let input_ids: Vec<u32> = (0..tmp_length)
.map(|_| rand::rng().random_range(0..max_token))
.collect();
let token_type_ids: Vec<u32> = vec![0; length as usize];
let position_ids: Vec<u32> = (0..length).collect();
let token_type_ids: Vec<u32> = vec![0; tmp_length as usize];
let position_ids: Vec<u32> = (0..tmp_length).collect();
let mut current_length = 0;
for batch_id in 0..batch_size {
batched_input_ids.extend(input_ids.iter().cloned());
@@ -206,7 +217,7 @@ token_type_ids: batched_token_type_ids,
token_type_ids: batched_token_type_ids,
position_ids: batched_position_ids,
cumulative_seq_lengths,
max_length: length,
max_length: tmp_length,
pooled_indices,
raw_indices: vec![],
}
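The new `seq_bucket_size` parameter is used above to randomize the warmup sequence length within the last bucket instead of always warming up with the full `length`. A standalone sketch of that sampling step, using the same `rand::rng()` / `random_range` API as the diff (the function name here is ours, not the PR's):

```rust
use rand::Rng;

/// Sketch of the warmup-length sampling: pick a random length in
/// [length - seq_bucket_size + 1, length), falling back to `length`
/// itself when that range is empty (e.g. seq_bucket_size <= 1).
fn sample_warmup_length(length: u32, seq_bucket_size: u32) -> u32 {
    let min_length = length.saturating_sub(seq_bucket_size).saturating_add(1);
    if min_length < length {
        rand::rng().random_range(min_length..length)
    } else {
        length
    }
}
```

The sampled value then drives the generated `input_ids`, `token_type_ids`, `position_ids`, and the batch's `max_length`, as shown in the diff.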