@@ -14,33 +14,8 @@ namespace {
1414/* ----- FP8E5M2 ------ */
1515// This data-type is the standard FP8E5M2 format
1616static SmallVector<Value>
17- Fp16_to_Fp8E5M2_func (Location loc, ConversionPatternRewriter &rewriter,
17+ Fp16_to_Fp8E5M2_RTNE (Location loc, ConversionPatternRewriter &rewriter,
1818 const SmallVector<Value> &v) {
19- auto fp16x2VecTy = vec_ty (f16_ty, 2 );
20- Value fp16x2Vec0 = undef (fp16x2VecTy);
21- Value fp16x2Vec1 = undef (fp16x2VecTy);
22- fp16x2Vec0 = insert_element (fp16x2VecTy, fp16x2Vec0, v[0 ], i32_val (0 ));
23- fp16x2Vec0 = insert_element (fp16x2VecTy, fp16x2Vec0, v[1 ], i32_val (1 ));
24- fp16x2Vec1 = insert_element (fp16x2VecTy, fp16x2Vec1, v[2 ], i32_val (0 ));
25- fp16x2Vec1 = insert_element (fp16x2VecTy, fp16x2Vec1, v[3 ], i32_val (1 ));
26-
27- Value a0 = bitcast (fp16x2Vec0, i32_ty);
28- Value a1 = bitcast (fp16x2Vec1, i32_ty);
29-
30- auto fp8x4VecTy = vec_ty (i8_ty, 4 );
31- a0 = bitcast (a0, fp8x4VecTy);
32- a1 = bitcast (a1, fp8x4VecTy);
33-
34- return {extract_element (i8_ty, a0, i32_val (1 )),
35- extract_element (i8_ty, a0, i32_val (3 )),
36- extract_element (i8_ty, a1, i32_val (1 )),
37- extract_element (i8_ty, a1, i32_val (3 ))};
38- }
39-
40- static SmallVector<Value>
41- Fp16_to_Fp8E5M2_RTNE_func (Location loc, ConversionPatternRewriter &rewriter,
42- const SmallVector<Value> &v) {
43-
4419 Value val = zext (i32_ty, bitcast (v[0 ], i16_ty));
4520 Value sign = and_ (i32_ty, val, i32_val (0x8000 ));
4621 Value nosign = and_ (i32_ty, val, i32_val (0x7fff ));
@@ -63,8 +38,32 @@ Fp16_to_Fp8E5M2_RTNE_func(Location loc, ConversionPatternRewriter &rewriter,
6338}
6439
// Fp16 -> Fp8E5M2 conversion registered for RoundingMode::RTZ; consumes four
// values (v[0]..v[3]) per call and returns four i8 values.
// Each result is a single byte picked out of the bit pattern of an fp16
// input; no rounding arithmetic is performed in this function.
6540static SmallVector<Value>
66- Fp8E5M2_to_Fp16_func (Location loc, ConversionPatternRewriter &rewriter,
67- const SmallVector<Value> &v) {
41+ Fp16_to_Fp8E5M2_RTZ (Location loc, ConversionPatternRewriter &rewriter,
42+ const SmallVector<Value> &v) {
// Pack the four inputs pairwise into two 2 x f16 vectors.
43+ auto fp16x2VecTy = vec_ty (f16_ty, 2 );
44+ Value fp16x2Vec0 = undef (fp16x2VecTy);
45+ Value fp16x2Vec1 = undef (fp16x2VecTy);
46+ fp16x2Vec0 = insert_element (fp16x2VecTy, fp16x2Vec0, v[0 ], i32_val (0 ));
47+ fp16x2Vec0 = insert_element (fp16x2VecTy, fp16x2Vec0, v[1 ], i32_val (1 ));
48+ fp16x2Vec1 = insert_element (fp16x2VecTy, fp16x2Vec1, v[2 ], i32_val (0 ));
49+ fp16x2Vec1 = insert_element (fp16x2VecTy, fp16x2Vec1, v[3 ], i32_val (1 ));
50+
// Reinterpret each pair as one 32-bit integer ...
51+ Value a0 = bitcast (fp16x2Vec0, i32_ty);
52+ Value a1 = bitcast (fp16x2Vec1, i32_ty);
53+
// ... and then as four individual bytes.
54+ auto fp8x4VecTy = vec_ty (i8_ty, 4 );
55+ a0 = bitcast (a0, fp8x4VecTy);
56+ a1 = bitcast (a1, fp8x4VecTy);
57+
// Keep bytes 1 and 3 of each 32-bit word, i.e. one byte per fp16 input.
// NOTE(review): presumably these are the upper bytes of each half
// (little-endian lane layout), so the low 8 bits are dropped - confirm.
58+ return {extract_element (i8_ty, a0, i32_val (1 )),
59+ extract_element (i8_ty, a0, i32_val (3 )),
60+ extract_element (i8_ty, a1, i32_val (1 )),
61+ extract_element (i8_ty, a1, i32_val (3 ))};
62+ }
63+
64+ static SmallVector<Value> Fp8E5M2_to_Fp16 (Location loc,
65+ ConversionPatternRewriter &rewriter,
66+ const SmallVector<Value> &v) {
6867 auto fp8x4VecTy = vec_ty (i8_ty, 4 );
6968 Value a0 = undef (fp8x4VecTy);
7069 a0 = insert_element (fp8x4VecTy, a0, int_val (8 , 0 ), i32_val (0 ));
@@ -89,9 +88,9 @@ Fp8E5M2_to_Fp16_func(Location loc, ConversionPatternRewriter &rewriter,
8988 extract_element (f16_ty, fp16x2Vec1, i32_val (1 ))};
9089}
9190
92- static SmallVector<Value>
93- Fp8E5M2_to_Bf16_func (Location loc, ConversionPatternRewriter &rewriter,
94- const SmallVector<Value> &v) {
91+ static SmallVector<Value> Fp8E5M2_to_Bf16 (Location loc,
92+ ConversionPatternRewriter &rewriter,
93+ const SmallVector<Value> &v) {
9594 auto fp8x4VecTy = vec_ty (i8_ty, 4 );
9695 Value a0 = undef (fp8x4VecTy);
9796 a0 = insert_element (fp8x4VecTy, a0, int_val (8 , 0 ), i32_val (0 ));
@@ -178,9 +177,9 @@ Fp8E5M2_to_Bf16_func(Location loc, ConversionPatternRewriter &rewriter,
178177 extract_element (bf16_ty, bf16x2Vec1, i32_val (1 ))};
179178}
180179
181- static SmallVector<Value>
182- Bf16_to_Fp8E5M2_func (Location loc, ConversionPatternRewriter &rewriter,
183- const SmallVector<Value> &v) {
180+ static SmallVector<Value> Bf16_to_Fp8E5M2 (Location loc,
181+ ConversionPatternRewriter &rewriter,
182+ const SmallVector<Value> &v) {
184183 auto bf16x2VecTy = vec_ty (bf16_ty, 2 );
185184 Value bf16x2Vec0 = undef (bf16x2VecTy);
186185 Value bf16x2Vec1 = undef (bf16x2VecTy);
@@ -259,8 +258,8 @@ Bf16_to_Fp8E5M2_func(Location loc, ConversionPatternRewriter &rewriter,
259258}
260259
261260static SmallVector<Value>
262- Bf16_to_Fp8E5M2_RTNE_func (Location loc, ConversionPatternRewriter &rewriter,
263- const SmallVector<Value> &v) {
261+ Bf16_to_Fp8E5M2_RTNE (Location loc, ConversionPatternRewriter &rewriter,
262+ const SmallVector<Value> &v) {
264263 Value val = zext (i32_ty, bitcast (v[0 ], i16_ty));
265264 Value sign = and_ (i32_ty, val, i32_val (0x8000 ));
266265 Value nosign = and_ (i32_ty, val, i32_val (0x7fff ));
@@ -320,8 +319,8 @@ Bf16_to_Fp8E5M2_RTNE_func(Location loc, ConversionPatternRewriter &rewriter,
320319// - has multiple nans (when all exponent bits are 1)
321320// - has an exponent bias of 15 (vs. 7 for fp8e4m3)
322321static SmallVector<Value>
323- Fp8E4M3B15_to_Fp16_func (Location loc, ConversionPatternRewriter &rewriter,
324- const SmallVector<Value> &v) {
322+ Fp8E4M3B15_to_Fp16 (Location loc, ConversionPatternRewriter &rewriter,
323+ const SmallVector<Value> &v) {
325324 auto fp8x4VecTy = vec_ty (i8_ty, 4 );
326325 Value a0 = undef (fp8x4VecTy);
327326 a0 = insert_element (fp8x4VecTy, a0, int_val (8 , 0 ), i32_val (0 ));
@@ -357,8 +356,8 @@ Fp8E4M3B15_to_Fp16_func(Location loc, ConversionPatternRewriter &rewriter,
357356}
358357
359358static SmallVector<Value>
360- Fp16_to_Fp8E4M3B15_func (Location loc, ConversionPatternRewriter &rewriter,
361- const SmallVector<Value> &v) {
359+ Fp16_to_Fp8E4M3B15 (Location loc, ConversionPatternRewriter &rewriter,
360+ const SmallVector<Value> &v) {
362361 auto fp16x2VecTy = vec_ty (f16_ty, 2 );
363362 Value fp16x2Vec0 = undef (fp16x2VecTy);
364363 Value fp16x2Vec1 = undef (fp16x2VecTy);
@@ -404,9 +403,9 @@ Fp16_to_Fp8E4M3B15_func(Location loc, ConversionPatternRewriter &rewriter,
404403// has more than a single NaN values.
405404
406405// Fp8E4M3 -> Fp16 (packed)
407- static SmallVector<Value>
408- Fp8E4M3Nv_to_Fp16_func (Location loc, ConversionPatternRewriter &rewriter,
409- const SmallVector<Value> &v) {
406+ static SmallVector<Value> Fp8E4M3Nv_to_Fp16 (Location loc,
407+ ConversionPatternRewriter &rewriter,
408+ const SmallVector<Value> &v) {
410409 auto fp8x4VecTy = vec_ty (i8_ty, 4 );
411410 Value a0 = undef (fp8x4VecTy);
412411 a0 = insert_element (fp8x4VecTy, a0, int_val (8 , 0 ), i32_val (0 ));
@@ -478,9 +477,9 @@ Fp8E4M3Nv_to_Fp16_func(Location loc, ConversionPatternRewriter &rewriter,
478477}
479478
480479// Fp16 -> Fp8E4M3 (packed)
481- static SmallVector<Value>
482- Fp16_to_Fp8E4M3Nv_func (Location loc, ConversionPatternRewriter &rewriter,
483- const SmallVector<Value> &v) {
480+ static SmallVector<Value> Fp16_to_Fp8E4M3Nv (Location loc,
481+ ConversionPatternRewriter &rewriter,
482+ const SmallVector<Value> &v) {
484483 auto fp16x2VecTy = vec_ty (f16_ty, 2 );
485484 Value fp16x2Vec0 = undef (fp16x2VecTy);
486485
@@ -503,8 +502,8 @@ Fp16_to_Fp8E4M3Nv_func(Location loc, ConversionPatternRewriter &rewriter,
503502}
504503
505504static SmallVector<Value>
506- Fp16_to_Fp8E4M3Nv_RTNE_func (Location loc, ConversionPatternRewriter &rewriter,
507- const SmallVector<Value> &v) {
505+ Fp16_to_Fp8E4M3Nv_RTNE (Location loc, ConversionPatternRewriter &rewriter,
506+ const SmallVector<Value> &v) {
508507 Value val = zext (i32_ty, bitcast (v[0 ], i16_ty));
509508 Value sign = and_ (i32_ty, val, i32_val (0x8000 ));
510509 Value nosign = and_ (i32_ty, val, i32_val (0x7fff ));
@@ -556,9 +555,9 @@ Fp16_to_Fp8E4M3Nv_RTNE_func(Location loc, ConversionPatternRewriter &rewriter,
556555 return {extract_element (i8_ty, res, i32_val (1 ))};
557556}
558557
559- static SmallVector<Value>
560- Fp8E4M3Nv_to_Bf16_func (Location loc, ConversionPatternRewriter &rewriter,
561- const SmallVector<Value> &v) {
558+ static SmallVector<Value> Fp8E4M3Nv_to_Bf16 (Location loc,
559+ ConversionPatternRewriter &rewriter,
560+ const SmallVector<Value> &v) {
562561 auto fp8x4VecTy = vec_ty (i8_ty, 4 );
563562 Value a0 = undef (fp8x4VecTy);
564563 a0 = insert_element (fp8x4VecTy, a0, int_val (8 , 0 ), i32_val (0 ));
@@ -656,9 +655,9 @@ Fp8E4M3Nv_to_Bf16_func(Location loc, ConversionPatternRewriter &rewriter,
656655 extract_element (bf16_ty, bf16x2Vec1, i32_val (1 ))};
657656}
658657
659- static SmallVector<Value>
660- Bf16_to_Fp8E4M3Nv_func (Location loc, ConversionPatternRewriter &rewriter,
661- const SmallVector<Value> &v) {
658+ static SmallVector<Value> Bf16_to_Fp8E4M3Nv (Location loc,
659+ ConversionPatternRewriter &rewriter,
660+ const SmallVector<Value> &v) {
662661 auto bf16x2VecTy = vec_ty (bf16_ty, 2 );
663662 Value bf16x2Vec0 = undef (bf16x2VecTy);
664663 Value bf16x2Vec1 = undef (bf16x2VecTy);
@@ -737,8 +736,8 @@ Bf16_to_Fp8E4M3Nv_func(Location loc, ConversionPatternRewriter &rewriter,
737736}
738737
739738static SmallVector<Value>
740- Bf16_to_Fp8E4M3Nv_RTNE_func (Location loc, ConversionPatternRewriter &rewriter,
741- const SmallVector<Value> &v) {
739+ Bf16_to_Fp8E4M3Nv_RTNE (Location loc, ConversionPatternRewriter &rewriter,
740+ const SmallVector<Value> &v) {
742741 Value val = zext (i32_ty, bitcast (v[0 ], i16_ty));
743742 Value sign = and_ (i32_ty, val, i32_val (0x8000 ));
744743 Value nosign = and_ (i32_ty, val, i32_val (0x7fff ));
@@ -790,9 +789,9 @@ Bf16_to_Fp8E4M3Nv_RTNE_func(Location loc, ConversionPatternRewriter &rewriter,
790789 return {extract_element (i8_ty, res, i32_val (1 ))};
791790}
792791
793- static SmallVector<Value> Bf16_to_Fp16_func (Location loc,
794- ConversionPatternRewriter &rewriter,
795- const SmallVector<Value> &v) {
792+ static SmallVector<Value> Bf16_to_Fp16 (Location loc,
793+ ConversionPatternRewriter &rewriter,
794+ const SmallVector<Value> &v) {
796795 auto bf16x2VecTy = vec_ty (bf16_ty, 2 );
797796
798797 Value bf16x2Vec = undef (bf16x2VecTy);
@@ -997,39 +996,34 @@ struct FpToFpOpConversion
997996 std::pair<ConverterT, size_t >>
998997 srcMap = {
999998 // F8 -> F16
1000- {{F8E4M3B15TyID, F16TyID, undefRounding},
1001- {Fp8E4M3B15_to_Fp16_func, 4 }},
1002- {{F8E4M3TyID, F16TyID, undefRounding}, {Fp8E4M3Nv_to_Fp16_func, 2 }},
1003- {{F8E5M2TyID, F16TyID, undefRounding}, {Fp8E5M2_to_Fp16_func, 4 }},
999+ {{F8E4M3B15TyID, F16TyID, undefRounding}, {Fp8E4M3B15_to_Fp16, 4 }},
1000+ {{F8E4M3TyID, F16TyID, undefRounding}, {Fp8E4M3Nv_to_Fp16, 2 }},
1001+ {{F8E5M2TyID, F16TyID, undefRounding}, {Fp8E5M2_to_Fp16, 4 }},
10041002 // F16 -> F8
10051003 {{F16TyID, F8E4M3B15TyID, RoundingMode::RTZ},
1006- {Fp16_to_Fp8E4M3B15_func , 4 }},
1004+ {Fp16_to_Fp8E4M3B15 , 4 }},
10071005 {{F16TyID, F8E4M3B15TyID, RoundingMode::RTNE},
10081006 // TODO: provide proper implementation for RTNE rounding.
1009- {Fp16_to_Fp8E4M3B15_func, 4 }},
1010- {{F16TyID, F8E4M3TyID, RoundingMode::RTZ},
1011- {Fp16_to_Fp8E4M3Nv_func, 2 }},
1007+ {Fp16_to_Fp8E4M3B15, 4 }},
1008+ {{F16TyID, F8E4M3TyID, RoundingMode::RTZ}, {Fp16_to_Fp8E4M3Nv, 2 }},
10121009 {{F16TyID, F8E4M3TyID, RoundingMode::RTNE},
1013- {Fp16_to_Fp8E4M3Nv_RTNE_func , 1 }},
1010+ {Fp16_to_Fp8E4M3Nv_RTNE , 1 }},
10141011 {{F16TyID, F8E5M2TyID, RoundingMode::RTZ},
1015- {Fp16_to_Fp8E5M2_func , 4 }},
1012+ {Fp16_to_Fp8E5M2_RTZ , 4 }},
10161013 {{F16TyID, F8E5M2TyID, RoundingMode::RTNE},
1017- {Fp16_to_Fp8E5M2_RTNE_func , 1 }},
1014+ {Fp16_to_Fp8E5M2_RTNE , 1 }},
10181015 // F8 -> BF16
1019- {{F8E5M2TyID, BF16TyID, undefRounding}, {Fp8E5M2_to_Bf16_func, 4 }},
1020- {{F8E4M3TyID, BF16TyID, undefRounding},
1021- {Fp8E4M3Nv_to_Bf16_func, 4 }},
1016+ {{F8E5M2TyID, BF16TyID, undefRounding}, {Fp8E5M2_to_Bf16, 4 }},
1017+ {{F8E4M3TyID, BF16TyID, undefRounding}, {Fp8E4M3Nv_to_Bf16, 4 }},
10221018 // BF16 -> F8
1023- {{BF16TyID, F8E5M2TyID, RoundingMode::RTZ},
1024- {Bf16_to_Fp8E5M2_func, 4 }},
1019+ {{BF16TyID, F8E5M2TyID, RoundingMode::RTZ}, {Bf16_to_Fp8E5M2, 4 }},
10251020 {{BF16TyID, F8E5M2TyID, RoundingMode::RTNE},
1026- {Bf16_to_Fp8E5M2_RTNE_func, 1 }},
1027- {{BF16TyID, F8E4M3TyID, RoundingMode::RTZ},
1028- {Bf16_to_Fp8E4M3Nv_func, 4 }},
1021+ {Bf16_to_Fp8E5M2_RTNE, 1 }},
1022+ {{BF16TyID, F8E4M3TyID, RoundingMode::RTZ}, {Bf16_to_Fp8E4M3Nv, 4 }},
10291023 {{BF16TyID, F8E4M3TyID, RoundingMode::RTNE},
1030- {Bf16_to_Fp8E4M3Nv_RTNE_func , 1 }},
1024+ {Bf16_to_Fp8E4M3Nv_RTNE , 1 }},
10311025 // BF16 -> F16
1032- {{BF16TyID, F16TyID, undefRounding}, {Bf16_to_Fp16_func , 2 }},
1026+ {{BF16TyID, F16TyID, undefRounding}, {Bf16_to_Fp16 , 2 }},
10331027 };
10341028
10351029 std::tuple<TypeID, TypeID, RoundingMode> key = {
@@ -1097,6 +1091,7 @@ struct FpToFpOpConversion
10971091 auto [cvtFunc, numElements] =
10981092 getConversionFunc (srcType, dstType, roundingMode);
10991093 SmallVector<Value> inVals;
1094+ inVals.reserve (std::min (numElements, operands.size ()));
11001095 for (unsigned i = 0 ; i < std::min (numElements, operands.size ()); i++) {
11011096 inVals.push_back (operands[i][0 ]);
11021097 }
@@ -1323,9 +1318,8 @@ struct TruncFOpConversion
13231318 return {// Trunc uses the default rounding mode: RTNE
13241319 intel::convertFp32ToBf16 (loc, rewriter, operands[0 ][0 ],
13251320 RoundingMode::RTNE)};
1326- } else {
1327- return {rewriter.create <LLVM::FPTruncOp>(loc, elemTy, operands[0 ][0 ])};
13281321 }
1322+ return {rewriter.create <LLVM::FPTruncOp>(loc, elemTy, operands[0 ][0 ])};
13291323 }
13301324};
13311325
@@ -1488,7 +1482,6 @@ void populateElementwiseOpToLLVMPatterns(
14881482 LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
14891483 ModuleAxisInfoAnalysis &axisInfoAnalysis, const TargetInfoBase &targetInfo,
14901484 PatternBenefit benefit) {
1491- using namespace mlir ::triton::gpu;
14921485
14931486 patterns.add <PreciseSqrtOpConversion>(typeConverter, axisInfoAnalysis,
14941487 benefit);
@@ -1511,7 +1504,6 @@ void populateElementwiseOpToLLVMPatterns(
15111504 patterns.add <ExtFOpConversion>(typeConverter, axisInfoAnalysis, benefit);
15121505 patterns.add <TruncFOpConversion>(typeConverter, axisInfoAnalysis, benefit);
15131506 patterns.add <FPToSIOpConversion>(typeConverter, axisInfoAnalysis, benefit);
1514-
15151507 patterns.add <SIToFPOpConversion>(typeConverter, axisInfoAnalysis, benefit);
15161508 patterns.add <FpToFpOpConversion>(typeConverter, axisInfoAnalysis, benefit);
15171509
0 commit comments