Skip to content

Commit d16a1dd

Browse files
authored
[NFI]: Cleanup ElementwiseOpToLLVM.cpp (#2973)
Signed-off-by: Tiotto, Ettore <[email protected]>
1 parent 0773668 commit d16a1dd

File tree

1 file changed

+76
-84
lines changed

1 file changed

+76
-84
lines changed

third_party/intel/lib/TritonIntelGPUToLLVM/ElementwiseOpToLLVM.cpp

Lines changed: 76 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -14,33 +14,8 @@ namespace {
1414
/* ----- FP8E5M2 ------ */
1515
// This data-type is the standard FP8E5M2 format
1616
static SmallVector<Value>
17-
Fp16_to_Fp8E5M2_func(Location loc, ConversionPatternRewriter &rewriter,
17+
Fp16_to_Fp8E5M2_RTNE(Location loc, ConversionPatternRewriter &rewriter,
1818
const SmallVector<Value> &v) {
19-
auto fp16x2VecTy = vec_ty(f16_ty, 2);
20-
Value fp16x2Vec0 = undef(fp16x2VecTy);
21-
Value fp16x2Vec1 = undef(fp16x2VecTy);
22-
fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v[0], i32_val(0));
23-
fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v[1], i32_val(1));
24-
fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v[2], i32_val(0));
25-
fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v[3], i32_val(1));
26-
27-
Value a0 = bitcast(fp16x2Vec0, i32_ty);
28-
Value a1 = bitcast(fp16x2Vec1, i32_ty);
29-
30-
auto fp8x4VecTy = vec_ty(i8_ty, 4);
31-
a0 = bitcast(a0, fp8x4VecTy);
32-
a1 = bitcast(a1, fp8x4VecTy);
33-
34-
return {extract_element(i8_ty, a0, i32_val(1)),
35-
extract_element(i8_ty, a0, i32_val(3)),
36-
extract_element(i8_ty, a1, i32_val(1)),
37-
extract_element(i8_ty, a1, i32_val(3))};
38-
}
39-
40-
static SmallVector<Value>
41-
Fp16_to_Fp8E5M2_RTNE_func(Location loc, ConversionPatternRewriter &rewriter,
42-
const SmallVector<Value> &v) {
43-
4419
Value val = zext(i32_ty, bitcast(v[0], i16_ty));
4520
Value sign = and_(i32_ty, val, i32_val(0x8000));
4621
Value nosign = and_(i32_ty, val, i32_val(0x7fff));
@@ -63,8 +38,32 @@ Fp16_to_Fp8E5M2_RTNE_func(Location loc, ConversionPatternRewriter &rewriter,
6338
}
6439

6540
static SmallVector<Value>
66-
Fp8E5M2_to_Fp16_func(Location loc, ConversionPatternRewriter &rewriter,
67-
const SmallVector<Value> &v) {
41+
Fp16_to_Fp8E5M2_RTZ(Location loc, ConversionPatternRewriter &rewriter,
42+
const SmallVector<Value> &v) {
43+
auto fp16x2VecTy = vec_ty(f16_ty, 2);
44+
Value fp16x2Vec0 = undef(fp16x2VecTy);
45+
Value fp16x2Vec1 = undef(fp16x2VecTy);
46+
fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v[0], i32_val(0));
47+
fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v[1], i32_val(1));
48+
fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v[2], i32_val(0));
49+
fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v[3], i32_val(1));
50+
51+
Value a0 = bitcast(fp16x2Vec0, i32_ty);
52+
Value a1 = bitcast(fp16x2Vec1, i32_ty);
53+
54+
auto fp8x4VecTy = vec_ty(i8_ty, 4);
55+
a0 = bitcast(a0, fp8x4VecTy);
56+
a1 = bitcast(a1, fp8x4VecTy);
57+
58+
return {extract_element(i8_ty, a0, i32_val(1)),
59+
extract_element(i8_ty, a0, i32_val(3)),
60+
extract_element(i8_ty, a1, i32_val(1)),
61+
extract_element(i8_ty, a1, i32_val(3))};
62+
}
63+
64+
static SmallVector<Value> Fp8E5M2_to_Fp16(Location loc,
65+
ConversionPatternRewriter &rewriter,
66+
const SmallVector<Value> &v) {
6867
auto fp8x4VecTy = vec_ty(i8_ty, 4);
6968
Value a0 = undef(fp8x4VecTy);
7069
a0 = insert_element(fp8x4VecTy, a0, int_val(8, 0), i32_val(0));
@@ -89,9 +88,9 @@ Fp8E5M2_to_Fp16_func(Location loc, ConversionPatternRewriter &rewriter,
8988
extract_element(f16_ty, fp16x2Vec1, i32_val(1))};
9089
}
9190

92-
static SmallVector<Value>
93-
Fp8E5M2_to_Bf16_func(Location loc, ConversionPatternRewriter &rewriter,
94-
const SmallVector<Value> &v) {
91+
static SmallVector<Value> Fp8E5M2_to_Bf16(Location loc,
92+
ConversionPatternRewriter &rewriter,
93+
const SmallVector<Value> &v) {
9594
auto fp8x4VecTy = vec_ty(i8_ty, 4);
9695
Value a0 = undef(fp8x4VecTy);
9796
a0 = insert_element(fp8x4VecTy, a0, int_val(8, 0), i32_val(0));
@@ -178,9 +177,9 @@ Fp8E5M2_to_Bf16_func(Location loc, ConversionPatternRewriter &rewriter,
178177
extract_element(bf16_ty, bf16x2Vec1, i32_val(1))};
179178
}
180179

181-
static SmallVector<Value>
182-
Bf16_to_Fp8E5M2_func(Location loc, ConversionPatternRewriter &rewriter,
183-
const SmallVector<Value> &v) {
180+
static SmallVector<Value> Bf16_to_Fp8E5M2(Location loc,
181+
ConversionPatternRewriter &rewriter,
182+
const SmallVector<Value> &v) {
184183
auto bf16x2VecTy = vec_ty(bf16_ty, 2);
185184
Value bf16x2Vec0 = undef(bf16x2VecTy);
186185
Value bf16x2Vec1 = undef(bf16x2VecTy);
@@ -259,8 +258,8 @@ Bf16_to_Fp8E5M2_func(Location loc, ConversionPatternRewriter &rewriter,
259258
}
260259

261260
static SmallVector<Value>
262-
Bf16_to_Fp8E5M2_RTNE_func(Location loc, ConversionPatternRewriter &rewriter,
263-
const SmallVector<Value> &v) {
261+
Bf16_to_Fp8E5M2_RTNE(Location loc, ConversionPatternRewriter &rewriter,
262+
const SmallVector<Value> &v) {
264263
Value val = zext(i32_ty, bitcast(v[0], i16_ty));
265264
Value sign = and_(i32_ty, val, i32_val(0x8000));
266265
Value nosign = and_(i32_ty, val, i32_val(0x7fff));
@@ -320,8 +319,8 @@ Bf16_to_Fp8E5M2_RTNE_func(Location loc, ConversionPatternRewriter &rewriter,
320319
// - has multiple nans (when all exponent bits are 1)
321320
// - has an exponent bias of 15 (vs. 7 for fp8e4m3)
322321
static SmallVector<Value>
323-
Fp8E4M3B15_to_Fp16_func(Location loc, ConversionPatternRewriter &rewriter,
324-
const SmallVector<Value> &v) {
322+
Fp8E4M3B15_to_Fp16(Location loc, ConversionPatternRewriter &rewriter,
323+
const SmallVector<Value> &v) {
325324
auto fp8x4VecTy = vec_ty(i8_ty, 4);
326325
Value a0 = undef(fp8x4VecTy);
327326
a0 = insert_element(fp8x4VecTy, a0, int_val(8, 0), i32_val(0));
@@ -357,8 +356,8 @@ Fp8E4M3B15_to_Fp16_func(Location loc, ConversionPatternRewriter &rewriter,
357356
}
358357

359358
static SmallVector<Value>
360-
Fp16_to_Fp8E4M3B15_func(Location loc, ConversionPatternRewriter &rewriter,
361-
const SmallVector<Value> &v) {
359+
Fp16_to_Fp8E4M3B15(Location loc, ConversionPatternRewriter &rewriter,
360+
const SmallVector<Value> &v) {
362361
auto fp16x2VecTy = vec_ty(f16_ty, 2);
363362
Value fp16x2Vec0 = undef(fp16x2VecTy);
364363
Value fp16x2Vec1 = undef(fp16x2VecTy);
@@ -404,9 +403,9 @@ Fp16_to_Fp8E4M3B15_func(Location loc, ConversionPatternRewriter &rewriter,
404403
// has more than a single NaN values.
405404

406405
// Fp8E4M3 -> Fp16 (packed)
407-
static SmallVector<Value>
408-
Fp8E4M3Nv_to_Fp16_func(Location loc, ConversionPatternRewriter &rewriter,
409-
const SmallVector<Value> &v) {
406+
static SmallVector<Value> Fp8E4M3Nv_to_Fp16(Location loc,
407+
ConversionPatternRewriter &rewriter,
408+
const SmallVector<Value> &v) {
410409
auto fp8x4VecTy = vec_ty(i8_ty, 4);
411410
Value a0 = undef(fp8x4VecTy);
412411
a0 = insert_element(fp8x4VecTy, a0, int_val(8, 0), i32_val(0));
@@ -478,9 +477,9 @@ Fp8E4M3Nv_to_Fp16_func(Location loc, ConversionPatternRewriter &rewriter,
478477
}
479478

480479
// Fp16 -> Fp8E4M3 (packed)
481-
static SmallVector<Value>
482-
Fp16_to_Fp8E4M3Nv_func(Location loc, ConversionPatternRewriter &rewriter,
483-
const SmallVector<Value> &v) {
480+
static SmallVector<Value> Fp16_to_Fp8E4M3Nv(Location loc,
481+
ConversionPatternRewriter &rewriter,
482+
const SmallVector<Value> &v) {
484483
auto fp16x2VecTy = vec_ty(f16_ty, 2);
485484
Value fp16x2Vec0 = undef(fp16x2VecTy);
486485

@@ -503,8 +502,8 @@ Fp16_to_Fp8E4M3Nv_func(Location loc, ConversionPatternRewriter &rewriter,
503502
}
504503

505504
static SmallVector<Value>
506-
Fp16_to_Fp8E4M3Nv_RTNE_func(Location loc, ConversionPatternRewriter &rewriter,
507-
const SmallVector<Value> &v) {
505+
Fp16_to_Fp8E4M3Nv_RTNE(Location loc, ConversionPatternRewriter &rewriter,
506+
const SmallVector<Value> &v) {
508507
Value val = zext(i32_ty, bitcast(v[0], i16_ty));
509508
Value sign = and_(i32_ty, val, i32_val(0x8000));
510509
Value nosign = and_(i32_ty, val, i32_val(0x7fff));
@@ -556,9 +555,9 @@ Fp16_to_Fp8E4M3Nv_RTNE_func(Location loc, ConversionPatternRewriter &rewriter,
556555
return {extract_element(i8_ty, res, i32_val(1))};
557556
}
558557

559-
static SmallVector<Value>
560-
Fp8E4M3Nv_to_Bf16_func(Location loc, ConversionPatternRewriter &rewriter,
561-
const SmallVector<Value> &v) {
558+
static SmallVector<Value> Fp8E4M3Nv_to_Bf16(Location loc,
559+
ConversionPatternRewriter &rewriter,
560+
const SmallVector<Value> &v) {
562561
auto fp8x4VecTy = vec_ty(i8_ty, 4);
563562
Value a0 = undef(fp8x4VecTy);
564563
a0 = insert_element(fp8x4VecTy, a0, int_val(8, 0), i32_val(0));
@@ -656,9 +655,9 @@ Fp8E4M3Nv_to_Bf16_func(Location loc, ConversionPatternRewriter &rewriter,
656655
extract_element(bf16_ty, bf16x2Vec1, i32_val(1))};
657656
}
658657

659-
static SmallVector<Value>
660-
Bf16_to_Fp8E4M3Nv_func(Location loc, ConversionPatternRewriter &rewriter,
661-
const SmallVector<Value> &v) {
658+
static SmallVector<Value> Bf16_to_Fp8E4M3Nv(Location loc,
659+
ConversionPatternRewriter &rewriter,
660+
const SmallVector<Value> &v) {
662661
auto bf16x2VecTy = vec_ty(bf16_ty, 2);
663662
Value bf16x2Vec0 = undef(bf16x2VecTy);
664663
Value bf16x2Vec1 = undef(bf16x2VecTy);
@@ -737,8 +736,8 @@ Bf16_to_Fp8E4M3Nv_func(Location loc, ConversionPatternRewriter &rewriter,
737736
}
738737

739738
static SmallVector<Value>
740-
Bf16_to_Fp8E4M3Nv_RTNE_func(Location loc, ConversionPatternRewriter &rewriter,
741-
const SmallVector<Value> &v) {
739+
Bf16_to_Fp8E4M3Nv_RTNE(Location loc, ConversionPatternRewriter &rewriter,
740+
const SmallVector<Value> &v) {
742741
Value val = zext(i32_ty, bitcast(v[0], i16_ty));
743742
Value sign = and_(i32_ty, val, i32_val(0x8000));
744743
Value nosign = and_(i32_ty, val, i32_val(0x7fff));
@@ -790,9 +789,9 @@ Bf16_to_Fp8E4M3Nv_RTNE_func(Location loc, ConversionPatternRewriter &rewriter,
790789
return {extract_element(i8_ty, res, i32_val(1))};
791790
}
792791

793-
static SmallVector<Value> Bf16_to_Fp16_func(Location loc,
794-
ConversionPatternRewriter &rewriter,
795-
const SmallVector<Value> &v) {
792+
static SmallVector<Value> Bf16_to_Fp16(Location loc,
793+
ConversionPatternRewriter &rewriter,
794+
const SmallVector<Value> &v) {
796795
auto bf16x2VecTy = vec_ty(bf16_ty, 2);
797796

798797
Value bf16x2Vec = undef(bf16x2VecTy);
@@ -997,39 +996,34 @@ struct FpToFpOpConversion
997996
std::pair<ConverterT, size_t>>
998997
srcMap = {
999998
// F8 -> F16
1000-
{{F8E4M3B15TyID, F16TyID, undefRounding},
1001-
{Fp8E4M3B15_to_Fp16_func, 4}},
1002-
{{F8E4M3TyID, F16TyID, undefRounding}, {Fp8E4M3Nv_to_Fp16_func, 2}},
1003-
{{F8E5M2TyID, F16TyID, undefRounding}, {Fp8E5M2_to_Fp16_func, 4}},
999+
{{F8E4M3B15TyID, F16TyID, undefRounding}, {Fp8E4M3B15_to_Fp16, 4}},
1000+
{{F8E4M3TyID, F16TyID, undefRounding}, {Fp8E4M3Nv_to_Fp16, 2}},
1001+
{{F8E5M2TyID, F16TyID, undefRounding}, {Fp8E5M2_to_Fp16, 4}},
10041002
// F16 -> F8
10051003
{{F16TyID, F8E4M3B15TyID, RoundingMode::RTZ},
1006-
{Fp16_to_Fp8E4M3B15_func, 4}},
1004+
{Fp16_to_Fp8E4M3B15, 4}},
10071005
{{F16TyID, F8E4M3B15TyID, RoundingMode::RTNE},
10081006
// TODO: provide proper implementation for RTNE rounding.
1009-
{Fp16_to_Fp8E4M3B15_func, 4}},
1010-
{{F16TyID, F8E4M3TyID, RoundingMode::RTZ},
1011-
{Fp16_to_Fp8E4M3Nv_func, 2}},
1007+
{Fp16_to_Fp8E4M3B15, 4}},
1008+
{{F16TyID, F8E4M3TyID, RoundingMode::RTZ}, {Fp16_to_Fp8E4M3Nv, 2}},
10121009
{{F16TyID, F8E4M3TyID, RoundingMode::RTNE},
1013-
{Fp16_to_Fp8E4M3Nv_RTNE_func, 1}},
1010+
{Fp16_to_Fp8E4M3Nv_RTNE, 1}},
10141011
{{F16TyID, F8E5M2TyID, RoundingMode::RTZ},
1015-
{Fp16_to_Fp8E5M2_func, 4}},
1012+
{Fp16_to_Fp8E5M2_RTZ, 4}},
10161013
{{F16TyID, F8E5M2TyID, RoundingMode::RTNE},
1017-
{Fp16_to_Fp8E5M2_RTNE_func, 1}},
1014+
{Fp16_to_Fp8E5M2_RTNE, 1}},
10181015
// F8 -> BF16
1019-
{{F8E5M2TyID, BF16TyID, undefRounding}, {Fp8E5M2_to_Bf16_func, 4}},
1020-
{{F8E4M3TyID, BF16TyID, undefRounding},
1021-
{Fp8E4M3Nv_to_Bf16_func, 4}},
1016+
{{F8E5M2TyID, BF16TyID, undefRounding}, {Fp8E5M2_to_Bf16, 4}},
1017+
{{F8E4M3TyID, BF16TyID, undefRounding}, {Fp8E4M3Nv_to_Bf16, 4}},
10221018
// BF16 -> F8
1023-
{{BF16TyID, F8E5M2TyID, RoundingMode::RTZ},
1024-
{Bf16_to_Fp8E5M2_func, 4}},
1019+
{{BF16TyID, F8E5M2TyID, RoundingMode::RTZ}, {Bf16_to_Fp8E5M2, 4}},
10251020
{{BF16TyID, F8E5M2TyID, RoundingMode::RTNE},
1026-
{Bf16_to_Fp8E5M2_RTNE_func, 1}},
1027-
{{BF16TyID, F8E4M3TyID, RoundingMode::RTZ},
1028-
{Bf16_to_Fp8E4M3Nv_func, 4}},
1021+
{Bf16_to_Fp8E5M2_RTNE, 1}},
1022+
{{BF16TyID, F8E4M3TyID, RoundingMode::RTZ}, {Bf16_to_Fp8E4M3Nv, 4}},
10291023
{{BF16TyID, F8E4M3TyID, RoundingMode::RTNE},
1030-
{Bf16_to_Fp8E4M3Nv_RTNE_func, 1}},
1024+
{Bf16_to_Fp8E4M3Nv_RTNE, 1}},
10311025
// BF16 -> F16
1032-
{{BF16TyID, F16TyID, undefRounding}, {Bf16_to_Fp16_func, 2}},
1026+
{{BF16TyID, F16TyID, undefRounding}, {Bf16_to_Fp16, 2}},
10331027
};
10341028

10351029
std::tuple<TypeID, TypeID, RoundingMode> key = {
@@ -1097,6 +1091,7 @@ struct FpToFpOpConversion
10971091
auto [cvtFunc, numElements] =
10981092
getConversionFunc(srcType, dstType, roundingMode);
10991093
SmallVector<Value> inVals;
1094+
inVals.reserve(std::min(numElements, operands.size()));
11001095
for (unsigned i = 0; i < std::min(numElements, operands.size()); i++) {
11011096
inVals.push_back(operands[i][0]);
11021097
}
@@ -1323,9 +1318,8 @@ struct TruncFOpConversion
13231318
return {// Trunc uses the default rounding mode: RTNE
13241319
intel::convertFp32ToBf16(loc, rewriter, operands[0][0],
13251320
RoundingMode::RTNE)};
1326-
} else {
1327-
return {rewriter.create<LLVM::FPTruncOp>(loc, elemTy, operands[0][0])};
13281321
}
1322+
return {rewriter.create<LLVM::FPTruncOp>(loc, elemTy, operands[0][0])};
13291323
}
13301324
};
13311325

@@ -1488,7 +1482,6 @@ void populateElementwiseOpToLLVMPatterns(
14881482
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
14891483
ModuleAxisInfoAnalysis &axisInfoAnalysis, const TargetInfoBase &targetInfo,
14901484
PatternBenefit benefit) {
1491-
using namespace mlir::triton::gpu;
14921485

14931486
patterns.add<PreciseSqrtOpConversion>(typeConverter, axisInfoAnalysis,
14941487
benefit);
@@ -1511,7 +1504,6 @@ void populateElementwiseOpToLLVMPatterns(
15111504
patterns.add<ExtFOpConversion>(typeConverter, axisInfoAnalysis, benefit);
15121505
patterns.add<TruncFOpConversion>(typeConverter, axisInfoAnalysis, benefit);
15131506
patterns.add<FPToSIOpConversion>(typeConverter, axisInfoAnalysis, benefit);
1514-
15151507
patterns.add<SIToFPOpConversion>(typeConverter, axisInfoAnalysis, benefit);
15161508
patterns.add<FpToFpOpConversion>(typeConverter, axisInfoAnalysis, benefit);
15171509

0 commit comments

Comments
 (0)