From 45a15c88b86338820211fa6fd5e4f410b24ebff7 Mon Sep 17 00:00:00 2001
From: CNE FICHEPOIL Pierre
Date: Mon, 7 Apr 2025 16:28:25 +0200
Subject: [PATCH 1/2] improve Flash Attention compatibility check for SM_1xx

---
 src/models/model.cc | 16 +++++++++-------
 1 file changed, 9 insertions(+), 7 deletions(-)

diff --git a/src/models/model.cc b/src/models/model.cc
index b8e1c2d8f..3123570d3 100644
--- a/src/models/model.cc
+++ b/src/models/model.cc
@@ -844,16 +844,18 @@ namespace ctranslate2 {
                                     " running independently a model in each device");
       }
 
-      bool is_sm8x = false;
-      bool is_sm90 = false;
       if (device == Device::CUDA) {
         int device_id = ctranslate2::get_device_index(ctranslate2::Device::CUDA);
         auto dprops = ctranslate2::cuda::get_device_properties(device_id);
-        is_sm8x = dprops.major == 8 && dprops.minor >= 0;
-        is_sm90 = dprops.major == 9 && dprops.minor == 0;
-      }
-      if (use_flash_attention && (device != Device::CUDA || (!is_sm8x && !is_sm90))) {
-        throw std::invalid_argument("FlashAttention only supports Ampere GPUs or newer.");
+        float compute_capability = dprops.major + (dprops.minor / 10.0f);
+
+        // Minimum compute capability for Flash Attention is Ampere (8.0)
+        const float min_flash_attn_compute_capability = 8.0f;
+        bool supports_flash_attention = compute_capability >= min_flash_attn_compute_capability;
+      }
+
+      if (use_flash_attention && (device != Device::CUDA || !supports_flash_attention)) {
+        throw std::invalid_argument("FlashAttention only supports Ampere GPUs (compute capability >= 8.0) or newer.");
       }
 #endif
 

From ea1b6b0eb5c166544940514a86e0635a3791fac9 Mon Sep 17 00:00:00 2001
From: ice0
Date: Sat, 12 Apr 2025 13:05:43 +0700
Subject: [PATCH 2/2] fix: Flash Attention compatibility check for SM_1xx (RTX
 5000 series)

Fixed build error

Includes the original commit from PR #1873
---
 src/models/model.cc | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/models/model.cc b/src/models/model.cc
index 3123570d3..f9270d10c 100644
--- a/src/models/model.cc
+++ b/src/models/model.cc
@@ -844,6 +844,7 @@ namespace ctranslate2 {
                                     " running independently a model in each device");
       }
 
+      bool supports_flash_attention = false;
       if (device == Device::CUDA) {
         int device_id = ctranslate2::get_device_index(ctranslate2::Device::CUDA);
         auto dprops = ctranslate2::cuda::get_device_properties(device_id);
@@ -851,7 +852,7 @@ namespace ctranslate2 {
 
         // Minimum compute capability for Flash Attention is Ampere (8.0)
         const float min_flash_attn_compute_capability = 8.0f;
-        bool supports_flash_attention = compute_capability >= min_flash_attn_compute_capability;
+        supports_flash_attention = compute_capability >= min_flash_attn_compute_capability;
       }
 
       if (use_flash_attention && (device != Device::CUDA || !supports_flash_attention)) {
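
For reference, the capability gate that the two patches converge on can be exercised outside CTranslate2. The following is a minimal standalone sketch, not part of the patch series: it assumes only the CUDA runtime API, replacing the ctranslate2::cuda::get_device_properties wrapper with a direct cudaGetDeviceProperties call, and applies the same "major + minor / 10" formula and 8.0 (Ampere) threshold from the patches.

// Standalone sketch (assumption: CUDA runtime available; not part of the patch).
// Reproduces the gate from the patches: SM 8.6 -> 8.6f, SM 9.0 -> 9.0f, and
// SM 12.0 (RTX 5000 series) -> 12.0f all pass the >= 8.0 check, while e.g.
// SM 7.5 (Turing) -> 7.5f does not.
#include <cuda_runtime.h>
#include <cstdio>

int main() {
  cudaDeviceProp dprops;
  if (cudaGetDeviceProperties(&dprops, /*device=*/0) != cudaSuccess) {
    std::fprintf(stderr, "no CUDA device visible\n");
    return 2;
  }

  float compute_capability = dprops.major + (dprops.minor / 10.0f);

  // Minimum compute capability for Flash Attention is Ampere (8.0).
  const float min_flash_attn_compute_capability = 8.0f;
  bool supports_flash_attention = compute_capability >= min_flash_attn_compute_capability;

  std::printf("SM %d.%d -> FlashAttention %s\n", dprops.major, dprops.minor,
              supports_flash_attention ? "supported" : "unsupported");
  return supports_flash_attention ? 0 : 1;
}

On the second commit's change: patch 1 declared supports_flash_attention inside the if (device == Device::CUDA) block but read it after the block closed, which is the build error the follow-up commit fixes by hoisting the declaration, initialized to false, above the block so the non-CUDA path also sees a defined value.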