diff --git a/backends/xnnpack/third-party/XNNPACK b/backends/xnnpack/third-party/XNNPACK index 84096dd536e..52208356940 160000 --- a/backends/xnnpack/third-party/XNNPACK +++ b/backends/xnnpack/third-party/XNNPACK @@ -1 +1 @@ -Subproject commit 84096dd536edffd19337d9297634c4f5c5449bfd +Subproject commit 52208356940a7c7d3597cf386d500a0f776f7bd0 diff --git a/backends/xnnpack/third-party/xnnpack.buck.bzl b/backends/xnnpack/third-party/xnnpack.buck.bzl index 8556fde3d8a..14520b07664 100644 --- a/backends/xnnpack/third-party/xnnpack.buck.bzl +++ b/backends/xnnpack/third-party/xnnpack.buck.bzl @@ -4,8 +4,8 @@ load( "OPERATOR_SRCS", "SUBGRAPH_SRCS", "TABLE_SRCS", - "XNNPACK_SRCS", "get_xnnpack_headers", + "get_ukernel_config_srcs", "prod_srcs_for_arch_wrapper", ) @@ -274,6 +274,38 @@ def define_xnnpack(): ], ) + SSE2_FMA_COMPILER_FLAGS = [ + "-msse2", + "-mno-sse3", + ] + + native.cxx_library( + name = "ukernels_sse2fma", + srcs = select({ + "DEFAULT": prod_srcs_for_arch_wrapper("sse2fma"), + "ovr_config//cpu:arm32": DEFAULT_DUMMY_SRC, + "ovr_config//cpu:arm64": DEFAULT_DUMMY_SRC, + }), + headers = get_xnnpack_headers(), + header_namespace = "", + compiler_flags = [ + "-O2", + "-Wno-error=missing-braces", # required since the SGX toolchain does not have this by default + ] + select({ + "DEFAULT": SSE2_FMA_COMPILER_FLAGS, + "ovr_config//cpu:arm32": [], + "ovr_config//cpu:arm64": [], + }), + preferred_linkage = "static", + preprocessor_flags = [ + "-DXNN_LOG_LEVEL=0", + ], + exported_deps = [ + ":FP16", + ":interface", + ], + ) + SSE3_COMPILER_FLAGS = ["-mssse3"] # @lint-ignore BUCKLINT: native and fb_native are explicitly forbidden in fbcode. @@ -961,6 +993,44 @@ def define_xnnpack(): ], ) + AMD64_COMPILER_FLAGS = [ + "-mf16c", + "-mfma", + "-mavx512f", + "-mavx512cd", + "-mavx512bw", + "-mavx512dq", + "-mavx512vl", + "-mavx512vnni", + "-mgfni", + ] + native.cxx_library( + name = "ukernels_amd64", + srcs = select({ + "DEFAULT": prod_srcs_for_arch_wrapper("amd64"), + "ovr_config//cpu:arm32": DEFAULT_DUMMY_SRC, + "ovr_config//cpu:arm64": DEFAULT_DUMMY_SRC, + }), + headers = get_xnnpack_headers(), + header_namespace = "", + compiler_flags = [ + "-O2", + "-Wno-error=missing-braces", # required since the SGX toolchain does not have this by default + ] + select({ + "DEFAULT": AMD64_COMPILER_FLAGS, + "ovr_config//cpu:arm32": [], + "ovr_config//cpu:arm64": [], + }), + preferred_linkage = "static", + preprocessor_flags = [ + "-DXNN_LOG_LEVEL=0", + ], + exported_deps = [ + ":FP16", + ":interface", + ], + ) + AVX512VNNIGFNI_COMPILER_FLAGS = AVX512VNNI_COMPILER_FLAGS + [ "-mgfni", ] @@ -1044,12 +1114,14 @@ def define_xnnpack(): ":ukernels_fma3", ":ukernels_sse", ":ukernels_sse2", + ":ukernels_sse2fma", ":ukernels_sse41", ":ukernels_ssse3", ":ukernels_avx512vbmi", ":ukernels_avx512vnnigfni", ":ukernels_avx512vnni", ":ukernels_avxvnni", + ":ukernels_amd64", ] ARM_XNNPACK_DEPS = [ @@ -1070,7 +1142,7 @@ def define_xnnpack(): # @lint-ignore BUCKLINT: native and fb_native are explicitly forbidden in fbcode. native.cxx_library( name = "XNNPACK", - srcs = XNNPACK_SRCS + LOGGING_SRCS + [ + srcs = get_ukernel_config_srcs() + LOGGING_SRCS + [ "XNNPACK/src/init.c", "XNNPACK/src/params.c", "XNNPACK/src/configs/hardware-config.c", @@ -1097,10 +1169,22 @@ def define_xnnpack(): "-DXNN_ENABLE_GEMM_M_SPECIALIZATION", "-DXNN_ENABLE_ARM_DOTPROD", "-DXNN_ENABLE_CPUINFO", - # "-DXNN_ENABLE_DWCONV_MULTIPLASS=1", + # "-DXNN_ENABLE_DWCONV_MULTIPLASS=0", "-DXNN_ENABLE_ARM_I8MM=1", "-DXNN_ENABLE_ARM_FP16_VECTOR=1", - "-DXNN_ENABLE_AVX512BF16=0" + "-DXNN_ENABLE_AVX512F=1", + "-DXNN_ENABLE_AVX512SKX=1", + "-DXNN_ENABLE_AVX512VNNI=1", + "-DXNN_ENABLE_AVX512VBMI=1", + "-DXNN_ENABLE_AVXVNNI=0", + "-DXNN_ENABLE_AVXVNNIINT8=0", + "-DXNN_ENABLE_AVX512FP16=0", + "-DXNN_ENABLE_AVX512VNNIGFNI=0", + "-DXNN_ENABLE_AVX512BF16=0", + "-DXNN_ENABLE_AVX256VNNIGFNI=0", + "-DXNN_ENABLE_AVX512AMX=0", + "-DXNN_ENABLE_AVX256SKX=0", + "-DXNN_ENABLE_AVX256VNNI=0", ], visibility = ["PUBLIC"], exported_deps = COMMON_XNNPACK_DEPS + [ diff --git a/backends/xnnpack/third-party/xnnpack_src_defs.bzl b/backends/xnnpack/third-party/xnnpack_src_defs.bzl index cb1f635e79e..25477e8c718 100644 --- a/backends/xnnpack/third-party/xnnpack_src_defs.bzl +++ b/backends/xnnpack/third-party/xnnpack_src_defs.bzl @@ -9,32 +9,6 @@ load("//backends/xnnpack/third-party/XNNPACK/gen:microkernels.bzl", "prod_srcs_f load("@fbsource//xplat/executorch/third-party:glob_defs.bzl", "subdir_glob") # To get from XNNPACK:build_srcs.bzl in the future -_XNNPACK_SRCS = [ - "src/configs/argmaxpool-config.c", - "src/configs/avgpool-config.c", - "src/configs/binary-elementwise-config.c", - "src/configs/cmul-config.c", - "src/configs/conv-hwc2chw-config.c", - "src/configs/dwconv-config.c", - "src/configs/dwconv2d-chw-config.c", - "src/configs/gemm-config.c", - "src/configs/ibilinear-chw-config.c", - "src/configs/ibilinear-config.c", - "src/configs/lut32norm-config.c", - "src/configs/maxpool-config.c", - "src/configs/pack-lh-config.c", - "src/configs/raddstoreexpminusmax-config.c", - "src/configs/reduce-config.c", - "src/configs/spmm-config.c", - "src/configs/transpose-config.c", - "src/configs/unary-elementwise-config.c", - "src/configs/unpool-config.c", - "src/configs/vmulcaddc-config.c", - "src/configs/x8-lut-config.c", - "src/configs/xx-fill-config.c", - "src/configs/xx-pad-config.c", -] - def define_xnnpack_build_src(xnnpack_build_src): return ["XNNPACK/{}".format(src) for src in xnnpack_build_src] @@ -56,8 +30,12 @@ def get_xnnpack_headers(): ]) return src_headers | include_headers | ukernel_headers +def get_ukernel_config_srcs(): + return subdir_glob([ + ("XNNPACK/src/configs", "*.c"), + ]).values() + OPERATOR_SRCS = define_xnnpack_build_src(_OPERATOR_SRCS) SUBGRAPH_SRCS = define_xnnpack_build_src(_SUBGRAPH_SRCS) TABLE_SRCS = define_xnnpack_build_src(_TABLE_SRCS) -XNNPACK_SRCS = define_xnnpack_build_src(_XNNPACK_SRCS) LOGGING_SRCS = define_xnnpack_build_src(_LOGGING_SRCS)