Skip to content

Commit 4b882ad

Browse files
Reintroduce TritonGEN::SplitBarrier[Arrive|Wait]Op and add its lowering to SPIRV dialect (#4523)
This patch reverts the previous removal of the GEN split barrier operations in order to avoid SPIRV-specific operations in the software pipeliner transformation. Signed-off-by: Whitney Tsang <[email protected]>
1 parent 6d6c0b0 commit 4b882ad

File tree

6 files changed

+147
-21
lines changed

6 files changed

+147
-21
lines changed

test/TritonGEN/tritongen-to-spirv.mlir

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,21 @@ llvm.func @triton_gen.barrier.global() {
1515
triton_gen.barrier { mem_fence = Global }
1616
llvm.return
1717
}
18+
19+
// -----
20+
21+
// Lowering test for the arrive half of the split barrier: the op below uses
// WorkGroup for both scopes, and the CHECK line expects a
// spirv.INTEL.ControlBarrierArrive with <Workgroup> execution scope,
// <Workgroup> memory scope and <None> memory semantics.
llvm.func @triton_gen.split_barrier_arrive() {
22+
// CHECK-LABEL: triton_gen.split_barrier_arrive() {
23+
// CHECK: spirv.INTEL.ControlBarrierArrive <Workgroup> <Workgroup> <None>
24+
triton_gen.split_barrier_arrive {execution_scope=WorkGroup, memory_scope=WorkGroup}
25+
llvm.return
26+
}
27+
28+
// -----
29+
30+
// Lowering test for the wait half of the split barrier: the op below uses
// WorkGroup for both scopes, and the CHECK line expects a
// spirv.INTEL.ControlBarrierWait with <Workgroup> execution scope,
// <Workgroup> memory scope and <None> memory semantics.
llvm.func @triton_gen.split_barrier_wait() {
31+
// CHECK-LABEL: triton_gen.split_barrier_wait() {
32+
// CHECK: spirv.INTEL.ControlBarrierWait <Workgroup> <Workgroup> <None>
33+
triton_gen.split_barrier_wait {execution_scope=WorkGroup, memory_scope=WorkGroup}
34+
llvm.return
35+
}

test/TritonGEN/tritongen.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,20 @@ llvm.func @triton_gen.barrier() {
77
llvm.return
88
}
99

10+
// Checks the textual form of split_barrier_arrive: the input is written without
// spaces around `=`, and the CHECK line expects it printed back with both
// scope attributes as `name = value`.
// NOTE(review): presumably a parse/print round-trip; the RUN line is not
// visible in this excerpt, so confirm against the top of the file.
llvm.func @triton_gen.split_barrier_arrive() {
11+
// CHECK-LABEL: triton_gen.split_barrier_arrive
12+
// CHECK: triton_gen.split_barrier_arrive {execution_scope = WorkGroup, memory_scope = WorkGroup}
13+
triton_gen.split_barrier_arrive {execution_scope=WorkGroup, memory_scope=WorkGroup}
14+
llvm.return
15+
}
16+
17+
// Checks the textual form of split_barrier_wait: the input is written without
// spaces around `=`, and the CHECK line expects it printed back with both
// scope attributes as `name = value`.
// NOTE(review): presumably a parse/print round-trip; the RUN line is not
// visible in this excerpt, so confirm against the top of the file.
llvm.func @triton_gen.split_barrier_wait() {
18+
// CHECK-LABEL: triton_gen.split_barrier_wait
19+
// CHECK: triton_gen.split_barrier_wait {execution_scope = WorkGroup, memory_scope = WorkGroup}
20+
triton_gen.split_barrier_wait {execution_scope=WorkGroup, memory_scope=WorkGroup}
21+
llvm.return
22+
}
23+
1024
llvm.func @triton_gen.dpas(%c : vector<8xi32>, %a : vector<8xi16>, %b : vector<8xi32>) {
1125
// CHECK: llvm.func @triton_gen.dpas(%arg0: vector<8xi32>, %arg1: vector<8xi16>, %arg2: vector<8xi32>) {
1226
// CHECK-NEXT: %0 = triton_gen.dpas %arg0, %arg1, %arg2 {pa = i8, pb = i8, rc = 8} : (vector<8xi32>, vector<8xi16>, vector<8xi32>) -> vector<8xi32>

test/TritonIntelGPU/split-barrier.mlir

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,13 @@ module attributes {"ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 16 : i32
2323
// CHECK: ttig.prefetch {{.*}} : !tt.ptr<tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>>>
2424
// CHECK-NEXT: ttig.prefetch {{.*}} : !tt.ptr<tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>>>
2525
// CHECK: scf.for %[[IV:.*]] = {{.*}} to {{.*}} step {{.*}} iter_args({{.*}}) -> (tensor<128x256xf32, #mma>, !tt.ptr<tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>>, !tt.ptr<tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>, !tt.ptr<tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>>, !tt.ptr<tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>)
26-
// WORKGROUP_SCOPE-NEXT: spirv.INTEL.ControlBarrierArrive <Workgroup> <Workgroup> <None>
27-
// SUBGROUP_SCOPE-NEXT: spirv.INTEL.ControlBarrierArrive <Subgroup> <Subgroup> <None>
26+
// WORKGROUP_SCOPE-NEXT: triton_gen.split_barrier_arrive {execution_scope = WorkGroup, memory_scope = WorkGroup}
27+
// SUBGROUP_SCOPE-NEXT: triton_gen.split_barrier_arrive {execution_scope = SubGroup, memory_scope = SubGroup}
2828
// CHECK: ttig.prefetch {{.*}} : !tt.ptr<tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>>>
2929
// CHECK: ttig.prefetch {{.*}} : !tt.ptr<tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>>
3030
// CHECK: tt.dot {{.*}} : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>> * tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>> -> tensor<128x256xf32, #[[$DPAS]]>
31-
// WORKGROUP_SCOPE: spirv.INTEL.ControlBarrierWait <Workgroup> <Workgroup> <None>
32-
// SUBGROUP_SCOPE: spirv.INTEL.ControlBarrierWait <Subgroup> <Subgroup> <None>
31+
// WORKGROUP_SCOPE: triton_gen.split_barrier_wait {execution_scope = WorkGroup, memory_scope = WorkGroup}
32+
// SUBGROUP_SCOPE: triton_gen.split_barrier_wait {execution_scope = SubGroup, memory_scope = SubGroup}
3333
// CHECK-NEXT: scf.yield
3434
%23:3 = scf.for %arg2 = %c0_i32 to %c64_i32 step %c64_i32 iter_args(%arg3 = %cst, %arg4 = %18, %arg5 = %22) -> (tensor<128x256xf32, #dpas>, !tt.ptr<tensor<128x64xf16, #dot0>>, !tt.ptr<tensor<64x256xf16, #dot1>>) : i32 {
3535
%55:3 = scf.for %arg9 = %c0_i32 to %c64_i32 step %c64_i32 iter_args(%arg10 = %cst, %arg11 = %18, %arg12 = %22) -> (tensor<128x256xf32, #dpas>, !tt.ptr<tensor<128x64xf16, #dot0>>, !tt.ptr<tensor<64x256xf16, #dot1>>) : i32 {
@@ -70,13 +70,13 @@ module attributes {"ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 16 : i32
7070
// CHECK: ttig.prefetch {{.*}} : !tt.ptr<tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>>>
7171
// CHECK-NEXT: ttig.prefetch {{.*}} : !tt.ptr<tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>>>
7272
// CHECK: scf.for %[[IV:.*]] = {{.*}} to {{.*}} step {{.*}} iter_args({{.*}}) -> (tensor<128x256xf32, #mma>, !tt.ptr<tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>>, !tt.ptr<tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>, !tt.ptr<tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>>, !tt.ptr<tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>)
73-
// WORKGROUP_SCOPE-NEXT: spirv.INTEL.ControlBarrierArrive <Workgroup> <Workgroup> <None>
74-
// SUBGROUP_SCOPE-NEXT: spirv.INTEL.ControlBarrierArrive <Subgroup> <Subgroup> <None>
73+
// WORKGROUP_SCOPE-NEXT: triton_gen.split_barrier_arrive {execution_scope = WorkGroup, memory_scope = WorkGroup}
74+
// SUBGROUP_SCOPE-NEXT: triton_gen.split_barrier_arrive {execution_scope = SubGroup, memory_scope = SubGroup}
7575
// CHECK: ttig.prefetch {{.*}} : !tt.ptr<tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>>>
7676
// CHECK: ttig.prefetch {{.*}} : !tt.ptr<tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>>
7777
// CHECK: tt.dot {{.*}} : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>> * tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>> -> tensor<128x256xf32, #[[$DPAS]]>
78-
// WORKGROUP_SCOPE: spirv.INTEL.ControlBarrierWait <Workgroup> <Workgroup> <None>
79-
// SUBGROUP_SCOPE: spirv.INTEL.ControlBarrierWait <Subgroup> <Subgroup> <None>
78+
// WORKGROUP_SCOPE: triton_gen.split_barrier_wait {execution_scope = WorkGroup, memory_scope = WorkGroup}
79+
// SUBGROUP_SCOPE: triton_gen.split_barrier_wait {execution_scope = SubGroup, memory_scope = SubGroup}
8080
// CHECK-NEXT: scf.yield
8181
%23:3 = scf.for %arg9 = %c0_i32 to %c64_i32 step %c64_i32 iter_args(%arg10 = %cst, %arg11 = %18, %arg12 = %22) -> (tensor<128x256xf32, #dpas>, !tt.ptr<tensor<128x64xf16, #dot0>>, !tt.ptr<tensor<64x256xf16, #dot1>>) : i32 {
8282
%56 = tt.load %arg11 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<128x64xf16, #dot0>>

third_party/intel/include/Dialect/TritonGEN/IR/TritonGENOps.td

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,59 @@ def TritonGEN_BarrierOp : TritonGEN_Op<"barrier"> {
4747
}];
4848
}
4949

50+
// Arrive ("signal") half of a split barrier.
//
// Arguments: two TritonGEN_MemScope values, `execution_scope` and
// `memory_scope`. The op produces no results. The custom assembly format prints
// both as `{execution_scope = <scope>, memory_scope = <scope>}` followed by the
// attribute dictionary.
//
// NOTE(review): the description below talks about `Execution`, `Memory` and
// `Semantics`, none of which are names of this op's arguments. Presumably the
// text follows the wording of the SPIR-V split-barrier instruction this op is
// lowered to; that lowering passes spirv::MemorySemantics::None. Confirm, and
// consider rewording in terms of `execution_scope` / `memory_scope`.
def TritonGEN_SplitBarrierArriveOp : TritonGEN_Op<"split_barrier_arrive"> {
51+
let summary = "Split barrier signal";
52+
let description = [{
53+
Indicates that an invocation has arrived at a split control barrier. This
54+
may allow other invocations waiting on the split control barrier to continue
55+
executing.
56+
57+
When `Execution` is `Workgroup` or larger, behavior is undefined unless all
58+
invocations within `Execution` execute the same dynamic instance of this
59+
instruction. When `Execution` is `Subgroup` or `Invocation`, the behavior of
60+
this instruction in non-uniform control flow is defined by the client API.
61+
62+
If `Semantics` is not `None`, this instruction also serves as the start of a
63+
memory barrier similar to an `OpMemoryBarrier` instruction with the same
64+
`Memory` and `Semantics` operands. This allows atomically specifying both a
65+
control barrier and a memory barrier (that is, without needing two
66+
instructions). If `Semantics` is `None`, `Memory` is ignored.
67+
}];
68+
let arguments = (ins TritonGEN_MemScope:$execution_scope, TritonGEN_MemScope:$memory_scope);
69+
let results = (outs);
70+
let assemblyFormat = [{
71+
` ` `{` `execution_scope` `=` $execution_scope `,` `memory_scope` `=` $memory_scope `}` attr-dict
72+
}];
73+
}
74+
75+
// Wait half of a split barrier.
//
// Arguments: two TritonGEN_MemScope values, `execution_scope` and
// `memory_scope`. The op produces no results. The custom assembly format prints
// both as `{execution_scope = <scope>, memory_scope = <scope>}` followed by the
// attribute dictionary.
//
// NOTE(review): the description below talks about `Execution`, `Memory` and
// `Semantics`, none of which are names of this op's arguments. Presumably the
// text follows the wording of the SPIR-V split-barrier instruction this op is
// lowered to; that lowering passes spirv::MemorySemantics::None. Confirm, and
// consider rewording in terms of `execution_scope` / `memory_scope`.
def TritonGEN_SplitBarrierWaitOp : TritonGEN_Op<"split_barrier_wait"> {
76+
let summary = "Split barrier wait";
77+
let description = [{
78+
Waits for other invocations of this module to arrive at a split control
79+
barrier.
80+
81+
When `Execution` is `Workgroup` or larger, behavior is undefined unless all
82+
invocations within `Execution` execute the same dynamic instance of this
83+
instruction. When `Execution` is `Subgroup` or `Invocation`, the behavior of
84+
this instruction in non-uniform control flow is defined by the client API.
85+
86+
If `Semantics` is not `None`, this instruction also serves as the end of a
87+
memory barrier similar to an `OpMemoryBarrier` instruction with the same
88+
`Memory` and `Semantics` operands. This ensures that memory accesses issued
89+
before arriving at the split barrier are observed before memory accesses
90+
issued after this instruction. This control is ensured only for memory
91+
accesses issued by this invocation and observed by another invocation
92+
executing within `Memory` scope. This allows atomically specifying both a
93+
control barrier and a memory barrier (that is, without needing two
94+
instructions). If `Semantics` is `None`, `Memory` is ignored.
95+
}];
96+
let arguments = (ins TritonGEN_MemScope:$execution_scope, TritonGEN_MemScope:$memory_scope);
97+
let results = (outs);
98+
let assemblyFormat = [{
99+
` ` `{` `execution_scope` `=` $execution_scope `,` `memory_scope` `=` $memory_scope `}` attr-dict
100+
}];
101+
}
102+
50103
//===----------------------------------------------------------------------===//
51104
// Matrix operations
52105
//===----------------------------------------------------------------------===//

third_party/intel/lib/TritonGENToSPIRV/TritonGENToSPIRVPass.cpp

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,17 @@ using namespace mlir::triton;
2424

2525
namespace {
2626

27+
/// Translates a TritonGEN::MemScope into the matching spirv::Scope.
///
/// Only two values are mapped: WORK_GROUP -> spirv::Scope::Workgroup and
/// SUB_GROUP -> spirv::Scope::Subgroup. Any other MemScope value falls into the
/// default case and hits llvm_unreachable, so callers must only pass one of the
/// two supported scopes.
static spirv::Scope getSpirvScope(TritonGEN::MemScope scope) {
28+
switch (scope) {
29+
case TritonGEN::MemScope::WORK_GROUP:
30+
return spirv::Scope::Workgroup;
31+
case TritonGEN::MemScope::SUB_GROUP:
32+
return spirv::Scope::Subgroup;
33+
default:
34+
llvm_unreachable("unexpected MemScope value");
35+
}
36+
}
37+
2738
struct TritonGENBarrierLowering
2839
: public OpConversionPattern<TritonGEN::BarrierOp> {
2940
using OpConversionPattern<TritonGEN::BarrierOp>::OpConversionPattern;
@@ -57,6 +68,35 @@ struct TritonGENBarrierLowering
5768
}
5869
};
5970

71+
/// Conversion pattern that rewrites TritonGEN::SplitBarrierArriveOp into
/// spirv::INTELControlBarrierArriveOp.
///
/// The execution and memory scopes are read from the op's accessors and
/// translated with getSpirvScope; the memory semantics operand is always
/// spirv::MemorySemantics::None. The adaptor parameter is not used, and the
/// pattern always returns success().
struct TritonGENSplitBarrierArriveLowering
72+
: public OpConversionPattern<TritonGEN::SplitBarrierArriveOp> {
73+
using OpConversionPattern<
74+
TritonGEN::SplitBarrierArriveOp>::OpConversionPattern;
75+
76+
LogicalResult
77+
matchAndRewrite(TritonGEN::SplitBarrierArriveOp op, OpAdaptor adaptor,
78+
ConversionPatternRewriter &rewriter) const override {
79+
rewriter.replaceOpWithNewOp<spirv::INTELControlBarrierArriveOp>(
80+
op, getSpirvScope(op.getExecutionScope()),
81+
getSpirvScope(op.getMemoryScope()), spirv::MemorySemantics::None);
82+
return success();
83+
}
84+
};
85+
86+
/// Conversion pattern that rewrites TritonGEN::SplitBarrierWaitOp into
/// spirv::INTELControlBarrierWaitOp.
///
/// The execution and memory scopes are read from the op's accessors and
/// translated with getSpirvScope; the memory semantics operand is always
/// spirv::MemorySemantics::None. The adaptor parameter is not used, and the
/// pattern always returns success().
struct TritonGENSplitBarrierWaitLowering
87+
: public OpConversionPattern<TritonGEN::SplitBarrierWaitOp> {
88+
using OpConversionPattern<TritonGEN::SplitBarrierWaitOp>::OpConversionPattern;
89+
90+
LogicalResult
91+
matchAndRewrite(TritonGEN::SplitBarrierWaitOp op, OpAdaptor adaptor,
92+
ConversionPatternRewriter &rewriter) const override {
93+
rewriter.replaceOpWithNewOp<spirv::INTELControlBarrierWaitOp>(
94+
op, getSpirvScope(op.getExecutionScope()),
95+
getSpirvScope(op.getMemoryScope()), spirv::MemorySemantics::None);
96+
return success();
97+
}
98+
};
99+
60100
} // namespace
61101

62102
//===----------------------------------------------------------------------===//
@@ -100,5 +140,6 @@ struct ConvertTritonGENToSPIRV
100140

101141
void mlir::triton::populateTritonGENToSPIRVConversionPatterns(
102142
RewritePatternSet &patterns) {
103-
patterns.add<TritonGENBarrierLowering>(patterns.getContext());
143+
patterns.add<TritonGENBarrierLowering, TritonGENSplitBarrierArriveLowering,
144+
TritonGENSplitBarrierWaitLowering>(patterns.getContext());
104145
}

third_party/intel/lib/TritonIntelGPUTransforms/Pipeliner/SoftwarePipeliner.cpp

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,9 @@ static bool preCondition(scf::ForOp forOp) {
3838
return true;
3939
}
4040

41-
static void
42-
pipelineLoop(scf::ForOp forOp, int numStages,
43-
std::optional<spirv::Scope> barrierScope = std::nullopt) {
41+
static void pipelineLoop(
42+
scf::ForOp forOp, int numStages,
43+
std::optional<triton::TritonGEN::MemScope> barrierScope = std::nullopt) {
4444
mlir::scf::PipeliningOption options;
4545
if (!preCondition(forOp))
4646
return;
@@ -60,18 +60,18 @@ pipelineLoop(scf::ForOp forOp, int numStages,
6060

6161
scf::ForOp loop = (*newForOp);
6262
if (barrierScope) {
63-
assert((*barrierScope == spirv::Scope::Subgroup) ||
64-
(*barrierScope == spirv::Scope::Workgroup) &&
63+
assert((*barrierScope == triton::TritonGEN::MemScope::SUB_GROUP) ||
64+
(*barrierScope == triton::TritonGEN::MemScope::WORK_GROUP) &&
6565
"The barrier scope must be SubGroup or Workgroup");
6666
OpBuilder b(loop);
6767
Location loc = loop.getLoc();
6868
b.setInsertionPointToStart(loop.getBody());
69-
b.create<spirv::INTELControlBarrierArriveOp>(
70-
loc, *barrierScope, *barrierScope, spirv::MemorySemantics::None);
69+
b.create<triton::TritonGEN::SplitBarrierArriveOp>(loc, *barrierScope,
70+
*barrierScope);
7171
auto yield = cast<scf::YieldOp>(loop.getBody()->getTerminator());
7272
b.setInsertionPoint(yield);
73-
b.create<spirv::INTELControlBarrierWaitOp>(
74-
loc, *barrierScope, *barrierScope, spirv::MemorySemantics::None);
73+
b.create<triton::TritonGEN::SplitBarrierWaitOp>(loc, *barrierScope,
74+
*barrierScope);
7575
}
7676
}
7777

@@ -92,15 +92,15 @@ struct IntelGPUPipelinePass
9292
if (numStages <= 1)
9393
return;
9494

95-
std::optional<spirv::Scope> barrierScope = std::nullopt;
95+
std::optional<triton::TritonGEN::MemScope> barrierScope = std::nullopt;
9696
switch (splitBarrierScope) {
9797
case ttgi::SplitBarrierScope::None:
9898
break;
9999
case ttgi::SplitBarrierScope::Workgroup:
100-
barrierScope = spirv::Scope::Workgroup;
100+
barrierScope = triton::TritonGEN::MemScope::WORK_GROUP;
101101
break;
102102
case ttgi::SplitBarrierScope::Subgroup:
103-
barrierScope = spirv::Scope::Subgroup;
103+
barrierScope = triton::TritonGEN::MemScope::SUB_GROUP;
104104
break;
105105
}
106106

0 commit comments

Comments
 (0)