Skip to content

Commit 9b542ad

Browse files
authored
enhance async_load with issue_space and use __fp16 type (ROCm#1915)
1 parent 4a9fea7 commit 9b542ad

File tree

1 file changed

+13
-4
lines changed

1 file changed

+13
-4
lines changed

csrc/include/opus/opus.hpp

Lines changed: 13 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -660,7 +660,6 @@ template <index_t vec, typename Layout>
660660
// Returns the vectorized issue space of `Layout`: layout_to_issue_space<Layout>()
// is evaluated first and its result is handed to vectorize_issue_space together
// with number<vec>{}. `vec` and `Layout` are the template parameters shown on the
// hunk-header line just above this excerpt.
// NOTE(review): this is a diff excerpt. The static_assert that follows the `663-`
// marker is the statement removed by this commit (hunk goes from 7 to 6 lines);
// the remaining lines are unchanged context.
OPUS_H_D constexpr auto layout_to_vectorized_issue_space() {
661661
constexpr auto issue_space = layout_to_issue_space<Layout>();
662662
constexpr auto issue_space_vec = vectorize_issue_space(issue_space, number<vec>{});
663-
static_assert(size<decltype(issue_space_vec)>() == Layout::coord_rank);
664663
return issue_space_vec;
665664
}
666665

@@ -829,7 +828,7 @@ template<typename T> constexpr bool is_dtype_v = is_dtype<remove_cvref_t<T>>::va
829828

830829
// Data-type registrations: each REGISTER_DTYPE(name, type) pairs a short dtype
// name with the C++ type used to represent it.
// NOTE(review): the macro definition is not visible in this excerpt; presumably
// it feeds the is_dtype / is_dtype_v trait named in the hunk header - confirm.
// NOTE(review): diff excerpt. The fp16 entry is the one touched by this commit:
// the `_Float16` line after the `832-` marker is the removed version and the
// `__fp16` line after the `831+` marker is its replacement.
REGISTER_DTYPE(fp32, float)
831830
REGISTER_DTYPE(bf16, unsigned short)
832-
REGISTER_DTYPE(fp16, _Float16)
831+
REGISTER_DTYPE(fp16, __fp16)
833832
REGISTER_DTYPE(fp8 , _BitInt(8))
834833
REGISTER_DTYPE(bf8 , unsigned _BitInt(8))
835834
REGISTER_DTYPE(i32 , int32_t)
@@ -964,7 +963,7 @@ template<> OPUS_D float min<float>(const float&a, const float&b) { return
964963

965964
// Generic median-of-three: returns the middle value of a, b and c.
// The median is max(min(a, b), min(max(a, b), c)): the smaller of the first pair
// is a lower bound, and c is clamped from above by the larger of the pair.
// (Previously both temporaries were computed with max() and combined with max(),
// which yielded the maximum of the three values rather than the median.)
template<typename T> OPUS_D T med3(const T&a, const T&b, const T&c) {
    auto max_0 = max(a, b);
    auto min_0 = min(a, b);
    return max(min_0, min(max_0, c));
}
966965
// float specialization of med3: hands the three operands straight to the
// __builtin_amdgcn_fmed3f builtin and returns its result.
template<>
OPUS_D float med3<float>(const float& a, const float& b, const float& c)
{
    return __builtin_amdgcn_fmed3f(a, b, c);
}
967-
// _Float16 specialization of med3, implemented with __builtin_amdgcn_fmed3h.
// NOTE(review): diff excerpt - this line follows the `967-` marker, i.e. it is the
// version removed by this commit.
template<> OPUS_D _Float16 med3<_Float16>(const _Float16&a, const _Float16&b, const _Float16&c) { return __builtin_amdgcn_fmed3h(a, b, c); }
966+
// __fp16 specialization of med3: hands the three operands straight to the
// __builtin_amdgcn_fmed3h builtin and returns its result.
template<>
OPUS_D __fp16 med3<__fp16>(const __fp16& a, const __fp16& b, const __fp16& c)
{
    return __builtin_amdgcn_fmed3h(a, b, c);
}
968967
/////////////////////////////////////////////////////////////////////////////////////////////////////////
969968
// buffer load/store related
970969
OPUS_D constexpr auto buffer_default_config() {
@@ -1031,6 +1030,16 @@ struct gmem {
10311030
template<index_t vec = 1, index_t aux = 0> // os in unit of T and cast to vector with vec
10321031
OPUS_D void async_load(__shared__ void* dst, int v_os, int s_os = 0, number<aux> = {}) { _async_load<vec>(dst, v_os * sizeof(T), s_os * sizeof(T), number<aux>{}); }
10331032

1033+
// Layout-driven async_load. Enabled only when both LayoutG and LayoutS satisfy
// is_layout_v. The issue space is derived from LayoutG and vectorized by `vec`;
// static_ford then invokes the callback over that space, and for every index pack
// one offset-based async_load<vec> is issued with
//   - destination: smem_base viewed as scalar_type*, advanced by u_smem(idx...)
//   - v_os:        u_gmem(idx...)
//   - s_os / aux:  passed through unchanged.
template<index_t vec = 1, typename LayoutG, typename LayoutS, index_t aux = 0, std::enable_if_t<is_layout_v<LayoutG> && is_layout_v<LayoutS>, bool> = true>
OPUS_D void async_load(__shared__ void* smem_base, const LayoutG& u_gmem, const LayoutS& u_smem, int s_os = 0, number<aux> = {})
{
    constexpr auto gmem_issue_space     = layout_to_issue_space<LayoutG>();
    constexpr auto gmem_issue_space_vec = vectorize_issue_space(gmem_issue_space, number<vec>{});

    // u_smem yields offsets in scalar_type elements, so index through a typed pointer.
    auto* smem_elems = reinterpret_cast<scalar_type*>(smem_base);

    static_ford(gmem_issue_space_vec, [&](auto... idx) {
        async_load<vec>(smem_elems + u_smem(idx...), u_gmem(idx...), s_os, number<aux>{});
    });
}
1042+
10341043
template<index_t vec = 1, typename V, index_t aux = 0, std::enable_if_t<(is_vector_v<V> || is_dtype_v<V> || is_array_v<V>), bool> = true> // os in unit of T and cast to vector with vec
10351044
OPUS_D void store(const V& x, int v_os, int s_os = 0, number<aux> = {}) {
10361045
static_assert(std::is_same_v<typename vector_traits<V>::dtype, scalar_type>, "scalar type must be same for the data to be stored" );
@@ -1562,7 +1571,7 @@ OPUS_D decltype(auto) make_tiled_mma(ES, TS, WS, WA&& = {}, TA&& = {}) {
15621571
/////////////////////////////////////////////////////////////////////////////////////////////////////////
15631572
// Builds a partitioned layout from an existing one. Participates in overload
// resolution only when L satisfies is_layout_v and D, S, C satisfy is_tuple_v.
// The result is make_layout<cached_vec>(shapes, strides, coord) where
//   strides = unfold_x_stride(dims, shapes, layout.stride())
//   coord   = unfold_p_coord(dims, p_coord)
// NOTE(review): OPUS_KP_ is a macro defined outside this excerpt; its effect on
// `dims` cannot be determined from here.
// NOTE(review): std::forward<D>(dims) and std::forward<S>(shapes) each appear
// twice in the return expression; this is only safe if the callees do not move
// from their arguments - confirm.
// NOTE(review): diff excerpt. The line after the `1565-` marker (with the
// static_assert) is the removed version; the line after `1574+` replaces it.
template<index_t cached_vec = 0, typename L, typename D, typename S, typename C, std::enable_if_t<is_layout_v<L> && is_tuple_v<D> && is_tuple_v<S> && is_tuple_v<C>, bool> = true>
15641573
OPUS_D constexpr auto partition_layout(L&& layout, D&& dims, S&& shapes, C&& p_coord) {
1565-
static_assert(L::rank == D::size()); OPUS_KP_(dims);
1574+
OPUS_KP_(dims);
15661575
return make_layout<cached_vec>(std::forward<S>(shapes), unfold_x_stride(std::forward<D>(dims), std::forward<S>(shapes), layout.stride()), unfold_p_coord(std::forward<D>(dims), p_coord));
15671576
}
15681577
// partition, use cached_vec to dispatch which layout implementation. cached_vec < 0 : "layout", cached_vec == 0 : "layout_linear", cached_vec > 0 : "layout_cached"

0 commit comments

Comments
 (0)