@@ -15,6 +15,8 @@ limitations under the License.
1515
1616#include " tsl/platform/cpu_info.h"
1717
18+ #include < string>
19+
1820#include " absl/base/call_once.h"
1921#include " xla/tsl/platform/logging.h"
2022#include " xla/tsl/platform/types.h"
@@ -23,6 +25,7 @@ limitations under the License.
2325#include < mutex> // NOLINT
2426#endif
2527#if defined(PLATFORM_IS_ARM64) && !defined(__APPLE__) && !defined(__OpenBSD__)
28+ #include < asm/hwcap.h> /* Get HWCAP bits from asm/hwcap.h */
2629#include < sys/auxv.h>
2730#ifndef HWCAP_CPUID
2831#define HWCAP_CPUID (1 << 11 )
@@ -375,6 +378,7 @@ void InitCPUIDInfo() {
375378
376379class CPUIDInfo ;
377380void InitCPUIDInfo ();
381+ void InitCPUIDFeatureInfo ();
378382
379383CPUIDInfo *cpuid = nullptr ;
380384
@@ -386,7 +390,8 @@ class CPUIDInfo {
386390 variant_ (0 ),
387391 cpunum_(0 ),
388392 is_arm_neoverse_v1_(0 ),
389- is_arm_neoverse_n1_(0 ) {}
393+ is_arm_neoverse_n1_(0 ),
394+ has_bf16_(0 ) {}
390395
391396 static void Initialize () {
392397 // Initialize CPUIDInfo pointer.
@@ -458,26 +463,54 @@ class CPUIDInfo {
458463 }
459464#endif // !PLATFORM_WINDOWS
460465 }
466+ static void InitializeCPUFeature () {
467+ // Initialize CPUIDInfo pointer.
468+ if (cpuid != nullptr ) return ;
469+
470+ cpuid = new CPUIDInfo;
471+
472+ const uint32_t hwcaps2 = getauxval (AT_HWCAP2);
473+ cpuid->has_bf16_ = IsFeatureSupported (hwcaps2, kHwcap2Bf16 );
474+ }
461475
462476 int implementer () const { return implementer_; }
463477 int cpunum () const { return cpunum_; }
464478
465479 static bool TestAarch64CPU (Aarch64CPU cpu) {
466480 InitCPUIDInfo ();
481+ // clang-format off
467482 switch (cpu) {
468483 case ARM_NEOVERSE_V1:
469484 return cpuid->is_arm_neoverse_v1_ ;
470485 default :
471- return 0 ;
486+ return false ;
472487 }
488+ // clang-format on
489+ return false ;
490+ }
491+
492+ static bool IsFeatureSupported (uint64_t features, uint64_t feature_mask) {
493+ return (features & feature_mask);
494+ }
495+ static bool TestAarch64Feature (CPUFeature feature) {
496+ InitCPUIDFeatureInfo ();
497+ switch (feature) {
498+ case AARCH64_BF16:
499+ return cpuid->has_bf16_ ;
500+ default :
501+ break ;
502+ }
503+ return false ;
473504 }
474505
475506 private:
507+ static constexpr uint64_t kHwcap2Bf16 = 1ull << 14 ;
476508 int implementer_;
477509 int variant_;
478510 int cpunum_;
479511 int is_arm_neoverse_v1_; // ARM NEOVERSE V1
480512 int is_arm_neoverse_n1_; // ARM NEOVERSE N1
513+ int has_bf16_;
481514};
482515
483516absl::once_flag cpuid_once_flag;
@@ -486,13 +519,19 @@ void InitCPUIDInfo() {
486519 absl::call_once (cpuid_once_flag, CPUIDInfo::Initialize);
487520}
488521
522+ void InitCPUIDFeatureInfo () {
523+ absl::call_once (cpuid_once_flag, CPUIDInfo::InitializeCPUFeature);
524+ }
525+
489526#endif // PLATFORM_IS_ARM64 && !__APPLE__ && !__OpenBSD__
490527
491528} // namespace
492529
493530bool TestCPUFeature (CPUFeature feature) {
494531#ifdef PLATFORM_IS_X86
495532 return CPUIDInfo::TestFeature (feature);
533+ #elif defined(PLATFORM_IS_ARM64) && !defined(__APPLE__) && !defined(__OpenBSD__)
534+ return CPUIDInfo::TestAarch64Feature (feature);
496535#else
497536 return false ;
498537#endif
0 commit comments