Merged

Commits (30 commits; the diff below shows changes from 27 commits)
737d53d  Add conversion tests (geyyer, Feb 12, 2025)
7c5c2c8  Fix ctor (geyyer, Feb 12, 2025)
ecd638f  Fix nan logic (geyyer, Feb 12, 2025)
ee8937a  Fix conversion logic (geyyer, Feb 12, 2025)
f918177  Permute packed f4_t values (geyyer, Feb 12, 2025)
db2c611  Fix conversion to float, repack vector elements (geyyer, Feb 13, 2025)
99e771a  Fix device tests (geyyer, Feb 13, 2025)
bec4968  Permute elements in a vector (geyyer, Feb 13, 2025)
d1499dd  Add a repro test (geyyer, Feb 14, 2025)
e38b4a3  Add a conversion for a repro test (geyyer, Feb 14, 2025)
83fcce2  Update test vectors (geyyer, Feb 17, 2025)
af493e6  Update conversion (geyyer, Feb 17, 2025)
d90b505  Fix the test (geyyer, Feb 18, 2025)
e323d61  Update test vector generator (geyyer, Feb 18, 2025)
7daf210  Fix vector sr conversion (geyyer, Feb 18, 2025)
d19cebb  Permute conversion args (geyyer, Feb 18, 2025)
99f47e8  Update conversion (geyyer, Feb 18, 2025)
5022c8b  Test (geyyer, Feb 18, 2025)
50c1291  Fix packing (geyyer, Feb 18, 2025)
bb953da  Simplify conversion function (geyyer, Feb 19, 2025)
e971702  Pack conversion in a loop (geyyer, Feb 19, 2025)
bd6c212  Pack conversion in a loop (geyyer, Feb 20, 2025)
c426051  Pack another conversion in a loop (geyyer, Feb 20, 2025)
19367bd  Pack one more conversion in a loop (geyyer, Feb 21, 2025)
409e29e  Pack the last conversion in a loop (geyyer, Feb 21, 2025)
3fe8d31  Clean up (geyyer, Feb 21, 2025)
ddaa893  Merge branch 'develop' into lwpck-2836 (geyyer, Feb 21, 2025)
ea66d0c  Merge branch 'develop' into lwpck-2838 (geyyer, Mar 24, 2025)
895c0f7  Add printf to fix intrinsic (geyyer, Mar 24, 2025)
bc3d2f4  Add a sw-based workaround (geyyer, Mar 26, 2025)
10 changes: 5 additions & 5 deletions include/ck/utility/data_type.hpp
@@ -36,22 +36,22 @@ struct f4x2_pk_t
{
using type = uint8_t;
type data;
f4x2_pk_t() : data{type{}} {}
f4x2_pk_t(type init) : data{init} {}
__host__ __device__ f4x2_pk_t() : data{type{}} {}
__host__ __device__ f4x2_pk_t(type init) : data{init} {}

template <index_t I>
__host__ __device__ inline type unpack(Number<I>) const
{
static_assert(I < 2, "Index is out of range.");
if constexpr(I == 0)
return data & 0b00001111;
else
return (data >> 4);
else
return data & 0b00001111;
}

__host__ __device__ inline type pack(const type x0, const type x1)
{
return (x1 << 4) | (x0 & 0b00001111);
return (x0 << 4) | (x1 & 0b00001111);
Contributor: Was the original order incorrect?

}
};

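For context on the new byte layout, here is a minimal standalone sketch of the packing convention introduced above (element 0 in the high nibble, element 1 in the low nibble). It replaces CK's Number<I> machinery with a plain runtime index, so the names and structure are illustrative only, not the CK type:

    #include <cassert>
    #include <cstdint>

    // Simplified stand-in for f4x2_pk_t: two 4-bit values packed into one byte,
    // element 0 stored in the high nibble, element 1 in the low nibble.
    struct f4x2_sketch
    {
        uint8_t data{};

        static uint8_t pack(uint8_t x0, uint8_t x1)
        {
            return static_cast<uint8_t>((x0 << 4) | (x1 & 0b00001111));
        }

        uint8_t unpack(int i) const
        {
            return (i == 0) ? static_cast<uint8_t>(data >> 4)
                            : static_cast<uint8_t>(data & 0b00001111);
        }
    };

    int main()
    {
        f4x2_sketch v{f4x2_sketch::pack(0b0101, 0b0011)};
        assert(v.unpack(0) == 0b0101); // element 0 comes back from the high nibble
        assert(v.unpack(1) == 0b0011); // element 1 comes back from the low nibble
        return 0;
    }
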
12 changes: 6 additions & 6 deletions include/ck/utility/mxf4_utils.hpp
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

@@ -13,7 +13,7 @@ __host__ __device__ inline bool is_nan<f4_t>(e8m0_bexp_t const scale,
f4_t const dataBytes [[maybe_unused]])
{
// no need to check for data as it does not have NaN representation
return scale == NumericLimits<e8m0_bexp_t>::QuietNaN();
return scale.is_nan();
}

// no infinity representation in ocp_e2m1_mxfp4 will always return false
@@ -26,11 +26,9 @@ __host__ __device__ inline bool is_inf<f4_t>(e8m0_bexp_t const scale [[maybe_unu
}

template <>
__host__ __device__ inline bool is_zero<f4_t>(e8m0_bexp_t const scale, f4_t const data)
__host__ __device__ inline bool is_zero<f4_t>(e8m0_bexp_t const scale [[maybe_unused]],
f4_t const data)
{
if(is_nan<f4_t>(scale, data))
return false;

// no need to check for scale as it does not have a 0 representation
f4_t result = (data & 0b00001111) & NumericUtils<f4_t>::set_sign_mask;

@@ -41,7 +39,9 @@ template <>
__host__ __device__ inline float to_float<f4_t>(e8m0_bexp_t const scale, f4_t const data)
{
if(is_nan<f4_t>(scale, data))
{
return std::numeric_limits<float>::quiet_NaN();
}

if(is_zero<f4_t>(scale, data))
return 0.0f;
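Since the hunk is truncated here, a host-side reference for what the full to_float<f4_t> decode computes may be useful. This is a sketch under the assumption that the types follow the OCP MX encodings (E2M1 element: 1 sign, 2 exponent, 1 mantissa bit, giving the values +/-{0, 0.5, 1, 1.5, 2, 3, 4, 6}; E8M0 scale: 2^(bits - 127), with the all-ones pattern reserved for NaN); the function name and raw-byte signature are illustrative, not the CK API:

    #include <cmath>
    #include <cstdint>
    #include <limits>

    // Reference decode of one scaled MXFP4 element from raw bits.
    inline float decode_mxfp4(uint8_t scale_bits, uint8_t data_nibble)
    {
        // E8M0 scale: 2^(scale_bits - 127); 0xFF encodes NaN, there is no zero or infinity.
        if(scale_bits == 0xFF)
            return std::numeric_limits<float>::quiet_NaN();
        const float scale = std::ldexp(1.0f, static_cast<int>(scale_bits) - 127);

        // E2M1 element: sign (1 bit), exponent (2 bits, bias 1), mantissa (1 bit).
        const uint8_t bits = data_nibble & 0b1111;
        const float sign   = (bits & 0b1000) ? -1.0f : 1.0f;
        const int exp      = (bits >> 1) & 0b11;
        const int man      = bits & 0b1;

        // exp == 0 is the subnormal range (0 or 0.5); otherwise there is an implicit leading 1.
        const float magnitude =
            (exp == 0) ? 0.5f * static_cast<float>(man)
                       : (1.0f + 0.5f * static_cast<float>(man)) * std::ldexp(1.0f, exp - 1);

        return sign * scale * magnitude;
    }

With this reading, the early returns above are simply the NaN-scale and zero-data short-circuits of the same computation.
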
230 changes: 26 additions & 204 deletions include/ck/utility/scaled_type_convert.hpp
@@ -377,12 +377,15 @@ inline __host__ __device__ float2_t scaled_type_convert<float2_t, f4x2_t>(e8m0_b
f4x2_t f4x2_array[4];
} value{};
value.f4x2_array[0] = x;
return __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.bitwise, type_convert<float>(scale), 0);
float2_t tmp =
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.bitwise, type_convert<float>(scale), 0);
// permute high bits and low bits to match the order of the original vector

Contributor: Do we need these changes because we modified the packing order in the fp4 storage?

    __host__ __device__ inline type pack(const type x0, const type x1)
    {
        return (x0 << 4) | (x1 & 0b00001111);
    }

Contributor (Author): I'll try to provide more context for both comments here. In CK we have several ways of packing elements into vectors: LLVM/Clang vectors, our custom non_native_vector_base, and custom types that we pack manually. In LLVM/Clang vectors the 0th element is stored in the highest bits and the Nth element in the lowest bits. The same layout is used in non_native_vector_base, which makes sense since it uses an LLVM/Clang vector under the hood. So I decided to update the f4x2_pk_t type to have a layout consistent with the other vectors. I believe the issue with the native conversion instructions is that they swap the high and low bits, so we have to swap either the input or the output vector elements. I believe keeping the old f4x2_pk_t layout would help with this issue, but it would have to be well documented and accounted for in the tests. @andriy-ca, what is your perspective on it?

Contributor: Having the elements in bytes aligned consistently with the other data types makes sense.

return float2_t{tmp[1], tmp[0]};
#else
float2_t ret{utils::to_float<f4_t>(
scale, x.template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{})),
scale, x.template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{})),
utils::to_float<f4_t>(
scale, x.template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}))};
scale, x.template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}))};
return ret;
#endif
}
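To illustrate why the result lanes are swapped, here is a host-only sketch. It assumes, as the swap in the code above implies, that the packed hardware conversion returns the low nibble in lane 0 and the high nibble in lane 1, while f4x2_pk_t now stores element 0 in the high nibble. The conversion builtin is modeled by a placeholder function and the per-nibble decode is deliberately simplified, so none of these names are the real API:

    #include <array>
    #include <cstdint>

    using f4x2_bits = uint8_t; // element 0 in the high nibble, element 1 in the low nibble

    // Placeholder model of the packed conversion: lane 0 <- low nibble, lane 1 <- high nibble.
    // The per-nibble decode here is a stub; the real instruction decodes E2M1 values.
    std::array<float, 2> model_cvt_pk_f32_fp4(f4x2_bits x, float scale)
    {
        auto decode = [&](uint8_t nibble) { return scale * static_cast<float>(nibble & 0b1111); };
        return {decode(static_cast<uint8_t>(x & 0b1111)), decode(static_cast<uint8_t>(x >> 4))};
    }

    // To return the floats in the same order as the packed elements, the lanes are swapped.
    std::array<float, 2> convert_in_element_order(f4x2_bits x, float scale)
    {
        const std::array<float, 2> tmp = model_cvt_pk_f32_fp4(x, scale);
        return {tmp[1], tmp[0]}; // element 0 (high nibble) first, element 1 (low nibble) second
    }
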
@@ -398,109 +401,16 @@ inline __host__ __device__ float32_t scaled_type_convert<float32_t, f4x32_t>(e8m
f4x32_t f4x32_array;
f4x2_t fp4x2[16];
} value{x};
union
{
uint32_t bitwise;
f4x2_t f4x2_array[4];
} bitwise_value{};
float2_t op;
float32_t ret;
// TODO: pack in a loop
bitwise_value.f4x2_array[0] = value.fp4x2[0];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
bitwise_value.bitwise, type_convert<float>(scale), 0);
ret[0] = op[0];
ret[1] = op[1];

bitwise_value.f4x2_array[0] = value.fp4x2[1];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
bitwise_value.bitwise, type_convert<float>(scale), 0);
ret[2] = op[0];
ret[3] = op[1];

bitwise_value.f4x2_array[0] = value.fp4x2[2];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
bitwise_value.bitwise, type_convert<float>(scale), 0);
ret[4] = op[0];
ret[5] = op[1];

bitwise_value.f4x2_array[0] = value.fp4x2[3];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
bitwise_value.bitwise, type_convert<float>(scale), 0);
ret[6] = op[0];
ret[7] = op[1];

bitwise_value.f4x2_array[0] = value.fp4x2[4];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
bitwise_value.bitwise, type_convert<float>(scale), 0);
ret[8] = op[0];
ret[9] = op[1];

bitwise_value.f4x2_array[0] = value.fp4x2[5];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
bitwise_value.bitwise, type_convert<float>(scale), 0);
ret[10] = op[0];
ret[11] = op[1];

bitwise_value.f4x2_array[0] = value.fp4x2[6];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
bitwise_value.bitwise, type_convert<float>(scale), 0);
ret[12] = op[0];
ret[13] = op[1];

bitwise_value.f4x2_array[0] = value.fp4x2[7];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
bitwise_value.bitwise, type_convert<float>(scale), 0);
ret[14] = op[0];
ret[15] = op[1];

bitwise_value.f4x2_array[0] = value.fp4x2[8];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
bitwise_value.bitwise, type_convert<float>(scale), 0);
ret[16] = op[0];
ret[17] = op[1];

bitwise_value.f4x2_array[0] = value.fp4x2[9];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
bitwise_value.bitwise, type_convert<float>(scale), 0);
ret[18] = op[0];
ret[19] = op[1];

bitwise_value.f4x2_array[0] = value.fp4x2[10];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
bitwise_value.bitwise, type_convert<float>(scale), 0);
ret[20] = op[0];
ret[21] = op[1];

bitwise_value.f4x2_array[0] = value.fp4x2[11];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
bitwise_value.bitwise, type_convert<float>(scale), 0);
ret[22] = op[0];
ret[23] = op[1];

bitwise_value.f4x2_array[0] = value.fp4x2[12];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
bitwise_value.bitwise, type_convert<float>(scale), 0);
ret[24] = op[0];
ret[25] = op[1];

bitwise_value.f4x2_array[0] = value.fp4x2[13];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
bitwise_value.bitwise, type_convert<float>(scale), 0);
ret[26] = op[0];
ret[27] = op[1];

bitwise_value.f4x2_array[0] = value.fp4x2[14];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
bitwise_value.bitwise, type_convert<float>(scale), 0);
ret[28] = op[0];
ret[29] = op[1];

bitwise_value.f4x2_array[0] = value.fp4x2[15];
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
bitwise_value.bitwise, type_convert<float>(scale), 0);
ret[30] = op[0];
ret[31] = op[1];
float f_scale = type_convert<float>(scale);

ck::static_for<0, 32 / 2, 1>{}([&](auto idx) {
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[idx], f_scale, 0);
// permute high bits and low bits to match the order of the original vector
ret[2 * idx] = op[1];
ret[2 * idx + 1] = op[0];
});

return ret;
#else
@@ -515,106 +425,18 @@ inline __host__ __device__ float32_t scaled_type_convert<float32_t, f4x32_t>(e8m
f4x2_t f4x2_array[16];
f4x32_t f4x32_array;
} f4_values{bit_cast<__uint128_t>(x)};
// TODO: pack in a loop
float_values.float_array[0] = utils::to_float<f4_t>(
scale,
f4_values.f4x2_array[0].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
float_values.float_array[1] = utils::to_float<f4_t>(
scale,
f4_values.f4x2_array[0].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
float_values.float_array[2] = utils::to_float<f4_t>(
scale,
f4_values.f4x2_array[1].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
float_values.float_array[3] = utils::to_float<f4_t>(
scale,
f4_values.f4x2_array[1].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
float_values.float_array[4] = utils::to_float<f4_t>(
scale,
f4_values.f4x2_array[2].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
float_values.float_array[5] = utils::to_float<f4_t>(
scale,
f4_values.f4x2_array[2].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
float_values.float_array[6] = utils::to_float<f4_t>(
scale,
f4_values.f4x2_array[3].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
float_values.float_array[7] = utils::to_float<f4_t>(
scale,
f4_values.f4x2_array[3].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));

float_values.float_array[0] = utils::to_float<f4_t>(
scale,
f4_values.f4x2_array[4].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
float_values.float_array[1] = utils::to_float<f4_t>(
scale,
f4_values.f4x2_array[4].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
float_values.float_array[2] = utils::to_float<f4_t>(
scale,
f4_values.f4x2_array[5].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
float_values.float_array[3] = utils::to_float<f4_t>(
scale,
f4_values.f4x2_array[5].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
float_values.float_array[4] = utils::to_float<f4_t>(
scale,
f4_values.f4x2_array[6].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
float_values.float_array[5] = utils::to_float<f4_t>(
scale,
f4_values.f4x2_array[6].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
float_values.float_array[6] = utils::to_float<f4_t>(
scale,
f4_values.f4x2_array[7].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
float_values.float_array[7] = utils::to_float<f4_t>(
scale,
f4_values.f4x2_array[7].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));

float_values.float_array[0] = utils::to_float<f4_t>(
scale,
f4_values.f4x2_array[8].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
float_values.float_array[1] = utils::to_float<f4_t>(
scale,
f4_values.f4x2_array[8].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
float_values.float_array[2] = utils::to_float<f4_t>(
scale,
f4_values.f4x2_array[9].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
float_values.float_array[3] = utils::to_float<f4_t>(
scale,
f4_values.f4x2_array[9].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
float_values.float_array[4] = utils::to_float<f4_t>(
scale,
f4_values.f4x2_array[10].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
float_values.float_array[5] = utils::to_float<f4_t>(
scale,
f4_values.f4x2_array[10].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
float_values.float_array[6] = utils::to_float<f4_t>(
scale,
f4_values.f4x2_array[11].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
float_values.float_array[7] = utils::to_float<f4_t>(
scale,
f4_values.f4x2_array[11].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));

float_values.float_array[0] = utils::to_float<f4_t>(
scale,
f4_values.f4x2_array[12].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
float_values.float_array[1] = utils::to_float<f4_t>(
scale,
f4_values.f4x2_array[12].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
float_values.float_array[2] = utils::to_float<f4_t>(
scale,
f4_values.f4x2_array[13].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
float_values.float_array[3] = utils::to_float<f4_t>(
scale,
f4_values.f4x2_array[13].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
float_values.float_array[4] = utils::to_float<f4_t>(
scale,
f4_values.f4x2_array[14].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
float_values.float_array[5] = utils::to_float<f4_t>(
scale,
f4_values.f4x2_array[14].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
float_values.float_array[6] = utils::to_float<f4_t>(
scale,
f4_values.f4x2_array[15].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
float_values.float_array[7] = utils::to_float<f4_t>(
scale,
f4_values.f4x2_array[15].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));

ck::static_for<0, 32 / 2, 1>{}([&](auto idx) {
float_values.float_array[2 * idx] = utils::to_float<f4_t>(
scale,
f4_values.f4x2_array[idx].template AsType<f4x2_pk_t>()[Number<0>{}].template unpack<>(
Number<0>{}));

float_values.float_array[2 * idx + 1] = utils::to_float<f4_t>(
scale,
f4_values.f4x2_array[idx].template AsType<f4x2_pk_t>()[Number<0>{}].template unpack<>(
Number<1>{}));
});

return float_values.float32_array;
#endif
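The last two hunks replace sixteen hand-unrolled conversions with a compile-time loop. As a rough illustration of that pattern (not the CK implementation), a simplified static_for with a single trip-count parameter can be written with an index sequence; convert_pair stands in for the per-pair scaled conversion (for example, a lambda capturing the shared scale) and is assumed to return the low-nibble lane first, as in the hunks above:

    #include <array>
    #include <cstddef>
    #include <cstdint>
    #include <utility>

    // Calls f with a distinct compile-time index for each iteration, so the loop unrolls.
    template <typename F, std::size_t... I>
    void static_for_impl(F& f, std::index_sequence<I...>)
    {
        (f(std::integral_constant<std::size_t, I>{}), ...);
    }

    template <std::size_t N, typename F>
    void static_for_sketch(F f)
    {
        static_for_impl(f, std::make_index_sequence<N>{});
    }

    // Convert 16 packed pairs into 32 floats, two lanes per step, swapping the lanes so the
    // output keeps the original element order.
    template <typename ConvertPair>
    std::array<float, 32> convert_32(const std::array<uint8_t, 16>& packed, ConvertPair convert_pair)
    {
        std::array<float, 32> out{};
        static_for_sketch<16>([&](auto idx) {
            const auto op    = convert_pair(packed[idx]); // returns {low-nibble, high-nibble} lanes
            out[2 * idx]     = op[1]; // element stored in the high nibble
            out[2 * idx + 1] = op[0]; // element stored in the low nibble
        });
        return out;
    }

In the actual diff, ck::static_for<Begin, End, Step> plays the role of static_for_sketch, for both the hardware-builtin path and the host fallback that calls utils::to_float<f4_t> per nibble.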