Skip to content

Commit f98381b

Browse files
committed
[XPU][GEN] Drop barrier operation
Drop the TritonGEN `barrier` operation, replacing its uses with equivalent `spirv.OpControlBarrier` operations. Note that the `GLOBAL` memory semantics specified by the original operation correspond to `SequentiallyConsistent | CrossWorkgroupMemory` in SPIR-V, because `SequentiallyConsistent` is implied by the original operation.
Signed-off-by: victor-eds <[email protected]>
1 parent 98dca47 commit f98381b

File tree

7 files changed

+30
-73
lines changed

7 files changed

+30
-73
lines changed

test/Conversion/intel/tritongpu_to_gen.mlir

Lines changed: 12 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -1045,7 +1045,6 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
10451045
// -----
10461046

10471047
module attributes {"triton_gpu.target" = "xpu", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
1048-
// CHECK: llvm.func spir_funccc @_Z7barrierj(i32) attributes {convergent, no_unwind, will_return}
10491048
// CHECK-LABEL: atomic_cas_f32_scalar_no_store
10501049
tt.func @atomic_cas_f32_scalar_no_store(%ptr : !tt.ptr<f32>, %cmp : f32, %val : f32) {
10511050
// CHECK: [[TRUE:%.*]] = llvm.mlir.constant(true) : i1
@@ -1054,7 +1053,10 @@ module attributes {"triton_gpu.target" = "xpu", "triton_gpu.num-ctas" = 1 : i32,
10541053
// CHECK: [[CMP:%.*]] = llvm.icmp "eq"
10551054
// CHECK: [[MASK:%.*]] = llvm.and [[MASK0]], [[CMP]]
10561055
// CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
1057-
// CHECK: llvm.call spir_funccc @_Z7barrierj({{.*}}) {{.*}} : (i32) -> ()
1056+
// CHECK: [[WGSCOPE:%.*]] = llvm.mlir.constant(2 : i32) : i32
1057+
// CHECK: [[WGMEMSCOPE:%.*]] = llvm.mlir.constant(2 : i32) : i32
1058+
// CHECK: [[GLOBAL:%.*]] = llvm.mlir.constant(528 : i32) : i32
1059+
// CHECK: llvm.call spir_funccc @_Z22__spirv_ControlBarrieriii([[WGSCOPE]], [[WGMEMSCOPE]], [[GLOBAL]]) {convergent, no_unwind, will_return} : (i32, i32, i32) -> ()
10581060
// CHECK-NEXT: llvm.cond_br [[MASK]], ^bb1, ^bb2([[ZERO]] : i32)
10591061
// CHECK-NEXT: ^bb1:
10601062
// CHECK-NEXT: [[BCAST1:%.*]] = llvm.bitcast %arg1 : f32 to i32
@@ -1109,7 +1111,6 @@ module attributes {"triton_gpu.target" = "xpu", "triton_gpu.num-ctas" = 1 : i32,
11091111

11101112
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
11111113
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
1112-
// CHECK: llvm.func spir_funccc @_Z7barrierj(i32) attributes {convergent, no_unwind, will_return}
11131114
// CHECK-LABEL: atomic_add_f32
11141115
tt.func @atomic_add_f32(%arg0 : tensor<256x!tt.ptr<f32>, #blocked0>, %arg1 : tensor<256xi1, #blocked0>, %arg2 : tensor<256xf32, #blocked0>) {
11151116
// CHECK: [[EV0_ARG2:%.*]] = llvm.extractvalue %arg2[0] : !llvm.struct<(f32, f32)>
@@ -1132,7 +1133,10 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
11321133
// CHECK: [[IE2:%.*]] = llvm.insertelement [[EV1_ARG2]], [[UNDEF2]][{{.*}} : i64] : vector<1xf32>
11331134
// CHECK-NEXT: [[PRED2:%.*]] = llvm.and [[CST_TRUE]], {{.*}} : i1
11341135
// CHECK-NEXT: [[ZERO2:%.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32
1135-
// CHECK: llvm.call spir_funccc @_Z7barrierj({{.*}}) {{.*}} : (i32) -> ()
1136+
// CHECK: [[WGSCOPE:%.*]] = llvm.mlir.constant(2 : i32) : i32
1137+
// CHECK: [[WGMEMSCOPE:%.*]] = llvm.mlir.constant(2 : i32) : i32
1138+
// CHECK: [[GLOBAL:%.*]] = llvm.mlir.constant(528 : i32) : i32
1139+
// CHECK: llvm.call spir_funccc @_Z22__spirv_ControlBarrieriii([[WGSCOPE]], [[WGMEMSCOPE]], [[GLOBAL]]) {convergent, no_unwind, will_return} : (i32, i32, i32) -> ()
11361140
// CHECK-NEXT: llvm.cond_br [[PRED2]], ^bb3, ^bb4([[ZERO2]] : f32)
11371141
// CHECK-NEXT: ^bb3:
11381142
// CHECK-NEXT: [[BCAST2:%.*]] = llvm.bitcast [[IE2]] : vector<1xf32> to f32
@@ -1147,7 +1151,6 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
11471151
// -----
11481152

11491153
module attributes {"triton_gpu.target" = "xpu", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
1150-
// CHECK: llvm.func spir_funccc @_Z7barrierj(i32) attributes {convergent, no_unwind, will_return}
11511154
// CHECK-LABEL: atomic_add_f32_scalar_no_store
11521155
tt.func @atomic_add_f32_scalar_no_store(%arg0 : !tt.ptr<f32>, %arg1 : i1, %arg2 : f32) {
11531156
// CHECK: [[CST_TRUE:%.*]] = llvm.mlir.constant(true) : i1
@@ -1159,7 +1162,10 @@ module attributes {"triton_gpu.target" = "xpu", "triton_gpu.num-ctas" = 1 : i32,
11591162
// CHECK: [[IE1:%.*]] = llvm.insertelement %arg2, [[UNDEF1]][{{.*}} : i64] : vector<1xf32>
11601163
// CHECK: [[PRED:%.*]] = llvm.and [[AND1]], %arg1 : i1
11611164
// CHECK-NEXT: [[ZERO:%.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32
1162-
// CHECK: llvm.call spir_funccc @_Z7barrierj({{.*}}) {{.*}} : (i32) -> ()
1165+
// CHECK: [[WGSCOPE:%.*]] = llvm.mlir.constant(2 : i32) : i32
1166+
// CHECK: [[WGMEMSCOPE:%.*]] = llvm.mlir.constant(2 : i32) : i32
1167+
// CHECK: [[GLOBAL:%.*]] = llvm.mlir.constant(528 : i32) : i32
1168+
// CHECK: llvm.call spir_funccc @_Z22__spirv_ControlBarrieriii([[WGSCOPE]], [[WGMEMSCOPE]], [[GLOBAL]]) {convergent, no_unwind, will_return} : (i32, i32, i32) -> ()
11631169
// CHECK-NEXT: llvm.cond_br [[PRED]], ^bb1, ^bb2([[ZERO]] : f32)
11641170
// CHECK-NEXT: ^bb1:
11651171
// CHECK-NEXT: [[BCAST2:%.*]] = llvm.bitcast [[IE1]] : vector<1xf32> to f32

test/TritonGEN/tritongen-to-llvm.mlir

Lines changed: 0 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -1,20 +1,5 @@
11
// RUN: triton-opt -convert-tritongen-to-llvm -split-input-file %s | FileCheck %s
22

3-
// CHECK: llvm.func spir_funccc @_Z7barrierj(i32) attributes {convergent, no_unwind, will_return}
4-
5-
llvm.func @triton_gen.barrier() {
6-
// CHECK-LABEL: triton_gen.barrier
7-
// CHECK: [[LOCAL:%.*]] = llvm.mlir.constant(1 : i32) : i32
8-
// CHECK: llvm.call spir_funccc @_Z7barrierj([[LOCAL]]) {{.*}} : (i32) -> ()
9-
// CHECK: [[GLOBAL:%.*]] = llvm.mlir.constant(2 : i32) : i32
10-
// CHECK: llvm.call spir_funccc @_Z7barrierj([[GLOBAL]]) {{.*}} : (i32) -> ()
11-
triton_gen.barrier {mem_fence=Local}
12-
triton_gen.barrier {mem_fence=Global}
13-
llvm.return
14-
}
15-
16-
// -----
17-
183
// CHECK-DAG: llvm.func spir_funccc @_Z31intel_work_group_barrier_arriveii(i32, i32) attributes {convergent, no_unwind, will_return}
194
// CHECK-DAG: llvm.func spir_funccc @_Z29intel_work_group_barrier_waitii(i32, i32) attributes {convergent, no_unwind, will_return}
205

test/TritonGEN/tritongen.mlir

Lines changed: 0 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1,12 +1,5 @@
11
// RUN: triton-opt %s -split-input-file -verify-diagnostics | FileCheck %s
22

3-
llvm.func @triton_gen.barrier() {
4-
// CHECK-LABEL: triton_gen.barrier
5-
// CHECK: triton_gen.barrier {mem_fence = Local}
6-
triton_gen.barrier {mem_fence=Local}
7-
llvm.return
8-
}
9-
103
llvm.func @triton_gen.split_barrier_signal() {
114
// CHECK-LABEL: triton_gen.split_barrier_signal
125
// CHECK: triton_gen.split_barrier_signal {mem_fence = None, mem_scope = WorkGroup}

third_party/intel/include/Dialect/TritonGEN/IR/TritonGENOps.td

Lines changed: 0 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -32,20 +32,6 @@ class TritonGEN_Op<string mnemonic, list<Trait> traits = []> :
3232
// Synchronization
3333
//===----------------------------------------------------------------------===//
3434

35-
def TritonGEN_BarrierOp : TritonGEN_Op<"barrier"> {
36-
let summary = "Workgroup barrier";
37-
let description = [{
38-
The `triton_gen.barrier` operation performs a workgroup barrier and ensures
39-
all outstanding memory transaction using local or global memory are complete.
40-
}];
41-
let arguments = (ins TritonGEN_MemFence:$mem_fence);
42-
let results = (outs);
43-
let assemblyFormat = "attr-dict";
44-
let assemblyFormat = [{
45-
` ` `{` `mem_fence` `=` $mem_fence `}` attr-dict
46-
}];
47-
}
48-
4935
def TritonGEN_SplitBarrierSignalOp : TritonGEN_Op<"split_barrier_signal"> {
5036
let summary = "Split barrier signal";
5137
let description = [{

third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp

Lines changed: 6 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -479,27 +479,6 @@ namespace {
479479
// Synchronization Ops Lowerings
480480
//===----------------------------------------------------------------------===//
481481

482-
struct TritonGENBarrierLowering
483-
: public ConvertOpToLLVMPattern<TritonGEN::BarrierOp> {
484-
using ConvertOpToLLVMPattern<TritonGEN::BarrierOp>::ConvertOpToLLVMPattern;
485-
486-
LogicalResult
487-
matchAndRewrite(TritonGEN::BarrierOp op, OpAdaptor adaptor,
488-
ConversionPatternRewriter &rewriter) const override {
489-
MLIRContext *ctx = rewriter.getContext();
490-
Location loc = op->getLoc();
491-
Type retType = void_ty(ctx);
492-
IntegerType argType = int_ty(32);
493-
Value arg = i32_val(static_cast<int>(op.getMemFence()));
494-
495-
LLVM::CallOp callOp =
496-
createDeviceFunctionCall(rewriter, "_Z7barrierj", {retType}, {argType},
497-
{arg}, {}, convergentNoUnwindWillReturnAttrs);
498-
rewriter.replaceOp(op, callOp);
499-
return success();
500-
}
501-
};
502-
503482
struct TritonGENSplitBarrier {
504483
protected:
505484
template <typename OpType>
@@ -1092,13 +1071,12 @@ struct TritonGENToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
10921071

10931072
void mlir::triton::populateTritonGENToLLVMConversionPatterns(
10941073
LLVMTypeConverter &converter, RewritePatternSet &patterns) {
1095-
patterns
1096-
.add<TritonGENBarrierLowering, TritonGENSplitBarrierSignalLowering,
1097-
TritonGENSplitBarrierWaitLowering, TritonSubGroupReduceLowering,
1098-
TritonSubGroupScanLowering, TritonMatrixDPASLowering,
1099-
TritonMatrix2DBlockLoadLowering, TritonMatrix2DBlockStoreLowering,
1100-
TritonMatrix2DBlockPrefetchLowering, TritonSIMDBlockReadLowering,
1101-
TritonSIMDBlockWriteLowering>(converter);
1074+
patterns.add<
1075+
TritonGENSplitBarrierSignalLowering, TritonGENSplitBarrierWaitLowering,
1076+
TritonSubGroupReduceLowering, TritonSubGroupScanLowering,
1077+
TritonMatrixDPASLowering, TritonMatrix2DBlockLoadLowering,
1078+
TritonMatrix2DBlockStoreLowering, TritonMatrix2DBlockPrefetchLowering,
1079+
TritonSIMDBlockReadLowering, TritonSIMDBlockWriteLowering>(converter);
11021080
}
11031081

11041082
void registerConvertTritonTritonGENToLLVMInterface(DialectRegistry &registry) {

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 9 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
#include "Dialect/TritonIntelGPU/IR/Dialect.h"
22
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
3+
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
34
#include "mlir/IR/Matchers.h"
45
#include "mlir/IR/TypeUtilities.h"
56
#include "llvm/ADT/SmallVector.h"
@@ -1307,7 +1308,10 @@ struct AtomicCASOpConversion
13071308

13081309
Value zero = (valueElemNBits == 32) ? i32_val(0) : i64_val(0);
13091310
if (!atomicNeedsSharedMemory(op.getResult()))
1310-
rewriter.create<TritonGEN::BarrierOp>(loc, TritonGEN::MemFence::GLOBAL);
1311+
rewriter.create<spirv::ControlBarrierOp>(
1312+
loc, spirv::Scope::Workgroup, spirv::Scope::Workgroup,
1313+
spirv::MemorySemantics::SequentiallyConsistent |
1314+
spirv::MemorySemantics::CrossWorkgroupMemory);
13111315
Block &endBlock =
13121316
LLVM::intel::createPredicatedBlock(rewriter, loc, mask, {zero}, [&] {
13131317
// casPtr = bitcast(casPtr, ptr_ty(ctx, 1));
@@ -1456,8 +1460,10 @@ struct AtomicRMWOpConversion
14561460
rmwPtr, rmwVal, rmwMask, {zero});
14571461
} else {
14581462
if (!atomicNeedsSharedMemory(op.getResult()))
1459-
rewriter.create<TritonGEN::BarrierOp>(loc,
1460-
TritonGEN::MemFence::GLOBAL);
1463+
rewriter.create<spirv::ControlBarrierOp>(
1464+
loc, spirv::Scope::Workgroup, spirv::Scope::Workgroup,
1465+
spirv::MemorySemantics::SequentiallyConsistent |
1466+
spirv::MemorySemantics::CrossWorkgroupMemory);
14611467
endBlock = &LLVM::intel::createPredicatedBlock(
14621468
rewriter, loc, rmwMask, {zero}, [&] {
14631469
mlir::LLVM::AtomicBinOp rmwKind;

third_party/intel/lib/TritonIntelGPUToLLVM/PipelineManager.h

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@
1818
#include "mlir/Conversion/GPUToLLVMSPV/GPUToLLVMSPVPass.h"
1919
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
2020
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
21+
#include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVM.h"
2122
#include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
2223
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
2324
#include "mlir/IR/PatternMatch.h"
@@ -268,6 +269,8 @@ class TritonGPUToLLVMPipelineManager {
268269
triton::populateGPUToTritonGENConversionPatterns(typeConverter, patterns);
269270
cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns);
270271
populateGpuToLLVMSPVConversionPatterns(typeConverter, patterns);
272+
populateSPIRVToLLVMConversionPatterns(typeConverter, patterns,
273+
spirv::ClientAPI::OpenCL);
271274
}
272275

273276
private:

0 commit comments

Comments
 (0)