diff --git a/backends/xnnpack/third-party/XNNPACK b/backends/xnnpack/third-party/XNNPACK index 84096dd536e..52208356940 160000 --- a/backends/xnnpack/third-party/XNNPACK +++ b/backends/xnnpack/third-party/XNNPACK @@ -1 +1 @@ -Subproject commit 84096dd536edffd19337d9297634c4f5c5449bfd +Subproject commit 52208356940a7c7d3597cf386d500a0f776f7bd0 diff --git a/backends/xnnpack/third-party/xnnpack.buck.bzl b/backends/xnnpack/third-party/xnnpack.buck.bzl index c01512eed26..14520b07664 100644 --- a/backends/xnnpack/third-party/xnnpack.buck.bzl +++ b/backends/xnnpack/third-party/xnnpack.buck.bzl @@ -274,6 +274,38 @@ def define_xnnpack(): ], ) + SSE2_FMA_COMPILER_FLAGS = [ + "-msse2", + "-mno-sse3", + ] + + native.cxx_library( + name = "ukernels_sse2fma", + srcs = select({ + "DEFAULT": prod_srcs_for_arch_wrapper("sse2fma"), + "ovr_config//cpu:arm32": DEFAULT_DUMMY_SRC, + "ovr_config//cpu:arm64": DEFAULT_DUMMY_SRC, + }), + headers = get_xnnpack_headers(), + header_namespace = "", + compiler_flags = [ + "-O2", + "-Wno-error=missing-braces", # required since the SGX toolchain does not have this by default + ] + select({ + "DEFAULT": SSE2_FMA_COMPILER_FLAGS, + "ovr_config//cpu:arm32": [], + "ovr_config//cpu:arm64": [], + }), + preferred_linkage = "static", + preprocessor_flags = [ + "-DXNN_LOG_LEVEL=0", + ], + exported_deps = [ + ":FP16", + ":interface", + ], + ) + SSE3_COMPILER_FLAGS = ["-mssse3"] # @lint-ignore BUCKLINT: native and fb_native are explicitly forbidden in fbcode. 
@@ -961,6 +993,44 @@ def define_xnnpack(): ], ) + AMD64_COMPILER_FLAGS = [ + "-mf16c", + "-mfma", + "-mavx512f", + "-mavx512cd", + "-mavx512bw", + "-mavx512dq", + "-mavx512vl", + "-mavx512vnni", + "-mgfni", + ] + native.cxx_library( + name = "ukernels_amd64", + srcs = select({ + "DEFAULT": prod_srcs_for_arch_wrapper("amd64"), + "ovr_config//cpu:arm32": DEFAULT_DUMMY_SRC, + "ovr_config//cpu:arm64": DEFAULT_DUMMY_SRC, + }), + headers = get_xnnpack_headers(), + header_namespace = "", + compiler_flags = [ + "-O2", + "-Wno-error=missing-braces", # required since the SGX toolchain does not have this by default + ] + select({ + "DEFAULT": AMD64_COMPILER_FLAGS, + "ovr_config//cpu:arm32": [], + "ovr_config//cpu:arm64": [], + }), + preferred_linkage = "static", + preprocessor_flags = [ + "-DXNN_LOG_LEVEL=0", + ], + exported_deps = [ + ":FP16", + ":interface", + ], + ) + AVX512VNNIGFNI_COMPILER_FLAGS = AVX512VNNI_COMPILER_FLAGS + [ "-mgfni", ] @@ -1044,12 +1114,14 @@ def define_xnnpack(): ":ukernels_fma3", ":ukernels_sse", ":ukernels_sse2", + ":ukernels_sse2fma", ":ukernels_sse41", ":ukernels_ssse3", ":ukernels_avx512vbmi", ":ukernels_avx512vnnigfni", ":ukernels_avx512vnni", ":ukernels_avxvnni", + ":ukernels_amd64", ] ARM_XNNPACK_DEPS = [ @@ -1097,10 +1169,22 @@ def define_xnnpack(): "-DXNN_ENABLE_GEMM_M_SPECIALIZATION", "-DXNN_ENABLE_ARM_DOTPROD", "-DXNN_ENABLE_CPUINFO", - # "-DXNN_ENABLE_DWCONV_MULTIPLASS=1", + # "-DXNN_ENABLE_DWCONV_MULTIPASS=0", "-DXNN_ENABLE_ARM_I8MM=1", "-DXNN_ENABLE_ARM_FP16_VECTOR=1", - "-DXNN_ENABLE_AVX512BF16=0" + "-DXNN_ENABLE_AVX512F=1", + "-DXNN_ENABLE_AVX512SKX=1", + "-DXNN_ENABLE_AVX512VNNI=1", + "-DXNN_ENABLE_AVX512VBMI=1", + "-DXNN_ENABLE_AVXVNNI=0", + "-DXNN_ENABLE_AVXVNNIINT8=0", + "-DXNN_ENABLE_AVX512FP16=0", + "-DXNN_ENABLE_AVX512VNNIGFNI=0", + "-DXNN_ENABLE_AVX512BF16=0", + "-DXNN_ENABLE_AVX256VNNIGFNI=0", + "-DXNN_ENABLE_AVX512AMX=0", + "-DXNN_ENABLE_AVX256SKX=0", + "-DXNN_ENABLE_AVX256VNNI=0", ], visibility = ["PUBLIC"], exported_deps 
= COMMON_XNNPACK_DEPS + [