Skip to content

Commit 4a9fea7

Browse files
authored
add dpacks support (ROCm#1916)
1 parent 1c564c2 commit 4a9fea7

File tree

1 file changed

+49
-0
lines changed

1 file changed

+49
-0
lines changed

csrc/include/opus/opus.hpp

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -866,6 +866,53 @@ OPUS_D constexpr auto fp32_to_bf16(const fp32_t& x, number<rm> = {}) {
866866
OPUS_CAST_DEFINE(fp16, fp32)
867867
OPUS_CAST_DEFINE(fp32, fp16)
868868

869+
namespace impl {
870+
// implement a "pack" of data, storage should pad to multiple of byte(8bit)
871+
template<typename storage_, uint32_t bits_, bool is_signed_ = true>
872+
struct dpacks {
873+
using storage = remove_cvref_t<storage_>;
874+
static constexpr uint32_t bits = bits_;
875+
static constexpr uint32_t mask = (1 << bits) - 1;
876+
static constexpr bool is_signed = is_signed_;
877+
static constexpr uint32_t num_packs = sizeof(storage) * 8 / bits; // we will not check if evenly divided or not here
878+
OPUS_H_D constexpr storage operator[](index_t i) const { return (value >> (i * bits)) & mask; } // NOTE: not efficient, better use v_bfi/v_bfe/v_perm on device
879+
template<index_t I> OPUS_H_D constexpr storage operator[](number<I>) const { return (value >> (I * bits)) & mask; } // NOTE: not efficient, better use v_bfi/v_bfe/v_perm on device
880+
storage value;
881+
};
882+
883+
template<typename storage_, uint32_t bits_, uint32_t exp_bits_, uint32_t mantissa_bits_, bool is_signed_ = true>
884+
struct fpacks : dpacks<storage_, bits_, is_signed_> {
885+
static constexpr uint32_t exp_bits = exp_bits_;
886+
static constexpr uint32_t mantissa_bits = mantissa_bits_;
887+
};
888+
} // namespace impl
889+
890+
template <typename> struct is_packs : false_type {};
891+
template <typename S, uint32_t B, bool X> struct is_packs<impl::dpacks<S, B, X>> : true_type {};
892+
template <typename S, uint32_t B, uint32_t E, uint32_t M, bool X> struct is_packs<impl::fpacks<S, B, E, M, X>> : true_type {};
893+
template <typename T> static constexpr bool is_packs_v = is_packs<remove_cvref_t<T>>::value;
894+
895+
template <typename T> struct sizeof_bits { static constexpr int value = int(sizeof(T) * 8); };
896+
template <> struct sizeof_bits<void> { static constexpr int value = 0; };
897+
template <typename S, uint32_t B, bool X> struct sizeof_bits<impl::dpacks<S, B, X>> { static constexpr int value = impl::dpacks<S, B, X>::bits; };
898+
template <typename S, uint32_t B, uint32_t E, uint32_t M, bool X> struct sizeof_bits<impl::fpacks<S, B, E, M, X>> { static constexpr int value = impl::fpacks<S, B, E, M, X>::bits; };
899+
template <class T> static constexpr auto sizeof_bits_v = sizeof_bits<T>::value;
900+
901+
// Declares `name_` as a named integer pack type derived from opus::impl::dpacks and
// specializes sizeof_bits, is_packs and is_dtype for it. The trait names are unqualified,
// so the macro has to be expanded in a scope where they are visible.
#define OPUS_DEFINE_DPACKS(name_, storage_, bits_, is_signed_)                  \
    struct name_ : opus::impl::dpacks<storage_, bits_, is_signed_> {            \
        using base = opus::impl::dpacks<storage_, bits_, is_signed_>;           \
    };                                                                          \
    template<> struct sizeof_bits<name_> { static constexpr int value = name_::bits; }; \
    template<> struct is_packs<name_> : true_type {};                           \
    template<> struct is_dtype<name_> : true_type {};
904+
905+
// Declares `name_` as a named floating-point pack type derived from opus::impl::fpacks and
// specializes sizeof_bits, is_packs and is_dtype for it. The trait names are unqualified,
// so the macro has to be expanded in a scope where they are visible.
#define OPUS_DEFINE_FPACKS(name_, storage_, bits_, exp_bits_, mantissa_bits_, is_signed_)        \
    struct name_ : opus::impl::fpacks<storage_, bits_, exp_bits_, mantissa_bits_, is_signed_> {  \
        using base = opus::impl::fpacks<storage_, bits_, exp_bits_, mantissa_bits_, is_signed_>; \
    };                                                                                           \
    template<> struct sizeof_bits<name_> { static constexpr int value = name_::bits; };          \
    template<> struct is_packs<name_> : true_type {};                                            \
    template<> struct is_dtype<name_> : true_type {};
908+
909+
// NOTE on the naming convention: each sub-byte type below is actually "packed" data. E.g. fp4_t is really fp4x2 underneath (two values in one byte), but we do not put the pack count in the name.
910+
// This differs from the cutlass convention (e.g. float4_e2m1_t, whose storage is uint8_t, so an array of float4_e2m1_t is expanded) and from the ck convention (which explicitly names the type fp4x2_t).
911+
OPUS_DEFINE_DPACKS(int4_t , uint8_t, 4, true) // int4x2: two 4-bit elements per byte, flagged signed
912+
OPUS_DEFINE_DPACKS(uint4_t, uint8_t, 4, false) // uint4x2: two 4-bit elements per byte, flagged unsigned
913+
OPUS_DEFINE_FPACKS(fp4_t, uint8_t, 4, 2, 1, true) // fp4x2: two 4-bit floats per byte (2 exponent bits, 1 mantissa bit, flagged signed)
914+
OPUS_DEFINE_FPACKS(e8m0_t, uint8_t, 8, 8, 0, false) // e8m0x1: one 8-bit element per byte (8 exponent bits, no mantissa, flagged unsigned)
915+
869916
// cast() overload selected when S satisfies is_vector_v.
// Converts the input lane by lane: each element of `s` is passed through the scalar
// cast<D>() together with the auxiliary arguments, and the results are collected in a
// vector_t<D, size<S>()> that is returned by value.
template<typename D, typename S, typename... Aux, std::enable_if_t<is_vector_v<S>, bool> = true>
OPUS_D constexpr decltype(auto) cast(const S& s, Aux&&... aux) {
    vector_t<D, size<S>()> out;
    static_for(
        [&](auto idx) {
            // aux is forwarded again for every lane
            out[idx.value] = cast<D>(s[idx.value], std::forward<Aux>(aux)...);
        },
        number<size<S>()>{});
    return out;
}
@@ -885,6 +932,8 @@ OPUS_D constexpr decltype(auto) cast(const S& s, Aux&&... aux) { return impl::ca
885932
// cast() overload selected when S satisfies is_array_v.
// Delegates to impl::cast_impl<D>, handing it the source, an index sequence with one entry
// per element of S, and the forwarded auxiliary arguments; its result is returned unchanged.
template<typename D, typename S, typename... Aux, std::enable_if_t<is_array_v<S>, bool> = true>
OPUS_D constexpr decltype(auto) cast(const S& s, Aux&&... aux) {
    using element_indices = make_index_seq<size<S>()>;
    return impl::cast_impl<D>(s, element_indices{}, std::forward<Aux>(aux)...);
}
887934

935+
#undef OPUS_DEFINE_DPACKS
936+
#undef OPUS_DEFINE_FPACKS
888937
#undef OPUS_CAST_DEFINE
889938
/////////////////////////////////////////////////////////////////////////////////////////////////////////
890939
// arch

0 commit comments

Comments
 (0)