From c4aa5c42fb69f432a06605c3c6c685b3b808f1f0 Mon Sep 17 00:00:00 2001 From: Chao Chen Date: Tue, 18 Feb 2025 02:38:47 +0800 Subject: [PATCH] amd bf16 gpu float support --- third_party/xla/xla/service/gpu/gpu_float_support.cc | 10 ++++++---- .../xla/xla/service/gpu/gpu_float_support_test.cc | 6 ++++++ 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/third_party/xla/xla/service/gpu/gpu_float_support.cc b/third_party/xla/xla/service/gpu/gpu_float_support.cc index 1403ad021a217d..5e8a2c59179f0d 100644 --- a/third_party/xla/xla/service/gpu/gpu_float_support.cc +++ b/third_party/xla/xla/service/gpu/gpu_float_support.cc @@ -99,10 +99,12 @@ bool GpuFloatSupport::IsSupported(const HloInstruction& hlo) const { case HloOpcode::kSubtract: case HloOpcode::kMultiply: { if (LowPrecisionType() == BF16) { - auto* cuda_compute_capability = - std::get_if(&compute_capability_); - return cuda_compute_capability != nullptr && - cuda_compute_capability->IsAtLeastHopper(); + if (std::holds_alternative(compute_capability_)){ + return std::get(compute_capability_).IsAtLeastHopper(); + } + else if (std::holds_alternative(compute_capability_)){ + return std::get(compute_capability_).gfx9_mi200_or_later(); + } } return false; } diff --git a/third_party/xla/xla/service/gpu/gpu_float_support_test.cc b/third_party/xla/xla/service/gpu/gpu_float_support_test.cc index 1d2f6c167bb090..da4d266e697268 100644 --- a/third_party/xla/xla/service/gpu/gpu_float_support_test.cc +++ b/third_party/xla/xla/service/gpu/gpu_float_support_test.cc @@ -253,5 +253,11 @@ TEST_F(FloatSupportTest, ShouldKeepBf16OnHopper) { /*should_convert_rhs=*/false, BF16); } +TEST_F(FloatSupportTest, ShouldKeepBf16OnMI200orLater) { + TestDotConversion(BF16, BF16, F32, se::RocmComputeCapability("gfx940"), + /*should_convert_lhs=*/false, + /*should_convert_rhs=*/false, BF16); + } + } // namespace } // namespace xla::gpu