
Commit 0ef2a3f

Merge branch 'main' into issue2662
2 parents 9f74d61 + 9952acf commit 0ef2a3f

File tree

11 files changed, +268 -78 lines changed
Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
// RUN: triton-opt %s -split-input-file --intel-allocate-shared-memory | FileCheck %s

#blocked = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [1, 1], order = [0, 1]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 1], order = [0, 1]}>

// Check no scratch memory is allocated for sub-group shuffle-like layout conversions.

// CHECK-LABEL: module attributes
// CHECK-SAME: triton_gpu.shared = 0 : i32
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
  // CHECK: tt.func @test_sub_group_shuffle
  // CHECK-NOT: llvm.ptr<3>
  tt.func @test_sub_group_shuffle(%arg0: tensor<16xf16, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<16xf16, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> {
    %0 = triton_gpu.convert_layout %arg0 : tensor<16xf16, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<16xf16, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
    tt.return %0 : tensor<16xf16, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
  }
}

// -----

#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 1], order = [0, 1]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [1, 1], order = [0, 1]}>

// Check scratch memory configuration for different sub-group transpose-like layout conversions.

// CHECK-LABEL: module attributes
// CHECK-SAME: triton_gpu.shared = 512 : i32
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
  tt.func @test_f16(%arg0: tensor<16x16xf16, #blocked>) -> tensor<16x16xf16, #blocked1> {
    %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf16, #blocked> -> tensor<16x16xf16, #blocked1>
    tt.return %0 : tensor<16x16xf16, #blocked1>
  }
}

// -----

#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 1], order = [0, 1]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [1, 1], order = [0, 1]}>

// Check scratch memory configuration for different sub-group transpose-like layout conversions.

// CHECK-LABEL: module attributes
// CHECK-SAME: triton_gpu.shared = 1024 : i32
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
  tt.func @test_f32(%arg0: tensor<16x16xf32, #blocked>) -> tensor<16x16xf32, #blocked1> {
    %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf32, #blocked> -> tensor<16x16xf32, #blocked1>
    tt.return %0 : tensor<16x16xf32, #blocked1>
  }
}

// -----

#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [4, 2], order = [0, 1]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [4, 2], order = [0, 1]}>

// Check scratch memory configuration for different sub-group transpose-like layout conversions.

// CHECK-LABEL: module attributes
// CHECK-SAME: triton_gpu.shared = 32768 : i32
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
  tt.func @test_f32(%arg0: tensor<128x64xf32, #blocked>) -> tensor<128x64xf32, #blocked1> {
    %0 = triton_gpu.convert_layout %arg0 : tensor<128x64xf32, #blocked> -> tensor<128x64xf32, #blocked1>
    tt.return %0 : tensor<128x64xf32, #blocked1>
  }
}
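
For reference, the triton_gpu.shared values checked above (512, 1024 and 32768 bytes) can be reproduced from the sub-group transpose scratch shape this commit adds to third_party/intel/lib/Analysis/Allocation.cpp further down. A small self-contained sketch, assuming the allocated size is simply the product of repShape times the element size with no extra padding (which matches these CHECK lines):

// Standalone sketch, not part of the commit: recomputes the expected
// shared-memory sizes from repShape = {threadsPerWarp, threadsPerWarp,
// remaining, warpsPerCTA}.
#include <cstdio>

static unsigned scratchBytes(unsigned rows, unsigned cols,
                             unsigned threadsPerWarp, unsigned warpsPerCTA,
                             unsigned elemBytes) {
  unsigned remaining =
      (rows * cols) / (threadsPerWarp * threadsPerWarp * warpsPerCTA);
  unsigned elems = threadsPerWarp * threadsPerWarp * remaining * warpsPerCTA;
  return elems * elemBytes;
}

int main() {
  std::printf("%u\n", scratchBytes(16, 16, 16, /*warpsPerCTA=*/1, /*f16*/ 2));  // 512
  std::printf("%u\n", scratchBytes(16, 16, 16, /*warpsPerCTA=*/1, /*f32*/ 4));  // 1024
  std::printf("%u\n", scratchBytes(128, 64, 16, /*warpsPerCTA=*/8, /*f32*/ 4)); // 32768
  return 0;
}

The sub-group shuffle case above needs no scratch memory at all, hence the triton_gpu.shared = 0 check.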

test/Conversion/intel/sub-group-shuffle.mlir

Lines changed: 6 additions & 6 deletions
@@ -9,7 +9,7 @@
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
   // CHECK-LABEL: llvm.func spir_kernelcc @test_f16(
-  // CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(f16)>,
+  // CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(f16)>)
   // CHECK: %[[VAL_2:.*]] = llvm.extractvalue %[[VAL_0]][0] : !llvm.struct<(f16)>
   // CHECK: %[[VAL_4:.*]] = llvm.mlir.constant(0 : i32) : i32
   // CHECK: llvm.call spir_funccc @_Z17sub_group_shuffleDhj(%[[VAL_2]], %[[VAL_4]])
@@ -49,7 +49,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
   }
 
   // CHECK-LABEL: llvm.func spir_kernelcc @test_bf16(
-  // CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(bf16)>,
+  // CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(bf16)>)
   // CHECK: %[[VAL_1:.*]] = llvm.extractvalue %[[VAL_0]][0] : !llvm.struct<(bf16)>
   // CHECK: %[[VAL_2:.*]] = llvm.bitcast %[[VAL_1]] : bf16 to i16
   // CHECK: %[[VAL_4:.*]] = llvm.mlir.constant(0 : i32) : i32
@@ -91,7 +91,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
   }
 
   // CHECK-LABEL: llvm.func spir_kernelcc @test_i1(
-  // CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(i1)>,
+  // CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(i1)>)
   // CHECK: %[[VAL_1:.*]] = llvm.extractvalue %[[VAL_0]][0] : !llvm.struct<(i1)>
   // CHECK: %[[VAL_2:.*]] = llvm.zext %[[VAL_1]] : i1 to i8
   // CHECK: %[[VAL_4:.*]] = llvm.mlir.constant(0 : i32) : i32
@@ -133,7 +133,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
   }
 
   // CHECK-LABEL: llvm.func spir_kernelcc @test_ptr(
-  // CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(ptr<1>)>,
+  // CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(ptr<1>)>)
   // CHECK: %[[VAL_1:.*]] = llvm.extractvalue %[[VAL_0]][0] : !llvm.struct<(ptr<1>)>
   // CHECK: %[[VAL_2:.*]] = llvm.ptrtoint %[[VAL_1]] : !llvm.ptr<1> to i64
   // CHECK: %[[VAL_4:.*]] = llvm.mlir.constant(0 : i32) : i32
@@ -186,7 +186,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
 
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
   // CHECK-LABEL: llvm.func spir_kernelcc @test_f32(
-  // CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(f32)>,
+  // CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(f32)>)
   // CHECK: %[[VAL_2:.*]] = llvm.extractvalue %[[VAL_0]][0] : !llvm.struct<(f32)>
   // CHECK: %[[VAL_4:.*]] = llvm.mlir.constant(0 : i32) : i32
   // CHECK: llvm.call spir_funccc @_Z17sub_group_shufflefj(%[[VAL_2]], %[[VAL_4]])
@@ -269,7 +269,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
 
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
   // CHECK-LABEL: llvm.func spir_kernelcc @test_non_sliced_multi_register(
-  // CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(f64, f64)>,
+  // CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(f64, f64)>)
   // CHECK: %[[VAL_2:.*]] = llvm.extractvalue %[[VAL_0]][0] : !llvm.struct<(f64, f64)>
   // CHECK: %[[VAL_3:.*]] = llvm.extractvalue %[[VAL_0]][1] : !llvm.struct<(f64, f64)>
   // CHECK: %[[VAL_5:.*]] = llvm.mlir.constant(0 : i32) : i32

test/Conversion/intel/sub-group-transpose.mlir

Lines changed: 17 additions & 0 deletions
@@ -426,3 +426,20 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
     tt.return %0 : tensor<32x64xf32, #blocked1>
   }
 }
+
+// -----
+
+// Test that no barriers are inserted when back-to-back transpositions are performed.
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 2], order = [0, 1]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [2, 2], order = [0, 1]}>
+
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
+  // CHECK-LABEL: llvm.func spir_kernelcc @test_back_to_back
+  // CHECK-NOT: barrier
+  tt.func @test_back_to_back(%arg0: tensor<32x64xf32, #blocked>, %arg1: tensor<32x64xf32, #blocked>) -> (tensor<32x64xf32, #blocked1>, tensor<32x64xf32, #blocked1>) {
+    %0 = triton_gpu.convert_layout %arg0 : tensor<32x64xf32, #blocked> -> tensor<32x64xf32, #blocked1>
+    %1 = triton_gpu.convert_layout %arg1 : tensor<32x64xf32, #blocked> -> tensor<32x64xf32, #blocked1>
+    tt.return %0, %1 : tensor<32x64xf32, #blocked1>, tensor<32x64xf32, #blocked1>
+  }
+}
Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
#ifndef TRITON_INTEL_ANALYSIS_MEMBAR_H
#define TRITON_INTEL_ANALYSIS_MEMBAR_H

namespace mlir {
class Operation;
namespace intel {
/// Intel-specific callback to filter out pairs of operations that need no
/// barrier between them.
///
/// This is useful because the default granularity for deciding whether a
/// barrier is needed is quite coarse. The filter returns true if no barrier
/// is needed between `lhsOp` and `rhsOp`.
bool membarFilter(Operation *lhsOp, Operation *rhsOp);
} // namespace intel
} // namespace mlir

#endif // TRITON_INTEL_ANALYSIS_MEMBAR_H
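
The matching Membar.cpp (added to the CMake target at the end of this commit) is not shown in this excerpt. Purely to illustrate the contract described above, a hedged sketch of what such a filter could look like, assuming it reuses the cvtIsSubGroupTranspose helper from intel/include/Analysis/Utility.h that Allocation.cpp below also calls; names and structure here are illustrative, not the actual implementation:

// Illustrative sketch only; not the Membar.cpp from this commit.
#include "intel/include/Analysis/Utility.h"
#include "mlir/IR/BuiltinTypes.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

namespace mlir {
namespace intel {
bool membarFilter(Operation *lhsOp, Operation *rhsOp) {
  // Assumption: two back-to-back sub-group transposes each read back only the
  // scratch data their own sub-group wrote, so no barrier is required between
  // them (the behavior the test_back_to_back case above checks for).
  auto isSubGroupTranspose = [](Operation *op) {
    auto cvt = dyn_cast_or_null<triton::gpu::ConvertLayoutOp>(op);
    if (!cvt)
      return false;
    auto srcTy = dyn_cast<RankedTensorType>(cvt.getSrc().getType());
    auto dstTy = dyn_cast<RankedTensorType>(cvt.getResult().getType());
    return srcTy && dstTy &&
           triton::gpu::intel::cvtIsSubGroupTranspose(srcTy, dstTy);
  };
  return isSubGroupTranspose(lhsOp) && isSubGroupTranspose(rhsOp);
}
} // namespace intel
} // namespace mlir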

third_party/intel/lib/Analysis/Allocation.cpp

Lines changed: 28 additions & 6 deletions
@@ -15,6 +15,8 @@
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "llvm/ADT/SmallVector.h"
 
+#include "intel/include/Analysis/Utility.h"
+
 using ::mlir::triton::gpu::AMDMfmaEncodingAttr;
 using ::mlir::triton::gpu::BlockedEncodingAttr;
 using ::mlir::triton::gpu::DotOperandEncodingAttr;
@@ -51,10 +53,10 @@ getCvtOrder(Attribute srcLayout, Attribute dstLayout) {
 
   // mma or dot layout does not have an order, so the order depends on the
   // layout of the other operand.
-  auto inOrd = (srcMmaLayout || srcDotLayout) ? getOrder(dstLayout)
-                                              : getOrder(srcLayout);
-  auto outOrd = (dstMmaLayout || dstDotLayout) ? getOrder(srcLayout)
-                                               : getOrder(dstLayout);
+  const auto &inOrd = (srcMmaLayout || srcDotLayout) ? getOrder(dstLayout)
+                                                     : getOrder(srcLayout);
+  const auto &outOrd = (dstMmaLayout || dstDotLayout) ? getOrder(srcLayout)
+                                                      : getOrder(dstLayout);
 
   return {inOrd, outOrd};
 }
@@ -104,6 +106,26 @@ static SmallVector<unsigned> getRepShapeForAtomic(Value result) {
 
 ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy,
                                      RankedTensorType dstTy) {
+  if (gpu::intel::cvtIsSubGroupShuffle(srcTy, dstTy)) {
+    // Conversions that can be implemented as sub-group shuffles do not need
+    // scratch memory.
+    return ScratchConfig({}, {});
+  }
+
+  if (gpu::intel::cvtIsSubGroupTranspose(srcTy, dstTy)) {
+    // Conversions that can be implemented as sub-group transposes store the
+    // whole tensor in shared memory and read it afterwards.
+    auto srcEncoding = cast<gpu::DistributedEncodingTrait>(srcTy.getEncoding());
+    unsigned threadsPerWarp = product(srcEncoding.getThreadsPerWarp());
+    unsigned warpsPerCTA = product(srcEncoding.getWarpsPerCTA());
+    unsigned remaining = product(srcTy.getShape()) /
+                         (threadsPerWarp * threadsPerWarp * warpsPerCTA);
+    SmallVector<unsigned> repShape{threadsPerWarp, threadsPerWarp, remaining,
+                                   warpsPerCTA};
+    return ScratchConfig(repShape, repShape,
+                         /*inVec=*/1, /*outVec=*/threadsPerWarp);
+  }
+
   // Initialize vector sizes and stride
   auto repShape = getRepShapeForCvt(srcTy, dstTy);
   if (repShape.empty())
@@ -346,7 +368,7 @@ class AllocationAnalysis {
   /// arguments are involved.
   void resolveAliasBufferLiveness(
       function_ref<Interval<size_t>(Value value)> getLiveness) {
-    for (auto aliasBufferIter : allocation->getAliasBuffer()) {
+    for (const auto &aliasBufferIter : allocation->getAliasBuffer()) {
       auto value = aliasBufferIter.first;
       auto buffers = aliasBufferIter.second;
       auto range = getLiveness(value);
@@ -486,7 +508,7 @@ class AllocationAnalysis {
         std::find_if(xBuffers.begin(), xBuffers.end(), [&](auto *buffer) {
           auto xRange = bufferRange[buffer];
           bool res = xRange.intersects(range);
-          for (auto val : tripleMap)
+          for (const auto &val : tripleMap)
             res = res &&
                   !val.second.intersects(xRange); // only one buffer intersect
           return res;

third_party/intel/lib/Analysis/AxisInfo.cpp

Lines changed: 24 additions & 14 deletions
@@ -123,7 +123,8 @@ class BinaryOpVisitorImpl : public AxisInfoVisitorImpl<OpTy> {
       divisibility.push_back(getDivisibility(op, lhsInfo, rhsInfo, d));
     }
   }
-  return AxisInfo(contiguity, divisibility, constancy, constantValue);
+  return AxisInfo(std::move(contiguity), std::move(divisibility),
+                  std::move(constancy), constantValue);
 }
 
 protected:
@@ -543,7 +544,8 @@ class SplatOpAxisInfoVisitor final
       divisibility.push_back(opInfo.getDivisibility(0));
       constancy.push_back(retTy.getShape()[d]);
     }
-    return AxisInfo(contiguity, divisibility, constancy,
+    return AxisInfo(std::move(contiguity), std::move(divisibility),
+                    std::move(constancy),
                     operands[0]->getValue().getConstantValue());
   }
 };
@@ -574,7 +576,8 @@ class LoadOpAxisInfoVisitor final : public AxisInfoVisitorImpl<triton::LoadOp> {
           maskInfo.has_value() ? maskInfo->getConstancy(d) : 0));
     }
 
-    return AxisInfo(contiguity, divisibility, constancy);
+    return AxisInfo(std::move(contiguity), std::move(divisibility),
+                    std::move(constancy));
   }
 };
 
@@ -608,7 +611,8 @@ class ExpandDimsOpAxisInfoVisitor final
     contiguity.insert(contiguity.begin() + op.getAxis(), 1);
     divisibility.insert(divisibility.begin() + op.getAxis(), newDivisibility);
     constancy.insert(constancy.begin() + op.getAxis(), 1);
-    return AxisInfo(contiguity, divisibility, constancy,
+    return AxisInfo(std::move(contiguity), std::move(divisibility),
+                    std::move(constancy),
                     operands[0]->getValue().getConstantValue());
   }
 };
@@ -637,7 +641,8 @@ class BroadcastOpAxisInfoVisitor final
       constancy.push_back(opShape[d] == 1 ? retShape[d]
                                           : opInfo.getConstancy(d));
     }
-    return AxisInfo(contiguity, divisibility, constancy,
+    return AxisInfo(std::move(contiguity), std::move(divisibility),
+                    std::move(constancy),
                     operands[0]->getValue().getConstantValue());
   }
 };
@@ -712,7 +717,8 @@ class CmpOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
       contiguity.push_back(1);
     }
 
-    return AxisInfo(contiguity, divisibility, constancy, constantValue);
+    return AxisInfo(std::move(contiguity), std::move(divisibility),
+                    std::move(constancy), constantValue);
   }
 
 private:
@@ -840,7 +846,8 @@ class SelectOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
       constantValue = lhsInfo.getConstantValue();
     }
 
-    return AxisInfo(contiguity, divisibility, constancy, constantValue);
+    return AxisInfo(std::move(contiguity), std::move(divisibility),
+                    std::move(constancy), constantValue);
   }
 };
 
@@ -993,7 +1000,8 @@ class MaxMinOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
         contiguity.push_back(
            std::min(lhsInfo.getContiguity(d), rhsInfo.getContiguity(d)));
       }
-      return AxisInfo(contiguity, divisibility, constancy, std::nullopt);
+      return AxisInfo(std::move(contiguity), std::move(divisibility),
+                      std::move(constancy), std::nullopt);
     }
   }
 };
@@ -1038,7 +1046,8 @@ class MakeTensorPtrOpAxisInfoVisitor final
       constancy.push_back(1);
     }
 
-    auto axisInfo = AxisInfo(contiguity, divisibility, constancy);
+    auto axisInfo = AxisInfo(std::move(contiguity), std::move(divisibility),
+                             std::move(constancy));
 
     LLVM_DEBUG({
       std::string axisStr;
@@ -1143,8 +1152,8 @@ LogicalResult AxisInfoAnalysis::visitOperation(
     auto vals = cast<DenseElementsAttr>(attr).getValues<int>();
     newConstancy = AxisInfo::DimVectorT(vals.begin(), vals.end());
   }
-  curr = AxisInfo(newContiguity, newDivisibility, newConstancy,
-                  curr.getConstantValue());
+  curr = AxisInfo(std::move(newContiguity), std::move(newDivisibility),
+                  std::move(newConstancy), curr.getConstantValue());
   // join all lattice elements
   for (auto *result : results)
     propagateIfChanged(result, result->join(curr));
@@ -1154,17 +1163,18 @@
 void AxisInfoAnalysis::visitForOpInductionVar(
     scf::ForOp op, ArrayRef<dataflow::Lattice<AxisInfo> *> argLattices) {
   ProgramPoint programPoint(op);
-  const auto lb =
+  const auto &lb =
       getLatticeElementFor(&programPoint, op.getLowerBound())->getValue();
-  const auto step =
+  const auto &step =
      getLatticeElementFor(&programPoint, op.getStep())->getValue();
 
   AxisInfo::DimVectorT knownContiguity(1, 1);
   AxisInfo::DimVectorT knownDivisibility(1, 1);
   AxisInfo::DimVectorT knownConstancy(1, 1);
   knownDivisibility[0] = gcd(lb.getDivisibility(0), step.getDivisibility(0));
   auto inductionVar =
-      AxisInfo(knownContiguity, knownDivisibility, knownConstancy);
+      AxisInfo(std::move(knownContiguity), std::move(knownDivisibility),
+               std::move(knownConstancy));
   (void)argLattices[0]->join(inductionVar);
 }
 
third_party/intel/lib/Analysis/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -3,6 +3,7 @@ add_triton_library(TritonIntelAnalysis
   AxisInfo.cpp
   DPAS.cpp
   Liveness.cpp
+  Membar.cpp
   Utility.cpp
 
   DEPENDS
