|
8 | 8 | #include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" |
9 | 9 |
|
10 | 10 | using mlir::triton::gpu::ElementwiseOpConversionBase; |
| 11 | +using mlir::triton::gpu::MultipleOperandsRange; |
11 | 12 |
|
12 | 13 | namespace { |
13 | 14 |
|
@@ -859,73 +860,10 @@ inline Type getElementType(Value value) { |
859 | 860 | return type; |
860 | 861 | } |
861 | 862 |
|
862 | | -inline SmallVector<Value> unpackI32(const SmallVector<Value> &inValues, |
863 | | - Type srcTy, |
864 | | - ConversionPatternRewriter &rewriter, |
865 | | - Location loc, |
866 | | - TypeConverter *typeConverter) { |
867 | | - auto tensorTy = dyn_cast<RankedTensorType>(srcTy); |
868 | | - if (!tensorTy) |
869 | | - return inValues; |
870 | | - auto encoding = dyn_cast<DotOperandEncodingAttr>(tensorTy.getEncoding()); |
871 | | - if (!(encoding && isa<NvidiaMmaEncodingAttr>(encoding.getParent()))) |
872 | | - return inValues; |
873 | | - SmallVector<Value> outValues; |
874 | | - for (auto v : inValues) { |
875 | | - // cast i32 to appropriate eltType vector and extract elements |
876 | | - auto eltType = typeConverter->convertType(tensorTy.getElementType()); |
877 | | - auto vecType = vec_ty(eltType, 32 / eltType.getIntOrFloatBitWidth()); |
878 | | - auto vec = bitcast(v, vecType); |
879 | | - for (int i = 0; i < 32 / eltType.getIntOrFloatBitWidth(); i++) { |
880 | | - outValues.push_back(extract_element(vec, i32_val(i))); |
881 | | - } |
882 | | - } |
883 | | - return outValues; |
884 | | -} |
885 | | - |
886 | | -inline SmallVector<Value> packI32(const SmallVector<Value> &inValues, |
887 | | - Type srcTy, |
888 | | - ConversionPatternRewriter &rewriter, |
889 | | - Location loc, TypeConverter *typeConverter) { |
890 | | - auto tensorTy = dyn_cast<RankedTensorType>(srcTy); |
891 | | - if (!tensorTy) |
892 | | - return inValues; |
893 | | - auto encoding = dyn_cast<DotOperandEncodingAttr>(tensorTy.getEncoding()); |
894 | | - if (!(encoding && isa<NvidiaMmaEncodingAttr>(encoding.getParent()))) |
895 | | - return inValues; |
896 | | - SmallVector<Value> outValues; |
897 | | - auto eltType = typeConverter->convertType(tensorTy.getElementType()); |
898 | | - int vecWidth = 32 / eltType.getIntOrFloatBitWidth(); |
899 | | - auto vecType = vec_ty(eltType, vecWidth); |
900 | | - for (int i = 0; i < inValues.size(); i += vecWidth) { |
901 | | - Value vec = undef(vecType); |
902 | | - for (int j = 0; j < vecWidth; j++) { |
903 | | - vec = insert_element(vec, inValues[i + j], i32_val(j)); |
904 | | - } |
905 | | - outValues.push_back(bitcast(vec, i32_ty)); |
906 | | - } |
907 | | - return outValues; |
908 | | -} |
909 | | - |
910 | 863 | typedef std::function<SmallVector<Value>(Location, ConversionPatternRewriter &, |
911 | 864 | const SmallVector<Value> &)> |
912 | 865 | ConverterT; |
913 | 866 |
|
914 | | -class MultipleOperandsRange |
915 | | - : public iterator_range<SmallVector<SmallVector<Value>>::iterator> { |
916 | | - using ContainerT = SmallVector<SmallVector<Value>>; |
917 | | - |
918 | | -public: |
919 | | - using iterator_range<ContainerT::iterator>::iterator_range; |
920 | | - ContainerT::reference operator[](ContainerT::size_type idx) { |
921 | | - return begin()[idx]; |
922 | | - } |
923 | | - ContainerT::const_reference operator[](ContainerT::size_type idx) const { |
924 | | - return begin()[idx]; |
925 | | - } |
926 | | - ContainerT::size_type size() const { return end() - begin(); } |
927 | | -}; |
928 | | - |
929 | 867 | // Attempts to use vectorized conversions via inline PTX when possible. |
930 | 868 | struct FpToFpOpConversion |
931 | 869 | : public ElementwiseOpConversionBase<FpToFpOp, FpToFpOpConversion> { |
@@ -1051,7 +989,7 @@ struct FpToFpOpConversion |
1051 | 989 |
|
1052 | 990 | if (dstElementType.isFloat8E5M2() || dstElementType.isFloat8E4M3FN()) { |
1053 | 991 | assert(roundingMode.has_value() && |
1054 | | - "Rounding mode must be specified for convertsions to fp8"); |
| 992 | + "Rounding mode must be specified for conversions to fp8"); |
1055 | 993 |
|
1056 | 994 | // For now only RTNE is supported for conversions from fp16 to fp8 |
1057 | 995 | if (!srcElementType.isF32() && |
@@ -1117,8 +1055,7 @@ Value EmitDualBF16ElementwiseOp(Location loc, |
1117 | 1055 | auto v0 = intel::convertBf16ToFp32(loc, rewriter, operands[0][0]); |
1118 | 1056 | auto v1 = intel::convertBf16ToFp32(loc, rewriter, operands[0][1]); |
1119 | 1057 | auto result = rewriter.create<OP>(loc, f32_ty, v0, v1); |
1120 | | - auto undefRounding = static_cast<RoundingMode>(-1); |
1121 | | - return intel::convertFp32ToBf16(loc, rewriter, result, undefRounding); |
| 1058 | + return intel::convertFp32ToBf16(loc, rewriter, result, RoundingMode::RTNE); |
1122 | 1059 | } |
1123 | 1060 |
|
1124 | 1061 | struct ExternElementwiseOpConversion |
@@ -1245,11 +1182,9 @@ struct SIToFPOpConversion |
1245 | 1182 | Type inElemTy = getElementType(op.getIn()); |
1246 | 1183 | Type outElemTy = getElementType(op.getOut()); |
1247 | 1184 | if (outElemTy.isBF16() && inElemTy.isInteger(8) && operands.size() >= 4) { |
1248 | | - SmallVector<Value> outVals; |
1249 | 1185 | auto value = rewriter.create<LLVM::SIToFPOp>(loc, f32_ty, operands[0][0]); |
1250 | 1186 | return { |
1251 | 1187 | intel::convertFp32ToBf16(loc, rewriter, value, RoundingMode::RTNE)}; |
1252 | | - llvm_unreachable(""); |
1253 | 1188 | } else if (outElemTy.isBF16()) { |
1254 | 1189 | auto value = rewriter.create<LLVM::SIToFPOp>(loc, f32_ty, operands[0][0]); |
1255 | 1190 | return { |
|
0 commit comments