@@ -116,12 +116,31 @@ struct GPUBarrierConversion final : ConvertOpToLLVMPattern<gpu::BarrierOp> {
116116 lookupOrCreateSPIRVFn (moduleOp, funcName, flagTy, voidTy,
117117 /* isMemNone=*/ false , /* isConvergent=*/ true );
118118
119- // Value used by SPIR-V backend to represent `CLK_LOCAL_MEM_FENCE`.
120- // See `llvm/lib/Target/SPIRV/SPIRVBuiltins.td`.
121- constexpr int64_t localMemFenceFlag = 1 ;
119+ // Values used by the SPIR-V backend to represent `CLK_LOCAL_MEM_FENCE` and
120+ // `CLK_GLOBAL_MEM_FENCE`. See `llvm/lib/Target/SPIRV/SPIRVBuiltins.td`.
121+ constexpr int32_t localMemFenceFlag = 1 ;
122+ constexpr int32_t globalMemFenceFlag = 2 ;
123+ int32_t memFenceFlag = 0 ;
124+ std::optional<ArrayAttr> addressSpaces = adaptor.getAddressSpaces ();
125+ if (addressSpaces) {
126+ for (Attribute attr : addressSpaces.value ()) {
127+ auto addressSpace = cast<gpu::AddressSpaceAttr>(attr).getValue ();
128+ switch (addressSpace) {
129+ case gpu::AddressSpace::Global:
130+ memFenceFlag = memFenceFlag | globalMemFenceFlag;
131+ break ;
132+ case gpu::AddressSpace::Workgroup:
133+ memFenceFlag = memFenceFlag | localMemFenceFlag;
134+ break ;
135+ case gpu::AddressSpace::Private:
136+ break ;
137+ }
138+ }
139+ } else {
140+ memFenceFlag = localMemFenceFlag | globalMemFenceFlag;
141+ }
122142 Location loc = op->getLoc ();
123- Value flag =
124- rewriter.create <LLVM::ConstantOp>(loc, flagTy, localMemFenceFlag);
143+ Value flag = rewriter.create <LLVM::ConstantOp>(loc, flagTy, memFenceFlag);
125144 rewriter.replaceOp (op, createSPIRVBuiltinCall (loc, rewriter, func, flag));
126145 return success ();
127146 }
0 commit comments