Skip to content

Commit 3b860a6

Browse files
authored
Update ElementwiseOpToLLVM.cpp (#2981)
Signed-off-by: Tiotto, Ettore <[email protected]>
1 parent 0e3a9c0 commit 3b860a6

File tree

1 file changed

+3
-68
lines changed

1 file changed

+3
-68
lines changed

third_party/intel/lib/TritonIntelGPUToLLVM/ElementwiseOpToLLVM.cpp

Lines changed: 3 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h"
99

1010
using mlir::triton::gpu::ElementwiseOpConversionBase;
11+
using mlir::triton::gpu::MultipleOperandsRange;
1112

1213
namespace {
1314

@@ -859,73 +860,10 @@ inline Type getElementType(Value value) {
859860
return type;
860861
}
861862

862-
inline SmallVector<Value> unpackI32(const SmallVector<Value> &inValues,
863-
Type srcTy,
864-
ConversionPatternRewriter &rewriter,
865-
Location loc,
866-
TypeConverter *typeConverter) {
867-
auto tensorTy = dyn_cast<RankedTensorType>(srcTy);
868-
if (!tensorTy)
869-
return inValues;
870-
auto encoding = dyn_cast<DotOperandEncodingAttr>(tensorTy.getEncoding());
871-
if (!(encoding && isa<NvidiaMmaEncodingAttr>(encoding.getParent())))
872-
return inValues;
873-
SmallVector<Value> outValues;
874-
for (auto v : inValues) {
875-
// cast i32 to appropriate eltType vector and extract elements
876-
auto eltType = typeConverter->convertType(tensorTy.getElementType());
877-
auto vecType = vec_ty(eltType, 32 / eltType.getIntOrFloatBitWidth());
878-
auto vec = bitcast(v, vecType);
879-
for (int i = 0; i < 32 / eltType.getIntOrFloatBitWidth(); i++) {
880-
outValues.push_back(extract_element(vec, i32_val(i)));
881-
}
882-
}
883-
return outValues;
884-
}
885-
886-
inline SmallVector<Value> packI32(const SmallVector<Value> &inValues,
887-
Type srcTy,
888-
ConversionPatternRewriter &rewriter,
889-
Location loc, TypeConverter *typeConverter) {
890-
auto tensorTy = dyn_cast<RankedTensorType>(srcTy);
891-
if (!tensorTy)
892-
return inValues;
893-
auto encoding = dyn_cast<DotOperandEncodingAttr>(tensorTy.getEncoding());
894-
if (!(encoding && isa<NvidiaMmaEncodingAttr>(encoding.getParent())))
895-
return inValues;
896-
SmallVector<Value> outValues;
897-
auto eltType = typeConverter->convertType(tensorTy.getElementType());
898-
int vecWidth = 32 / eltType.getIntOrFloatBitWidth();
899-
auto vecType = vec_ty(eltType, vecWidth);
900-
for (int i = 0; i < inValues.size(); i += vecWidth) {
901-
Value vec = undef(vecType);
902-
for (int j = 0; j < vecWidth; j++) {
903-
vec = insert_element(vec, inValues[i + j], i32_val(j));
904-
}
905-
outValues.push_back(bitcast(vec, i32_ty));
906-
}
907-
return outValues;
908-
}
909-
910863
typedef std::function<SmallVector<Value>(Location, ConversionPatternRewriter &,
911864
const SmallVector<Value> &)>
912865
ConverterT;
913866

914-
class MultipleOperandsRange
915-
: public iterator_range<SmallVector<SmallVector<Value>>::iterator> {
916-
using ContainerT = SmallVector<SmallVector<Value>>;
917-
918-
public:
919-
using iterator_range<ContainerT::iterator>::iterator_range;
920-
ContainerT::reference operator[](ContainerT::size_type idx) {
921-
return begin()[idx];
922-
}
923-
ContainerT::const_reference operator[](ContainerT::size_type idx) const {
924-
return begin()[idx];
925-
}
926-
ContainerT::size_type size() const { return end() - begin(); }
927-
};
928-
929867
// Attempts to use vectorized conversions via inline PTX when possible.
930868
struct FpToFpOpConversion
931869
: public ElementwiseOpConversionBase<FpToFpOp, FpToFpOpConversion> {
@@ -1051,7 +989,7 @@ struct FpToFpOpConversion
1051989

1052990
if (dstElementType.isFloat8E5M2() || dstElementType.isFloat8E4M3FN()) {
1053991
assert(roundingMode.has_value() &&
1054-
"Rounding mode must be specified for convertsions to fp8");
992+
"Rounding mode must be specified for conversions to fp8");
1055993

1056994
// For now only RTNE is supported for conversions from fp16 to fp8
1057995
if (!srcElementType.isF32() &&
@@ -1117,8 +1055,7 @@ Value EmitDualBF16ElementwiseOp(Location loc,
11171055
auto v0 = intel::convertBf16ToFp32(loc, rewriter, operands[0][0]);
11181056
auto v1 = intel::convertBf16ToFp32(loc, rewriter, operands[0][1]);
11191057
auto result = rewriter.create<OP>(loc, f32_ty, v0, v1);
1120-
auto undefRounding = static_cast<RoundingMode>(-1);
1121-
return intel::convertFp32ToBf16(loc, rewriter, result, undefRounding);
1058+
return intel::convertFp32ToBf16(loc, rewriter, result, RoundingMode::RTNE);
11221059
}
11231060

11241061
struct ExternElementwiseOpConversion
@@ -1245,11 +1182,9 @@ struct SIToFPOpConversion
12451182
Type inElemTy = getElementType(op.getIn());
12461183
Type outElemTy = getElementType(op.getOut());
12471184
if (outElemTy.isBF16() && inElemTy.isInteger(8) && operands.size() >= 4) {
1248-
SmallVector<Value> outVals;
12491185
auto value = rewriter.create<LLVM::SIToFPOp>(loc, f32_ty, operands[0][0]);
12501186
return {
12511187
intel::convertFp32ToBf16(loc, rewriter, value, RoundingMode::RTNE)};
1252-
llvm_unreachable("");
12531188
} else if (outElemTy.isBF16()) {
12541189
auto value = rewriter.create<LLVM::SIToFPOp>(loc, f32_ty, operands[0][0]);
12551190
return {

0 commit comments

Comments
 (0)