11##############################################################################
22# MIT License
33#
4- # Copyright (c) 2025 Advanced Micro Devices, Inc. All Rights Reserved.
4+ # Copyright (c) 2025 - 2026 Advanced Micro Devices, Inc. All Rights Reserved.
55#
66# Permission is hereby granted, free of charge, to any person obtaining a copy
77# of this software and associated documentation files (the "Software"), to deal
143143 "F6" : {"gfx950" : 131072 },
144144 "F6F4" : {"gfx950" : 131072 }, # Mixed precision F6 x F4
145145 "F8" : dict .fromkeys (["gfx90a" , "gfx940" , "gfx941" , "gfx942" , "gfx950" ], 32768 ),
146- "F16" : dict .fromkeys (["gfx90a" , "gfx940" , "gfx941" , "gfx942" , "gfx950" ], 16384 ),
146+ "F16" : dict .fromkeys (["gfx90a" , "gfx940" , "gfx941" , "gfx942" ], 16384 )
147+ | dict .fromkeys (["gfx950" ], 32768 ),
147148 "F32" : dict .fromkeys (
148149 ["gfx908" , "gfx90a" , "gfx940" , "gfx941" , "gfx942" , "gfx950" ], 4096
149150 ),
150- "BF16" : dict .fromkeys (["gfx940" , "gfx941" , "gfx942" , "gfx950" ], 16384 )
151- | dict .fromkeys (["gfx90a" ], 8192 ),
152- "I8" : dict .fromkeys (["gfx940" , "gfx941" , "gfx942" , "gfx950" ], 32768 )
153- | dict .fromkeys (["gfx90a" ], 16384 ),
151+ "BF16" : dict .fromkeys (["gfx940" , "gfx941" , "gfx942" ], 16384 )
152+ | dict .fromkeys (["gfx90a" ], 8192 )
153+ | dict .fromkeys (["gfx950" ], 32768 ),
154+ "I8" : dict .fromkeys (["gfx940" , "gfx941" , "gfx942" ], 32768 )
155+ | dict .fromkeys (["gfx90a" ], 16384 )
156+ | dict .fromkeys (["gfx950" ], 65536 ),
154157 "F64" : dict .fromkeys (["gfx90a" , "gfx940" , "gfx941" , "gfx942" , "gfx950" ], 2048 ),
155158}
156159
@@ -726,7 +729,7 @@ def flops_bench(device: int, type: str, unit: str, rate: int) -> PerfMetrics:
726729
727730extern "C" __global__ void mfma_f32(int iter, float *dummy)
728731{
729- float a = threadIdx.x;
732+ float a = threadIdx.x;
730733 vec16<float> result = {0};
731734
732735 for(int i = 0; i < iter; ++i)
@@ -749,15 +752,25 @@ def flops_bench(device: int, type: str, unit: str, rate: int) -> PerfMetrics:
749752
750753extern "C" __global__ void mfma_f16(int iter, float *dummy)
751754{
752- vec4<__fp16> a;
753- a[1] = a[0] = threadIdx.x;
754-
755755 vec16<float> result = {0};
756-
756+ #if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || \
757+ defined(__gfx941__) || defined(__gfx942__)
758+ vec4<__fp16> a;
759+ a[3] = a[2] = a[1] = a[0] = threadIdx.x;
757760 for(int i = 0; i < iter; ++i)
758761 {
759762 result = __builtin_amdgcn_mfma_f32_32x32x8f16(a, a, result, 0, 0, 0);
760763 }
764+ #elif defined(__gfx950__)
765+ vec8<__fp16> a;
766+ a[7] = a[6] = a[5] = a[4] = a[3] = a[2] = a[1] = a[0] = threadIdx.x;
767+ for(int i = 0; i < iter; ++i)
768+ {
769+ result = __builtin_amdgcn_mfma_f32_32x32x16_f16(a, a, result, 0, 0, 0);
770+ }
771+ #else
772+ #error "Unsupported gfx arch"
773+ #endif
761774
762775 if (result[0] != 2*result[0])
763776 {
@@ -776,23 +789,34 @@ def flops_bench(device: int, type: str, unit: str, rate: int) -> PerfMetrics:
776789 vec16<float> result = {0};
777790
778791// MI100/MI200
779- #if defined(__gfx908__) or defined(__gfx90a__)
792+ #if defined(__gfx908__) || defined(__gfx90a__)
780793 vec2<short> a;
781794 a[1] = a[0]= threadIdx.x;
782795
783796 for(int i = 0; i < iter; ++i)
784797 {
785798 result = __builtin_amdgcn_mfma_f32_32x32x4bf16(a, a, result, 0, 0, 0);
786799 }
787- //MI300 series
788- #else
800+ // MI300 series
801+ #elif defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
789802 vec4<short> a;
790803 a[3] = a[2] = a[1] = a[0] = threadIdx.x;
791804
792805 for(int i = 0; i < iter; ++i)
793806 {
794807 result = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(a, a, result, 0, 0, 0);
795808 }
809+ // MI350
810+ #elif defined(__gfx950__)
811+ vec8<short> a;
812+ a[7] = a[6] = a[5] = a[4] = a[3] = a[2] = a[1] = a[0] = threadIdx.x;
813+
814+ for(int i = 0; i < iter; ++i)
815+ {
816+ result = __builtin_amdgcn_mfma_f32_32x32x16_bf16(a, a, result, 0, 0, 0);
817+ }
818+ #else
819+ #error "Unsupported gfx arch"
796820#endif
797821
798822 if (result[0] != 2*result[0])
@@ -835,21 +859,32 @@ def flops_bench(device: int, type: str, unit: str, rate: int) -> PerfMetrics:
835859 vec16<int> result = {0};
836860
837861// MI100/MI200
838- #if defined(__gfx908__) or defined(__gfx90a__)
862+ #if defined(__gfx908__) || defined(__gfx90a__)
839863 int a = threadIdx.x;
840864
841865 for(int i = 0; i < iter; ++i)
842866 {
843867 result = __builtin_amdgcn_mfma_i32_32x32x8i8(a, a, result, 0, 0, 0);
844868 }
845869// MI300 series
846- #else
870+ #elif defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
847871 long a = threadIdx.x;
848872
849873 for(int i = 0; i < iter; ++i)
850874 {
851875 result = __builtin_amdgcn_mfma_i32_32x32x16_i8(a, a, result, 0, 0, 0);
852876 }
877+ // MI350 series
878+ #elif defined(__gfx950__)
879+ vec2<long> a;
880+ a[1] = a[0] = threadIdx.x;
881+
882+ for(int i = 0; i < iter; ++i)
883+ {
884+ result = __builtin_amdgcn_mfma_i32_32x32x32_i8(a, a, result, 0, 0, 0);
885+ }
886+ #else
887+ #error "Unsupported gfx arch"
853888#endif
854889
855890 if (result[0] != 2*result[0])
0 commit comments