@@ -866,6 +866,53 @@ OPUS_D constexpr auto fp32_to_bf16(const fp32_t& x, number<rm> = {}) {
866866OPUS_CAST_DEFINE (fp16, fp32)
867867OPUS_CAST_DEFINE(fp32, fp16)
868868
869+ namespace impl {
870+ // implement a "pack" of data, storage should pad to multiple of byte(8bit)
871+ template <typename storage_, uint32_t bits_, bool is_signed_ = true >
872+ struct dpacks {
873+ using storage = remove_cvref_t <storage_>;
874+ static constexpr uint32_t bits = bits_;
875+ static constexpr uint32_t mask = (1 << bits) - 1 ;
876+ static constexpr bool is_signed = is_signed_;
877+ static constexpr uint32_t num_packs = sizeof (storage) * 8 / bits; // we will not check if evenly divided or not here
878+ OPUS_H_D constexpr storage operator [](index_t i) const { return (value >> (i * bits)) & mask; } // NOTE: not efficient, better use v_bfi/v_bfe/v_perm on device
879+ template <index_t I> OPUS_H_D constexpr storage operator [](number<I>) const { return (value >> (I * bits)) & mask; } // NOTE: not efficient, better use v_bfi/v_bfe/v_perm on device
880+ storage value;
881+ };
882+
883+ template <typename storage_, uint32_t bits_, uint32_t exp_bits_, uint32_t mantissa_bits_, bool is_signed_ = true >
884+ struct fpacks : dpacks<storage_, bits_, is_signed_> {
885+ static constexpr uint32_t exp_bits = exp_bits_;
886+ static constexpr uint32_t mantissa_bits = mantissa_bits_;
887+ };
888+ } // namespace impl
889+
890+ template <typename > struct is_packs : false_type {};
891+ template <typename S, uint32_t B, bool X> struct is_packs <impl::dpacks<S, B, X>> : true_type {};
892+ template <typename S, uint32_t B, uint32_t E, uint32_t M, bool X> struct is_packs <impl::fpacks<S, B, E, M, X>> : true_type {};
893+ template <typename T> static constexpr bool is_packs_v = is_packs<remove_cvref_t <T>>::value;
894+
895+ template <typename T> struct sizeof_bits { static constexpr int value = int (sizeof (T) * 8 ); };
896+ template <> struct sizeof_bits <void > { static constexpr int value = 0 ; };
897+ template <typename S, uint32_t B, bool X> struct sizeof_bits <impl::dpacks<S, B, X>> { static constexpr int value = impl::dpacks<S, B, X>::bits; };
898+ template <typename S, uint32_t B, uint32_t E, uint32_t M, bool X> struct sizeof_bits <impl::fpacks<S, B, E, M, X>> { static constexpr int value = impl::fpacks<S, B, E, M, X>::bits; };
899+ template <class T > static constexpr auto sizeof_bits_v = sizeof_bits<T>::value;
900+
// Declares `name_` as a distinct struct derived from opus::impl::dpacks<storage_, bits_, is_signed_>
// (the base is reachable as `name_::base`) and specializes sizeof_bits, is_packs and is_dtype for it.
// The explicit specializations are needed because the partial specializations written for
// impl::dpacks<...> do not match a type that merely derives from it.
#define OPUS_DEFINE_DPACKS(name_, storage_, bits_, is_signed_)                               \
    struct name_ : opus::impl::dpacks<storage_, bits_, is_signed_> {                         \
        using base = opus::impl::dpacks<storage_, bits_, is_signed_>;                        \
    };                                                                                       \
    template <> struct sizeof_bits<name_> { static constexpr int value = name_::bits; };     \
    template <> struct is_packs<name_> : true_type {};                                       \
    template <> struct is_dtype<name_> : true_type {};
904+
// Declares `name_` as a distinct struct derived from
// opus::impl::fpacks<storage_, bits_, exp_bits_, mantissa_bits_, is_signed_> (the base is reachable
// as `name_::base`) and specializes sizeof_bits, is_packs and is_dtype for it.
// The explicit specializations are needed because the partial specializations written for
// impl::fpacks<...> do not match a type that merely derives from it.
#define OPUS_DEFINE_FPACKS(name_, storage_, bits_, exp_bits_, mantissa_bits_, is_signed_)            \
    struct name_ : opus::impl::fpacks<storage_, bits_, exp_bits_, mantissa_bits_, is_signed_> {      \
        using base = opus::impl::fpacks<storage_, bits_, exp_bits_, mantissa_bits_, is_signed_>;     \
    };                                                                                               \
    template <> struct sizeof_bits<name_> { static constexpr int value = name_::bits; };             \
    template <> struct is_packs<name_> : true_type {};                                               \
    template <> struct is_dtype<name_> : true_type {};
908+
909+ // NOTE: convention here. The subbyte type below is indeed "packed" data. e.g. fp4_t, underneath it is fp4x2 in one byte, but we don't name it this way
910+ // This is different from cutlass convention (e.g float4_e2m1_t, but storage is uint8_t, hence an array of float4_e2m1_t will be expanded), and different from ck convention(explicitly name it fp4x2_t)
911+ OPUS_DEFINE_DPACKS (int4_t , uint8_t , 4 , true ) // int4x2
912+ OPUS_DEFINE_DPACKS(uint4_t , uint8_t , 4 , false ) // uint4x2
913+ OPUS_DEFINE_FPACKS(fp4_t , uint8_t , 4 , 2 , 1 , true ) // fp4x2
914+ OPUS_DEFINE_FPACKS(e8m0_t , uint8_t , 8 , 8 , 0 , false ) // fp4x2
915+
869916template<typename D, typename S, typename... Aux, std::enable_if_t<is_vector_v<S>, bool> = true>
870917OPUS_D constexpr decltype(auto ) cast(const S& s, Aux&&... aux) {
871918 vector_t <D, size<S>()> r; static_for ([&](auto i){ r[i.value ] = cast<D>(s[i.value ], std::forward<Aux>(aux)...); }, number<size<S>()>{}); return r;
@@ -885,6 +932,8 @@ OPUS_D constexpr decltype(auto) cast(const S& s, Aux&&... aux) { return impl::ca
885932template <typename D, typename S, typename ... Aux, std::enable_if_t <is_array_v<S>, bool > = true >
886933OPUS_D constexpr decltype (auto ) cast(const S& s, Aux&&... aux) { return impl::cast_impl<D>(s, make_index_seq<size<S>()>{}, std::forward<Aux>(aux)...); }
887934
// The definition helper macros are only needed up to this point; remove them so they are
// no longer defined in the code that follows.
#undef OPUS_CAST_DEFINE
#undef OPUS_DEFINE_FPACKS
#undef OPUS_DEFINE_DPACKS
889938// ///////////////////////////////////////////////////////////////////////////////////////////////////////
890939// arch
0 commit comments