@@ -660,7 +660,6 @@ template <index_t vec, typename Layout>
660660OPUS_H_D constexpr auto layout_to_vectorized_issue_space () {
661661 constexpr auto issue_space = layout_to_issue_space<Layout>();
662662 constexpr auto issue_space_vec = vectorize_issue_space (issue_space, number<vec>{});
663- static_assert (size<decltype (issue_space_vec)>() == Layout::coord_rank);
664663 return issue_space_vec;
665664}
666665
@@ -829,7 +828,7 @@ template<typename T> constexpr bool is_dtype_v = is_dtype<remove_cvref_t<T>>::va
829828
830829REGISTER_DTYPE (fp32, float )
831830REGISTER_DTYPE(bf16 , unsigned short )
832- REGISTER_DTYPE(fp16, _Float16 )
831+ REGISTER_DTYPE(fp16, __fp16 )
833832REGISTER_DTYPE(fp8 , _BitInt (8 ))
834833REGISTER_DTYPE(bf8 , unsigned _BitInt (8 ))
835834REGISTER_DTYPE(i32 , int32_t )
@@ -964,7 +963,7 @@ template<> OPUS_D float min<float>(const float&a, const float&b) { return
964963
965964template <typename T> OPUS_D T med3 (const T&a, const T&b, const T&c) { auto max_0 = max (a, b); auto min_0 = max (a, b); return max (max_0, max (min_0, c)); }
966965template <> OPUS_D float med3<float >(const float &a, const float &b, const float &c) { return __builtin_amdgcn_fmed3f (a, b, c); }
967- template <> OPUS_D _Float16 med3<_Float16 >(const _Float16 &a, const _Float16 &b, const _Float16 &c) { return __builtin_amdgcn_fmed3h (a, b, c); }
966+ template <> OPUS_D __fp16 med3<__fp16 >(const __fp16 &a, const __fp16 &b, const __fp16 &c) { return __builtin_amdgcn_fmed3h (a, b, c); }
968967// ///////////////////////////////////////////////////////////////////////////////////////////////////////
969968// buffer load/store related
970969OPUS_D constexpr auto buffer_default_config () {
@@ -1031,6 +1030,16 @@ struct gmem {
10311030 template <index_t vec = 1 , index_t aux = 0 > // os in unit of T and cast to vector with vec
10321031 OPUS_D void async_load (__shared__ void * dst, int v_os, int s_os = 0 , number<aux> = {}) { _async_load<vec>(dst, v_os * sizeof (T), s_os * sizeof (T), number<aux>{}); }
10331032
1033+ template <index_t vec = 1 , typename LayoutG, typename LayoutS, index_t aux = 0 , std::enable_if_t <is_layout_v<LayoutG> && is_layout_v<LayoutS>, bool > = true >
1034+ OPUS_D void async_load (__shared__ void * smem_base, const LayoutG& u_gmem, const LayoutS& u_smem, int s_os = 0 , number<aux> = {}) {
1035+ constexpr auto issue_space = layout_to_issue_space<LayoutG>();
1036+ constexpr auto issue_space_vec = vectorize_issue_space (issue_space, number<vec>{});
1037+ scalar_type* smem_ptr = reinterpret_cast <scalar_type*>(smem_base);
1038+ static_ford (issue_space_vec, [&](auto ... ids) {
1039+ async_load<vec>(smem_ptr + u_smem (ids...), u_gmem (ids...), s_os, number<aux>{});
1040+ });
1041+ }
1042+
10341043 template <index_t vec = 1 , typename V, index_t aux = 0 , std::enable_if_t <(is_vector_v<V> || is_dtype_v<V> || is_array_v<V>), bool > = true > // os in unit of T and cast to vector with vec
10351044 OPUS_D void store (const V& x, int v_os, int s_os = 0 , number<aux> = {}) {
10361045 static_assert (std::is_same_v<typename vector_traits<V>::dtype, scalar_type>, " scalar type must be same for the data to be stored" );
@@ -1562,7 +1571,7 @@ OPUS_D decltype(auto) make_tiled_mma(ES, TS, WS, WA&& = {}, TA&& = {}) {
15621571// ///////////////////////////////////////////////////////////////////////////////////////////////////////
15631572template <index_t cached_vec = 0 , typename L, typename D, typename S, typename C, std::enable_if_t <is_layout_v<L> && is_tuple_v<D> && is_tuple_v<S> && is_tuple_v<C>, bool > = true >
15641573OPUS_D constexpr auto partition_layout (L&& layout, D&& dims, S&& shapes, C&& p_coord) {
1565- static_assert (L::rank == D::size ()); OPUS_KP_ (dims);
1574+ OPUS_KP_ (dims);
15661575 return make_layout<cached_vec>(std::forward<S>(shapes), unfold_x_stride (std::forward<D>(dims), std::forward<S>(shapes), layout.stride ()), unfold_p_coord (std::forward<D>(dims), p_coord));
15671576}
15681577// partition, use cached_vec to dispatch which layout implementation. cached_vec < 0 : "layout", cached_vec == 0 : "layout_linear", cached_vec > 0 : "layout_cached"
0 commit comments