Skip to content

Commit 441343a

Browse files
authored
Add MX FP4 device conversion tests (#1889)
* Add conversion tests
* Fix ctor
* Fix nan logic
* Fix conversion logic
* Permute packed f4_t values
* Fix conversion to float, repack vector elements
* Fix device tests
* Permute elements in a vector
* Add a repro test
* Add a conversion for a repro test
* Update test vectors
* Update conversion
* Fix the test
* Update test vector generator
* Fix vector sr conversion
* Permute conversion args
* Update conversion
* Test
* Fix packing
* Simplify conversion function
* Pack conversion in a loop
* Pack conversion in a loop
* Pack another conversion in a loop
* Pack one more conversion in a loop
* Pack the last conversion in a loop
* Clean up
* Add printf to fix intrinsic
* Add a sw-based workaround
1 parent 23a9497 commit 441343a

File tree

8 files changed

+658
-728
lines changed

8 files changed

+658
-728
lines changed

include/ck/ck.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,9 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
245245
// workaround: compiler issue on gfx908
246246
#define CK_WORKAROUND_SWDEV_388832 1
247247

248+
// workaround: compiler issue on gfx950
249+
#define CK_WORKAROUND_FP32_TO_FP4_SR_CONVERSION 1
250+
248251
// denorm test fix, necessary for gfx90a
249252
#ifndef CK_GFX90A_DENORM_WORKAROUND
250253
#define CK_GFX90A_DENORM_WORKAROUND 0

include/ck/utility/data_type.hpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,22 +36,22 @@ struct f4x2_pk_t
3636
{
3737
using type = uint8_t;
3838
type data;
39-
f4x2_pk_t() : data{type{}} {}
40-
f4x2_pk_t(type init) : data{init} {}
39+
__host__ __device__ f4x2_pk_t() : data{type{}} {}
40+
__host__ __device__ f4x2_pk_t(type init) : data{init} {}
4141

4242
template <index_t I>
4343
__host__ __device__ inline type unpack(Number<I>) const
4444
{
4545
static_assert(I < 2, "Index is out of range.");
4646
if constexpr(I == 0)
47-
return data & 0b00001111;
48-
else
4947
return (data >> 4);
48+
else
49+
return data & 0b00001111;
5050
}
5151

5252
__host__ __device__ inline type pack(const type x0, const type x1)
5353
{
54-
return (x1 << 4) | (x0 & 0b00001111);
54+
return (x0 << 4) | (x1 & 0b00001111);
5555
}
5656
};
5757

include/ck/utility/mxf4_utils.hpp

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
// SPDX-License-Identifier: MIT
2-
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
2+
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
33

44
#ifndef CK_CODE_GEN_RTC
55
#pragma once
@@ -14,7 +14,7 @@ __host__ __device__ inline bool is_nan<f4_t>(e8m0_bexp_t const scale,
1414
f4_t const dataBytes [[maybe_unused]])
1515
{
1616
// no need to check for data as it does not have NaN representation
17-
return scale == NumericLimits<e8m0_bexp_t>::QuietNaN();
17+
return scale.is_nan();
1818
}
1919

2020
// ocp_e2m1_mxfp4 has no infinity representation, so this always returns false
@@ -27,11 +27,9 @@ __host__ __device__ inline bool is_inf<f4_t>(e8m0_bexp_t const scale [[maybe_unu
2727
}
2828

2929
template <>
30-
__host__ __device__ inline bool is_zero<f4_t>(e8m0_bexp_t const scale, f4_t const data)
30+
__host__ __device__ inline bool is_zero<f4_t>(e8m0_bexp_t const scale [[maybe_unused]],
31+
f4_t const data)
3132
{
32-
if(is_nan<f4_t>(scale, data))
33-
return false;
34-
3533
// no need to check for scale as it does not have a 0 representation
3634
f4_t result = (data & 0b00001111) & NumericUtils<f4_t>::set_sign_mask;
3735

include/ck/utility/mxfp_utils.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ template <typename T>
9999
__host__ __device__ T sat_convert_to_type_sr(float value, uint32_t seed);
100100

101101
template <typename T>
102-
inline T convert_to_type(float value)
102+
__host__ __device__ inline T convert_to_type(float value)
103103
{
104104
using bitwise_type = typename NumericUtils<T>::bitwise_type;
105105

@@ -258,7 +258,7 @@ inline T convert_to_type(float value)
258258
}
259259

260260
template <typename T>
261-
inline T convert_to_type_sr(float value, uint32_t seed)
261+
__host__ __device__ inline T convert_to_type_sr(float value, uint32_t seed)
262262
{
263263
if(std::abs(value) > NumericLimits<T>::Max())
264264
{

include/ck/utility/scaled_type_convert.hpp

Lines changed: 26 additions & 204 deletions
Original file line numberDiff line numberDiff line change
@@ -377,12 +377,15 @@ inline __host__ __device__ float2_t scaled_type_convert<float2_t, f4x2_t>(e8m0_b
377377
f4x2_t f4x2_array[4];
378378
} value{};
379379
value.f4x2_array[0] = x;
380-
return __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.bitwise, type_convert<float>(scale), 0);
380+
float2_t tmp =
381+
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.bitwise, type_convert<float>(scale), 0);
382+
// permute high bits and low bits to match the order of the original vector
383+
return float2_t{tmp[1], tmp[0]};
381384
#else
382385
float2_t ret{utils::to_float<f4_t>(
383-
scale, x.template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{})),
386+
scale, x.template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{})),
384387
utils::to_float<f4_t>(
385-
scale, x.template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}))};
388+
scale, x.template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}))};
386389
return ret;
387390
#endif
388391
}
@@ -398,109 +401,16 @@ inline __host__ __device__ float32_t scaled_type_convert<float32_t, f4x32_t>(e8m
398401
f4x32_t f4x32_array;
399402
f4x2_t fp4x2[16];
400403
} value{x};
401-
union
402-
{
403-
uint32_t bitwise;
404-
f4x2_t f4x2_array[4];
405-
} bitwise_value{};
406404
float2_t op;
407405
float32_t ret;
408-
// TODO: pack in a loop
409-
bitwise_value.f4x2_array[0] = value.fp4x2[0];
410-
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
411-
bitwise_value.bitwise, type_convert<float>(scale), 0);
412-
ret[0] = op[0];
413-
ret[1] = op[1];
414-
415-
bitwise_value.f4x2_array[0] = value.fp4x2[1];
416-
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
417-
bitwise_value.bitwise, type_convert<float>(scale), 0);
418-
ret[2] = op[0];
419-
ret[3] = op[1];
420-
421-
bitwise_value.f4x2_array[0] = value.fp4x2[2];
422-
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
423-
bitwise_value.bitwise, type_convert<float>(scale), 0);
424-
ret[4] = op[0];
425-
ret[5] = op[1];
426-
427-
bitwise_value.f4x2_array[0] = value.fp4x2[3];
428-
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
429-
bitwise_value.bitwise, type_convert<float>(scale), 0);
430-
ret[6] = op[0];
431-
ret[7] = op[1];
432-
433-
bitwise_value.f4x2_array[0] = value.fp4x2[4];
434-
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
435-
bitwise_value.bitwise, type_convert<float>(scale), 0);
436-
ret[8] = op[0];
437-
ret[9] = op[1];
438-
439-
bitwise_value.f4x2_array[0] = value.fp4x2[5];
440-
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
441-
bitwise_value.bitwise, type_convert<float>(scale), 0);
442-
ret[10] = op[0];
443-
ret[11] = op[1];
444-
445-
bitwise_value.f4x2_array[0] = value.fp4x2[6];
446-
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
447-
bitwise_value.bitwise, type_convert<float>(scale), 0);
448-
ret[12] = op[0];
449-
ret[13] = op[1];
450-
451-
bitwise_value.f4x2_array[0] = value.fp4x2[7];
452-
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
453-
bitwise_value.bitwise, type_convert<float>(scale), 0);
454-
ret[14] = op[0];
455-
ret[15] = op[1];
456-
457-
bitwise_value.f4x2_array[0] = value.fp4x2[8];
458-
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
459-
bitwise_value.bitwise, type_convert<float>(scale), 0);
460-
ret[16] = op[0];
461-
ret[17] = op[1];
462-
463-
bitwise_value.f4x2_array[0] = value.fp4x2[9];
464-
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
465-
bitwise_value.bitwise, type_convert<float>(scale), 0);
466-
ret[18] = op[0];
467-
ret[19] = op[1];
468-
469-
bitwise_value.f4x2_array[0] = value.fp4x2[10];
470-
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
471-
bitwise_value.bitwise, type_convert<float>(scale), 0);
472-
ret[20] = op[0];
473-
ret[21] = op[1];
474-
475-
bitwise_value.f4x2_array[0] = value.fp4x2[11];
476-
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
477-
bitwise_value.bitwise, type_convert<float>(scale), 0);
478-
ret[22] = op[0];
479-
ret[23] = op[1];
480-
481-
bitwise_value.f4x2_array[0] = value.fp4x2[12];
482-
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
483-
bitwise_value.bitwise, type_convert<float>(scale), 0);
484-
ret[24] = op[0];
485-
ret[25] = op[1];
486-
487-
bitwise_value.f4x2_array[0] = value.fp4x2[13];
488-
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
489-
bitwise_value.bitwise, type_convert<float>(scale), 0);
490-
ret[26] = op[0];
491-
ret[27] = op[1];
492-
493-
bitwise_value.f4x2_array[0] = value.fp4x2[14];
494-
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
495-
bitwise_value.bitwise, type_convert<float>(scale), 0);
496-
ret[28] = op[0];
497-
ret[29] = op[1];
498-
499-
bitwise_value.f4x2_array[0] = value.fp4x2[15];
500-
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
501-
bitwise_value.bitwise, type_convert<float>(scale), 0);
502-
ret[30] = op[0];
503-
ret[31] = op[1];
406+
float f_scale = type_convert<float>(scale);
407+
408+
ck::static_for<0, 32 / 2, 1>{}([&](auto idx) {
409+
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[idx], f_scale, 0);
410+
// permute high bits and low bits to match the order of the original vector
411+
ret[2 * idx] = op[1];
412+
ret[2 * idx + 1] = op[0];
413+
});
504414

505415
return ret;
506416
#else
@@ -515,106 +425,18 @@ inline __host__ __device__ float32_t scaled_type_convert<float32_t, f4x32_t>(e8m
515425
f4x2_t f4x2_array[16];
516426
f4x32_t f4x32_array;
517427
} f4_values{bit_cast<__uint128_t>(x)};
518-
// TODO: pack in a loop
519-
float_values.float_array[0] = utils::to_float<f4_t>(
520-
scale,
521-
f4_values.f4x2_array[0].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
522-
float_values.float_array[1] = utils::to_float<f4_t>(
523-
scale,
524-
f4_values.f4x2_array[0].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
525-
float_values.float_array[2] = utils::to_float<f4_t>(
526-
scale,
527-
f4_values.f4x2_array[1].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
528-
float_values.float_array[3] = utils::to_float<f4_t>(
529-
scale,
530-
f4_values.f4x2_array[1].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
531-
float_values.float_array[4] = utils::to_float<f4_t>(
532-
scale,
533-
f4_values.f4x2_array[2].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
534-
float_values.float_array[5] = utils::to_float<f4_t>(
535-
scale,
536-
f4_values.f4x2_array[2].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
537-
float_values.float_array[6] = utils::to_float<f4_t>(
538-
scale,
539-
f4_values.f4x2_array[3].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
540-
float_values.float_array[7] = utils::to_float<f4_t>(
541-
scale,
542-
f4_values.f4x2_array[3].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
543-
544-
float_values.float_array[0] = utils::to_float<f4_t>(
545-
scale,
546-
f4_values.f4x2_array[4].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
547-
float_values.float_array[1] = utils::to_float<f4_t>(
548-
scale,
549-
f4_values.f4x2_array[4].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
550-
float_values.float_array[2] = utils::to_float<f4_t>(
551-
scale,
552-
f4_values.f4x2_array[5].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
553-
float_values.float_array[3] = utils::to_float<f4_t>(
554-
scale,
555-
f4_values.f4x2_array[5].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
556-
float_values.float_array[4] = utils::to_float<f4_t>(
557-
scale,
558-
f4_values.f4x2_array[6].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
559-
float_values.float_array[5] = utils::to_float<f4_t>(
560-
scale,
561-
f4_values.f4x2_array[6].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
562-
float_values.float_array[6] = utils::to_float<f4_t>(
563-
scale,
564-
f4_values.f4x2_array[7].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
565-
float_values.float_array[7] = utils::to_float<f4_t>(
566-
scale,
567-
f4_values.f4x2_array[7].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
568-
569-
float_values.float_array[0] = utils::to_float<f4_t>(
570-
scale,
571-
f4_values.f4x2_array[8].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
572-
float_values.float_array[1] = utils::to_float<f4_t>(
573-
scale,
574-
f4_values.f4x2_array[8].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
575-
float_values.float_array[2] = utils::to_float<f4_t>(
576-
scale,
577-
f4_values.f4x2_array[9].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
578-
float_values.float_array[3] = utils::to_float<f4_t>(
579-
scale,
580-
f4_values.f4x2_array[9].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
581-
float_values.float_array[4] = utils::to_float<f4_t>(
582-
scale,
583-
f4_values.f4x2_array[10].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
584-
float_values.float_array[5] = utils::to_float<f4_t>(
585-
scale,
586-
f4_values.f4x2_array[10].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
587-
float_values.float_array[6] = utils::to_float<f4_t>(
588-
scale,
589-
f4_values.f4x2_array[11].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
590-
float_values.float_array[7] = utils::to_float<f4_t>(
591-
scale,
592-
f4_values.f4x2_array[11].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
593-
594-
float_values.float_array[0] = utils::to_float<f4_t>(
595-
scale,
596-
f4_values.f4x2_array[12].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
597-
float_values.float_array[1] = utils::to_float<f4_t>(
598-
scale,
599-
f4_values.f4x2_array[12].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
600-
float_values.float_array[2] = utils::to_float<f4_t>(
601-
scale,
602-
f4_values.f4x2_array[13].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
603-
float_values.float_array[3] = utils::to_float<f4_t>(
604-
scale,
605-
f4_values.f4x2_array[13].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
606-
float_values.float_array[4] = utils::to_float<f4_t>(
607-
scale,
608-
f4_values.f4x2_array[14].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
609-
float_values.float_array[5] = utils::to_float<f4_t>(
610-
scale,
611-
f4_values.f4x2_array[14].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
612-
float_values.float_array[6] = utils::to_float<f4_t>(
613-
scale,
614-
f4_values.f4x2_array[15].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
615-
float_values.float_array[7] = utils::to_float<f4_t>(
616-
scale,
617-
f4_values.f4x2_array[15].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
428+
429+
ck::static_for<0, 32 / 2, 1>{}([&](auto idx) {
430+
float_values.float_array[2 * idx] = utils::to_float<f4_t>(
431+
scale,
432+
f4_values.f4x2_array[idx].template AsType<f4x2_pk_t>()[Number<0>{}].template unpack<>(
433+
Number<0>{}));
434+
435+
float_values.float_array[2 * idx + 1] = utils::to_float<f4_t>(
436+
scale,
437+
f4_values.f4x2_array[idx].template AsType<f4x2_pk_t>()[Number<0>{}].template unpack<>(
438+
Number<1>{}));
439+
});
618440

619441
return float_values.float32_array;
620442
#endif

0 commit comments

Comments
 (0)