@@ -1057,17 +1057,21 @@ static LLVM::CallOp createSPIRVBuiltinCall(Location loc, OpBuilder &builder,
10571057 return call;
10581058}
10591059
1060- class ControlBarrierPattern
1061- : public SPIRVToLLVMConversion<spirv::ControlBarrierOp > {
1060+ template < typename BarrierOpTy>
1061+ class ControlBarrierPattern : public SPIRVToLLVMConversion <BarrierOpTy > {
10621062public:
1063- using SPIRVToLLVMConversion<spirv::ControlBarrierOp>::SPIRVToLLVMConversion;
1063+ using OpAdaptor = typename SPIRVToLLVMConversion<BarrierOpTy>::OpAdaptor;
1064+
1065+ using SPIRVToLLVMConversion<BarrierOpTy>::SPIRVToLLVMConversion;
1066+
1067+ static constexpr StringRef getFuncName ();
10641068
10651069 LogicalResult
1066- matchAndRewrite (spirv::ControlBarrierOp controlBarrierOp, OpAdaptor adaptor,
1070+ matchAndRewrite (BarrierOpTy controlBarrierOp, OpAdaptor adaptor,
10671071 ConversionPatternRewriter &rewriter) const override {
1068- constexpr StringLiteral funcName = " _Z22__spirv_ControlBarrieriii " ;
1072+ constexpr StringRef funcName = getFuncName () ;
10691073 Operation *symbolTable =
1070- controlBarrierOp->getParentWithTrait <OpTrait::SymbolTable>();
1074+ controlBarrierOp->template getParentWithTrait <OpTrait::SymbolTable>();
10711075
10721076 Type i32 = rewriter.getI32Type ();
10731077
@@ -1266,6 +1270,24 @@ class GroupReducePattern : public SPIRVToLLVMConversion<ReduceOp> {
12661270 }
12671271};
12681272
1273+ template <>
1274+ constexpr StringRef
1275+ ControlBarrierPattern<spirv::ControlBarrierOp>::getFuncName() {
1276+ return " _Z22__spirv_ControlBarrieriii" ;
1277+ }
1278+
1279+ template <>
1280+ constexpr StringRef
1281+ ControlBarrierPattern<spirv::INTELControlBarrierArriveOp>::getFuncName() {
1282+ return " _Z33__spirv_ControlBarrierArriveINTELiii" ;
1283+ }
1284+
1285+ template <>
1286+ constexpr StringRef
1287+ ControlBarrierPattern<spirv::INTELControlBarrierWaitOp>::getFuncName() {
1288+ return " _Z31__spirv_ControlBarrierWaitINTELiii" ;
1289+ }
1290+
12691291// / Converts `spirv.mlir.loop` to LLVM dialect. All blocks within selection
12701292// / should be reachable for conversion to succeed. The structure of the loop in
12711293// / LLVM dialect will be the following:
@@ -1899,7 +1921,9 @@ void mlir::populateSPIRVToLLVMConversionPatterns(
18991921 ReturnPattern, ReturnValuePattern,
19001922
19011923 // Barrier ops
1902- ControlBarrierPattern,
1924+ ControlBarrierPattern<spirv::ControlBarrierOp>,
1925+ ControlBarrierPattern<spirv::INTELControlBarrierArriveOp>,
1926+ ControlBarrierPattern<spirv::INTELControlBarrierWaitOp>,
19031927
19041928 // Group reduction operations
19051929 GroupReducePattern<spirv::GroupIAddOp>,
0 commit comments