Skip to content

Commit 874dd36

Browse files
committed
Add address space modifier to barrier
1 parent 4ae0c50 commit 874dd36

File tree

8 files changed

+71
-11
lines changed

8 files changed

+71
-11
lines changed

mlir/include/mlir/Dialect/GPU/IR/GPUBase.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,8 @@ def GPU_AddressSpaceEnum : GPU_I32Enum<
9999
def GPU_AddressSpaceAttr :
100100
GPU_I32EnumAttr<"address_space", GPU_AddressSpaceEnum>;
101101

102+
// Array attribute whose elements are each constrained to be a
// GPU_AddressSpaceAttr (i.e. a `#gpu.address_space<...>` value).
def GPU_AddressSpaceAttrArray : TypedArrayAttrBase<GPU_AddressSpaceAttr, "GPU Address Space array">;
103+
102104
//===----------------------------------------------------------------------===//
103105
// GPU Types.
104106
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/GPU/IR/GPUOps.td

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1355,7 +1355,8 @@ def GPU_ShuffleOp : GPU_Op<
13551355
];
13561356
}
13571357

1358-
def GPU_BarrierOp : GPU_Op<"barrier"> {
1358+
def GPU_BarrierOp : GPU_Op<"barrier">,
1359+
Arguments<(ins OptionalAttr<GPU_AddressSpaceAttrArray> :$address_spaces)> {
13591360
let summary = "Synchronizes all work items of a workgroup.";
13601361
let description = [{
13611362
The "barrier" op synchronizes all work items of a workgroup. It is used
@@ -1371,11 +1372,25 @@ def GPU_BarrierOp : GPU_Op<"barrier"> {
13711372
accessing the same memory can be avoided by synchronizing work items
13721373
in-between these accesses.
13731374

1375+
The address space of visible memory accesses can be modified by adding a
1376+
list of address spaces required to be visible. By default all address spaces
1377+
are included.
1378+
1379+
```mlir
1380+
// only workgroup address space accesses are required to be visible
1381+
gpu.barrier memfence [#gpu.address_space<workgroup>]
1382+
// no memory accesses required to be visible
1383+
gpu.barrier memfence []
1384+
// all memory accesses required to be visible
1385+
gpu.barrier
1386+
```
1387+
13741388
Either none or all work items of a workgroup need to execute this op
13751389
in convergence.
13761390
}];
1377-
let assemblyFormat = "attr-dict";
1391+
let assemblyFormat = "(`memfence` $address_spaces^)? attr-dict";
13781392
let hasCanonicalizer = 1;
1393+
let builders = [OpBuilder<(ins)>];
13791394
}
13801395

13811396
def GPU_GPUModuleOp : GPU_Op<"module", [

mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -116,12 +116,31 @@ struct GPUBarrierConversion final : ConvertOpToLLVMPattern<gpu::BarrierOp> {
116116
lookupOrCreateSPIRVFn(moduleOp, funcName, flagTy, voidTy,
117117
/*isMemNone=*/false, /*isConvergent=*/true);
118118

119-
// Value used by SPIR-V backend to represent `CLK_LOCAL_MEM_FENCE`.
120-
// See `llvm/lib/Target/SPIRV/SPIRVBuiltins.td`.
121-
constexpr int64_t localMemFenceFlag = 1;
119+
// Values used by SPIR-V backend to represent `CLK_LOCAL_MEM_FENCE` and
120+
// `CLK_GLOBAL_MEM_FENCE`. See `llvm/lib/Target/SPIRV/SPIRVBuiltins.td`.
121+
constexpr int32_t localMemFenceFlag = 1;
122+
constexpr int32_t globalMemFenceFlag = 2;
123+
int32_t memFenceFlag = 0;
124+
std::optional<ArrayAttr> addressSpaces = adaptor.getAddressSpaces();
125+
if (addressSpaces) {
126+
for (Attribute attr : addressSpaces.value()) {
127+
auto addressSpace = cast<gpu::AddressSpaceAttr>(attr).getValue();
128+
switch (addressSpace) {
129+
case gpu::AddressSpace::Global:
130+
memFenceFlag = memFenceFlag | globalMemFenceFlag;
131+
break;
132+
case gpu::AddressSpace::Workgroup:
133+
memFenceFlag = memFenceFlag | localMemFenceFlag;
134+
break;
135+
case gpu::AddressSpace::Private:
136+
break;
137+
}
138+
}
139+
} else {
140+
memFenceFlag = localMemFenceFlag | globalMemFenceFlag;
141+
}
122142
Location loc = op->getLoc();
123-
Value flag =
124-
rewriter.create<LLVM::ConstantOp>(loc, flagTy, localMemFenceFlag);
143+
Value flag = rewriter.create<LLVM::ConstantOp>(loc, flagTy, memFenceFlag);
125144
rewriter.replaceOp(op, createSPIRVBuiltinCall(loc, rewriter, func, flag));
126145
return success();
127146
}

mlir/lib/Conversion/GPUToNVVM/GPUToNVVM.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,6 @@ include "mlir/IR/PatternBase.td"
1717
include "mlir/Dialect/GPU/IR/GPUOps.td"
1818
include "mlir/Dialect/LLVMIR/NVVMOps.td"
1919

20-
def : Pat<(GPU_BarrierOp), (NVVM_Barrier0Op)>;
20+
// Rewrites GPU_BarrierOp into NVVM_Barrier0Op. The barrier's argument is bound
// to $memory_fence only so the source pattern matches; it is not forwarded to
// the result op, so the NVVM op is the same for every matched barrier.
def : Pat<(GPU_BarrierOp : $op $memory_fence), (NVVM_Barrier0Op)>;
2121

2222
#endif // MLIR_CONVERSION_GPUTONVVM_TD

mlir/lib/Conversion/GPUToROCDL/GPUToROCDL.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,6 @@ include "mlir/IR/PatternBase.td"
1717
include "mlir/Dialect/GPU/IR/GPUOps.td"
1818
include "mlir/Dialect/LLVMIR/ROCDLOps.td"
1919

20-
def : Pat<(GPU_BarrierOp), (ROCDL_BarrierOp)>;
20+
// Rewrites GPU_BarrierOp into ROCDL_BarrierOp. The barrier's argument is bound
// to $memory_fence only so the source pattern matches; it is not forwarded to
// the result op, so the ROCDL op is the same for every matched barrier.
def : Pat<(GPU_BarrierOp : $op $memory_fence), (ROCDL_BarrierOp)>;
2121

2222
#endif // MLIR_CONVERSION_GPUTOROCDL_TD

mlir/lib/Dialect/GPU/IR/GPUDialect.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1351,6 +1351,9 @@ void BarrierOp::getCanonicalizationPatterns(RewritePatternSet &results,
13511351
results.add(eraseRedundantGpuBarrierOps);
13521352
}
13531353

1354+
// Argument-less builder for BarrierOp. The body is intentionally empty: it
// adds nothing to `odsState`, so no attributes are attached by this overload.
void BarrierOp::build(mlir::OpBuilder &odsBuilder,
1355+
mlir::OperationState &odsState) {}
1356+
13541357
//===----------------------------------------------------------------------===//
13551358
// GPUFuncOp
13561359
//===----------------------------------------------------------------------===//

mlir/test/Conversion/GPUToLLVMSPV/gpu-to-llvm-spv.mlir

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -213,14 +213,29 @@ gpu.module @barriers {
213213

214214
// CHECK-LABEL: gpu_barrier
215215
func.func @gpu_barrier() {
216-
// CHECK: [[FLAGS:%.*]] = llvm.mlir.constant(1 : i32) : i32
217-
// CHECK: llvm.call spir_funccc @_Z7barrierj([[FLAGS]]) {
216+
// CHECK: [[GLOBAL_AND_LOCAL_FLAG:%.*]] = llvm.mlir.constant(3 : i32) : i32
217+
// CHECK: llvm.call spir_funccc @_Z7barrierj([[GLOBAL_AND_LOCAL_FLAG]]) {
218218
// CHECK-SAME-DAG: no_unwind
219219
// CHECK-SAME-DAG: convergent
220220
// CHECK-SAME-DAG: will_return
221221
// CHECK-NOT: memory_effects = #llvm.memory_effects
222222
// CHECK-SAME: } : (i32) -> ()
223223
gpu.barrier
224+
// CHECK: [[GLOBAL_AND_LOCAL_FLAG2:%.*]] = llvm.mlir.constant(3 : i32) : i32
225+
// CHECK: llvm.call spir_funccc @_Z7barrierj([[GLOBAL_AND_LOCAL_FLAG2]])
226+
gpu.barrier memfence [#gpu.address_space<global>, #gpu.address_space<workgroup>]
227+
// CHECK: [[LOCAL_FLAG:%.*]] = llvm.mlir.constant(1 : i32) : i32
228+
// CHECK: llvm.call spir_funccc @_Z7barrierj([[LOCAL_FLAG]])
229+
gpu.barrier memfence [#gpu.address_space<workgroup>]
230+
// CHECK: [[GLOBAL_FLAG:%.*]] = llvm.mlir.constant(2 : i32) : i32
231+
// CHECK: llvm.call spir_funccc @_Z7barrierj([[GLOBAL_FLAG]])
232+
gpu.barrier memfence [#gpu.address_space<global>]
233+
// CHECK: [[NONE_FLAG:%.*]] = llvm.mlir.constant(0 : i32) : i32
234+
// CHECK: llvm.call spir_funccc @_Z7barrierj([[NONE_FLAG]])
235+
gpu.barrier memfence []
236+
// CHECK: [[NONE_FLAG2:%.*]] = llvm.mlir.constant(0 : i32) : i32
237+
// CHECK: llvm.call spir_funccc @_Z7barrierj([[NONE_FLAG2]])
238+
gpu.barrier memfence [#gpu.address_space<private>]
224239
return
225240
}
226241
}

mlir/test/Dialect/GPU/ops.mlir

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,12 @@ module attributes {gpu.container_module} {
141141
%shfl3, %pred3 = gpu.shuffle idx %arg0, %offset, %width : f32
142142

143143
"gpu.barrier"() : () -> ()
144+
gpu.barrier
145+
gpu.barrier memfence [#gpu.address_space<workgroup>]
146+
gpu.barrier memfence [#gpu.address_space<global>]
147+
gpu.barrier memfence [#gpu.address_space<global>, #gpu.address_space<workgroup>]
148+
gpu.barrier memfence [#gpu.address_space<private>]
149+
gpu.barrier memfence []
144150

145151
"some_op"(%bIdX, %tIdX) : (index, index) -> ()
146152
%42 = memref.load %arg1[%bIdX] : memref<?xf32, 1>

0 commit comments

Comments
 (0)