Commit 89a2ff9

[TensorRT EP] Address GPU bf16 support check (microsoft#24915)
BF16 support is primarily available on NVIDIA GPUs with the Ampere architecture or later, i.e. compute capability 8.0 or higher. If trt_bf16_enable is true and the device's compute capability is below 8.0, the TensorRT EP resets trt_bf16_enable to false and logs a warning.
1 parent 196dea1 commit 89a2ff9

1 file changed: +5 -0 lines changed

onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc

Lines changed: 5 additions & 0 deletions
@@ -1369,6 +1369,11 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
   max_workspace_size_ = info.max_workspace_size;
   fp16_enable_ = info.fp16_enable;
   bf16_enable_ = info.bf16_enable;
+  // BF16 support is primarily available on NVIDIA GPUs with the Ampere and later architectures with compute capability of 8.0 or higher.
+  if (bf16_enable_ && prop.major < 8) {
+    bf16_enable_ = false;
+    LOGS_DEFAULT(WARNING) << "[TensorRT EP] trt_bf16_enable is set, but platform doesn't support bf16.";
+  }
   int8_enable_ = info.int8_enable;
   if (int8_enable_) {
     int8_calibration_cache_name_ = info.int8_calibration_table_name;
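
The guard added in this hunk can be read in isolation as a small pure function. The following is a minimal sketch, not code from the commit: `DeviceProps`, `kMinBf16ComputeMajor`, and `ResolveBf16Enable` are hypothetical names introduced for illustration, and `DeviceProps` only stands in for the `major`/`minor` fields of the device properties that the constructor reads through `prop`. The warning goes to `std::cerr` here rather than through `LOGS_DEFAULT`.

```cpp
#include <iostream>

// Hypothetical stand-in for the compute-capability fields read via `prop`
// in the constructor; not the real CUDA device properties struct.
struct DeviceProps {
  int major;  // compute capability, major version
  int minor;  // compute capability, minor version
};

// Assumption taken from the commit message: BF16 needs compute capability 8.0+.
constexpr int kMinBf16ComputeMajor = 8;

// Returns the effective BF16 setting. A request for BF16 on a device below
// compute capability 8.0 is downgraded to false with a warning, mirroring
// the check in the diff above.
bool ResolveBf16Enable(bool requested, const DeviceProps& prop) {
  if (requested && prop.major < kMinBf16ComputeMajor) {
    std::cerr << "[sketch] bf16 requested, but compute capability "
              << prop.major << "." << prop.minor
              << " is below 8.0; disabling bf16.\n";
    return false;
  }
  return requested;
}
```

Under these assumptions, `ResolveBf16Enable(true, DeviceProps{7, 5})` yields `false`, while `ResolveBf16Enable(true, DeviceProps{8, 0})` leaves the flag set. Only the major version is compared, as in the commit, so any 8.x or later device passes.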
