@@ -818,9 +818,6 @@ struct ForwardToKernel : WorkingKernel {
818818 * Kernel static dispatching *
819819 *******************************/
820820
821- template <typename Traits, typename = void >
822- struct KernelDispatch ;
823-
824821// Benchmarking show unpack to uint64_t is underperforming on SSE4.2 and Avx2
825822template <typename Traits, typename Arch = typename Traits::arch_type>
826823constexpr bool MediumShouldUseUint32 =
@@ -829,36 +826,36 @@ constexpr bool MediumShouldUseUint32 =
829826 (Traits::kShape .packed_bit_size() < 32 ) &&
830827 KernelTraitsWithUnpack<Traits, uint32_t >::kShape .is_medium();
831828
832- template <typename Traits>
833- struct KernelDispatch <Traits, std::enable_if_t <Traits::kShape .is_medium() &&
834- !MediumShouldUseUint32<Traits>>>
835- : MediumKernel<Traits> {};
836-
837- template <typename Traits>
838- struct KernelDispatch <
839- Traits, std::enable_if_t <Traits::kShape .is_medium() && MediumShouldUseUint32<Traits>>>
840- : ForwardToKernel<Traits, MediumKernel<KernelTraitsWithUnpack<Traits, uint32_t >>> {};
841-
842829// Benchmarking show large unpack to uint8_t is underperforming on SSE4.2
843830template <typename Traits, typename Arch = typename Traits::arch_type>
844831constexpr bool LargeShouldUseUint16 = HasSse2<Arch> &&
845832 (Traits::kShape .unpacked_byte_size() ==
846833 sizeof (uint8_t ));
847834
835+ // A ``std::enable_if`` that works on MSVC
848836template <typename Traits>
849- struct KernelDispatch <
850- Traits, std::enable_if_t <Traits::kShape .is_large() && !LargeShouldUseUint16<Traits>>>
851- : LargeKernel<Traits> {};
852-
853- template <typename Traits>
854- struct KernelDispatch <
855- Traits, std::enable_if_t <Traits::kShape .is_large() && LargeShouldUseUint16<Traits>>>
856- : ForwardToKernel<Traits, MediumKernel<KernelTraitsWithUnpack<Traits, uint16_t >>> {};
837+ constexpr auto KernelDispatchImpl () {
838+ if constexpr (Traits::kShape .is_medium ()) {
839+ if constexpr (MediumShouldUseUint32<Traits>) {
840+ using Kernel32 = MediumKernel<KernelTraitsWithUnpack<Traits, uint32_t >>;
841+ return ForwardToKernel<Traits, Kernel32>{};
842+ } else {
843+ return MediumKernel<Traits>{};
844+ }
845+ } else if constexpr (Traits::kShape .is_large ()) {
846+ if constexpr (LargeShouldUseUint16<Traits>) {
847+ using Kernel16 = MediumKernel<KernelTraitsWithUnpack<Traits, uint16_t >>;
848+ return ForwardToKernel<Traits, Kernel16>{};
849+ } else {
850+ return LargeKernel<Traits>{};
851+ }
852+ } else if constexpr (Traits::kShape .is_oversized ()) {
853+ return NoOpKernel<Traits>{};
854+ }
855+ }
857856
858- // Oversize kernel is only a few edge cases
859857template <typename Traits>
860- struct KernelDispatch <Traits, std::enable_if_t <Traits::kShape .is_oversized()>>
861- : NoOpKernel<Traits> {};
858+ using KernelDispatch = decltype (KernelDispatchImpl<Traits>());
862859
863860// / The public kernel exposed for any size.
864861template <typename UnpackedUint, int kPackedBitSize , int kSimdBitSize >
0 commit comments