2222#include " mlir/IR/BuiltinOps.h"
2323#include " mlir/IR/PatternMatch.h"
2424#include " mlir/Transforms/DialectConversion.h"
25+ #include " llvm/ADT/TypeSwitch.h"
2526#include " llvm/Support/Debug.h"
2627#include " llvm/Support/FormatVariadic.h"
2728
@@ -1027,7 +1028,8 @@ class ReturnValuePattern : public SPIRVToLLVMConversion<spirv::ReturnValueOp> {
10271028static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn (Operation *symbolTable,
10281029 StringRef name,
10291030 ArrayRef<Type> paramTypes,
1030- Type resultType) {
1031+ Type resultType,
1032+ bool convergent = true ) {
10311033 auto func = dyn_cast_or_null<LLVM::LLVMFuncOp>(
10321034 SymbolTable::lookupSymbolIn (symbolTable, name));
10331035 if (func)
@@ -1038,7 +1040,7 @@ static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable,
10381040 symbolTable->getLoc (), name,
10391041 LLVM::LLVMFunctionType::get (resultType, paramTypes));
10401042 func.setCConv (LLVM::cconv::CConv::SPIR_FUNC);
1041- func.setConvergent (true );
1043+ func.setConvergent (convergent );
10421044 func.setNoUnwind (true );
10431045 func.setWillReturn (true );
10441046 return func;
@@ -1089,6 +1091,181 @@ class ControlBarrierPattern
10891091 }
10901092};
10911093
1094+ namespace {
1095+
1096+ StringRef getTypeMangling (Type type, bool isSigned) {
1097+ return llvm::TypeSwitch<Type, StringRef>(type)
1098+ .Case <Float16Type>([](auto ) { return " Dh" ; })
1099+ .Case <Float32Type>([](auto ) { return " f" ; })
1100+ .Case <Float64Type>([](auto ) { return " d" ; })
1101+ .Case <IntegerType>([isSigned](IntegerType intTy) {
1102+ switch (intTy.getWidth ()) {
1103+ case 1 :
1104+ return " b" ;
1105+ case 8 :
1106+ return (isSigned) ? " a" : " c" ;
1107+ case 16 :
1108+ return (isSigned) ? " s" : " t" ;
1109+ case 32 :
1110+ return (isSigned) ? " i" : " j" ;
1111+ case 64 :
1112+ return (isSigned) ? " l" : " m" ;
1113+ default :
1114+ llvm_unreachable (" Unsupported integer width" );
1115+ }
1116+ })
1117+ .Default ([](auto ) {
1118+ llvm_unreachable (" No mangling defined" );
1119+ return " " ;
1120+ });
1121+ }
1122+
1123+ template <typename ReduceOp>
1124+ constexpr StringLiteral getGroupFuncName ();
1125+
1126+ template <>
1127+ constexpr StringLiteral getGroupFuncName<spirv::GroupIAddOp>() {
1128+ return " _Z17__spirv_GroupIAddii" ;
1129+ }
1130+ template <>
1131+ constexpr StringLiteral getGroupFuncName<spirv::GroupFAddOp>() {
1132+ return " _Z17__spirv_GroupFAddii" ;
1133+ }
1134+ template <>
1135+ constexpr StringLiteral getGroupFuncName<spirv::GroupSMinOp>() {
1136+ return " _Z17__spirv_GroupSMinii" ;
1137+ }
1138+ template <>
1139+ constexpr StringLiteral getGroupFuncName<spirv::GroupUMinOp>() {
1140+ return " _Z17__spirv_GroupUMinii" ;
1141+ }
1142+ template <>
1143+ constexpr StringLiteral getGroupFuncName<spirv::GroupFMinOp>() {
1144+ return " _Z17__spirv_GroupFMinii" ;
1145+ }
1146+ template <>
1147+ constexpr StringLiteral getGroupFuncName<spirv::GroupSMaxOp>() {
1148+ return " _Z17__spirv_GroupSMaxii" ;
1149+ }
1150+ template <>
1151+ constexpr StringLiteral getGroupFuncName<spirv::GroupUMaxOp>() {
1152+ return " _Z17__spirv_GroupUMaxii" ;
1153+ }
1154+ template <>
1155+ constexpr StringLiteral getGroupFuncName<spirv::GroupFMaxOp>() {
1156+ return " _Z17__spirv_GroupFMaxii" ;
1157+ }
1158+ template <>
1159+ constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformIAddOp>() {
1160+ return " _Z27__spirv_GroupNonUniformIAddii" ;
1161+ }
1162+ template <>
1163+ constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFAddOp>() {
1164+ return " _Z27__spirv_GroupNonUniformFAddii" ;
1165+ }
1166+ template <>
1167+ constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformIMulOp>() {
1168+ return " _Z27__spirv_GroupNonUniformIMulii" ;
1169+ }
1170+ template <>
1171+ constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMulOp>() {
1172+ return " _Z27__spirv_GroupNonUniformFMulii" ;
1173+ }
1174+ template <>
1175+ constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformSMinOp>() {
1176+ return " _Z27__spirv_GroupNonUniformSMinii" ;
1177+ }
1178+ template <>
1179+ constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformUMinOp>() {
1180+ return " _Z27__spirv_GroupNonUniformUMinii" ;
1181+ }
1182+ template <>
1183+ constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMinOp>() {
1184+ return " _Z27__spirv_GroupNonUniformFMinii" ;
1185+ }
1186+ template <>
1187+ constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformSMaxOp>() {
1188+ return " _Z27__spirv_GroupNonUniformSMaxii" ;
1189+ }
1190+ template <>
1191+ constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformUMaxOp>() {
1192+ return " _Z27__spirv_GroupNonUniformUMaxii" ;
1193+ }
1194+ template <>
1195+ constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMaxOp>() {
1196+ return " _Z27__spirv_GroupNonUniformFMaxii" ;
1197+ }
1198+ template <>
1199+ constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseAndOp>() {
1200+ return " _Z33__spirv_GroupNonUniformBitwiseAndii" ;
1201+ }
1202+ template <>
1203+ constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseOrOp>() {
1204+ return " _Z32__spirv_GroupNonUniformBitwiseOrii" ;
1205+ }
1206+ template <>
1207+ constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseXorOp>() {
1208+ return " _Z33__spirv_GroupNonUniformBitwiseXorii" ;
1209+ }
1210+ template <>
1211+ constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalAndOp>() {
1212+ return " _Z33__spirv_GroupNonUniformLogicalAndii" ;
1213+ }
1214+ template <>
1215+ constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalOrOp>() {
1216+ return " _Z32__spirv_GroupNonUniformLogicalOrii" ;
1217+ }
1218+ template <>
1219+ constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalXorOp>() {
1220+ return " _Z33__spirv_GroupNonUniformLogicalXorii" ;
1221+ }
1222+ } // namespace
1223+
1224+ template <typename ReduceOp, bool Signed = false , bool NonUniform = false >
1225+ class GroupReducePattern : public SPIRVToLLVMConversion <ReduceOp> {
1226+ public:
1227+ using SPIRVToLLVMConversion<ReduceOp>::SPIRVToLLVMConversion;
1228+
1229+ LogicalResult
1230+ matchAndRewrite (ReduceOp op, typename ReduceOp::Adaptor adaptor,
1231+ ConversionPatternRewriter &rewriter) const override {
1232+
1233+ Type retTy = op.getResult ().getType ();
1234+ if (!retTy.isIntOrFloat ()) {
1235+ return failure ();
1236+ }
1237+ SmallString<36 > funcName = getGroupFuncName<ReduceOp>();
1238+ funcName += getTypeMangling (retTy, false );
1239+
1240+ Type i32Ty = rewriter.getI32Type ();
1241+ SmallVector<Type> paramTypes{i32Ty, i32Ty, retTy};
1242+ if constexpr (NonUniform) {
1243+ if (adaptor.getClusterSize ()) {
1244+ funcName += " j" ;
1245+ paramTypes.push_back (i32Ty);
1246+ }
1247+ }
1248+
1249+ Operation *symbolTable =
1250+ op->template getParentWithTrait <OpTrait::SymbolTable>();
1251+
1252+ LLVM::LLVMFuncOp func = lookupOrCreateSPIRVFn (
1253+ symbolTable, funcName, paramTypes, retTy, !NonUniform);
1254+
1255+ Location loc = op.getLoc ();
1256+ Value scope = rewriter.create <LLVM::ConstantOp>(
1257+ loc, i32Ty, static_cast <int32_t >(adaptor.getExecutionScope ()));
1258+ Value groupOp = rewriter.create <LLVM::ConstantOp>(
1259+ loc, i32Ty, static_cast <int32_t >(adaptor.getGroupOperation ()));
1260+ SmallVector<Value> operands{scope, groupOp};
1261+ operands.append (adaptor.getOperands ().begin (), adaptor.getOperands ().end ());
1262+
1263+ auto call = createSPIRVBuiltinCall (loc, rewriter, func, operands);
1264+ rewriter.replaceOp (op, call);
1265+ return success ();
1266+ }
1267+ };
1268+
10921269// / Converts `spirv.mlir.loop` to LLVM dialect. All blocks within selection
10931270// / should be reachable for conversion to succeed. The structure of the loop in
10941271// / LLVM dialect will be the following:
@@ -1722,7 +1899,50 @@ void mlir::populateSPIRVToLLVMConversionPatterns(
17221899 ReturnPattern, ReturnValuePattern,
17231900
17241901 // Barrier ops
1725- ControlBarrierPattern>(patterns.getContext (), typeConverter);
1902+ ControlBarrierPattern,
1903+
1904+ // Group reduction operations
1905+ GroupReducePattern<spirv::GroupIAddOp>,
1906+ GroupReducePattern<spirv::GroupFAddOp>,
1907+ GroupReducePattern<spirv::GroupFMinOp>,
1908+ GroupReducePattern<spirv::GroupUMinOp>,
1909+ GroupReducePattern<spirv::GroupSMinOp, /* Signed=*/ true >,
1910+ GroupReducePattern<spirv::GroupFMaxOp>,
1911+ GroupReducePattern<spirv::GroupUMaxOp>,
1912+ GroupReducePattern<spirv::GroupSMaxOp, /* Signed=*/ true >,
1913+ GroupReducePattern<spirv::GroupNonUniformIAddOp, /* Signed=*/ false ,
1914+ /* NonUniform=*/ true >,
1915+ GroupReducePattern<spirv::GroupNonUniformFAddOp, /* Signed=*/ false ,
1916+ /* NonUniform=*/ true >,
1917+ GroupReducePattern<spirv::GroupNonUniformIMulOp, /* Signed=*/ false ,
1918+ /* NonUniform=*/ true >,
1919+ GroupReducePattern<spirv::GroupNonUniformFMulOp, /* Signed=*/ false ,
1920+ /* NonUniform=*/ true >,
1921+ GroupReducePattern<spirv::GroupNonUniformSMinOp, /* Signed=*/ true ,
1922+ /* NonUniform=*/ true >,
1923+ GroupReducePattern<spirv::GroupNonUniformUMinOp, /* Signed=*/ false ,
1924+ /* NonUniform=*/ true >,
1925+ GroupReducePattern<spirv::GroupNonUniformFMinOp, /* Signed=*/ false ,
1926+ /* NonUniform=*/ true >,
1927+ GroupReducePattern<spirv::GroupNonUniformSMaxOp, /* Signed=*/ true ,
1928+ /* NonUniform=*/ true >,
1929+ GroupReducePattern<spirv::GroupNonUniformUMaxOp, /* Signed=*/ false ,
1930+ /* NonUniform=*/ true >,
1931+ GroupReducePattern<spirv::GroupNonUniformFMaxOp, /* Signed=*/ false ,
1932+ /* NonUniform=*/ true >,
1933+ GroupReducePattern<spirv::GroupNonUniformBitwiseAndOp, /* Signed=*/ false ,
1934+ /* NonUniform=*/ true >,
1935+ GroupReducePattern<spirv::GroupNonUniformBitwiseOrOp, /* Signed=*/ false ,
1936+ /* NonUniform=*/ true >,
1937+ GroupReducePattern<spirv::GroupNonUniformBitwiseXorOp, /* Signed=*/ false ,
1938+ /* NonUniform=*/ true >,
1939+ GroupReducePattern<spirv::GroupNonUniformLogicalAndOp, /* Signed=*/ false ,
1940+ /* NonUniform=*/ true >,
1941+ GroupReducePattern<spirv::GroupNonUniformLogicalOrOp, /* Signed=*/ false ,
1942+ /* NonUniform=*/ true >,
1943+ GroupReducePattern<spirv::GroupNonUniformLogicalXorOp, /* Signed=*/ false ,
1944+ /* NonUniform=*/ true >>(patterns.getContext (),
1945+ typeConverter);
17261946
17271947 patterns.add <GlobalVariablePattern>(clientAPI, patterns.getContext (),
17281948 typeConverter);
0 commit comments