@@ -1024,6 +1024,71 @@ class ReturnValuePattern : public SPIRVToLLVMConversion<spirv::ReturnValueOp> {
10241024 }
10251025};
10261026
1027+ static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn (Operation *symbolTable,
1028+ StringRef name,
1029+ ArrayRef<Type> paramTypes,
1030+ Type resultType) {
1031+ auto func = dyn_cast_or_null<LLVM::LLVMFuncOp>(
1032+ SymbolTable::lookupSymbolIn (symbolTable, name));
1033+ if (func)
1034+ return func;
1035+
1036+ OpBuilder b (symbolTable->getRegion (0 ));
1037+ func = b.create <LLVM::LLVMFuncOp>(
1038+ symbolTable->getLoc (), name,
1039+ LLVM::LLVMFunctionType::get (resultType, paramTypes));
1040+ func.setCConv (LLVM::cconv::CConv::SPIR_FUNC);
1041+ func.setConvergent (true );
1042+ func.setNoUnwind (true );
1043+ func.setWillReturn (true );
1044+ return func;
1045+ }
1046+
1047+ static LLVM::CallOp createSPIRVBuiltinCall (Location loc, OpBuilder &builder,
1048+ LLVM::LLVMFuncOp func,
1049+ ValueRange args) {
1050+ auto call = builder.create <LLVM::CallOp>(loc, func, args);
1051+ call.setCConv (func.getCConv ());
1052+ call.setConvergentAttr (func.getConvergentAttr ());
1053+ call.setNoUnwindAttr (func.getNoUnwindAttr ());
1054+ call.setWillReturnAttr (func.getWillReturnAttr ());
1055+ return call;
1056+ }
1057+
1058+ class ControlBarrierPattern
1059+ : public SPIRVToLLVMConversion<spirv::ControlBarrierOp> {
1060+ public:
1061+ using SPIRVToLLVMConversion<spirv::ControlBarrierOp>::SPIRVToLLVMConversion;
1062+
1063+ LogicalResult
1064+ matchAndRewrite (spirv::ControlBarrierOp controlBarrierOp, OpAdaptor adaptor,
1065+ ConversionPatternRewriter &rewriter) const override {
1066+ constexpr StringLiteral funcName = " _Z22__spirv_ControlBarrieriii" ;
1067+ Operation *symbolTable =
1068+ controlBarrierOp->getParentWithTrait <OpTrait::SymbolTable>();
1069+
1070+ Type i32 = rewriter.getI32Type ();
1071+
1072+ Type voidTy = rewriter.getType <LLVM::LLVMVoidType>();
1073+ LLVM::LLVMFuncOp func =
1074+ lookupOrCreateSPIRVFn (symbolTable, funcName, {i32 , i32 , i32 }, voidTy);
1075+
1076+ Location loc = controlBarrierOp->getLoc ();
1077+ Value execution = rewriter.create <LLVM::ConstantOp>(
1078+ loc, i32 , static_cast <int32_t >(adaptor.getExecutionScope ()));
1079+ Value memory = rewriter.create <LLVM::ConstantOp>(
1080+ loc, i32 , static_cast <int32_t >(adaptor.getMemoryScope ()));
1081+ Value semantics = rewriter.create <LLVM::ConstantOp>(
1082+ loc, i32 , static_cast <int32_t >(adaptor.getMemorySemantics ()));
1083+
1084+ auto call = createSPIRVBuiltinCall (loc, rewriter, func,
1085+ {execution, memory, semantics});
1086+
1087+ rewriter.replaceOp (controlBarrierOp, call);
1088+ return success ();
1089+ }
1090+ };
1091+
10271092// / Converts `spirv.mlir.loop` to LLVM dialect. All blocks within selection
10281093// / should be reachable for conversion to succeed. The structure of the loop in
10291094// / LLVM dialect will be the following:
@@ -1648,7 +1713,10 @@ void mlir::populateSPIRVToLLVMConversionPatterns(
16481713 ShiftPattern<spirv::ShiftLeftLogicalOp, LLVM::ShlOp>,
16491714
16501715 // Return ops
1651- ReturnPattern, ReturnValuePattern>(patterns.getContext (), typeConverter);
1716+ ReturnPattern, ReturnValuePattern,
1717+
1718+ // Barrier ops
1719+ ControlBarrierPattern>(patterns.getContext (), typeConverter);
16521720
16531721 patterns.add <GlobalVariablePattern>(clientAPI, patterns.getContext (),
16541722 typeConverter);
0 commit comments