Skip to content

Commit 7ac2214

Browse files
committed
Allow disabling Flash Attention via the USE_FLASH_ATTENTION environment variable (defaults to "True"; any other value disables it)
1 parent 2bff275 commit 7ac2214

File tree

1 file changed

+16
-0
lines changed

1 file changed

+16
-0
lines changed

backends/candle/src/lib.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -423,6 +423,10 @@ impl CandleBackend {
423423
if dtype != DType::F16
424424
|| !cfg!(feature = "flash-attn")
425425
|| get_runtime_compute_cap().unwrap() < 80
426+
|| &std::env::var("USE_FLASH_ATTENTION")
427+
.unwrap_or("True".to_string())
428+
.to_lowercase()
429+
!= "true"
426430
{
427431
return Err(BackendError::Start("Mistral is only supported on Cuda devices in fp16 with flash attention v2 enabled".to_string()));
428432
}
@@ -435,6 +439,10 @@ impl CandleBackend {
435439
(Config::Gte(config), Device::Cuda(_)) => {
436440
if dtype != DType::F16
437441
|| !cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
442+
|| &std::env::var("USE_FLASH_ATTENTION")
443+
.unwrap_or("True".to_string())
444+
.to_lowercase()
445+
!= "true"
438446
{
439447
tracing::info!("Starting GTE model on {:?}", device);
440448
Ok(Box::new(GTEModel::load(vb, &config, model_type).s()?))
@@ -447,6 +455,10 @@ impl CandleBackend {
447455
(Config::Qwen2(config), Device::Cuda(_)) => {
448456
if dtype != DType::F16
449457
|| !cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
458+
|| &std::env::var("USE_FLASH_ATTENTION")
459+
.unwrap_or("True".to_string())
460+
.to_lowercase()
461+
!= "true"
450462
{
451463
return Err(BackendError::Start("Qwen2 is only supported on Cuda devices in fp16 with flash attention v2 enabled".to_string()));
452464
}
@@ -459,6 +471,10 @@ impl CandleBackend {
459471
(Config::Qwen3(config), Device::Cuda(_)) => {
460472
if dtype != DType::F16
461473
|| !cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
474+
|| &std::env::var("USE_FLASH_ATTENTION")
475+
.unwrap_or("True".to_string())
476+
.to_lowercase()
477+
!= "true"
462478
{
463479
tracing::info!("Starting Qwen3 model on {:?}", device);
464480
Ok(Box::new(Qwen3Model::load(vb, &config, model_type).s()?))

0 commit comments

Comments
 (0)