@@ -377,12 +377,15 @@ inline __host__ __device__ float2_t scaled_type_convert<float2_t, f4x2_t>(e8m0_b
377377 f4x2_t f4x2_array[4 ];
378378 } value{};
379379 value.f4x2_array [0 ] = x;
380- return __builtin_amdgcn_cvt_scalef32_pk_f32_fp4 (value.bitwise , type_convert<float >(scale), 0 );
380+ float2_t tmp =
381+ __builtin_amdgcn_cvt_scalef32_pk_f32_fp4 (value.bitwise , type_convert<float >(scale), 0 );
382+ // permute high bits and low bits to match the order of the original vector
383+ return float2_t {tmp[1 ], tmp[0 ]};
381384#else
382385 float2_t ret{utils::to_float<f4_t >(
383- scale, x.template AsType <f4x2_pk_t >()[Number<0 >{}].unpack <>(Number<1 >{})),
386+ scale, x.template AsType <f4x2_pk_t >()[Number<0 >{}].unpack <>(Number<0 >{})),
384387 utils::to_float<f4_t >(
385- scale, x.template AsType <f4x2_pk_t >()[Number<0 >{}].unpack <>(Number<0 >{}))};
388+ scale, x.template AsType <f4x2_pk_t >()[Number<0 >{}].unpack <>(Number<1 >{}))};
386389 return ret;
387390#endif
388391}
@@ -398,109 +401,16 @@ inline __host__ __device__ float32_t scaled_type_convert<float32_t, f4x32_t>(e8m
398401 f4x32_t f4x32_array;
399402 f4x2_t fp4x2[16 ];
400403 } value{x};
401- union
402- {
403- uint32_t bitwise;
404- f4x2_t f4x2_array[4 ];
405- } bitwise_value{};
406404 float2_t op;
407405 float32_t ret;
408- // TODO: pack in a loop
409- bitwise_value.f4x2_array [0 ] = value.fp4x2 [0 ];
410- op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4 (
411- bitwise_value.bitwise , type_convert<float >(scale), 0 );
412- ret[0 ] = op[0 ];
413- ret[1 ] = op[1 ];
414-
415- bitwise_value.f4x2_array [0 ] = value.fp4x2 [1 ];
416- op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4 (
417- bitwise_value.bitwise , type_convert<float >(scale), 0 );
418- ret[2 ] = op[0 ];
419- ret[3 ] = op[1 ];
420-
421- bitwise_value.f4x2_array [0 ] = value.fp4x2 [2 ];
422- op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4 (
423- bitwise_value.bitwise , type_convert<float >(scale), 0 );
424- ret[4 ] = op[0 ];
425- ret[5 ] = op[1 ];
426-
427- bitwise_value.f4x2_array [0 ] = value.fp4x2 [3 ];
428- op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4 (
429- bitwise_value.bitwise , type_convert<float >(scale), 0 );
430- ret[6 ] = op[0 ];
431- ret[7 ] = op[1 ];
432-
433- bitwise_value.f4x2_array [0 ] = value.fp4x2 [4 ];
434- op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4 (
435- bitwise_value.bitwise , type_convert<float >(scale), 0 );
436- ret[8 ] = op[0 ];
437- ret[9 ] = op[1 ];
438-
439- bitwise_value.f4x2_array [0 ] = value.fp4x2 [5 ];
440- op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4 (
441- bitwise_value.bitwise , type_convert<float >(scale), 0 );
442- ret[10 ] = op[0 ];
443- ret[11 ] = op[1 ];
444-
445- bitwise_value.f4x2_array [0 ] = value.fp4x2 [6 ];
446- op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4 (
447- bitwise_value.bitwise , type_convert<float >(scale), 0 );
448- ret[12 ] = op[0 ];
449- ret[13 ] = op[1 ];
450-
451- bitwise_value.f4x2_array [0 ] = value.fp4x2 [7 ];
452- op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4 (
453- bitwise_value.bitwise , type_convert<float >(scale), 0 );
454- ret[14 ] = op[0 ];
455- ret[15 ] = op[1 ];
456-
457- bitwise_value.f4x2_array [0 ] = value.fp4x2 [8 ];
458- op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4 (
459- bitwise_value.bitwise , type_convert<float >(scale), 0 );
460- ret[16 ] = op[0 ];
461- ret[17 ] = op[1 ];
462-
463- bitwise_value.f4x2_array [0 ] = value.fp4x2 [9 ];
464- op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4 (
465- bitwise_value.bitwise , type_convert<float >(scale), 0 );
466- ret[18 ] = op[0 ];
467- ret[19 ] = op[1 ];
468-
469- bitwise_value.f4x2_array [0 ] = value.fp4x2 [10 ];
470- op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4 (
471- bitwise_value.bitwise , type_convert<float >(scale), 0 );
472- ret[20 ] = op[0 ];
473- ret[21 ] = op[1 ];
474-
475- bitwise_value.f4x2_array [0 ] = value.fp4x2 [11 ];
476- op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4 (
477- bitwise_value.bitwise , type_convert<float >(scale), 0 );
478- ret[22 ] = op[0 ];
479- ret[23 ] = op[1 ];
480-
481- bitwise_value.f4x2_array [0 ] = value.fp4x2 [12 ];
482- op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4 (
483- bitwise_value.bitwise , type_convert<float >(scale), 0 );
484- ret[24 ] = op[0 ];
485- ret[25 ] = op[1 ];
486-
487- bitwise_value.f4x2_array [0 ] = value.fp4x2 [13 ];
488- op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4 (
489- bitwise_value.bitwise , type_convert<float >(scale), 0 );
490- ret[26 ] = op[0 ];
491- ret[27 ] = op[1 ];
492-
493- bitwise_value.f4x2_array [0 ] = value.fp4x2 [14 ];
494- op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4 (
495- bitwise_value.bitwise , type_convert<float >(scale), 0 );
496- ret[28 ] = op[0 ];
497- ret[29 ] = op[1 ];
498-
499- bitwise_value.f4x2_array [0 ] = value.fp4x2 [15 ];
500- op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4 (
501- bitwise_value.bitwise , type_convert<float >(scale), 0 );
502- ret[30 ] = op[0 ];
503- ret[31 ] = op[1 ];
406+ float f_scale = type_convert<float >(scale);
407+
408+ ck::static_for<0 , 32 / 2 , 1 >{}([&](auto idx) {
409+ op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4 (value.fp4x2 [idx], f_scale, 0 );
410+ // permute high bits and low bits to match the order of the original vector
411+ ret[2 * idx] = op[1 ];
412+ ret[2 * idx + 1 ] = op[0 ];
413+ });
504414
505415 return ret;
506416#else
@@ -515,106 +425,18 @@ inline __host__ __device__ float32_t scaled_type_convert<float32_t, f4x32_t>(e8m
515425 f4x2_t f4x2_array[16 ];
516426 f4x32_t f4x32_array;
517427 } f4_values{bit_cast<__uint128_t >(x)};
518- // TODO: pack in a loop
519- float_values.float_array [0 ] = utils::to_float<f4_t >(
520- scale,
521- f4_values.f4x2_array [0 ].template AsType <f4x2_pk_t >()[Number<0 >{}].unpack <>(Number<0 >{}));
522- float_values.float_array [1 ] = utils::to_float<f4_t >(
523- scale,
524- f4_values.f4x2_array [0 ].template AsType <f4x2_pk_t >()[Number<0 >{}].unpack <>(Number<1 >{}));
525- float_values.float_array [2 ] = utils::to_float<f4_t >(
526- scale,
527- f4_values.f4x2_array [1 ].template AsType <f4x2_pk_t >()[Number<0 >{}].unpack <>(Number<0 >{}));
528- float_values.float_array [3 ] = utils::to_float<f4_t >(
529- scale,
530- f4_values.f4x2_array [1 ].template AsType <f4x2_pk_t >()[Number<0 >{}].unpack <>(Number<1 >{}));
531- float_values.float_array [4 ] = utils::to_float<f4_t >(
532- scale,
533- f4_values.f4x2_array [2 ].template AsType <f4x2_pk_t >()[Number<0 >{}].unpack <>(Number<0 >{}));
534- float_values.float_array [5 ] = utils::to_float<f4_t >(
535- scale,
536- f4_values.f4x2_array [2 ].template AsType <f4x2_pk_t >()[Number<0 >{}].unpack <>(Number<1 >{}));
537- float_values.float_array [6 ] = utils::to_float<f4_t >(
538- scale,
539- f4_values.f4x2_array [3 ].template AsType <f4x2_pk_t >()[Number<0 >{}].unpack <>(Number<0 >{}));
540- float_values.float_array [7 ] = utils::to_float<f4_t >(
541- scale,
542- f4_values.f4x2_array [3 ].template AsType <f4x2_pk_t >()[Number<0 >{}].unpack <>(Number<1 >{}));
543-
544- float_values.float_array [0 ] = utils::to_float<f4_t >(
545- scale,
546- f4_values.f4x2_array [4 ].template AsType <f4x2_pk_t >()[Number<0 >{}].unpack <>(Number<0 >{}));
547- float_values.float_array [1 ] = utils::to_float<f4_t >(
548- scale,
549- f4_values.f4x2_array [4 ].template AsType <f4x2_pk_t >()[Number<0 >{}].unpack <>(Number<1 >{}));
550- float_values.float_array [2 ] = utils::to_float<f4_t >(
551- scale,
552- f4_values.f4x2_array [5 ].template AsType <f4x2_pk_t >()[Number<0 >{}].unpack <>(Number<0 >{}));
553- float_values.float_array [3 ] = utils::to_float<f4_t >(
554- scale,
555- f4_values.f4x2_array [5 ].template AsType <f4x2_pk_t >()[Number<0 >{}].unpack <>(Number<1 >{}));
556- float_values.float_array [4 ] = utils::to_float<f4_t >(
557- scale,
558- f4_values.f4x2_array [6 ].template AsType <f4x2_pk_t >()[Number<0 >{}].unpack <>(Number<0 >{}));
559- float_values.float_array [5 ] = utils::to_float<f4_t >(
560- scale,
561- f4_values.f4x2_array [6 ].template AsType <f4x2_pk_t >()[Number<0 >{}].unpack <>(Number<1 >{}));
562- float_values.float_array [6 ] = utils::to_float<f4_t >(
563- scale,
564- f4_values.f4x2_array [7 ].template AsType <f4x2_pk_t >()[Number<0 >{}].unpack <>(Number<0 >{}));
565- float_values.float_array [7 ] = utils::to_float<f4_t >(
566- scale,
567- f4_values.f4x2_array [7 ].template AsType <f4x2_pk_t >()[Number<0 >{}].unpack <>(Number<1 >{}));
568-
569- float_values.float_array [0 ] = utils::to_float<f4_t >(
570- scale,
571- f4_values.f4x2_array [8 ].template AsType <f4x2_pk_t >()[Number<0 >{}].unpack <>(Number<0 >{}));
572- float_values.float_array [1 ] = utils::to_float<f4_t >(
573- scale,
574- f4_values.f4x2_array [8 ].template AsType <f4x2_pk_t >()[Number<0 >{}].unpack <>(Number<1 >{}));
575- float_values.float_array [2 ] = utils::to_float<f4_t >(
576- scale,
577- f4_values.f4x2_array [9 ].template AsType <f4x2_pk_t >()[Number<0 >{}].unpack <>(Number<0 >{}));
578- float_values.float_array [3 ] = utils::to_float<f4_t >(
579- scale,
580- f4_values.f4x2_array [9 ].template AsType <f4x2_pk_t >()[Number<0 >{}].unpack <>(Number<1 >{}));
581- float_values.float_array [4 ] = utils::to_float<f4_t >(
582- scale,
583- f4_values.f4x2_array [10 ].template AsType <f4x2_pk_t >()[Number<0 >{}].unpack <>(Number<0 >{}));
584- float_values.float_array [5 ] = utils::to_float<f4_t >(
585- scale,
586- f4_values.f4x2_array [10 ].template AsType <f4x2_pk_t >()[Number<0 >{}].unpack <>(Number<1 >{}));
587- float_values.float_array [6 ] = utils::to_float<f4_t >(
588- scale,
589- f4_values.f4x2_array [11 ].template AsType <f4x2_pk_t >()[Number<0 >{}].unpack <>(Number<0 >{}));
590- float_values.float_array [7 ] = utils::to_float<f4_t >(
591- scale,
592- f4_values.f4x2_array [11 ].template AsType <f4x2_pk_t >()[Number<0 >{}].unpack <>(Number<1 >{}));
593-
594- float_values.float_array [0 ] = utils::to_float<f4_t >(
595- scale,
596- f4_values.f4x2_array [12 ].template AsType <f4x2_pk_t >()[Number<0 >{}].unpack <>(Number<0 >{}));
597- float_values.float_array [1 ] = utils::to_float<f4_t >(
598- scale,
599- f4_values.f4x2_array [12 ].template AsType <f4x2_pk_t >()[Number<0 >{}].unpack <>(Number<1 >{}));
600- float_values.float_array [2 ] = utils::to_float<f4_t >(
601- scale,
602- f4_values.f4x2_array [13 ].template AsType <f4x2_pk_t >()[Number<0 >{}].unpack <>(Number<0 >{}));
603- float_values.float_array [3 ] = utils::to_float<f4_t >(
604- scale,
605- f4_values.f4x2_array [13 ].template AsType <f4x2_pk_t >()[Number<0 >{}].unpack <>(Number<1 >{}));
606- float_values.float_array [4 ] = utils::to_float<f4_t >(
607- scale,
608- f4_values.f4x2_array [14 ].template AsType <f4x2_pk_t >()[Number<0 >{}].unpack <>(Number<0 >{}));
609- float_values.float_array [5 ] = utils::to_float<f4_t >(
610- scale,
611- f4_values.f4x2_array [14 ].template AsType <f4x2_pk_t >()[Number<0 >{}].unpack <>(Number<1 >{}));
612- float_values.float_array [6 ] = utils::to_float<f4_t >(
613- scale,
614- f4_values.f4x2_array [15 ].template AsType <f4x2_pk_t >()[Number<0 >{}].unpack <>(Number<0 >{}));
615- float_values.float_array [7 ] = utils::to_float<f4_t >(
616- scale,
617- f4_values.f4x2_array [15 ].template AsType <f4x2_pk_t >()[Number<0 >{}].unpack <>(Number<1 >{}));
428+
429+ ck::static_for<0 , 32 / 2 , 1 >{}([&](auto idx) {
430+ float_values.float_array [2 * idx] = utils::to_float<f4_t >(
431+ scale,
432+ f4_values.f4x2_array [idx].template AsType <f4x2_pk_t >()[Number<0 >{}].template unpack <>(
433+ Number<0 >{}));
434+
435+ float_values.float_array [2 * idx + 1 ] = utils::to_float<f4_t >(
436+ scale,
437+ f4_values.f4x2_array [idx].template AsType <f4x2_pk_t >()[Number<0 >{}].template unpack <>(
438+ Number<1 >{}));
439+ });
618440
619441 return float_values.float32_array ;
620442#endif
0 commit comments