diff --git a/src/models/model.cc b/src/models/model.cc
index b8e1c2d8f..f9270d10c 100644
--- a/src/models/model.cc
+++ b/src/models/model.cc
@@ -844,16 +844,19 @@ namespace ctranslate2 {
                                     " running independently a model in each device");
       }

-      bool is_sm8x = false;
-      bool is_sm90 = false;
+      bool supports_flash_attention = false;
       if (device == Device::CUDA) {
         int device_id = ctranslate2::get_device_index(ctranslate2::Device::CUDA);
         auto dprops = ctranslate2::cuda::get_device_properties(device_id);
-        is_sm8x = dprops.major == 8 && dprops.minor >= 0;
-        is_sm90 = dprops.major == 9 && dprops.minor == 0;
-      }
-      if (use_flash_attention && (device != Device::CUDA || (!is_sm8x && !is_sm90))) {
-        throw std::invalid_argument("FlashAttention only supports Ampere GPUs or newer.");
+        float compute_capability = dprops.major + (dprops.minor / 10.0f);
+
+        // Minimum compute capability for Flash Attention is Ampere (8.0)
+        const float min_flash_attn_compute_capability = 8.0f;
+        supports_flash_attention = compute_capability >= min_flash_attn_compute_capability;
+      }
+
+      if (use_flash_attention && (device != Device::CUDA || !supports_flash_attention)) {
+        throw std::invalid_argument("FlashAttention only supports Ampere GPUs (compute capability >= 8.0) or newer.");
       }
 #endif
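
For reference, the new gate reduces to the standalone sketch below. This is an illustration under two assumptions not confirmed by the patch itself: that ctranslate2::cuda::get_device_properties is a thin wrapper around the CUDA runtime's cudaGetDeviceProperties, and that device_supports_flash_attention is a hypothetical helper name chosen here for clarity.

// Minimal sketch of the same capability check, using the plain CUDA runtime API.
#include <cuda_runtime.h>
#include <stdexcept>
#include <string>

// Hypothetical helper, not part of the patch.
bool device_supports_flash_attention(int device_id) {
  cudaDeviceProp dprops;
  if (cudaGetDeviceProperties(&dprops, device_id) != cudaSuccess)
    throw std::runtime_error("Failed to query properties of CUDA device "
                             + std::to_string(device_id));

  // Combine major/minor into a single numeric compute capability, e.g. 8.6 for sm_86.
  const float compute_capability = dprops.major + (dprops.minor / 10.0f);

  // FlashAttention requires Ampere (compute capability 8.0) or newer.
  const float min_flash_attn_compute_capability = 8.0f;
  return compute_capability >= min_flash_attn_compute_capability;
}

Folding the per-architecture flags (is_sm8x, is_sm90) into a single numeric threshold means architectures newer than Hopper pass the check without further code changes, which appears to be the point of the patch: the old boolean pair rejected, for example, sm_100 devices even though they satisfy the Ampere-or-newer requirement.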