@@ -274,6 +274,38 @@ def define_xnnpack():
274274 ],
275275 )
276276
277+ SSE2_FMA_COMPILER_FLAGS = [
278+ "-msse2" ,
279+ "-mno-sse3" ,
280+ ]
281+
282+ native .cxx_library (
283+ name = "ukernels_sse2fma" ,
284+ srcs = select ({
285+ "DEFAULT" : prod_srcs_for_arch_wrapper ("sse2fma" ),
286+ "ovr_config//cpu:arm32" : DEFAULT_DUMMY_SRC ,
287+ "ovr_config//cpu:arm64" : DEFAULT_DUMMY_SRC ,
288+ }),
289+ headers = get_xnnpack_headers (),
290+ header_namespace = "" ,
291+ compiler_flags = [
292+ "-O2" ,
293+ "-Wno-error=missing-braces" , # required since the SGX toolchain does not have this by default
294+ ] + select ({
295+ "DEFAULT" : SSE2_FMA_COMPILER_FLAGS ,
296+ "ovr_config//cpu:arm32" : [],
297+ "ovr_config//cpu:arm64" : [],
298+ }),
299+ preferred_linkage = "static" ,
300+ preprocessor_flags = [
301+ "-DXNN_LOG_LEVEL=0" ,
302+ ],
303+ exported_deps = [
304+ ":FP16" ,
305+ ":interface" ,
306+ ],
307+ )
308+
277309 SSE3_COMPILER_FLAGS = ["-mssse3" ]
278310
279311 # @lint-ignore BUCKLINT: native and fb_native are explicitly forbidden in fbcode.
@@ -961,6 +993,44 @@ def define_xnnpack():
961993 ],
962994 )
963995
996+ AMD64_COMPILER_FLAGS = [
997+ "-mf16c" ,
998+ "-mfma" ,
999+ "-mavx512f" ,
1000+ "-mavx512cd" ,
1001+ "-mavx512bw" ,
1002+ "-mavx512dq" ,
1003+ "-mavx512vl" ,
1004+ "-mavx512vnni" ,
1005+ "-mgfni" ,
1006+ ]
1007+ native .cxx_library (
1008+ name = "ukernels_amd64" ,
1009+ srcs = select ({
1010+ "DEFAULT" : prod_srcs_for_arch_wrapper ("amd64" ),
1011+ "ovr_config//cpu:arm32" : DEFAULT_DUMMY_SRC ,
1012+ "ovr_config//cpu:arm64" : DEFAULT_DUMMY_SRC ,
1013+ }),
1014+ headers = get_xnnpack_headers (),
1015+ header_namespace = "" ,
1016+ compiler_flags = [
1017+ "-O2" ,
1018+ "-Wno-error=missing-braces" , # required since the SGX toolchain does not have this by default
1019+ ] + select ({
1020+ "DEFAULT" : AMD64_COMPILER_FLAGS ,
1021+ "ovr_config//cpu:arm32" : [],
1022+ "ovr_config//cpu:arm64" : [],
1023+ }),
1024+ preferred_linkage = "static" ,
1025+ preprocessor_flags = [
1026+ "-DXNN_LOG_LEVEL=0" ,
1027+ ],
1028+ exported_deps = [
1029+ ":FP16" ,
1030+ ":interface" ,
1031+ ],
1032+ )
1033+
9641034 AVX512VNNIGFNI_COMPILER_FLAGS = AVX512VNNI_COMPILER_FLAGS + [
9651035 "-mgfni" ,
9661036 ]
@@ -1044,12 +1114,14 @@ def define_xnnpack():
10441114 ":ukernels_fma3" ,
10451115 ":ukernels_sse" ,
10461116 ":ukernels_sse2" ,
1117+ ":ukernels_sse2fma" ,
10471118 ":ukernels_sse41" ,
10481119 ":ukernels_ssse3" ,
10491120 ":ukernels_avx512vbmi" ,
10501121 ":ukernels_avx512vnnigfni" ,
10511122 ":ukernels_avx512vnni" ,
10521123 ":ukernels_avxvnni" ,
1124+ ":ukernels_amd64" ,
10531125 ]
10541126
10551127 ARM_XNNPACK_DEPS = [
@@ -1097,10 +1169,22 @@ def define_xnnpack():
10971169 "-DXNN_ENABLE_GEMM_M_SPECIALIZATION" ,
10981170 "-DXNN_ENABLE_ARM_DOTPROD" ,
10991171 "-DXNN_ENABLE_CPUINFO" ,
1100- # "-DXNN_ENABLE_DWCONV_MULTIPLASS=1 ",
1172+ # "-DXNN_ENABLE_DWCONV_MULTIPLASS=0 ",
11011173 "-DXNN_ENABLE_ARM_I8MM=1" ,
11021174 "-DXNN_ENABLE_ARM_FP16_VECTOR=1" ,
1103- "-DXNN_ENABLE_AVX512BF16=0"
1175+ "-DXNN_ENABLE_AVX512F=1" ,
1176+ "-DXNN_ENABLE_AVX512SKX=1" ,
1177+ "-DXNN_ENABLE_AVX512VNNI=1" ,
1178+ "-DXNN_ENABLE_AVX512VBMI=1" ,
1179+ "-DXNN_ENABLE_AVXVNNI=0" ,
1180+ "-DXNN_ENABLE_AVXVNNIINT8=0" ,
1181+ "-DXNN_ENABLE_AVX512FP16=0" ,
1182+ "-DXNN_ENABLE_AVX512VNNIGFNI=0" ,
1183+ "-DXNN_ENABLE_AVX512BF16=0" ,
1184+ "-DXNN_ENABLE_AVX256VNNIGFNI=0" ,
1185+ "-DXNN_ENABLE_AVX512AMX=0" ,
1186+ "-DXNN_ENABLE_AVX256SKX=0" ,
1187+ "-DXNN_ENABLE_AVX256VNNI=0" ,
11041188 ],
11051189 visibility = ["PUBLIC" ],
11061190 exported_deps = COMMON_XNNPACK_DEPS + [
0 commit comments