Commit 7d03355

Merge commit '755d4164081b92a909df2e1ad4c56174c8ce5529'
2 parents 3e1165f + 755d416 commit 7d03355

16 files changed: 260 additions, 85 deletions


include/triton/Tools/LayoutUtils.h

Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
+#ifndef TRITON_TOOLS_LAYOUTUTILS_H
+#define TRITON_TOOLS_LAYOUTUTILS_H
+
+#include "triton/Tools/LinearLayout.h"
+
+namespace mlir::triton {
+// Is the sublayout defined from dimNames to dimNames the identity?
+// In particular, are the input and output sizes in these dimensions
+// the same, and are the bases the identity?
+bool squareSublayoutIsIdentity(const LinearLayout &ll,
+                               ArrayRef<StringAttr> dimNames);
+} // namespace mlir::triton
+
+#endif // TRITON_TOOLS_LAYOUTUTILS_H
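For reference, a minimal usage sketch of the new helper (not part of this commit). It assumes an MLIRContext is at hand and builds a square layout with LinearLayout::identity1D from LinearLayout.h, using the same names for the in and out dimensions; the function name and dimension sizes below are illustrative only.

    #include "mlir/IR/BuiltinAttributes.h"
    #include "mlir/IR/MLIRContext.h"
    #include "triton/Tools/LayoutUtils.h"
    #include "triton/Tools/LinearLayout.h"

    using namespace mlir;
    using namespace mlir::triton;

    // Build a layout whose "register" and "lane" components are both identity
    // maps, then check it with the new free function.
    bool exampleIdentityCheck(MLIRContext *ctx) {
      StringAttr kReg = StringAttr::get(ctx, "register");
      StringAttr kLane = StringAttr::get(ctx, "lane");
      LinearLayout ll = LinearLayout::identity1D(4, kReg, kReg) *
                        LinearLayout::identity1D(32, kLane, kLane);
      // True: the in/out sizes match on both dimensions and the flattened
      // bases are 1, 2, 4, ... in order.
      return squareSublayoutIsIdentity(ll, {kReg, kLane});
    }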

include/triton/Tools/LinearLayout.h

Lines changed: 0 additions & 5 deletions
@@ -611,11 +611,6 @@ class LinearLayout {
   bool sublayoutIsZero(ArrayRef<StringAttr> inDimNames,
                        ArrayRef<StringAttr> outDimNames) const;

-  // Is the sublayout defined from dimNames to dimNames the identity?
-  // In particular, is the input and output size in these dimensions
-  // the same, and are the bases the identity?
-  bool squareSublayoutIsIdentity(ArrayRef<StringAttr> dimNames) const;
-
   // Computes and returns L(x, y, z).
   //
   // If you want to apply the layout to mlir Values instead of integers, that

lib/Analysis/Utility.cpp

Lines changed: 3 additions & 4 deletions
@@ -225,15 +225,14 @@ bool ReduceOpHelper::isSupportedLayout() {
   }

   auto srcLayout = getSrcLayout();
-  if (isa<BlockedEncodingAttr>(srcLayout)) {
+  if (isa<BlockedEncodingAttr, LinearEncodingAttr, SliceEncodingAttr>(
+          srcLayout)) {
     return true;
   }
+
   if (auto mmaLayout = dyn_cast<MmaEncodingTrait>(srcLayout)) {
     return mmaLayout.supportReduction();
   }
-  if (auto sliceLayout = dyn_cast<SliceEncodingAttr>(srcLayout)) {
-    return true;
-  }
   return false;
 }

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 3 additions & 3 deletions
@@ -1470,7 +1470,7 @@ SmallVector<unsigned> basesPerDim(const LinearLayout::BasesT &namedBases,

   SmallVector<unsigned> ret(rank, 1);
   auto nonZero = [](auto val) { return val != 0; };
-  int nonZeroIdx = -1;
+  int nonZeroIdx = 0;
   for (const auto &basis : bases) {
     auto it = std::find_if(basis.begin(), basis.end(), nonZero);
     // Bases can have one or zero non-zero elements
@@ -1482,7 +1482,6 @@ SmallVector<unsigned> basesPerDim(const LinearLayout::BasesT &namedBases,
     } else if (!skipBroadcast) {
       // If we've seen a non-zero basis, we double the size of the previous dim
       // This is just needed to count the CTAsPerCGA
-      assert(nonZeroIdx != -1);
       ret[nonZeroIdx] *= 2;
     }
   }
@@ -1633,7 +1632,8 @@ LinearEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape, Type) const {
   // the invariant that the shape of the LL is that of the tensor
   // We choose the former for BC
   auto ll = *toLinearLayout(shape);
-  return basesPerDim(ll, StringAttr::get(getContext(), "register"));
+  return basesPerDim(ll, StringAttr::get(getContext(), "register"),
+                     /*skipBroadcast=*/false);
 }


lib/Dialect/TritonGPU/Transforms/LoopScheduling.cpp

Lines changed: 22 additions & 1 deletion
@@ -23,6 +23,27 @@ namespace gpu {

 namespace {

+bool hasGpuBarriers(scf::ForOp forOp) {
+  WalkResult result = forOp.walk(
+      [&](mlir::gpu::BarrierOp barrier) { return WalkResult::interrupt(); });
+  return result.wasInterrupted();
+}
+
+// Return true if the preconditions for pipelining the loop are met.
+bool isSafeToPipeline(scf::ForOp forOp) {
+  // Skip loop with distance > 1 for now.
+  // TODO: relax the constraint in the expander.
+  if (loopHasDistGreaterThanOne(forOp))
+    return false;
+  // Don't pipeline outer loops.
+  if (isOuterLoop(forOp))
+    return false;
+  // Skip loops with barriers.
+  if (hasGpuBarriers(forOp))
+    return false;
+  return true;
+}
+
 bool hasLatenciesAssigned(scf::ForOp forOp,
                           const DenseMap<Operation *, int> &opLatency) {
   for (auto &op : forOp.getBody()->without_terminator()) {
@@ -261,7 +282,7 @@ void scheduleRemainingToLastStage(scf::ForOp forOp, CoarseSchedule &schedule,

 void scheduleLoop(scf::ForOp forOp,
                   const DenseMap<Operation *, int> &opLatency) {
-  if (!hasLatenciesAssigned(forOp, opLatency))
+  if (!hasLatenciesAssigned(forOp, opLatency) || !isSafeToPipeline(forOp))
     return;
   // Based on the latencies, schedule the key ops to the stages.
   CoarseSchedule schedule = scheduleKeyOps(forOp, opLatency);
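As a side note, hasGpuBarriers above is an instance of the usual MLIR walk-and-interrupt idiom. A generic standalone sketch of the same pattern follows; containsOp is a hypothetical helper, not part of the commit.

    #include "mlir/IR/Operation.h"

    // Returns true if any operation of type OpTy is nested under `root`;
    // the walk stops at the first match via WalkResult::interrupt().
    template <typename OpTy>
    static bool containsOp(mlir::Operation *root) {
      return root->walk([](OpTy) { return mlir::WalkResult::interrupt(); })
          .wasInterrupted();
    }

With OpTy = mlir::gpu::BarrierOp and root = forOp.getOperation(), this reduces to the check used by isSafeToPipeline.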

lib/Dialect/TritonGPU/Transforms/Pipeliner/SoftwarePipeliner.cpp

Lines changed: 0 additions & 14 deletions
@@ -33,18 +33,6 @@ namespace gpu {
 #define GEN_PASS_DEF_TRITONGPUPIPELINE
 #include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"

-// Return true if the preconditions for pipelining the loop are met.
-static bool preCondition(scf::ForOp forOp) {
-  // Skip loop with distance > 1 for now.
-  // TODO: relax the constraint in the expander.
-  if (loopHasDistGreaterThanOne(forOp))
-    return false;
-  // Don't pipeline outer loops.
-  if (isOuterLoop(forOp))
-    return false;
-  return true;
-}
-
 static void tryAndPipelineOuterLoop(scf::ForOp forOp) {
   mlir::triton::PipeliningOption options;
   bool foundSchedule = false;
@@ -60,8 +48,6 @@ static void tryAndPipelineOuterLoop(scf::ForOp forOp) {

 static bool pipelineLoop(scf::ForOp forOp, int numStages) {
   mlir::triton::PipeliningOption options;
-  if (!preCondition(forOp))
-    return false;

   bool foundSchedule = false;
   foundSchedule = preProcessLoopAndGetSchedule(forOp, numStages, options);

lib/Tools/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -1,4 +1,5 @@
 add_triton_library(TritonTools
+  LayoutUtils.cpp
   LinearLayout.cpp

   DEPENDS

lib/Tools/LayoutUtils.cpp

Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
+#include "triton/Tools/LayoutUtils.h"
+
+namespace mlir::triton {
+
+bool squareSublayoutIsIdentity(const LinearLayout &ll,
+                               ArrayRef<StringAttr> dimNames) {
+  // The empty layout is the identity
+  if (dimNames.size() == 0) {
+    return true;
+  }
+  // Check that the input-output sizes are the same
+  LinearLayout sl = ll.sublayout(dimNames, dimNames);
+  for (StringAttr dim : dimNames) {
+    if (ll.getInDimSize(dim) != ll.getOutDimSize(dim)) {
+      return false;
+    }
+  }
+  // Once the input and output dimensions are the same, we can just check
+  // that the basis for the single remaining dimension is the identity.
+  sl = sl.flattenIns().flattenOuts();
+  int b = 0;
+  const auto &inDimBases = sl.getBases().begin()->second;
+  for (auto basis : inDimBases) {
+    if (basis[0] != (1 << b)) {
+      return false;
+    }
+    b++;
+  }
+  return true;
+}
+
+} // namespace mlir::triton
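A note on the flattened check: a square F2-linear map is the identity exactly when basis b sends input bit b to output bit b, i.e. when the flattened bases read 1, 2, 4, 8, ... in order, which is what the basis[0] != (1 << b) test rejects. A standalone illustration on plain integers (flattenedBasesAreIdentity is a hypothetical helper, not from the commit):

    #include <cstdint>
    #include <vector>

    // A flattened square layout is the identity iff basis b equals 1 << b.
    static bool flattenedBasesAreIdentity(const std::vector<int32_t> &bases) {
      for (size_t b = 0; b < bases.size(); ++b)
        if (bases[b] != (int32_t(1) << b))
          return false;
      return true;
    }

    // flattenedBasesAreIdentity({1, 2, 4}) -> true   (3-bit identity)
    // flattenedBasesAreIdentity({1, 4, 2}) -> false  (bits 1 and 2 swapped)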

lib/Tools/LinearLayout.cpp

Lines changed: 2 additions & 28 deletions
@@ -6,6 +6,7 @@

 #include "mlir/IR/BuiltinAttributes.h"
 #include "third_party/f2reduce/f2reduce.h"
+#include "triton/Tools/LayoutUtils.h"
 #include "triton/Tools/StrUtil.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetOperations.h"
@@ -651,7 +652,7 @@ bool LinearLayout::isTrivialOver(ArrayRef<StringAttr> dimNames) const {
   // We can quotient out dimNames iff they don't affect the remainingInDimNames
   // in the result. In other words, we want to check that B is zero, and C is
   // zero, and D is the identity
-  return squareSublayoutIsIdentity(dimNames) &&
+  return squareSublayoutIsIdentity(*this, dimNames) &&
          sublayoutIsZero(remainingInDimNames, dimNames) &&
          sublayoutIsZero(dimNames, remainingOutDimNames);
 }
@@ -730,33 +731,6 @@ bool LinearLayout::sublayoutIsZero(ArrayRef<StringAttr> inDimNames,
   return true;
 }

-bool LinearLayout::squareSublayoutIsIdentity(
-    ArrayRef<StringAttr> dimNames) const {
-  // The empty layout is the identity
-  if (dimNames.size() == 0) {
-    return true;
-  }
-  // Check that the input-output sizes are the same
-  LinearLayout sl = sublayout(dimNames, dimNames);
-  for (StringAttr dim : dimNames) {
-    if (getInDimSize(dim) != getOutDimSize(dim)) {
-      return false;
-    }
-  }
-  // Once the inputs and output dimensions are the same, we can just check
-  // that the basis for the single remaining dimension is the identity.
-  sl = sl.flattenIns().flattenOuts();
-  int b = 0;
-  const auto &inDimBases = sl.bases.begin()->second;
-  for (auto basis : inDimBases) {
-    if (basis[0] != (1 << b)) {
-      return false;
-    }
-    b++;
-  }
-  return true;
-}
-
 SmallVector<std::pair<StringAttr, int32_t>>
 LinearLayout::apply(ArrayRef<std::pair<StringAttr, int32_t>> ins) const {
   assertDimsEqualIgnoringOrder(llvm::make_first_range(ins), getInDimNames());
Lines changed: 70 additions & 0 deletions
@@ -0,0 +1,70 @@
+// RUN: triton-opt %s --allocate-shared-memory --convert-triton-gpu-to-llvm --convert-nv-gpu-to-llvm | mlir-translate -mlir-to-llvmir | opt -S -O1 | FileCheck %s
+
+#linear = #ttg.linear<{register = [[0, 2], [2, 0]], lane = [[0, 8], [8, 0], [1, 0], [4, 0], [16, 0]], warp = [[0, 1], [0, 4]], block = []}>
+
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
+
+// CHECK-LABEL: @reduce_linear_layout
+tt.func private @reduce_linear_layout(%arg0: tensor<32x16xi32, #linear>) -> tensor<16xi32, #ttg.slice<{dim = 0, parent = #linear}>> {
+  // CHECK-NEXT: [[SRC0:%.*]] = extractvalue {{.*}} %0, 0
+  // CHECK-NEXT: [[SRC1:%.*]] = extractvalue {{.*}} %0, 1
+  // CHECK-NEXT: [[SRC2:%.*]] = extractvalue {{.*}} %0, 2
+  // CHECK-NEXT: [[SRC3:%.*]] = extractvalue {{.*}} %0, 3
+
+  // The layout looks like
+  // [[ T0:0, T32:0, T0:1, T32:1, ...
+  // [ T4:0, T36:0, T4:1, T36:1, ...
+  // [ T0:2, T32:2, T0:3, T32:3, ...
+  // [ T4:2, T36:2, T4:3, T36:3,
+  // ...
+  //
+  // A reduction along axis=0 consists of adding registers (0, 2) and (1, 3)
+  // before shuffling.
+  //
+  // Columns along axis=0 are contained within a warp, so reduction across warps
+  // is not needed.

+  // Reduce within threads
+  // CHECK-NEXT: [[SUM0:%.*]] = add i32 [[SRC0]], [[SRC2]]
+  // CHECK-NEXT: [[SUM1:%.*]] = add i32 [[SRC1]], [[SRC3]]
+
+  // Reduce within warp.
+  // CHECK-NEXT: [[W0:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 [[SUM0]], i32 16, i32 31)
+  // CHECK-NEXT: [[WSUM0:%.*]] = add i32 [[W0]], [[SUM0]]
+  // CHECK-NEXT: [[W1:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 [[WSUM0]], i32 8, i32 31)
+  // CHECK-NEXT: [[WSUM1:%.*]] = add i32 [[WSUM0]], [[W1]]
+  // CHECK-NEXT: [[W2:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 [[WSUM1]], i32 4, i32 31)
+  // CHECK-NEXT: [[WSUM2:%.*]] = add i32 [[WSUM1]], [[W2]]
+  // CHECK-NEXT: [[W3:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 [[WSUM2]], i32 2, i32 31)
+  // CHECK-NEXT: [[WSUM3:%.*]] = add i32 [[WSUM2]], [[W3]]
+
+  // CHECK-NEXT: [[W4:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 [[SUM1]], i32 16, i32 31)
+  // CHECK-NEXT: [[WSUM4:%.*]] = add i32 [[W4]], [[SUM1]]
+  // CHECK-NEXT: [[W5:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 [[WSUM4]], i32 8, i32 31)
+  // CHECK-NEXT: [[WSUM5:%.*]] = add i32 [[WSUM4]], [[W5]]
+  // CHECK-NEXT: [[W6:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 [[WSUM5]], i32 4, i32 31)
+  // CHECK-NEXT: [[WSUM6:%.*]] = add i32 [[WSUM5]], [[W6]]
+  // CHECK-NEXT: [[W7:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 [[WSUM6]], i32 2, i32 31)
+  // CHECK-NEXT: [[WSUM7:%.*]] = add i32 [[WSUM6]], [[W7]]
+
+  // CHECK-NEXT: [[DST0:%.*]] = insertvalue { i32, i32 } undef, i32 [[WSUM3]], 0
+  // CHECK-NEXT: [[DST1:%.*]] = insertvalue { i32, i32 } [[DST0]], i32 [[WSUM7]], 1
+
+  %0 = "tt.reduce"(%arg0) ({
+  ^bb0(%arg1: i32, %arg2: i32):
+    %1 = arith.addi %arg1, %arg2 : i32
+    tt.reduce.return %1 : i32
+  }) {axis = 0 : i32} : (tensor<32x16xi32, #linear>) -> tensor<16xi32, #ttg.slice<{dim = 0, parent = #linear}>>
+
+  // CHECK-NEXT: ret { i32, i32 } [[DST1]]
+  tt.return %0 : tensor<16xi32, #ttg.slice<{dim = 0, parent = #linear}>>
+}
+
+tt.func @anchor(%ptr: !llvm.ptr, %arg0: tensor<32x16xi32, #linear>) {
+  %0 = tt.call @reduce_linear_layout(%arg0) : (tensor<32x16xi32, #linear>) -> tensor<16xi32, #ttg.slice<{dim = 0, parent = #linear}>>
+  %1 = builtin.unrealized_conversion_cast %0 : tensor<16xi32, #ttg.slice<{dim = 0, parent = #linear}>> to !llvm.struct<(i32, i32)>
+  llvm.store volatile %1, %ptr : !llvm.struct<(i32, i32)>, !llvm.ptr
+  tt.return
+}
+
+}
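To make the CHECK sequence easier to follow, the mapping can be read off the #linear bases above (this derivation is an editorial aid, not part of the test). Writing the hardware index as register bits r_0, r_1, lane bits l_0 ... l_4, and warp bits w_0, w_1, the tensor coordinate is the XOR of the selected bases:

    (d_0, d_1) = r_0(0,2) \oplus r_1(2,0) \oplus l_0(0,8) \oplus l_1(8,0) \oplus l_2(1,0) \oplus l_3(4,0) \oplus l_4(16,0) \oplus w_0(0,1) \oplus w_1(0,4)

For example, lane 4 sets l_2, so (register 0, lane 4, warp 0) maps to (1, 0), the T4:0 at the start of the second row of the sketch. Only r_1 and l_1, l_2, l_3, l_4 reach d_0, so the axis-0 reduction is one in-thread add (pairing registers {0, 2} and {1, 3}, which differ in r_1) followed by butterfly shuffles with offsets 16, 8, 4, 2 (lane bits l_4, l_3, l_2, l_1); warps only affect d_1, so no cross-warp step is emitted.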
