-
Notifications
You must be signed in to change notification settings - Fork 257
Add MX FP4 device conversion tests #1889
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 27 commits
737d53d
7c5c2c8
ecd638f
ee8937a
f918177
db2c611
99e771a
bec4968
d1499dd
e38b4a3
83fcce2
af493e6
d90b505
e323d61
7daf210
d19cebb
99f47e8
5022c8b
50c1291
bb953da
e971702
bd6c212
c426051
19367bd
409e29e
3fe8d31
ddaa893
ea66d0c
895c0f7
bc3d2f4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -377,12 +377,15 @@ inline __host__ __device__ float2_t scaled_type_convert<float2_t, f4x2_t>(e8m0_b | |
| f4x2_t f4x2_array[4]; | ||
| } value{}; | ||
| value.f4x2_array[0] = x; | ||
| return __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.bitwise, type_convert<float>(scale), 0); | ||
| float2_t tmp = | ||
| __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.bitwise, type_convert<float>(scale), 0); | ||
| // permute high bits and low bits to match the order of the original vector | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we need these changes because we modified packing order in the fp4 storage? __host__ __device__ inline type pack(const type x0, const type x1)
{
return (x0 << 4) | (x1 & 0b00001111);
}
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'll try to provide more context for both comments here. In the CK we have several ways of packing elements into vectors: llvm clang vectors, our custom non_native_vector_base and custom types which we pack manually. In llvm clang vectors 0th element is stored in the highest bits and Nth element is in the lowest bits. Same layout is used in the non_native_vector_base, which makes sense as we use llvm clang vector under the hood. So I decided to update the f4x2_pk_t type to have a consistent layout with other vectors. I believe the issue with native conversion instructions is that they swap high and low bits, so we have to swap either input or output vector elements. I believe keeping old f4x2_pk_t layout would help with this issue, but have to be well documented and considered in the tests. @andriy-ca what is your perspective on it?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Having elements in bytes aligned consistently with the other data types makes sense. |
||
| return float2_t{tmp[1], tmp[0]}; | ||
| #else | ||
| float2_t ret{utils::to_float<f4_t>( | ||
| scale, x.template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{})), | ||
| scale, x.template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{})), | ||
| utils::to_float<f4_t>( | ||
| scale, x.template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}))}; | ||
| scale, x.template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}))}; | ||
| return ret; | ||
| #endif | ||
| } | ||
|
|
@@ -398,109 +401,16 @@ inline __host__ __device__ float32_t scaled_type_convert<float32_t, f4x32_t>(e8m | |
| f4x32_t f4x32_array; | ||
| f4x2_t fp4x2[16]; | ||
| } value{x}; | ||
| union | ||
| { | ||
| uint32_t bitwise; | ||
| f4x2_t f4x2_array[4]; | ||
| } bitwise_value{}; | ||
| float2_t op; | ||
| float32_t ret; | ||
| // TODO: pack in a loop | ||
| bitwise_value.f4x2_array[0] = value.fp4x2[0]; | ||
| op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( | ||
| bitwise_value.bitwise, type_convert<float>(scale), 0); | ||
| ret[0] = op[0]; | ||
| ret[1] = op[1]; | ||
|
|
||
| bitwise_value.f4x2_array[0] = value.fp4x2[1]; | ||
| op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( | ||
| bitwise_value.bitwise, type_convert<float>(scale), 0); | ||
| ret[2] = op[0]; | ||
| ret[3] = op[1]; | ||
|
|
||
| bitwise_value.f4x2_array[0] = value.fp4x2[2]; | ||
| op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( | ||
| bitwise_value.bitwise, type_convert<float>(scale), 0); | ||
| ret[4] = op[0]; | ||
| ret[5] = op[1]; | ||
|
|
||
| bitwise_value.f4x2_array[0] = value.fp4x2[3]; | ||
| op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( | ||
| bitwise_value.bitwise, type_convert<float>(scale), 0); | ||
| ret[6] = op[0]; | ||
| ret[7] = op[1]; | ||
|
|
||
| bitwise_value.f4x2_array[0] = value.fp4x2[4]; | ||
| op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( | ||
| bitwise_value.bitwise, type_convert<float>(scale), 0); | ||
| ret[8] = op[0]; | ||
| ret[9] = op[1]; | ||
|
|
||
| bitwise_value.f4x2_array[0] = value.fp4x2[5]; | ||
| op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( | ||
| bitwise_value.bitwise, type_convert<float>(scale), 0); | ||
| ret[10] = op[0]; | ||
| ret[11] = op[1]; | ||
|
|
||
| bitwise_value.f4x2_array[0] = value.fp4x2[6]; | ||
| op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( | ||
| bitwise_value.bitwise, type_convert<float>(scale), 0); | ||
| ret[12] = op[0]; | ||
| ret[13] = op[1]; | ||
|
|
||
| bitwise_value.f4x2_array[0] = value.fp4x2[7]; | ||
| op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( | ||
| bitwise_value.bitwise, type_convert<float>(scale), 0); | ||
| ret[14] = op[0]; | ||
| ret[15] = op[1]; | ||
|
|
||
| bitwise_value.f4x2_array[0] = value.fp4x2[8]; | ||
| op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( | ||
| bitwise_value.bitwise, type_convert<float>(scale), 0); | ||
| ret[16] = op[0]; | ||
| ret[17] = op[1]; | ||
|
|
||
| bitwise_value.f4x2_array[0] = value.fp4x2[9]; | ||
| op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( | ||
| bitwise_value.bitwise, type_convert<float>(scale), 0); | ||
| ret[18] = op[0]; | ||
| ret[19] = op[1]; | ||
|
|
||
| bitwise_value.f4x2_array[0] = value.fp4x2[10]; | ||
| op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( | ||
| bitwise_value.bitwise, type_convert<float>(scale), 0); | ||
| ret[20] = op[0]; | ||
| ret[21] = op[1]; | ||
|
|
||
| bitwise_value.f4x2_array[0] = value.fp4x2[11]; | ||
| op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( | ||
| bitwise_value.bitwise, type_convert<float>(scale), 0); | ||
| ret[22] = op[0]; | ||
| ret[23] = op[1]; | ||
|
|
||
| bitwise_value.f4x2_array[0] = value.fp4x2[12]; | ||
| op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( | ||
| bitwise_value.bitwise, type_convert<float>(scale), 0); | ||
| ret[24] = op[0]; | ||
| ret[25] = op[1]; | ||
|
|
||
| bitwise_value.f4x2_array[0] = value.fp4x2[13]; | ||
| op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( | ||
| bitwise_value.bitwise, type_convert<float>(scale), 0); | ||
| ret[26] = op[0]; | ||
| ret[27] = op[1]; | ||
|
|
||
| bitwise_value.f4x2_array[0] = value.fp4x2[14]; | ||
| op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( | ||
| bitwise_value.bitwise, type_convert<float>(scale), 0); | ||
| ret[28] = op[0]; | ||
| ret[29] = op[1]; | ||
|
|
||
| bitwise_value.f4x2_array[0] = value.fp4x2[15]; | ||
| op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4( | ||
| bitwise_value.bitwise, type_convert<float>(scale), 0); | ||
| ret[30] = op[0]; | ||
| ret[31] = op[1]; | ||
| float f_scale = type_convert<float>(scale); | ||
|
|
||
| ck::static_for<0, 32 / 2, 1>{}([&](auto idx) { | ||
| op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[idx], f_scale, 0); | ||
| // permute high bits and low bits to match the order of the original vector | ||
| ret[2 * idx] = op[1]; | ||
| ret[2 * idx + 1] = op[0]; | ||
| }); | ||
|
|
||
| return ret; | ||
| #else | ||
|
|
@@ -515,106 +425,18 @@ inline __host__ __device__ float32_t scaled_type_convert<float32_t, f4x32_t>(e8m | |
| f4x2_t f4x2_array[16]; | ||
| f4x32_t f4x32_array; | ||
| } f4_values{bit_cast<__uint128_t>(x)}; | ||
| // TODO: pack in a loop | ||
| float_values.float_array[0] = utils::to_float<f4_t>( | ||
| scale, | ||
| f4_values.f4x2_array[0].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{})); | ||
| float_values.float_array[1] = utils::to_float<f4_t>( | ||
| scale, | ||
| f4_values.f4x2_array[0].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{})); | ||
| float_values.float_array[2] = utils::to_float<f4_t>( | ||
| scale, | ||
| f4_values.f4x2_array[1].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{})); | ||
| float_values.float_array[3] = utils::to_float<f4_t>( | ||
| scale, | ||
| f4_values.f4x2_array[1].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{})); | ||
| float_values.float_array[4] = utils::to_float<f4_t>( | ||
| scale, | ||
| f4_values.f4x2_array[2].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{})); | ||
| float_values.float_array[5] = utils::to_float<f4_t>( | ||
| scale, | ||
| f4_values.f4x2_array[2].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{})); | ||
| float_values.float_array[6] = utils::to_float<f4_t>( | ||
| scale, | ||
| f4_values.f4x2_array[3].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{})); | ||
| float_values.float_array[7] = utils::to_float<f4_t>( | ||
| scale, | ||
| f4_values.f4x2_array[3].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{})); | ||
|
|
||
| float_values.float_array[0] = utils::to_float<f4_t>( | ||
| scale, | ||
| f4_values.f4x2_array[4].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{})); | ||
| float_values.float_array[1] = utils::to_float<f4_t>( | ||
| scale, | ||
| f4_values.f4x2_array[4].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{})); | ||
| float_values.float_array[2] = utils::to_float<f4_t>( | ||
| scale, | ||
| f4_values.f4x2_array[5].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{})); | ||
| float_values.float_array[3] = utils::to_float<f4_t>( | ||
| scale, | ||
| f4_values.f4x2_array[5].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{})); | ||
| float_values.float_array[4] = utils::to_float<f4_t>( | ||
| scale, | ||
| f4_values.f4x2_array[6].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{})); | ||
| float_values.float_array[5] = utils::to_float<f4_t>( | ||
| scale, | ||
| f4_values.f4x2_array[6].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{})); | ||
| float_values.float_array[6] = utils::to_float<f4_t>( | ||
| scale, | ||
| f4_values.f4x2_array[7].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{})); | ||
| float_values.float_array[7] = utils::to_float<f4_t>( | ||
| scale, | ||
| f4_values.f4x2_array[7].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{})); | ||
|
|
||
| float_values.float_array[0] = utils::to_float<f4_t>( | ||
| scale, | ||
| f4_values.f4x2_array[8].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{})); | ||
| float_values.float_array[1] = utils::to_float<f4_t>( | ||
| scale, | ||
| f4_values.f4x2_array[8].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{})); | ||
| float_values.float_array[2] = utils::to_float<f4_t>( | ||
| scale, | ||
| f4_values.f4x2_array[9].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{})); | ||
| float_values.float_array[3] = utils::to_float<f4_t>( | ||
| scale, | ||
| f4_values.f4x2_array[9].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{})); | ||
| float_values.float_array[4] = utils::to_float<f4_t>( | ||
| scale, | ||
| f4_values.f4x2_array[10].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{})); | ||
| float_values.float_array[5] = utils::to_float<f4_t>( | ||
| scale, | ||
| f4_values.f4x2_array[10].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{})); | ||
| float_values.float_array[6] = utils::to_float<f4_t>( | ||
| scale, | ||
| f4_values.f4x2_array[11].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{})); | ||
| float_values.float_array[7] = utils::to_float<f4_t>( | ||
| scale, | ||
| f4_values.f4x2_array[11].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{})); | ||
|
|
||
| float_values.float_array[0] = utils::to_float<f4_t>( | ||
| scale, | ||
| f4_values.f4x2_array[12].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{})); | ||
| float_values.float_array[1] = utils::to_float<f4_t>( | ||
| scale, | ||
| f4_values.f4x2_array[12].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{})); | ||
| float_values.float_array[2] = utils::to_float<f4_t>( | ||
| scale, | ||
| f4_values.f4x2_array[13].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{})); | ||
| float_values.float_array[3] = utils::to_float<f4_t>( | ||
| scale, | ||
| f4_values.f4x2_array[13].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{})); | ||
| float_values.float_array[4] = utils::to_float<f4_t>( | ||
| scale, | ||
| f4_values.f4x2_array[14].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{})); | ||
| float_values.float_array[5] = utils::to_float<f4_t>( | ||
| scale, | ||
| f4_values.f4x2_array[14].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{})); | ||
| float_values.float_array[6] = utils::to_float<f4_t>( | ||
| scale, | ||
| f4_values.f4x2_array[15].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{})); | ||
| float_values.float_array[7] = utils::to_float<f4_t>( | ||
| scale, | ||
| f4_values.f4x2_array[15].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{})); | ||
|
|
||
| ck::static_for<0, 32 / 2, 1>{}([&](auto idx) { | ||
| float_values.float_array[2 * idx] = utils::to_float<f4_t>( | ||
| scale, | ||
| f4_values.f4x2_array[idx].template AsType<f4x2_pk_t>()[Number<0>{}].template unpack<>( | ||
| Number<0>{})); | ||
|
|
||
| float_values.float_array[2 * idx + 1] = utils::to_float<f4_t>( | ||
| scale, | ||
| f4_values.f4x2_array[idx].template AsType<f4x2_pk_t>()[Number<0>{}].template unpack<>( | ||
| Number<1>{})); | ||
| }); | ||
|
|
||
| return float_values.float32_array; | ||
| #endif | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Was the original order incorrect?