diff --git a/include/ck/ck.hpp b/include/ck/ck.hpp index 6f510e735e..5fa73d2fda 100644 --- a/include/ck/ck.hpp +++ b/include/ck/ck.hpp @@ -245,6 +245,9 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING) // workaround: compiler issue on gfx908 #define CK_WORKAROUND_SWDEV_388832 1 +// workaround: compiler issue on gfx950 +#define CK_WORKAROUND_FP32_TO_FP4_SR_CONVERSION 1 + // denorm test fix, necessary for gfx90a #ifndef CK_GFX90A_DENORM_WORKAROUND #define CK_GFX90A_DENORM_WORKAROUND 0 diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp index a4d96edc6d..1173f32303 100644 --- a/include/ck/utility/data_type.hpp +++ b/include/ck/utility/data_type.hpp @@ -36,22 +36,22 @@ struct f4x2_pk_t { using type = uint8_t; type data; - f4x2_pk_t() : data{type{}} {} - f4x2_pk_t(type init) : data{init} {} + __host__ __device__ f4x2_pk_t() : data{type{}} {} + __host__ __device__ f4x2_pk_t(type init) : data{init} {} template __host__ __device__ inline type unpack(Number) const { static_assert(I < 2, "Index is out of range."); if constexpr(I == 0) - return data & 0b00001111; - else return (data >> 4); + else + return data & 0b00001111; } __host__ __device__ inline type pack(const type x0, const type x1) { - return (x1 << 4) | (x0 & 0b00001111); + return (x0 << 4) | (x1 & 0b00001111); } }; diff --git a/include/ck/utility/mxf4_utils.hpp b/include/ck/utility/mxf4_utils.hpp index 757d3914e3..0930134790 100644 --- a/include/ck/utility/mxf4_utils.hpp +++ b/include/ck/utility/mxf4_utils.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_CODE_GEN_RTC #pragma once @@ -14,7 +14,7 @@ __host__ __device__ inline bool is_nan(e8m0_bexp_t const scale, f4_t const dataBytes [[maybe_unused]]) { // no need to check for data as it does not have NaN representation - return scale == NumericLimits::QuietNaN(); + return scale.is_nan(); } // no infinity representation in ocp_e2m1_mxfp4 will always return false @@ -27,11 +27,9 @@ __host__ __device__ inline bool is_inf(e8m0_bexp_t const scale [[maybe_unu } template <> -__host__ __device__ inline bool is_zero(e8m0_bexp_t const scale, f4_t const data) +__host__ __device__ inline bool is_zero(e8m0_bexp_t const scale [[maybe_unused]], + f4_t const data) { - if(is_nan(scale, data)) - return false; - // no need to check for scale as it does not have a 0 representation f4_t result = (data & 0b00001111) & NumericUtils::set_sign_mask; diff --git a/include/ck/utility/mxfp_utils.hpp b/include/ck/utility/mxfp_utils.hpp index 947d64b705..f0a86f8750 100644 --- a/include/ck/utility/mxfp_utils.hpp +++ b/include/ck/utility/mxfp_utils.hpp @@ -99,7 +99,7 @@ template __host__ __device__ T sat_convert_to_type_sr(float value, uint32_t seed); template -inline T convert_to_type(float value) +__host__ __device__ inline T convert_to_type(float value) { using bitwise_type = typename NumericUtils::bitwise_type; @@ -258,7 +258,7 @@ inline T convert_to_type(float value) } template -inline T convert_to_type_sr(float value, uint32_t seed) +__host__ __device__ inline T convert_to_type_sr(float value, uint32_t seed) { if(std::abs(value) > NumericLimits::Max()) { diff --git a/include/ck/utility/scaled_type_convert.hpp b/include/ck/utility/scaled_type_convert.hpp index 5b7a822e1f..9a9c53caec 100644 --- a/include/ck/utility/scaled_type_convert.hpp +++ b/include/ck/utility/scaled_type_convert.hpp @@ -377,12 +377,15 @@ inline __host__ __device__ float2_t scaled_type_convert(e8m0_b f4x2_t f4x2_array[4]; } value{}; value.f4x2_array[0] = x; - return __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.bitwise, type_convert(scale), 0); + float2_t tmp = + __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.bitwise, type_convert(scale), 0); + // permute high bits and low bits to match the order of the original vector + return float2_t{tmp[1], tmp[0]}; #else float2_t ret{utils::to_float( - scale, x.template AsType()[Number<0>{}].unpack<>(Number<1>{})), + scale, x.template AsType()[Number<0>{}].unpack<>(Number<0>{})), utils::to_float( - scale, x.template AsType()[Number<0>{}].unpack<>(Number<0>{}))}; + scale, x.template AsType()[Number<0>{}].unpack<>(Number<1>{}))}; return ret; #endif } @@ -398,109 +401,16 @@ inline __host__ __device__ float32_t scaled_type_convert(e8m f4x32_t f4x32_array; f4x2_t fp4x2[16]; } value{x}; - union - { - uint32_t bitwise; - f4x2_t f4x2_array[4]; - } bitwise_value{}; float2_t op; float32_t ret; - // TODO: pack in a loop - bitwise_value.f4x2_array[0] = value.fp4x2[0]; - op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( - bitwise_value.bitwise, type_convert(scale), 0); - ret[0] = op[0]; - ret[1] = op[1]; - - bitwise_value.f4x2_array[0] = value.fp4x2[1]; - op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( - bitwise_value.bitwise, type_convert(scale), 0); - ret[2] = op[0]; - ret[3] = op[1]; - - bitwise_value.f4x2_array[0] = value.fp4x2[2]; - op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( - bitwise_value.bitwise, type_convert(scale), 0); - ret[4] = op[0]; - ret[5] = op[1]; - - bitwise_value.f4x2_array[0] = value.fp4x2[3]; - op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( - bitwise_value.bitwise, type_convert(scale), 0); - ret[6] = op[0]; - ret[7] = op[1]; - - bitwise_value.f4x2_array[0] = value.fp4x2[4]; - op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( - bitwise_value.bitwise, type_convert(scale), 0); - ret[8] = op[0]; - ret[9] = op[1]; - - bitwise_value.f4x2_array[0] = value.fp4x2[5]; - op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( - bitwise_value.bitwise, type_convert(scale), 0); - ret[10] = op[0]; - ret[11] = op[1]; - - bitwise_value.f4x2_array[0] = value.fp4x2[6]; - op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( - bitwise_value.bitwise, type_convert(scale), 0); - ret[12] = op[0]; - ret[13] = op[1]; - - bitwise_value.f4x2_array[0] = value.fp4x2[7]; - op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( - bitwise_value.bitwise, type_convert(scale), 0); - ret[14] = op[0]; - ret[15] = op[1]; - - bitwise_value.f4x2_array[0] = value.fp4x2[8]; - op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( - bitwise_value.bitwise, type_convert(scale), 0); - ret[16] = op[0]; - ret[17] = op[1]; - - bitwise_value.f4x2_array[0] = value.fp4x2[9]; - op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( - bitwise_value.bitwise, type_convert(scale), 0); - ret[18] = op[0]; - ret[19] = op[1]; - - bitwise_value.f4x2_array[0] = value.fp4x2[10]; - op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( - bitwise_value.bitwise, type_convert(scale), 0); - ret[20] = op[0]; - ret[21] = op[1]; - - bitwise_value.f4x2_array[0] = value.fp4x2[11]; - op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( - bitwise_value.bitwise, type_convert(scale), 0); - ret[22] = op[0]; - ret[23] = op[1]; - - bitwise_value.f4x2_array[0] = value.fp4x2[12]; - op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( - bitwise_value.bitwise, type_convert(scale), 0); - ret[24] = op[0]; - ret[25] = op[1]; - - bitwise_value.f4x2_array[0] = value.fp4x2[13]; - op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( - bitwise_value.bitwise, type_convert(scale), 0); - ret[26] = op[0]; - ret[27] = op[1]; - - bitwise_value.f4x2_array[0] = value.fp4x2[14]; - op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( - bitwise_value.bitwise, type_convert(scale), 0); - ret[28] = op[0]; - ret[29] = op[1]; - - bitwise_value.f4x2_array[0] = value.fp4x2[15]; - op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( - bitwise_value.bitwise, type_convert(scale), 0); - ret[30] = op[0]; - ret[31] = op[1]; + float f_scale = type_convert(scale); + + ck::static_for<0, 32 / 2, 1>{}([&](auto idx) { + op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[idx], f_scale, 0); + // permute high bits and low bits to match the order of the original vector + ret[2 * idx] = op[1]; + ret[2 * idx + 1] = op[0]; + }); return ret; #else @@ -515,106 +425,18 @@ inline __host__ __device__ float32_t scaled_type_convert(e8m f4x2_t f4x2_array[16]; f4x32_t f4x32_array; } f4_values{bit_cast<__uint128_t>(x)}; - // TODO: pack in a loop - float_values.float_array[0] = utils::to_float( - scale, - f4_values.f4x2_array[0].template AsType()[Number<0>{}].unpack<>(Number<0>{})); - float_values.float_array[1] = utils::to_float( - scale, - f4_values.f4x2_array[0].template AsType()[Number<0>{}].unpack<>(Number<1>{})); - float_values.float_array[2] = utils::to_float( - scale, - f4_values.f4x2_array[1].template AsType()[Number<0>{}].unpack<>(Number<0>{})); - float_values.float_array[3] = utils::to_float( - scale, - f4_values.f4x2_array[1].template AsType()[Number<0>{}].unpack<>(Number<1>{})); - float_values.float_array[4] = utils::to_float( - scale, - f4_values.f4x2_array[2].template AsType()[Number<0>{}].unpack<>(Number<0>{})); - float_values.float_array[5] = utils::to_float( - scale, - f4_values.f4x2_array[2].template AsType()[Number<0>{}].unpack<>(Number<1>{})); - float_values.float_array[6] = utils::to_float( - scale, - f4_values.f4x2_array[3].template AsType()[Number<0>{}].unpack<>(Number<0>{})); - float_values.float_array[7] = utils::to_float( - scale, - f4_values.f4x2_array[3].template AsType()[Number<0>{}].unpack<>(Number<1>{})); - - float_values.float_array[0] = utils::to_float( - scale, - f4_values.f4x2_array[4].template AsType()[Number<0>{}].unpack<>(Number<0>{})); - float_values.float_array[1] = utils::to_float( - scale, - f4_values.f4x2_array[4].template AsType()[Number<0>{}].unpack<>(Number<1>{})); - float_values.float_array[2] = utils::to_float( - scale, - f4_values.f4x2_array[5].template AsType()[Number<0>{}].unpack<>(Number<0>{})); - float_values.float_array[3] = utils::to_float( - scale, - f4_values.f4x2_array[5].template AsType()[Number<0>{}].unpack<>(Number<1>{})); - float_values.float_array[4] = utils::to_float( - scale, - f4_values.f4x2_array[6].template AsType()[Number<0>{}].unpack<>(Number<0>{})); - float_values.float_array[5] = utils::to_float( - scale, - f4_values.f4x2_array[6].template AsType()[Number<0>{}].unpack<>(Number<1>{})); - float_values.float_array[6] = utils::to_float( - scale, - f4_values.f4x2_array[7].template AsType()[Number<0>{}].unpack<>(Number<0>{})); - float_values.float_array[7] = utils::to_float( - scale, - f4_values.f4x2_array[7].template AsType()[Number<0>{}].unpack<>(Number<1>{})); - - float_values.float_array[0] = utils::to_float( - scale, - f4_values.f4x2_array[8].template AsType()[Number<0>{}].unpack<>(Number<0>{})); - float_values.float_array[1] = utils::to_float( - scale, - f4_values.f4x2_array[8].template AsType()[Number<0>{}].unpack<>(Number<1>{})); - float_values.float_array[2] = utils::to_float( - scale, - f4_values.f4x2_array[9].template AsType()[Number<0>{}].unpack<>(Number<0>{})); - float_values.float_array[3] = utils::to_float( - scale, - f4_values.f4x2_array[9].template AsType()[Number<0>{}].unpack<>(Number<1>{})); - float_values.float_array[4] = utils::to_float( - scale, - f4_values.f4x2_array[10].template AsType()[Number<0>{}].unpack<>(Number<0>{})); - float_values.float_array[5] = utils::to_float( - scale, - f4_values.f4x2_array[10].template AsType()[Number<0>{}].unpack<>(Number<1>{})); - float_values.float_array[6] = utils::to_float( - scale, - f4_values.f4x2_array[11].template AsType()[Number<0>{}].unpack<>(Number<0>{})); - float_values.float_array[7] = utils::to_float( - scale, - f4_values.f4x2_array[11].template AsType()[Number<0>{}].unpack<>(Number<1>{})); - - float_values.float_array[0] = utils::to_float( - scale, - f4_values.f4x2_array[12].template AsType()[Number<0>{}].unpack<>(Number<0>{})); - float_values.float_array[1] = utils::to_float( - scale, - f4_values.f4x2_array[12].template AsType()[Number<0>{}].unpack<>(Number<1>{})); - float_values.float_array[2] = utils::to_float( - scale, - f4_values.f4x2_array[13].template AsType()[Number<0>{}].unpack<>(Number<0>{})); - float_values.float_array[3] = utils::to_float( - scale, - f4_values.f4x2_array[13].template AsType()[Number<0>{}].unpack<>(Number<1>{})); - float_values.float_array[4] = utils::to_float( - scale, - f4_values.f4x2_array[14].template AsType()[Number<0>{}].unpack<>(Number<0>{})); - float_values.float_array[5] = utils::to_float( - scale, - f4_values.f4x2_array[14].template AsType()[Number<0>{}].unpack<>(Number<1>{})); - float_values.float_array[6] = utils::to_float( - scale, - f4_values.f4x2_array[15].template AsType()[Number<0>{}].unpack<>(Number<0>{})); - float_values.float_array[7] = utils::to_float( - scale, - f4_values.f4x2_array[15].template AsType()[Number<0>{}].unpack<>(Number<1>{})); + + ck::static_for<0, 32 / 2, 1>{}([&](auto idx) { + float_values.float_array[2 * idx] = utils::to_float( + scale, + f4_values.f4x2_array[idx].template AsType()[Number<0>{}].template unpack<>( + Number<0>{})); + + float_values.float_array[2 * idx + 1] = utils::to_float( + scale, + f4_values.f4x2_array[idx].template AsType()[Number<0>{}].template unpack<>( + Number<1>{})); + }); return float_values.float32_array; #endif diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp index 3ac0098fd9..b9aeb44999 100644 --- a/include/ck/utility/type_convert.hpp +++ b/include/ck/utility/type_convert.hpp @@ -732,7 +732,8 @@ inline __host__ __device__ f4x2_t f4_convert_rne(float2_t x, float scale = 1.0f) uint32_t bitwise; f4x2_t f4x2_array[4]; } value{0}; - value.bitwise = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(value.bitwise, x[0], x[1], scale, 0); + // permute high bits and low bits to match the order of the original vector + value.bitwise = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(value.bitwise, x[1], x[0], scale, 0); return value.f4x2_array[0]; #else union @@ -757,58 +758,13 @@ inline __host__ __device__ f4x32_t f4_convert_rne(float32_t x, float scale = 1.0 f4x2_t f4x2_array[16]; f4x32_t f4x32_array; } f4_values{}, tmp_values{}; - // TODO: pack in a loop - tmp_values.bitwise = - __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[0], x[1], scale, 0); - f4_values.f4x2_array[0] = tmp_values.f4x2_array[0]; - tmp_values.bitwise = - __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[2], x[3], scale, 0); - f4_values.f4x2_array[1] = tmp_values.f4x2_array[0]; - tmp_values.bitwise = - __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[4], x[5], scale, 0); - f4_values.f4x2_array[2] = tmp_values.f4x2_array[0]; - tmp_values.bitwise = - __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[6], x[7], scale, 0); - f4_values.f4x2_array[3] = tmp_values.f4x2_array[0]; - - tmp_values.bitwise = - __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[8], x[9], scale, 0); - f4_values.f4x2_array[4] = tmp_values.f4x2_array[0]; - tmp_values.bitwise = - __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[10], x[11], scale, 0); - f4_values.f4x2_array[5] = tmp_values.f4x2_array[0]; - tmp_values.bitwise = - __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[12], x[13], scale, 0); - f4_values.f4x2_array[6] = tmp_values.f4x2_array[0]; - tmp_values.bitwise = - __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[14], x[15], scale, 0); - f4_values.f4x2_array[7] = tmp_values.f4x2_array[0]; - - tmp_values.bitwise = - __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[16], x[17], scale, 0); - f4_values.f4x2_array[8] = tmp_values.f4x2_array[0]; - tmp_values.bitwise = - __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[18], x[19], scale, 0); - f4_values.f4x2_array[9] = tmp_values.f4x2_array[0]; - tmp_values.bitwise = - __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[20], x[21], scale, 0); - f4_values.f4x2_array[10] = tmp_values.f4x2_array[0]; - tmp_values.bitwise = - __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[22], x[23], scale, 0); - f4_values.f4x2_array[11] = tmp_values.f4x2_array[0]; - - tmp_values.bitwise = - __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[24], x[25], scale, 0); - f4_values.f4x2_array[12] = tmp_values.f4x2_array[0]; - tmp_values.bitwise = - __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[26], x[27], scale, 0); - f4_values.f4x2_array[13] = tmp_values.f4x2_array[0]; - tmp_values.bitwise = - __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[28], x[29], scale, 0); - f4_values.f4x2_array[14] = tmp_values.f4x2_array[0]; - tmp_values.bitwise = - __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[30], x[31], scale, 0); - f4_values.f4x2_array[15] = tmp_values.f4x2_array[0]; + + ck::static_for<0, 32 / 2, 1>{}([&](auto idx) { + // permute high bits and low bits to match the order of the original vector + tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32( + tmp_values.bitwise, x[2 * idx + 1], x[2 * idx], scale, 0); + f4_values.f4x2_array[idx] = tmp_values.f4x2_array[0]; + }); return f4_values.f4x32_array; #else @@ -818,106 +774,14 @@ inline __host__ __device__ f4x32_t f4_convert_rne(float32_t x, float scale = 1.0 f4x2_t f4x2_array[16]; f4x32_t f4x32_array; } f4_values{}; - // TODO: pack in a loop - auto tmp = utils::sat_convert_to_type(x[0] / scale); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type(x[1] / scale); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type(x[2] / scale); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type(x[3] / scale); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type(x[4] / scale); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type(x[5] / scale); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type(x[6] / scale); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type(x[7] / scale); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - - tmp = utils::sat_convert_to_type(x[8] / scale); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type(x[9] / scale); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type(x[10] / scale); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type(x[11] / scale); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type(x[12] / scale); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type(x[13] / scale); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type(x[14] / scale); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type(x[15] / scale); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - - tmp = utils::sat_convert_to_type(x[16] / scale); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type(x[17] / scale); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type(x[18] / scale); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type(x[19] / scale); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type(x[20] / scale); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type(x[21] / scale); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type(x[22] / scale); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type(x[23] / scale); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - - tmp = utils::sat_convert_to_type(x[24] / scale); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type(x[25] / scale); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type(x[26] / scale); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type(x[27] / scale); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type(x[28] / scale); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type(x[29] / scale); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type(x[30] / scale); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type(x[31] / scale); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; + + f4_t tmp; + + ck::static_for<0, 32, 1>{}([&](auto idx) { + tmp = utils::sat_convert_to_type(x[static_cast(idx)] / scale); + f4_values.bitwise <<= 4; + f4_values.bitwise |= tmp; + }); return f4_values.f4x32_array; #endif @@ -967,7 +831,16 @@ inline __host__ __device__ f4x2_t f4_convert_sr(float2_t x, float scale = 1.0f) uint32_t bitwise; f4x2_t f4x2_array[4]; } value{0}; - value.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(value.bitwise, x, rng, scale, 0); +// apply a temporary workaround for gfx950 +#if CK_WORKAROUND_FP32_TO_FP4_SR_CONVERSION + uint8_t l = utils::sat_convert_to_type_sr(x[1] / scale, rng); + uint8_t h = utils::sat_convert_to_type_sr(x[0] / scale, rng); + value.bitwise = (h << 4) | l; +#else + // permute high bits and low bits to match the order of the original vector + value.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( + value.bitwise, float2_t{x[1], x[0]}, rng, scale, 0); +#endif // CK_WORKAROUND_FP32_TO_FP4_SR_CONVERSION return value.f4x2_array[0]; #else union @@ -997,64 +870,23 @@ inline __host__ __device__ f4x32_t f4_convert_sr(float32_t x, float scale = 1.0f __uint128_t bitwise; f4x2_t f4x2_array[16]; f4x32_t f4x32_array; - } f4_values{0}, tmp_values{0}; + } f4_values{0}; union { float2_t floatx2_array[16]; float32_t floatx32_array; } float_values{{0}}; - // TODO: pack in a loop - tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( - tmp_values.bitwise, float_values.floatx2_array[0], rng, scale, 0); - f4_values.f4x2_array[0] = tmp_values.f4x2_array[0]; - tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( - tmp_values.bitwise, float_values.floatx2_array[1], rng, scale, 0); - f4_values.f4x2_array[1] = tmp_values.f4x2_array[0]; - tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( - tmp_values.bitwise, float_values.floatx2_array[2], rng, scale, 0); - f4_values.f4x2_array[2] = tmp_values.f4x2_array[0]; - tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( - tmp_values.bitwise, float_values.floatx2_array[3], rng, scale, 0); - f4_values.f4x2_array[3] = tmp_values.f4x2_array[0]; - - tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( - tmp_values.bitwise, float_values.floatx2_array[4], rng, scale, 0); - f4_values.f4x2_array[4] = tmp_values.f4x2_array[0]; - tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( - tmp_values.bitwise, float_values.floatx2_array[5], rng, scale, 0); - f4_values.f4x2_array[5] = tmp_values.f4x2_array[0]; - tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( - tmp_values.bitwise, float_values.floatx2_array[6], rng, scale, 0); - f4_values.f4x2_array[6] = tmp_values.f4x2_array[0]; - tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( - tmp_values.bitwise, float_values.floatx2_array[7], rng, scale, 0); - f4_values.f4x2_array[7] = tmp_values.f4x2_array[0]; - - tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( - tmp_values.bitwise, float_values.floatx2_array[8], rng, scale, 0); - f4_values.f4x2_array[8] = tmp_values.f4x2_array[0]; - tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( - tmp_values.bitwise, float_values.floatx2_array[9], rng, scale, 0); - f4_values.f4x2_array[9] = tmp_values.f4x2_array[0]; - tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( - tmp_values.bitwise, float_values.floatx2_array[10], rng, scale, 0); - f4_values.f4x2_array[10] = tmp_values.f4x2_array[0]; - tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( - tmp_values.bitwise, float_values.floatx2_array[11], rng, scale, 0); - f4_values.f4x2_array[11] = tmp_values.f4x2_array[0]; - - tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( - tmp_values.bitwise, float_values.floatx2_array[12], rng, scale, 0); - f4_values.f4x2_array[12] = tmp_values.f4x2_array[0]; - tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( - tmp_values.bitwise, float_values.floatx2_array[13], rng, scale, 0); - f4_values.f4x2_array[13] = tmp_values.f4x2_array[0]; - tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( - tmp_values.bitwise, float_values.floatx2_array[14], rng, scale, 0); - f4_values.f4x2_array[14] = tmp_values.f4x2_array[0]; - tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( - tmp_values.bitwise, float_values.floatx2_array[15], rng, scale, 0); - f4_values.f4x2_array[15] = tmp_values.f4x2_array[0]; + float_values.floatx32_array = x; + + ck::static_for<0, 32 / 2, 1>{}([&](auto idx) { + // permute high bits and low bits to match the order of the original vector + f4_values.f4x2_array[idx] = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32( + f4_values.bitwise, + float2_t{float_values.floatx2_array[idx][1], float_values.floatx2_array[idx][0]}, + rng, + scale, + 0); + }); return f4_values.f4x32_array; #else @@ -1064,106 +896,14 @@ inline __host__ __device__ f4x32_t f4_convert_sr(float32_t x, float scale = 1.0f f4x2_t f4x2_array[16]; f4x32_t f4x32_array; } f4_values{0}; - // TODO: pack in a loop - auto tmp = utils::sat_convert_to_type_sr(x[0] / scale, rng); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type_sr(x[1] / scale, rng); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type_sr(x[2] / scale, rng); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type_sr(x[3] / scale, rng); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type_sr(x[4] / scale, rng); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type_sr(x[5] / scale, rng); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type_sr(x[6] / scale, rng); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type_sr(x[7] / scale, rng); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - - tmp = utils::sat_convert_to_type_sr(x[8] / scale, rng); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type_sr(x[9] / scale, rng); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type_sr(x[10] / scale, rng); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type_sr(x[11] / scale, rng); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type_sr(x[12] / scale, rng); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type_sr(x[13] / scale, rng); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type_sr(x[14] / scale, rng); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type_sr(x[15] / scale, rng); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - - tmp = utils::sat_convert_to_type_sr(x[16] / scale, rng); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type_sr(x[17] / scale, rng); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type_sr(x[18] / scale, rng); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type_sr(x[19] / scale, rng); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type_sr(x[20] / scale, rng); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type_sr(x[21] / scale, rng); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type_sr(x[22] / scale, rng); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type_sr(x[23] / scale, rng); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - - tmp = utils::sat_convert_to_type_sr(x[24] / scale, rng); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type_sr(x[25] / scale, rng); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type_sr(x[26] / scale, rng); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type_sr(x[27] / scale, rng); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type_sr(x[28] / scale, rng); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type_sr(x[29] / scale, rng); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type_sr(x[30] / scale, rng); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; - tmp = utils::sat_convert_to_type_sr(x[31] / scale, rng); - f4_values.bitwise <<= 4; - f4_values.bitwise |= tmp; + + f4_t tmp; + + ck::static_for<0, 32, 1>{}([&](auto idx) { + tmp = utils::sat_convert_to_type_sr(x[static_cast(idx)] / scale, rng); + f4_values.bitwise <<= 4; + f4_values.bitwise |= tmp; + }); return f4_values.f4x32_array; #endif @@ -1232,13 +972,15 @@ inline __host__ __device__ float2_t type_convert(f4x2_t x) } value{}; value.f4x2_array[0] = x; float scale = 1.0f; - return __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.bitwise, scale, 0); + float2_t tmp = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.bitwise, scale, 0); + // permute high bits and low bits to match the order of the original vector + return float2_t{tmp[1], tmp[0]}; #else float2_t ret{ utils::to_float(NumericLimits::Binary_1(), - x.template AsType()[Number<0>{}].unpack<>(Number<1>{})), + x.template AsType()[Number<0>{}].unpack<>(Number<0>{})), utils::to_float(NumericLimits::Binary_1(), - x.template AsType()[Number<0>{}].unpack<>(Number<0>{}))}; + x.template AsType()[Number<0>{}].unpack<>(Number<1>{}))}; return ret; #endif } @@ -1253,110 +995,16 @@ inline __host__ __device__ float32_t type_convert(f4x32_t x) f4x32_t f4x32_array; f4x2_t fp4x2[16]; } value{x}; - union - { - uint32_t bitwise; - f4x2_t f4x2_array[4]; - } bitwise_value{}; float2_t op; float32_t ret; float scale = 1.0f; - // TODO: pack in a loop - bitwise_value.f4x2_array[0] = value.fp4x2[0]; - op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( - bitwise_value.bitwise, type_convert(scale), 0); - ret[0] = op[0]; - ret[1] = op[1]; - - bitwise_value.f4x2_array[0] = value.fp4x2[1]; - op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( - bitwise_value.bitwise, type_convert(scale), 0); - ret[2] = op[0]; - ret[3] = op[1]; - - bitwise_value.f4x2_array[0] = value.fp4x2[2]; - op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( - bitwise_value.bitwise, type_convert(scale), 0); - ret[4] = op[0]; - ret[5] = op[1]; - - bitwise_value.f4x2_array[0] = value.fp4x2[3]; - op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( - bitwise_value.bitwise, type_convert(scale), 0); - ret[6] = op[0]; - ret[7] = op[1]; - - bitwise_value.f4x2_array[0] = value.fp4x2[4]; - op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( - bitwise_value.bitwise, type_convert(scale), 0); - ret[8] = op[0]; - ret[9] = op[1]; - - bitwise_value.f4x2_array[0] = value.fp4x2[5]; - op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( - bitwise_value.bitwise, type_convert(scale), 0); - ret[10] = op[0]; - ret[11] = op[1]; - - bitwise_value.f4x2_array[0] = value.fp4x2[6]; - op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( - bitwise_value.bitwise, type_convert(scale), 0); - ret[12] = op[0]; - ret[13] = op[1]; - - bitwise_value.f4x2_array[0] = value.fp4x2[7]; - op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( - bitwise_value.bitwise, type_convert(scale), 0); - ret[14] = op[0]; - ret[15] = op[1]; - - bitwise_value.f4x2_array[0] = value.fp4x2[8]; - op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( - bitwise_value.bitwise, type_convert(scale), 0); - ret[16] = op[0]; - ret[17] = op[1]; - - bitwise_value.f4x2_array[0] = value.fp4x2[9]; - op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( - bitwise_value.bitwise, type_convert(scale), 0); - ret[18] = op[0]; - ret[19] = op[1]; - - bitwise_value.f4x2_array[0] = value.fp4x2[10]; - op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( - bitwise_value.bitwise, type_convert(scale), 0); - ret[20] = op[0]; - ret[21] = op[1]; - - bitwise_value.f4x2_array[0] = value.fp4x2[11]; - op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( - bitwise_value.bitwise, type_convert(scale), 0); - ret[22] = op[0]; - ret[23] = op[1]; - - bitwise_value.f4x2_array[0] = value.fp4x2[12]; - op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( - bitwise_value.bitwise, type_convert(scale), 0); - ret[24] = op[0]; - ret[25] = op[1]; - - bitwise_value.f4x2_array[0] = value.fp4x2[13]; - op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( - bitwise_value.bitwise, type_convert(scale), 0); - ret[26] = op[0]; - ret[27] = op[1]; - - bitwise_value.f4x2_array[0] = value.fp4x2[14]; - op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( - bitwise_value.bitwise, type_convert(scale), 0); - ret[28] = op[0]; - ret[29] = op[1]; - - bitwise_value.f4x2_array[0] = value.fp4x2[15]; - op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( - bitwise_value.bitwise, type_convert(scale), 0); - ret[30] = op[0]; - ret[31] = op[1]; + + ck::static_for<0, 32 / 2, 1>{}([&](auto idx) { + op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[idx], scale, 0); + // permute high bits and low bits to match the order of the original vector + ret[2 * idx] = op[1]; + ret[2 * idx + 1] = op[0]; + }); return ret; #else @@ -1371,106 +1019,18 @@ inline __host__ __device__ float32_t type_convert(f4x32_t x) f4x2_t f4x2_array[16]; f4x32_t f4x32_array; } f4_values{bit_cast<__uint128_t>(x)}; - // TODO: pack in a loop - float_values.float_array[0] = utils::to_float( - NumericLimits::Binary_1(), - f4_values.f4x2_array[0].template AsType()[Number<0>{}].unpack<>(Number<0>{})); - float_values.float_array[1] = utils::to_float( - NumericLimits::Binary_1(), - f4_values.f4x2_array[0].template AsType()[Number<0>{}].unpack<>(Number<1>{})); - float_values.float_array[2] = utils::to_float( - NumericLimits::Binary_1(), - f4_values.f4x2_array[1].template AsType()[Number<0>{}].unpack<>(Number<0>{})); - float_values.float_array[3] = utils::to_float( - NumericLimits::Binary_1(), - f4_values.f4x2_array[1].template AsType()[Number<0>{}].unpack<>(Number<1>{})); - float_values.float_array[4] = utils::to_float( - NumericLimits::Binary_1(), - f4_values.f4x2_array[2].template AsType()[Number<0>{}].unpack<>(Number<0>{})); - float_values.float_array[5] = utils::to_float( - NumericLimits::Binary_1(), - f4_values.f4x2_array[2].template AsType()[Number<0>{}].unpack<>(Number<1>{})); - float_values.float_array[6] = utils::to_float( - NumericLimits::Binary_1(), - f4_values.f4x2_array[3].template AsType()[Number<0>{}].unpack<>(Number<0>{})); - float_values.float_array[7] = utils::to_float( - NumericLimits::Binary_1(), - f4_values.f4x2_array[3].template AsType()[Number<0>{}].unpack<>(Number<1>{})); - - float_values.float_array[0] = utils::to_float( - NumericLimits::Binary_1(), - f4_values.f4x2_array[4].template AsType()[Number<0>{}].unpack<>(Number<0>{})); - float_values.float_array[1] = utils::to_float( - NumericLimits::Binary_1(), - f4_values.f4x2_array[4].template AsType()[Number<0>{}].unpack<>(Number<1>{})); - float_values.float_array[2] = utils::to_float( - NumericLimits::Binary_1(), - f4_values.f4x2_array[5].template AsType()[Number<0>{}].unpack<>(Number<0>{})); - float_values.float_array[3] = utils::to_float( - NumericLimits::Binary_1(), - f4_values.f4x2_array[5].template AsType()[Number<0>{}].unpack<>(Number<1>{})); - float_values.float_array[4] = utils::to_float( - NumericLimits::Binary_1(), - f4_values.f4x2_array[6].template AsType()[Number<0>{}].unpack<>(Number<0>{})); - float_values.float_array[5] = utils::to_float( - NumericLimits::Binary_1(), - f4_values.f4x2_array[6].template AsType()[Number<0>{}].unpack<>(Number<1>{})); - float_values.float_array[6] = utils::to_float( - NumericLimits::Binary_1(), - f4_values.f4x2_array[7].template AsType()[Number<0>{}].unpack<>(Number<0>{})); - float_values.float_array[7] = utils::to_float( - NumericLimits::Binary_1(), - f4_values.f4x2_array[7].template AsType()[Number<0>{}].unpack<>(Number<1>{})); - - float_values.float_array[0] = utils::to_float( - NumericLimits::Binary_1(), - f4_values.f4x2_array[8].template AsType()[Number<0>{}].unpack<>(Number<0>{})); - float_values.float_array[1] = utils::to_float( - NumericLimits::Binary_1(), - f4_values.f4x2_array[8].template AsType()[Number<0>{}].unpack<>(Number<1>{})); - float_values.float_array[2] = utils::to_float( - NumericLimits::Binary_1(), - f4_values.f4x2_array[9].template AsType()[Number<0>{}].unpack<>(Number<0>{})); - float_values.float_array[3] = utils::to_float( - NumericLimits::Binary_1(), - f4_values.f4x2_array[9].template AsType()[Number<0>{}].unpack<>(Number<1>{})); - float_values.float_array[4] = utils::to_float( - NumericLimits::Binary_1(), - f4_values.f4x2_array[10].template AsType()[Number<0>{}].unpack<>(Number<0>{})); - float_values.float_array[5] = utils::to_float( - NumericLimits::Binary_1(), - f4_values.f4x2_array[10].template AsType()[Number<0>{}].unpack<>(Number<1>{})); - float_values.float_array[6] = utils::to_float( - NumericLimits::Binary_1(), - f4_values.f4x2_array[11].template AsType()[Number<0>{}].unpack<>(Number<0>{})); - float_values.float_array[7] = utils::to_float( - NumericLimits::Binary_1(), - f4_values.f4x2_array[11].template AsType()[Number<0>{}].unpack<>(Number<1>{})); - - float_values.float_array[0] = utils::to_float( - NumericLimits::Binary_1(), - f4_values.f4x2_array[12].template AsType()[Number<0>{}].unpack<>(Number<0>{})); - float_values.float_array[1] = utils::to_float( - NumericLimits::Binary_1(), - f4_values.f4x2_array[12].template AsType()[Number<0>{}].unpack<>(Number<1>{})); - float_values.float_array[2] = utils::to_float( - NumericLimits::Binary_1(), - f4_values.f4x2_array[13].template AsType()[Number<0>{}].unpack<>(Number<0>{})); - float_values.float_array[3] = utils::to_float( - NumericLimits::Binary_1(), - f4_values.f4x2_array[13].template AsType()[Number<0>{}].unpack<>(Number<1>{})); - float_values.float_array[4] = utils::to_float( - NumericLimits::Binary_1(), - f4_values.f4x2_array[14].template AsType()[Number<0>{}].unpack<>(Number<0>{})); - float_values.float_array[5] = utils::to_float( - NumericLimits::Binary_1(), - f4_values.f4x2_array[14].template AsType()[Number<0>{}].unpack<>(Number<1>{})); - float_values.float_array[6] = utils::to_float( - NumericLimits::Binary_1(), - f4_values.f4x2_array[15].template AsType()[Number<0>{}].unpack<>(Number<0>{})); - float_values.float_array[7] = utils::to_float( - NumericLimits::Binary_1(), - f4_values.f4x2_array[15].template AsType()[Number<0>{}].unpack<>(Number<1>{})); + + ck::static_for<0, 32 / 2, 1>{}([&](auto idx) { + float_values.float_array[2 * idx] = utils::to_float( + NumericLimits::Binary_1(), + f4_values.f4x2_array[idx].template AsType()[Number<0>{}].template unpack<>( + Number<0>{})); + + float_values.float_array[2 * idx + 1] = utils::to_float( + NumericLimits::Binary_1(), + f4_values.f4x2_array[idx].template AsType()[Number<0>{}].template unpack<>( + Number<1>{})); + }); return float_values.float32_array; #endif diff --git a/test/data_type/CMakeLists.txt b/test/data_type/CMakeLists.txt index 58d8768736..8a0f631b39 100644 --- a/test/data_type/CMakeLists.txt +++ b/test/data_type/CMakeLists.txt @@ -75,6 +75,12 @@ if(GPU_TARGETS MATCHES "gfx950") endif() add_dependencies(test_mx_data_types test_mx_bf8) + add_gtest_executable(test_mx_fp4 test_mx_fp4.cpp) + if(result EQUAL 0) + target_link_libraries(test_mx_fp4 PRIVATE utility) + endif() + add_dependencies(test_mx_data_types test_mx_fp4) + add_gtest_executable(test_e8m0 test_e8m0.cpp) if(result EQUAL 0) target_link_libraries(test_e8m0 PRIVATE utility) diff --git a/test/data_type/test_mx_fp4.cpp b/test/data_type/test_mx_fp4.cpp new file mode 100644 index 0000000000..449f6fc777 --- /dev/null +++ b/test/data_type/test_mx_fp4.cpp @@ -0,0 +1,541 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gtest/gtest.h" +#include "ck/library/utility/device_memory.hpp" +#include "ck/utility/scaled_type_convert.hpp" + +using ck::e8m0_bexp_t; +using ck::float16_t; +using ck::float2_t; +using ck::float32_t; +using ck::scaled_type_convert; +using ck::type_convert; + +using ck::f4_convert_rne; +using ck::f4_convert_sr; +using ck::f4_t; +using ck::f4x16_t; +using ck::f4x2_pk_t; +using ck::f4x2_t; +using ck::f4x32_t; + +constexpr uint64_t test_size = 256 * 16 + 2 + 4 + 6; + +/** + * @brief Tests conversion of FP4 values to float using E8M0 exponent scaling. + * + * This function performs a series of conversions from FP4 values to float values using + * E8M0 exponent scaling. It handles all possible combinations of E8M0 and FP4 values, + * as well as specific vector and rounding conversions. + * + * @param N The maximum number of conversions to perform. + * @param p_test Pointer to the output array where the converted float values will be stored. + * @param p_completed Pointer to a variable that tracks the number of completed conversions. + * + * @note If either p_test or p_completed is nullptr, the function will return immediately. + * @note The function will stop converting if the number of conversions reaches N. + * @note First 256*16 conversions are for all possible combinations of E8M0 and FP4 values that are + * stored in memory sequentially with FP4 values varying faster. + * + * The function performs the following conversions: + * - All possible combinations of E8M0 and FP4 values. [256x16] + * - Vector conversions f4x2 -> f32x2. [2] + * - Vector conversions f32x2 -> f4x2 rne. [2] + * - Vector conversions f32x2 -> f4x2 sr. [2] + * - Round to nearest even conversions for specific float values. [6] + * + * The results are stored in the p_test array, and the number of completed conversions + * is updated in the p_completed variable. + */ +__host__ __device__ void +test_mx_fp4_scaled_convert(uint64_t N, float* p_test, uint64_t* p_completed) +{ + if(p_completed == nullptr) + { + return; + } + + uint64_t& i = *p_completed; + i = 0; + + if(p_test == nullptr) + { + return; + } + + // All possible combinations of E8M0 and FP4 + for(ck::index_t exp_id = 0; exp_id < 256; exp_id++) + { + for(ck::index_t fp4_id = 0; fp4_id < 16; fp4_id++) + { + uint8_t fp4_uid = static_cast(fp4_id); + auto v = scaled_type_convert(e8m0_bexp_t(exp_id), f4_t(fp4_uid & 0b00001111)); + p_test[i] = v; + i++; + if(i >= N) + { + return; + } + } + } + + /// Test vector conversions + // f4x2 -> f32x2 + f4x2_t f4x2{f4x2_t::data_v{0b00011100}}; // 0b0001(=0.5) and 0b1100(=-2.0) + auto scale2 = e8m0_bexp_t(2.0f); + + float2_t f32x2 = scaled_type_convert(scale2, f4x2); + p_test[i++] = f32x2[0]; + if(i >= N) + { + return; + } + p_test[i++] = f32x2[1]; + if(i >= N) + { + return; + } + + // f32x2 -> f4x2 + f32x2 = {1.0f, -4.0f}; + f4x2 = f4_convert_rne(f32x2, type_convert(scale2)); // expect {0.5, -2} + + p_test[i++] = type_convert( + f4_t(f4x2.AsType()(ck::Number<0>{}).unpack<>(ck::Number<0>{}))); // 0.5f + if(i >= N) + { + return; + } + p_test[i++] = type_convert( + f4_t(f4x2.AsType()(ck::Number<0>{}).unpack<>(ck::Number<1>{}))); // -2.0f + if(i >= N) + { + return; + } + + f4x2 = f4_convert_sr(f32x2, type_convert(scale2)); // expect {0.5, -2} + + p_test[i++] = type_convert( + f4_t(f4x2.AsType()(ck::Number<0>{}).unpack<>(ck::Number<0>{}))); // 0.5f + if(i >= N) + { + return; + } + p_test[i++] = type_convert( + f4_t(f4x2.AsType()(ck::Number<0>{}).unpack<>(ck::Number<1>{}))); // -2.0f + if(i >= N) + { + return; + } + + /// Test round to nearest even + + p_test[i++] = type_convert(f4_convert_rne(24.0f, 4.0f)); // 24/4 + if(i >= N) + { + return; + } + + p_test[i++] = type_convert( + f4_convert_rne(std::numeric_limits::quiet_NaN(), 4.0f)); // => NaN + if(i >= N) + { + return; + } + + // Inf/2 > 6.0 => 6.0 on device + p_test[i++] = type_convert(f4_convert_rne(std::numeric_limits::infinity(), 2.0f)); + if(i >= N) + { + return; + } + + // 256/0.5 > 6.0 => 6.0 on device + p_test[i++] = type_convert(f4_convert_rne(256.0f, 0.5f)); + if(i >= N) + { + return; + } + + // -256/0.5 < -6.0 => -6.0 on device + p_test[i++] = type_convert(f4_convert_rne(-256.0f, 0.5f)); + if(i >= N) + { + return; + } + + // proper scale selection + p_test[i++] = type_convert(f4_convert_rne(20.0f, 4.0f)); // 20.0/4.0 = 5.0 + if(i >= N) + { + return; + } +} + +TEST(MXFP4, HostScaledConvert) +{ + std::vector out(test_size, -1.0f); + uint64_t completed = 0; + + test_mx_fp4_scaled_convert(test_size, out.data(), &completed); + + // V = X * P; X - E8M0 scale, P - FP4 + + // If X = NaN, then V = NaN regardless of P + uint8_t e8m0_nan_id = ck::NumericLimits::QuietNaN().data; + for(ck::index_t fp4_id = 0; fp4_id < 16; fp4_id++) + { + auto idx = e8m0_nan_id * 16 + fp4_id; + ASSERT_TRUE(std::isnan(out[idx])); + } + + for(ck::index_t exp_id = 0; exp_id < 256; exp_id++) + { + if(exp_id == e8m0_nan_id) + continue; + for(ck::index_t fp4_id = 0; fp4_id < 16; fp4_id++) + { + uint8_t fp4_uid = static_cast(fp4_id); + auto idx = exp_id * 16 + fp4_uid; + ASSERT_FLOAT_EQ(out[idx], + type_convert(e8m0_bexp_t(exp_id)) * + type_convert(f4_t(fp4_uid & 0b00001111))) + << "exp_id: " << exp_id << " fp4_id: " << fp4_id << std::endl + << type_convert(e8m0_bexp_t(exp_id)) << " * " + << type_convert(f4_t(fp4_uid & 0b00001111)); + } + } + + /// Test vector conversions + + auto i = 256 * 16; + + // f4x2 -> f32x2 + EXPECT_EQ(out[i++], 1.0f); + EXPECT_EQ(out[i++], -4.0f); + + // f32x2 -> f4x2 + // RNE + EXPECT_EQ(out[i++], 0.5f); + EXPECT_EQ(out[i++], -2.0f); + // SR + EXPECT_EQ(out[i++], 0.5f); + EXPECT_EQ(out[i++], -2.0f); + + /// Test round to nearest even + EXPECT_EQ(out[i++], 24.0f / 4.0f) << "out[i-1]: " << out[i - 1]; + EXPECT_EQ(out[i++], type_convert(ck::NumericLimits::Max())) + << "out[i-1]: " << out[i - 1]; + EXPECT_EQ(out[i++], type_convert(ck::NumericLimits::Max())) + << "out[i-1]: " << out[i - 1]; + EXPECT_EQ(out[i++], type_convert(ck::NumericLimits::Max())) + << "out[i-1]: " << out[i - 1]; + EXPECT_EQ(out[i++], type_convert(ck::NumericLimits::Lowest())) + << "out[i-1]: " << out[i - 1]; + EXPECT_EQ(out[i++], type_convert(type_convert(5.0f))) + << "out[i-1]: " << out[i - 1]; + + EXPECT_EQ(test_size, completed); + EXPECT_EQ(test_size, i); +} + +__global__ void test_mx_fp4_device_scaled_convert(uint64_t N, float* p_test, uint64_t* p_completed) +{ + test_mx_fp4_scaled_convert(N, p_test, p_completed); +} + +TEST(MXFP4, DeviceScaledConvert) +{ + std::vector out(test_size, -1.0f); + + DeviceMem device_out(test_size * sizeof(float)); + DeviceMem device_completed(sizeof(uint64_t)); + + device_out.SetValue(-21.0f); + device_completed.SetValue(-21.0f); + + test_mx_fp4_device_scaled_convert<<<1, 1>>>( + test_size, + static_cast(device_out.GetDeviceBuffer()), + static_cast(device_completed.GetDeviceBuffer())); + + uint64_t completed = 0; + device_completed.FromDevice(&completed); + device_out.FromDevice(out.data()); + + // V = X * P; X - E8M0 scale, P - FP4 + + // If X = NaN, then V = NaN regardless of P + uint8_t e8m0_nan_id = ck::NumericLimits::QuietNaN().data; + for(ck::index_t fp4_id = 0; fp4_id < 16; fp4_id++) + { + auto idx = e8m0_nan_id * 16 + fp4_id; + ASSERT_TRUE(std::isnan(out[idx])) << "idx: " << idx << " out[idx]: " << out[idx]; + } + + for(ck::index_t exp_id = 0; exp_id < 256; exp_id++) + { + if(exp_id == e8m0_nan_id) + continue; + for(ck::index_t fp4_id = 0; fp4_id < 16; fp4_id++) + { + uint8_t fp4_uid = static_cast(fp4_id); + auto idx = exp_id * 16 + fp4_uid; + ASSERT_FLOAT_EQ(out[idx], + type_convert(e8m0_bexp_t(exp_id)) * + type_convert(f4_t(fp4_uid & 0b00001111))) + << "exp_id: " << exp_id << " fp4_id: " << fp4_id << std::endl + << type_convert(e8m0_bexp_t(exp_id)) << " * " + << type_convert(f4_t(fp4_uid & 0b00001111)); + } + } + + /// Test vector conversions + + auto i = 256 * 16; + + // f4x2 -> f32x2 + EXPECT_EQ(out[i++], 1.0f); + EXPECT_EQ(out[i++], -4.0f); + + // f32x2 -> f4x2 + // RNE + EXPECT_EQ(out[i++], 0.5f); + EXPECT_EQ(out[i++], -2.0f); + // SR + EXPECT_EQ(out[i++], 0.5f); + EXPECT_EQ(out[i++], -2.0f); + + /// Test round to nearest even + EXPECT_EQ(out[i++], 24.0f / 4.0f) << "out[i-1]: " << out[i - 1]; + EXPECT_EQ(out[i++], type_convert(ck::NumericLimits::Max())) + << "out[i-1]: " << out[i - 1]; + EXPECT_EQ(out[i++], type_convert(ck::NumericLimits::Max())) + << "out[i-1]: " << out[i - 1]; + EXPECT_EQ(out[i++], type_convert(ck::NumericLimits::Max())) + << "out[i-1]: " << out[i - 1]; + EXPECT_EQ(out[i++], type_convert(ck::NumericLimits::Lowest())) + << "out[i-1]: " << out[i - 1]; + EXPECT_EQ(out[i++], type_convert(type_convert(5.0f))) + << "out[i-1]: " << out[i - 1]; + + EXPECT_EQ(test_size, completed); + EXPECT_EQ(test_size, i); +} + +__host__ __device__ float vec16_generator(ck::index_t i, float scale) +{ + return scale * type_convert(f4_t(i & 0b00001111)); +} + +__host__ __device__ float vec32_generator(ck::index_t i, float scale) +{ + if(i < 16) + { + return vec16_generator( + i, scale); // all positive values, then all negative values in ascending order + } + else + { + return vec16_generator( + 15 - (i % 16), + scale); // all negative values, then all positive values in descending order + } +} + +__global__ void test_mx_fp4x32_device_scaled_convert(float* p_test, uint64_t* p_completed) +{ + constexpr int N = 32; + if(p_completed == nullptr) + { + return; + } + + uint64_t& i = *p_completed; + i = 0; + + if(p_test == nullptr) + { + return; + } + + auto scale2 = e8m0_bexp_t(2.0f); + + f4x32_t f4x32{}; + float32_t float32{}; + ck::static_for<0, N, 1>{}([&](auto ii) { + float32[static_cast(ii)] = vec32_generator(ii, type_convert(scale2)); + }); + + f4x32 = f4_convert_rne(float32, type_convert(scale2)); + + ck::static_for<0, N / 2, 1>{}([&](auto ii) { + p_test[i++] = type_convert( + f4_t(f4x32.AsType()(ck::Number{}).template unpack<>(ck::Number<0>{}))); + p_test[i++] = type_convert( + f4_t(f4x32.AsType()(ck::Number{}).template unpack<>(ck::Number<1>{}))); + }); +} + +TEST(MXFP4, DeviceF32x32ToF4x32ScaledConvert) +{ + constexpr int N = 32; + std::vector out(N, -1.0f); + + DeviceMem device_out(N * sizeof(float)); + DeviceMem device_completed(sizeof(uint64_t)); + + device_out.SetValue(-21.0f); + device_completed.SetValue(-21.0f); + + test_mx_fp4x32_device_scaled_convert<<<1, 1>>>( + static_cast(device_out.GetDeviceBuffer()), + static_cast(device_completed.GetDeviceBuffer())); + + uint64_t completed = 0; + device_completed.FromDevice(&completed); + device_out.FromDevice(out.data()); + + auto i = 0; + auto scale2 = e8m0_bexp_t(2.0f); + + ck::static_for<0, N, 1>{}([&](auto ii) { + EXPECT_EQ(out[i++], + vec32_generator(ii, type_convert(scale2)) / type_convert(scale2)) + << "ii: " << ii << std::endl; + }); + + EXPECT_EQ(N, completed); + EXPECT_EQ(N, i); +} + +__global__ void test_mx_fp4x32_device_scaled_convert_sr(float* p_test, uint64_t* p_completed) +{ + constexpr int N = 32; + if(p_completed == nullptr) + { + return; + } + + uint64_t& i = *p_completed; + i = 0; + + if(p_test == nullptr) + { + return; + } + + auto scale2 = e8m0_bexp_t(2.0f); + + f4x32_t f4x32{}; + float32_t float32{}; + ck::static_for<0, N, 1>{}([&](auto ii) { + float32[static_cast(ii)] = vec32_generator(ii, type_convert(scale2)); + }); + + f4x32 = f4_convert_sr(float32, type_convert(scale2)); + + ck::static_for<0, N / 2, 1>{}([&](auto ii) { + p_test[i++] = type_convert( + f4_t(f4x32.AsType()(ck::Number{}).template unpack<>(ck::Number<0>{}))); + p_test[i++] = type_convert( + f4_t(f4x32.AsType()(ck::Number{}).template unpack<>(ck::Number<1>{}))); + }); +} + +TEST(MXFP4, DeviceF32x32ToF4x32ScaledConvertSR) +{ + constexpr int N = 32; + std::vector out(N, -1.0f); + + DeviceMem device_out(N * sizeof(float)); + DeviceMem device_completed(sizeof(uint64_t)); + + device_out.SetValue(-21.0f); + device_completed.SetValue(-21.0f); + + test_mx_fp4x32_device_scaled_convert_sr<<<1, 1>>>( + static_cast(device_out.GetDeviceBuffer()), + static_cast(device_completed.GetDeviceBuffer())); + + uint64_t completed = 0; + device_completed.FromDevice(&completed); + device_out.FromDevice(out.data()); + + auto i = 0; + auto scale2 = e8m0_bexp_t(2.0f); + + ck::static_for<0, N, 1>{}([&](auto ii) { + EXPECT_EQ(out[i++], + vec32_generator(ii, type_convert(scale2)) / type_convert(scale2)) + << "ii: " << ii << std::endl; + }); + + EXPECT_EQ(N, completed); + EXPECT_EQ(N, i); +} + +__global__ void test_mx_f32x32_device_scaled_convert(float* p_test, uint64_t* p_completed) +{ + constexpr int N = 32; + if(p_completed == nullptr) + { + return; + } + + uint64_t& i = *p_completed; + i = 0; + + if(p_test == nullptr) + { + return; + } + + auto scale2 = e8m0_bexp_t(2.0f); + + f4x32_t f4x32{}; + float32_t float32{}; + ck::static_for<0, N / 2, 1>{}([&](auto ii) { + f4x32.AsType()(ck::Number{}) = f4x2_pk_t{}.pack( + type_convert(vec32_generator(2 * ii, type_convert(scale2)) / + type_convert(scale2)), + type_convert(vec32_generator(2 * ii + 1, type_convert(scale2)) / + type_convert(scale2))); + }); + + float32 = scaled_type_convert(scale2, f4x32); + + ck::static_for<0, N, 1>{}([&](auto ii) { p_test[i++] = float32[static_cast(ii)]; }); +} + +TEST(MXFP4, DeviceF4x32ToF32x32ScaledConvert) +{ + constexpr int N = 32; + std::vector out(N, -1.0f); + + DeviceMem device_out(N * sizeof(float)); + DeviceMem device_completed(sizeof(uint64_t)); + + device_out.SetValue(-21.0f); + device_completed.SetValue(-21.0f); + + test_mx_f32x32_device_scaled_convert<<<1, 1>>>( + static_cast(device_out.GetDeviceBuffer()), + static_cast(device_completed.GetDeviceBuffer())); + + uint64_t completed = 0; + device_completed.FromDevice(&completed); + device_out.FromDevice(out.data()); + + auto i = 0; + auto scale2 = e8m0_bexp_t(2.0f); + + ck::static_for<0, N, 1>{}([&](auto ii) { + EXPECT_EQ(out[i++], vec32_generator(ii, type_convert(scale2))) + << "ii: " << ii << std::endl; + }); + + EXPECT_EQ(N, completed); + EXPECT_EQ(N, i); +}