Skip to content

Commit 8473cab

Browse files
authored
[XPU][TritonGEN] Replace split barrier ops usages with SPIR-V ops (#2814)
Use SPIR-V operations to encode split barriers. Semantics are exactly the same. As memory semantics are `None`, memory scope is omitted. Execution scope value is maintained. Signed-off-by: victor-eds <[email protected]>
1 parent c16eccf commit 8473cab

File tree

8 files changed

+23
-152
lines changed

8 files changed

+23
-152
lines changed

test/TritonGEN/tritongen-to-llvm.mlir

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,5 @@
11
// RUN: triton-opt -convert-tritongen-to-llvm -split-input-file %s | FileCheck %s
22

3-
// CHECK-DAG: llvm.func spir_funccc @_Z31intel_work_group_barrier_arriveii(i32, i32) attributes {convergent, no_unwind, will_return}
4-
// CHECK-DAG: llvm.func spir_funccc @_Z29intel_work_group_barrier_waitii(i32, i32) attributes {convergent, no_unwind, will_return}
5-
6-
llvm.func @triton_gen.split_barrier() {
7-
// CHECK-LABEL: triton_gen.split_barrier() {
8-
// CHECK-DAG: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
9-
// CHECK-DAG: [[ONE:%.*]] = llvm.mlir.constant(1 : i32) : i32
10-
// CHECK: llvm.call spir_funccc @_Z31intel_work_group_barrier_arriveii([[ZERO]], [[ONE]]) {{.*}} : (i32, i32) -> ()
11-
// CHECK-DAG: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
12-
// CHECK-DAG: [[ONE:%.*]] = llvm.mlir.constant(1 : i32) : i32
13-
// CHECK: llvm.call spir_funccc @_Z29intel_work_group_barrier_waitii([[ZERO]], [[ONE]]) {{.*}} : (i32, i32) -> ()
14-
triton_gen.split_barrier_signal {mem_fence=None, mem_scope=WorkGroup}
15-
triton_gen.split_barrier_wait {mem_fence=None, mem_scope=WorkGroup}
16-
llvm.return
17-
}
18-
19-
// -----
20-
213
// CHECK-DAG: llvm.func spir_funccc @_Z30sub_group_clustered_reduce_addij(i32, i32) -> i32 attributes {convergent, no_unwind, will_return}
224
// CHECK-DAG: llvm.func spir_funccc @_Z30sub_group_clustered_reduce_mulij(i32, i32) -> i32 attributes {convergent, no_unwind, will_return}
235
// CHECK-DAG: llvm.func spir_funccc @_Z30sub_group_clustered_reduce_maxij(i32, i32) -> i32 attributes {convergent, no_unwind, will_return}

test/TritonGEN/tritongen.mlir

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,5 @@
11
// RUN: triton-opt %s -split-input-file -verify-diagnostics | FileCheck %s
22

3-
llvm.func @triton_gen.split_barrier_signal() {
4-
// CHECK-LABEL: triton_gen.split_barrier_signal
5-
// CHECK: triton_gen.split_barrier_signal {mem_fence = None, mem_scope = WorkGroup}
6-
triton_gen.split_barrier_signal {mem_fence=None, mem_scope=WorkGroup}
7-
llvm.return
8-
}
9-
10-
llvm.func @triton_gen.split_barrier_wait() {
11-
// CHECK-LABEL: triton_gen.split_barrier_wait
12-
// CHECK: triton_gen.split_barrier_wait {mem_fence = Local, mem_scope = SubGroup}
13-
triton_gen.split_barrier_wait {mem_fence=Local, mem_scope=SubGroup}
14-
llvm.return
15-
}
16-
17-
// -----
18-
193
module attributes {
204
spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Kernel, Addresses, GroupNonUniformShuffle, Int64], []>, #spirv.resource_limits<subgroup_size = 32>>
215
} {

test/TritonIntelGPU/prefetch-block.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ module attributes {"triton_gpu.num-warps" = 32 : i32, "triton_gpu.threads-per-wa
3333
// CHECK-NEXT: [[B3:%.*]] = tt.advance [[B2]], {{.*}} : <tensor<32x256xf16, #blocked2>>
3434
// CHECK-NEXT: [[B4:%.*]] = tt.make_tensor_ptr %arg1, {{.*}} : <tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>>
3535

36-
// CHECK: triton_gen.split_barrier_signal {mem_fence = None, mem_scope = WorkGroup}
36+
// CHECK: spirv.INTEL.ControlBarrierArrive <Workgroup>, <Workgroup>, <None>
3737
// CHECK-NEXT: scf.for [[IV:%.*]] = [[CST_ZERO]] to [[CST_4096]] step [[CST_32]]
3838
// CHECK-SAME: iter_args([[CST:%.*]] = {{.*}}, [[A6:%.*]] = [[A4]], [[B6:%.*]] = [[B4]], [[A5:%.*]] = [[A3]], [[B5:%.*]] = [[B3]])
3939
// CHECK-NEXT: [[LD_A:%.*]] = tt.load [[A6]]
@@ -45,11 +45,11 @@ module attributes {"triton_gpu.num-warps" = 32 : i32, "triton_gpu.threads-per-wa
4545
// CHECK-DAG: tt.advance [[A6]], {{.*}} : <tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>>
4646
// CHECK-NEXT: tt.advance [[B5]], {{.*}} : <tensor<32x256xf16, #blocked2>>
4747
// CHECK-DAG: tt.advance [[B6]], {{.*}} : <tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>>
48-
// CHECK: triton_gen.split_barrier_wait {mem_fence = None, mem_scope = WorkGroup}
49-
// CHECK-NEXT: triton_gen.split_barrier_signal {mem_fence = None, mem_scope = WorkGroup}
48+
// CHECK: spirv.INTEL.ControlBarrierWait <Workgroup>, <Workgroup>, <None>
49+
// CHECK-NEXT: spirv.INTEL.ControlBarrierArrive <Workgroup>, <Workgroup>, <None>
5050
// CHECK-NEXT: scf.yield {{.*}}
5151
// CHECK-NEXT: }
52-
// CHECK-NEXT: triton_gen.split_barrier_wait {mem_fence = None, mem_scope = WorkGroup}
52+
// CHECK-NEXT: spirv.INTEL.ControlBarrierWait <Workgroup>, <Workgroup>, <None>
5353

5454
%c64_i32 = arith.constant 64 : i32
5555
%c16_i32 = arith.constant 16 : i32

third_party/intel/include/Dialect/TritonGEN/IR/TritonGENOps.td

Lines changed: 0 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -32,49 +32,6 @@ class TritonGEN_Op<string mnemonic, list<Trait> traits = []> :
3232
// Synchronization
3333
//===----------------------------------------------------------------------===//
3434

35-
def TritonGEN_SplitBarrierSignalOp : TritonGEN_Op<"split_barrier_signal"> {
36-
let summary = "Split barrier signal";
37-
let description = [{
38-
The `triton_gen.split_barrier_signal` operation signals the arrival of a
39-
workgroup thread at the current program point; the issuing thread continues
40-
execution.
41-
Once all threads in a workgroup have signaled their arrival, any other thread
42-
waiting at the `triton_gen.split_barrier_wait` operation can exit it.
43-
The '$mem_fence' attribute is a bitfield that specifies the memory address
44-
spaces to apply the memory ordering constraints. The '$mem_scope' attribute
45-
describes the work-items to apply the memory ordering constraints.
46-
Behavior is undefined:
47-
- unless all threads in a workgroup participate in the barrier
48-
- if a thread waits on a barrier before signaling
49-
- if the '$mem_fence' and '$mem_scope' attributes aren't the same for all
50-
threads in a workgroup
51-
Furthermore, if the `$mem_fence` argument differs between the barrier signal
52-
and wait operation, then only memory operations for the address spaces specified
53-
by the intersection of the two flags arguments are visible.
54-
}];
55-
let arguments = (ins TritonGEN_MemFence:$mem_fence, TritonGEN_MemScope:$mem_scope);
56-
let results = (outs);
57-
let assemblyFormat = [{
58-
` ` `{` `mem_fence` `=` $mem_fence `,` `mem_scope` `=` $mem_scope `}` attr-dict
59-
}];
60-
}
61-
62-
def TritonGEN_SplitBarrierWaitOp : TritonGEN_Op<"split_barrier_wait"> {
63-
let summary = "Split barrier wait";
64-
let description = [{
65-
The `triton_gen.split_barrier_wait` operation blocks the issuing workgroup
66-
thread until all other threads arrive at the corresponding
67-
`triton_gen.split_barrier_signal` operation. Please refer to the
68-
`triton_gen.split_barrier_signal` documentation for the description of the
69-
attributes and the constraints on the operation.
70-
}];
71-
let arguments = (ins TritonGEN_MemFence:$mem_fence, TritonGEN_MemScope:$mem_scope);
72-
let results = (outs);
73-
let assemblyFormat = [{
74-
` ` `{` `mem_fence` `=` $mem_fence `,` `mem_scope` `=` $mem_scope `}` attr-dict
75-
}];
76-
}
77-
7835
def TritonGEN_SubGroupReduceOp : TritonGEN_Op<"sub_group_reduce", [
7936
AllTypesMatch<["res", "value"]>]>,
8037
Results<(outs SignlessIntegerOrFloatLike:$res)>,

third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,7 @@ def TritonIntelGPUPrefetchBlock : Pass<"tritonintelgpu-prefetch-block", "mlir::M
203203
"mlir::triton::TritonGEN::TritonGENDialect",
204204
"mlir::triton::gpu::intel::TritonIntelGPUDialect",
205205
"mlir::scf::SCFDialect",
206+
"mlir::spirv::SPIRVDialect",
206207
"mlir::gpu::GPUDialect"];
207208
let options = [
208209
Option<"numAdvancePrefetches", "num-advance-prefetches",

third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp

Lines changed: 3 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -481,61 +481,6 @@ namespace {
481481
// Synchronization Ops Lowerings
482482
//===----------------------------------------------------------------------===//
483483

484-
struct TritonGENSplitBarrier {
485-
protected:
486-
template <typename OpType>
487-
void replaceWithCall(OpType op, StringRef funcName,
488-
ConversionPatternRewriter &rewriter) const {
489-
static_assert(
490-
std::is_same<OpType, TritonGEN::SplitBarrierSignalOp>::value ||
491-
std::is_same<OpType, TritonGEN::SplitBarrierWaitOp>::value,
492-
"Unexpected OpType");
493-
494-
MLIRContext *ctx = rewriter.getContext();
495-
Location loc = op->getLoc();
496-
Type retType = void_ty(ctx);
497-
Value memFence = i32_val(static_cast<int>(op.getMemFence()));
498-
Value memScope = i32_val(static_cast<int>(op.getMemScope()));
499-
SmallVector<Value> args{memFence, memScope};
500-
SmallVector<Type> argTypes;
501-
for (auto arg : args)
502-
argTypes.push_back(arg.getType());
503-
504-
LLVM::CallOp callOp =
505-
createDeviceFunctionCall(rewriter, funcName, retType, argTypes, args,
506-
{}, convergentNoUnwindWillReturnAttrs);
507-
rewriter.replaceOp(op, callOp);
508-
}
509-
};
510-
511-
struct TritonGENSplitBarrierSignalLowering
512-
: public ConvertOpToLLVMPattern<TritonGEN::SplitBarrierSignalOp>,
513-
public TritonGENSplitBarrier {
514-
using ConvertOpToLLVMPattern<
515-
TritonGEN::SplitBarrierSignalOp>::ConvertOpToLLVMPattern;
516-
LogicalResult
517-
matchAndRewrite(TritonGEN::SplitBarrierSignalOp op, OpAdaptor adaptor,
518-
ConversionPatternRewriter &rewriter) const override {
519-
TritonGENSplitBarrier::replaceWithCall(
520-
op, "_Z31intel_work_group_barrier_arriveii", rewriter);
521-
return success();
522-
}
523-
};
524-
525-
struct TritonGENSplitBarrierWaitLowering
526-
: public ConvertOpToLLVMPattern<TritonGEN::SplitBarrierWaitOp>,
527-
public TritonGENSplitBarrier {
528-
using ConvertOpToLLVMPattern<
529-
TritonGEN::SplitBarrierWaitOp>::ConvertOpToLLVMPattern;
530-
LogicalResult
531-
matchAndRewrite(TritonGEN::SplitBarrierWaitOp op, OpAdaptor adaptor,
532-
ConversionPatternRewriter &rewriter) const override {
533-
TritonGENSplitBarrier::replaceWithCall(
534-
op, "_Z29intel_work_group_barrier_waitii", rewriter);
535-
return success();
536-
}
537-
};
538-
539484
struct TritonSubGroupBase {
540485
protected:
541486
template <typename OpType, typename = std::enable_if_t<llvm::is_one_of<
@@ -1082,10 +1027,9 @@ struct TritonGENToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
10821027
void mlir::triton::populateTritonGENToLLVMConversionPatterns(
10831028
LLVMTypeConverter &converter, RewritePatternSet &patterns) {
10841029
patterns
1085-
.add<TritonGENSplitBarrierSignalLowering,
1086-
TritonGENSplitBarrierWaitLowering, TritonSubGroupReduceLowering,
1087-
TritonSubGroupScanLowering, TritonMatrixDPASLowering,
1088-
TritonMatrix2DBlockLoadLowering, TritonMatrix2DBlockStoreLowering,
1030+
.add<TritonSubGroupReduceLowering, TritonSubGroupScanLowering,
1031+
TritonMatrixDPASLowering, TritonMatrix2DBlockLoadLowering,
1032+
TritonMatrix2DBlockStoreLowering,
10891033
TritonMatrix2DBlockPrefetchLowering, TritonSubGroupBlockReadLowering,
10901034
TritonSubGroupBlockWriteLowering>(converter);
10911035
}

third_party/intel/lib/TritonIntelGPUTransforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ add_triton_library(TritonIntelGPUTransforms
2020

2121
LINK_LIBS PUBLIC
2222
MLIRSCFTransforms
23+
MLIRSPIRVDialect
2324
MLIRTransforms
2425
MLIRTransformUtils
2526
TritonIntelAnalysis

third_party/intel/lib/TritonIntelGPUTransforms/PrefetchBlock.cpp

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@
3939

4040
#include "TritonToTritonGPUWarp/TritonToTritonGPUWarpPass.h"
4141
#include "mlir/Dialect/SCF/IR/SCF.h"
42+
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
43+
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
4244
#include "mlir/IR/BuiltinAttributes.h"
4345
#include "mlir/IR/PatternMatch.h"
4446

@@ -367,13 +369,13 @@ void PrefetchBlockPass::injectPrefetchOpsInPreheader(
367369
if (injectSplitBarriers) {
368370
Location loc = loop.getLoc();
369371
b.setInsertionPoint(loop);
370-
b.create<tt::TritonGEN::SplitBarrierSignalOp>(
371-
loc, tt::TritonGEN::MemFence::NONE,
372-
tt::TritonGEN::MemScope::WORK_GROUP);
372+
b.create<spirv::INTELControlBarrierArriveOp>(loc, spirv::Scope::Workgroup,
373+
spirv::Scope::Workgroup,
374+
spirv::MemorySemantics::None);
373375
b.setInsertionPoint(loop->getNextNode());
374-
b.create<tt::TritonGEN::SplitBarrierWaitOp>(
375-
loc, tt::TritonGEN::MemFence::NONE,
376-
tt::TritonGEN::MemScope::WORK_GROUP);
376+
b.create<spirv::INTELControlBarrierWaitOp>(loc, spirv::Scope::Workgroup,
377+
spirv::Scope::Workgroup,
378+
spirv::MemorySemantics::None);
377379
}
378380
}
379381

@@ -454,12 +456,12 @@ void PrefetchBlockPass::injectPrefetchOpsInBody(
454456
if (injectSplitBarriers) {
455457
Location loc = loop.getLoc();
456458
b.setInsertionPoint(yield);
457-
b.create<tt::TritonGEN::SplitBarrierWaitOp>(
458-
loc, tt::TritonGEN::MemFence::NONE,
459-
tt::TritonGEN::MemScope::WORK_GROUP);
460-
b.create<tt::TritonGEN::SplitBarrierSignalOp>(
461-
loc, tt::TritonGEN::MemFence::NONE,
462-
tt::TritonGEN::MemScope::WORK_GROUP);
459+
b.create<spirv::INTELControlBarrierWaitOp>(loc, spirv::Scope::Workgroup,
460+
spirv::Scope::Workgroup,
461+
spirv::MemorySemantics::None);
462+
b.create<spirv::INTELControlBarrierArriveOp>(loc, spirv::Scope::Workgroup,
463+
spirv::Scope::Workgroup,
464+
spirv::MemorySemantics::None);
463465
}
464466

465467
yield.getResultsMutable().append(advances);

0 commit comments

Comments
 (0)