Skip to content

Commit eebd285

Browse files
cfRod authored and copybara-github committed
PR #22814: Enable BF16 detection on Aarch64
Imported from GitHub PR openxla/xla#22814.
Copybara import of the project:
-- 810c0d2f9bcca7b7a8e8594d09ec5109fd35a66e by Crefeda Rodrigues <[email protected]>: Enable BF16 detection on Aarch64.
Signed-off-by: Crefeda Rodrigues <[email protected]>
Merging this change closes #22814.
PiperOrigin-RevId: 753444548
1 parent aaa84db commit eebd285

File tree

4 files changed

+54
-3
lines changed

4 files changed

+54
-3
lines changed

tsl/platform/BUILD

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -693,7 +693,12 @@ cc_library(
693693
"numa.h",
694694
"snappy.h",
695695
],
696-
deps = tf_windows_aware_platform_deps("platform_port"),
696+
deps = tf_windows_aware_platform_deps("platform_port") + [
697+
":platform",
698+
"@xla//xla/tsl/platform:byte_order",
699+
"@xla//xla/tsl/platform:dynamic_annotations",
700+
"@xla//xla/tsl/platform:types",
701+
],
697702
)
698703

699704
cc_library(

tsl/platform/cpu_info.cc

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ limitations under the License.
1515

1616
#include "tsl/platform/cpu_info.h"
1717

18+
#include <string>
19+
1820
#include "absl/base/call_once.h"
1921
#include "xla/tsl/platform/logging.h"
2022
#include "xla/tsl/platform/types.h"
@@ -23,6 +25,7 @@ limitations under the License.
2325
#include <mutex> // NOLINT
2426
#endif
2527
#if defined(PLATFORM_IS_ARM64) && !defined(__APPLE__) && !defined(__OpenBSD__)
28+
#include <asm/hwcap.h> /* Get HWCAP bits from asm/hwcap.h */
2629
#include <sys/auxv.h>
2730
#ifndef HWCAP_CPUID
2831
#define HWCAP_CPUID (1 << 11)
@@ -375,6 +378,7 @@ void InitCPUIDInfo() {
375378

376379
class CPUIDInfo;
377380
void InitCPUIDInfo();
381+
void InitCPUIDFeatureInfo();
378382

379383
CPUIDInfo *cpuid = nullptr;
380384

@@ -386,7 +390,8 @@ class CPUIDInfo {
386390
variant_(0),
387391
cpunum_(0),
388392
is_arm_neoverse_v1_(0),
389-
is_arm_neoverse_n1_(0) {}
393+
is_arm_neoverse_n1_(0),
394+
has_bf16_(0) {}
390395

391396
static void Initialize() {
392397
// Initialize CPUIDInfo pointer.
@@ -458,26 +463,54 @@ class CPUIDInfo {
458463
}
459464
#endif // !PLATFORM_WINDOWS
460465
}
466+
// Allocates the global `cpuid` object and records in it whether the
// AT_HWCAP2 auxiliary-vector word has the kHwcap2Bf16 bit set.
// NOTE(review): if `cpuid` is already non-null this returns before reading
// AT_HWCAP2, so has_bf16_ keeps whatever value it already held (the
// constructor sets it to 0) -- confirm this is intended when another
// initializer has populated `cpuid` first.
static void InitializeCPUFeature() {
467+
// Initialize CPUIDInfo pointer.
468+
if (cpuid != nullptr) return;
469+
470+
cpuid = new CPUIDInfo;
471+
472+
// Only the HWCAP2 word is queried here; the result is narrowed to 32 bits.
const uint32_t hwcaps2 = getauxval(AT_HWCAP2);
473+
cpuid->has_bf16_ = IsFeatureSupported(hwcaps2, kHwcap2Bf16);
474+
}
461475

462476
int implementer() const { return implementer_; }
463477
int cpunum() const { return cpunum_; }
464478

465479
// Reports whether the running CPU was identified as `cpu`.
// Calls InitCPUIDInfo() first, then reads the matching flag from `cpuid`.
// Only ARM_NEOVERSE_V1 has a case in this switch; every other value yields
// false.
static bool TestAarch64CPU(Aarch64CPU cpu) {
466480
InitCPUIDInfo();
481+
// clang-format off
467482
switch (cpu) {
468483
case ARM_NEOVERSE_V1:
469484
return cpuid->is_arm_neoverse_v1_;
470485
default:
471-
return 0;
486+
return false;
472487
}
488+
// clang-format on
489+
// Fallback return after the switch; both switch arms above already return.
return false;
490+
}
491+
492+
// Returns true when at least one bit of `feature_mask` is also set in
// `features` (the bitwise AND is converted to bool).
static bool IsFeatureSupported(uint64_t features, uint64_t feature_mask) {
493+
return (features & feature_mask);
494+
}
495+
// Reports whether the AArch64 CPU feature `feature` was detected.
// Calls InitCPUIDFeatureInfo() first, then reads the cached flag from
// `cpuid`. Only AARCH64_BF16 is handled; any other feature returns false.
static bool TestAarch64Feature(CPUFeature feature) {
496+
InitCPUIDFeatureInfo();
497+
switch (feature) {
498+
case AARCH64_BF16:
499+
return cpuid->has_bf16_;
500+
default:
501+
break;
502+
}
503+
// Reached for every feature without an explicit case above.
return false;
473504
}
474505

475506
private:
507+
static constexpr uint64_t kHwcap2Bf16 = 1ull << 14;
476508
int implementer_;
477509
int variant_;
478510
int cpunum_;
479511
int is_arm_neoverse_v1_; // ARM NEOVERSE V1
480512
int is_arm_neoverse_n1_; // ARM NEOVERSE N1
513+
int has_bf16_;
481514
};
482515

483516
absl::once_flag cpuid_once_flag;
@@ -486,13 +519,19 @@ void InitCPUIDInfo() {
486519
absl::call_once(cpuid_once_flag, CPUIDInfo::Initialize);
487520
}
488521

522+
// Runs CPUIDInfo::InitializeCPUFeature through absl::call_once.
// NOTE(review): this passes the same `cpuid_once_flag` that InitCPUIDInfo()
// passes with CPUIDInfo::Initialize. A once-flag lets only one callable run,
// so whichever of the two init functions is called first appears to prevent
// the other initializer from ever executing -- confirm whether a separate
// flag (or a single combined initializer) is needed.
void InitCPUIDFeatureInfo() {
523+
absl::call_once(cpuid_once_flag, CPUIDInfo::InitializeCPUFeature);
524+
}
525+
489526
#endif // PLATFORM_IS_ARM64 && !__APPLE__ && !__OpenBSD__
490527

491528
} // namespace
492529

493530
bool TestCPUFeature(CPUFeature feature) {
494531
#ifdef PLATFORM_IS_X86
495532
return CPUIDInfo::TestFeature(feature);
533+
#elif defined(PLATFORM_IS_ARM64) && !defined(__APPLE__) && !defined(__OpenBSD__)
534+
return CPUIDInfo::TestAarch64Feature(feature);
496535
#else
497536
return false;
498537
#endif

tsl/platform/cpu_info.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ enum CPUFeature {
148148
AARCH64_NEON = 1000,
149149
AARCH64_SVE = 1001,
150150
AARCH64_SVE2 = 1002,
151+
AARCH64_BF16 = 1003, // BF16 on AArch64 systems
151152
};
152153

153154
enum Aarch64CPU {

tsl/platform/cpu_info_test.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,4 +33,10 @@ TEST(CPUInfo, Aarch64NeoverseV1CPU) {
3333
}
3434
}
3535

36+
// Consistency check: if the AARCH64_BF16 feature is reported, the platform
// must also be reported as AArch64. Nothing is asserted when the feature is
// not reported, so the test passes trivially in that case.
TEST(CPUInfo, Aarch64Bf16) {
37+
if (port::TestCPUFeature(port::CPUFeature::AARCH64_BF16)) {
38+
EXPECT_TRUE(port::IsAarch64CPU());
39+
}
40+
}
41+
3642
} // namespace tsl

0 commit comments

Comments
 (0)