Commit cc1a80c

Merge commit '0e9706cd4f5d69302e9b7331cc820fdad062c80b'
2 parents 40cde44 + 0e9706c

49 files changed: +2970 −540 lines changed

.github/workflows/integration-tests-amd.yml

Lines changed: 6 additions & 0 deletions
@@ -116,6 +116,12 @@ jobs:
             TRITON_DISABLE_LINE_INFO=0 python3 -m pytest -s -n 8 language/test_line_info.py
           fi

+          # Run tests under triton/python/triton_kernels/tests/ on gfx950 and gfx942
+          if [ "${{ matrix.runner[0] }}" = "amd-gfx950" ] || [ "${{ matrix.runner[0] }}" = "amd-gfx942" ]; then
+            cd ../../triton_kernels/
+            python3 -m pytest -s -n 12 tests/
+          fi
+
       - name: Run asan tests on AMD
         if: false
         run: |

Makefile

Lines changed: 1 addition & 0 deletions
@@ -39,6 +39,7 @@ test-unit: all
 	TRITON_DISABLE_LINE_INFO=0 $(PYTEST) -s python/test/unit/language/test_line_info.py
 	# Run attention separately to avoid out of gpu memory
 	$(PYTEST) -vs python/tutorials/06-fused-attention.py
+	$(PYTEST) -vs python/tutorials/gluon/01-attention-forward.py
 	TRITON_ALWAYS_COMPILE=1 TRITON_DISABLE_LINE_INFO=0 LLVM_PASS_PLUGIN_PATH=python/triton/instrumentation/libGPUInstrumentationTestLib.so \
 	$(PYTEST) --capture=tee-sys -rfs -vvv python/test/unit/instrumentation/test_gpuhello.py
 	$(PYTEST) -s -n $(NUM_PROCS) python/test/gluon

include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h

Lines changed: 1 addition & 0 deletions
@@ -96,6 +96,7 @@ class TargetInfoBase {

   virtual bool supportLdMatrix() const { return false; }
   virtual bool supportStMatrix() const { return false; }
+  virtual bool isCuda() const { return false; }

   // Annotate target specific information to local store operations during
   // lowering to LLVM.
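A quick, hedged illustration of how the new hook is meant to be consumed: a backend's TargetInfo subclass overrides it to report that it targets CUDA. The class name below is hypothetical; only the isCuda() signature comes from the diff above.

// Hypothetical subclass for illustration; real backends define their own
// TargetInfo types derived from TargetInfoBase.
class HypotheticalCudaTargetInfo : public TargetInfoBase {
public:
  // Opt in to code paths that query isCuda() through the base class.
  bool isCuda() const override { return true; }
  // (remaining TargetInfoBase overrides omitted in this sketch)
};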

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 2 additions & 3 deletions
@@ -1143,11 +1143,10 @@ Row |
   let hasCustomAssemblyFormat = 1;

   let extraClassDeclaration = extraDistributedDeclaration # [{
-    SmallVector<int64_t> getElemsPerInstrForOperands() const;
+    SmallVector<int64_t> getElemsPerInstrForOperands(int kDim, int opIdx) const;
     SmallVector<int64_t> getRepForOperand(ArrayRef<int64_t> operandShape,
-                                          Type elemType, int kWidth, int opIdx) const;
+                                          Type elemType, int kWidth, int kDim, int opIdx) const;
     SmallVector<unsigned> getRepOrderForOperand(int opIdx) const;
-    unsigned getKWidthForOperands() const;
     static SmallVector<unsigned> getMNKDimPerInstr();
   }];
 }

include/triton/Dialect/TritonGPU/Transforms/Schedule.h

Lines changed: 9 additions & 9 deletions
@@ -45,7 +45,7 @@ class CoarseSchedule {
   const_iterator begin() const { return orderClusters.begin(); }
   iterator end() { return orderClusters.end(); }
   const_iterator end() const { return orderClusters.end(); }
-  size_t size() { return orderClusters.size(); }
+  size_t size() const { return orderClusters.size(); }
   iterator newAtBack() {
     orderClusters.push_back(orderClusters.size());
     return std::prev(orderClusters.end());
@@ -88,7 +88,7 @@ class CoarseSchedule {
   DenseMap<Operation *, std::pair<int, Cluster>> opToStageAndCluster;

   void setNumStages(int numStages) { this->numStages = numStages; }
-  int getNumStages() { return numStages; }
+  int getNumStages() const { return numStages; }

   void insert(Operation *op, int stage, Cluster cluster) {
     if (stage >= numStages) {
@@ -115,7 +115,7 @@ class CoarseSchedule {

   void erase(Operation *op) { opToStageAndCluster.erase(op); }

-  int count(Operation *op) { return opToStageAndCluster.count(op); }
+  int count(Operation *op) const { return opToStageAndCluster.count(op); }

   std::pair<int, Cluster> operator[](Operation *op) {
     return opToStageAndCluster[op];
@@ -129,25 +129,25 @@ class CoarseSchedule {
   Cluster splitClusterBefore(Operation *op, scf::ForOp forOp);

   // Check if op a will show up before op b in the final unrolled code.
-  bool isOpBefore(Operation *a, Operation *b);
+  bool isOpBefore(Operation *a, Operation *b) const;

   // Check if op a is in earlier cluster than op b.
-  bool isOpInEarlierCluster(Operation *a, Operation *b);
+  bool isOpInEarlierCluster(Operation *a, Operation *b) const;

   // Check if op a is in the same cluster as op b.
-  bool isOpInSameCluster(Operation *a, Operation *b);
+  bool isOpInSameCluster(Operation *a, Operation *b) const;

   SmallVector<std::tuple<Operation *, int, Cluster>>
-  getOpsInOrder(scf::ForOp forOp);
+  getOpsInOrder(scf::ForOp forOp) const;
   std::vector<std::pair<Operation *, unsigned>>
-  createFinalSchedule(scf::ForOp forOp);
+  createFinalSchedule(scf::ForOp forOp) const;

   bool empty() const { return opToStageAndCluster.size() == 0; }
   auto end() const { return opToStageAndCluster.end(); }
   auto begin() const { return opToStageAndCluster.begin(); }

   // Set <stage, cluster> based on CoarseSchedule.
-  void serialize(scf::ForOp &forOp);
+  void serialize(scf::ForOp &forOp) const;
   // Create a CoarseSchedule based on forOp's <stage, cluster>.
   LogicalResult deSerialize(scf::ForOp &forOp);
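A small sketch (not part of this commit) of what the const-qualification above buys: read-only helpers can now take the schedule by const reference. The helper name is hypothetical; the member signatures are the ones declared above.

// Hypothetical read-only query over a pipeliner schedule; compiles only now
// that getOpsInOrder() and count() are const.
static int countScheduledOps(const triton::CoarseSchedule &schedule,
                             scf::ForOp forOp) {
  int numOps = 0;
  for (auto [op, stage, cluster] : schedule.getOpsInOrder(forOp))
    numOps += schedule.count(op);
  return numOps;
}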

include/triton/Tools/GenericSwizzling.h

Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
+#ifndef TRITON_GENERIC_SWIZZLING_H
+#define TRITON_GENERIC_SWIZZLING_H
+
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/SmallVector.h"
+#include <cstdint>
+
+namespace mlir::triton {
+class LinearLayout;
+}
+
+namespace mlir::triton::gpu {
+LinearLayout optimalSwizzling(const LinearLayout &src, const LinearLayout &dst,
+                              int32_t bitwidth);
+
+std::pair<int, int> logBankConflicts(const LinearLayout &src,
+                                     const LinearLayout &dst,
+                                     const LinearLayout &smem,
+                                     int32_t bitwidth);
+} // namespace mlir::triton::gpu
+
+#endif // TRITON_GENERIC_SWIZZLING_H
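A rough usage sketch for the two new entry points, assembled from the Allocation.cpp change further down; it is not verbatim code from this commit, and the variable names and the 16-bit element width are illustrative.

// Sketch: srcTy and dstTy are the RankedTensorTypes of a layout conversion.
auto srcLayout = gpu::toLinearLayout(srcTy.getShape(), srcTy.getEncoding());
auto dstLayout = gpu::toLinearLayout(dstTy.getShape(), dstTy.getEncoding());
srcLayout = actionRemoveBroadcastedRegs(srcLayout).apply(srcLayout);
dstLayout = actionRemoveBroadcastedRegs(dstLayout).apply(dstLayout);

// Shared-memory layout chosen to minimize bank conflicts for this pair.
LinearLayout smem =
    gpu::optimalSwizzling(srcLayout, dstLayout, /*bitwidth=*/16);

// Scratch elements per repetition, as the allocation analysis computes below.
auto reps = smem.getInDimSize(StringAttr::get(srcTy.getContext(), "reps"));
unsigned elems = smem.getTotalOutDimSize() / reps;

// Diagnostic: log2 bank-conflict counts for the chosen layout (interpreting
// the pair as read/write conflicts is an assumption, not stated in the diff).
auto [readLog, writeLog] =
    gpu::logBankConflicts(srcLayout, dstLayout, smem, /*bitwidth=*/16);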

include/triton/Tools/LayoutUtils.h

Lines changed: 7 additions & 4 deletions
@@ -116,15 +116,18 @@ LinearLayout zerosLike(const LinearLayout &layout);
 // For a layout A with A.hasInDim(kReg), find a permutation of registers action
 // such that action.apply(A) may be divisible by B
 // It's not always true that the action returned by this function will
-// allow us to divideLeft, but it is true that if it if there exists one, it is
-// the one returned by this function.
-std::optional<ColumnAction> regPermForDivideLeft(const LinearLayout &A,
-                                                 const LinearLayout &B);
+// allow us to divideLeft (resp. divideRight), but it is true that if it if
+// there exists one, it is the one returned by this function.
+std::optional<ColumnAction> regPermForDivide(const LinearLayout &A,
+                                             const LinearLayout &B, bool left);

 // For a layout A with A.hasInDim(kReg), find a permutation of registers action
 // such that action.apply(A) has the broadcasted registers removed
 ColumnAction actionRemoveBroadcastedRegs(const LinearLayout &layout);

+std::pair<int64_t, ColumnAction>
+actionAdditiveStrides(const LinearLayout &layout);
+
 // For a layout A with A.hasInDim(kReg), repeat the values so that they have
 // the same broadcasting as layout
 SmallVector<Value> broadcastAs(const SmallVector<Value> &values,
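A minimal sketch (not from this commit) of how the generalized regPermForDivide pairs with divideLeft; tryDivideLeft is a hypothetical helper name, and only the signatures shown above are taken from the diff.

// Hypothetical helper: permute A's register columns so that, if any register
// permutation makes A left-divisible by B, the division below succeeds.
std::optional<LinearLayout> tryDivideLeft(const LinearLayout &A,
                                          const LinearLayout &B) {
  std::optional<ColumnAction> perm = regPermForDivide(A, B, /*left=*/true);
  if (!perm)
    return std::nullopt;
  LinearLayout permuted = perm->apply(A); // reorder A's register basis vectors
  return divideLeft(permuted, B);         // nullopt if no factorization exists
}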

include/triton/Tools/LinearLayout.h

Lines changed: 10 additions & 1 deletion
@@ -453,6 +453,11 @@ class LinearLayout {
   auto getOutDimNames() const { return llvm::make_first_range(outDims); }
   auto getOutDimSizes() const { return llvm::make_second_range(outDims); }

+  // Relevant for reshaping
+  SmallVector<std::pair<StringAttr, int32_t>> getOutDims() const {
+    return to_vector(outDims);
+  }
+
   // Gets the position that this outDim occupies in getOutDimNames(). Asserts
   // if the dim is not present.
   int32_t getOutDimIndex(StringAttr outDim) const;
@@ -620,6 +625,7 @@ class LinearLayout {

   // Compute a C such that A = B * C if it exists.
   // In other words, C = B^{-1} * A.
+  // For divideRight, we compute A = C * B, that is, C = A * B^{-1}.
   // Note that such a C exists iff (every pair of input/output dim of) A is
   // of the form
   // [[B, 0],
@@ -633,6 +639,8 @@ class LinearLayout {
   // same dimensions as A ensures that C is well-defined.
   friend std::optional<LinearLayout> divideLeft(const LinearLayout &A,
                                                 const LinearLayout &B);
+  friend std::optional<LinearLayout> divideRight(const LinearLayout &A,
+                                                 const LinearLayout &B);

   // Returns true if this layout acts trivially (as the identity) on the given
   // dimensions. This means that it's the identity on those dimensions, and it
@@ -798,9 +806,10 @@ class ColumnAction {
   SmallVector<size_t> action;
   StringAttr inDim;
   size_t inSizeLog2;
-  bool isIdentity;
+  bool isIdentity = true;

 public:
+  ColumnAction() = default;
   ColumnAction(ArrayRef<size_t> action, StringAttr inDim, size_t inSizeLog2)
       : action(action), inDim(inDim), inSizeLog2(inSizeLog2) {
     auto it = llvm::max_element(action);
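A short sketch (not part of this commit) restating the factorizations the comments above describe; checkDivision is a hypothetical name, and it assumes the LinearLayout composition operator* referenced in those comments plus an equality operator.

// Hypothetical check: divideLeft finds C with A = B * C, while the new
// divideRight finds C with A = C * B.
void checkDivision(const LinearLayout &A, const LinearLayout &B) {
  if (auto C = divideLeft(A, B))
    assert(A == B * *C);
  if (auto C = divideRight(A, B))
    assert(A == *C * B);
}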

lib/Analysis/Allocation.cpp

Lines changed: 35 additions & 22 deletions
@@ -10,6 +10,8 @@
 #include "triton/Dialect/Triton/IR/Utility.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
+#include "triton/Tools/GenericSwizzling.h"
+#include "triton/Tools/LayoutUtils.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
@@ -32,6 +34,30 @@ constexpr int kPtrBitWidth = 64;
 // Max shmem LDS/STS instruction in bits
 constexpr int kMaxShmemVecBitLength = 128;

+static unsigned getBitwidth(RankedTensorType ty) {
+  auto isPtr = isa<PointerType>(ty.getElementType());
+  return isPtr ? kPtrBitWidth : std::max(ty.getElementTypeBitWidth(), 8u);
+}
+
+static unsigned getNumScratchElemsSwizzledCvt(RankedTensorType srcTy,
+                                              RankedTensorType dstTy) {
+  auto *ctx = srcTy.getContext();
+  auto srcLayout = gpu::toLinearLayout(srcTy.getShape(), srcTy.getEncoding());
+  auto dstLayout = gpu::toLinearLayout(dstTy.getShape(), dstTy.getEncoding());
+  srcLayout = actionRemoveBroadcastedRegs(srcLayout).apply(srcLayout);
+  dstLayout = actionRemoveBroadcastedRegs(dstLayout).apply(dstLayout);
+  auto bitwidth = getBitwidth(srcTy);
+  auto smem = gpu::optimalSwizzling(srcLayout, dstLayout, bitwidth);
+  auto reps = smem.getInDimSize(StringAttr::get(ctx, "reps"));
+  return smem.getTotalOutDimSize() / reps;
+}
+
+static unsigned getNumScratchElemsPaddedCvt(RankedTensorType srcTy,
+                                            RankedTensorType dstTy) {
+  auto scratchConfig = getScratchConfigForCvt(srcTy, dstTy);
+  return getNumScratchElements(scratchConfig.paddedRepShape);
+}
+
 static SmallVector<unsigned> getRepShapeForCvt(RankedTensorType srcTy,
                                                RankedTensorType dstTy) {
   Attribute srcLayout = srcTy.getEncoding();
@@ -135,12 +161,8 @@ ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy,
   scratchConfig.outVec = std::min(scratchConfig.outVec, contiguousShapeDim);
   // Clamp the vector length to kMaxShmemVecBitLength / element bitwidth as this
   // is the max vectorisation
-  auto inBitWidth = isa<PointerType>(srcTy.getElementType())
-                        ? kPtrBitWidth
-                        : srcTy.getElementTypeBitWidth();
-  auto outBitWidth = isa<PointerType>(dstTy.getElementType())
-                         ? kPtrBitWidth
-                         : dstTy.getElementTypeBitWidth();
+  auto inBitWidth = getBitwidth(srcTy);
+  auto outBitWidth = getBitwidth(dstTy);
   scratchConfig.inVec =
       std::min(scratchConfig.inVec, kMaxShmemVecBitLength / inBitWidth);
   scratchConfig.outVec =
@@ -174,27 +196,18 @@ unsigned defaultAllocationAnalysisScratchSizeFn(Operation *op) {
     int threadsPerWarp = gpu::TritonGPUDialect::getThreadsPerWarp(
         op->getParentOfType<ModuleOp>());
     return std::max<int>(dstTy.getNumElements(), threadsPerWarp) *
-           std::max<int>(8, dstTy.getElementTypeBitWidth()) / 8;
+           getBitwidth(dstTy) / 8;
   }
   if (auto cvtLayout = dyn_cast<gpu::ConvertLayoutOp>(op)) {
     auto srcTy = cvtLayout.getSrc().getType();
     auto dstTy = cvtLayout.getType();
-    auto srcEncoding = srcTy.getEncoding();
-    auto dstEncoding = dstTy.getEncoding();
-    if (mlir::isa<gpu::SharedEncodingTrait>(srcEncoding) ||
-        mlir::isa<gpu::SharedEncodingTrait>(dstEncoding)) {
-      // Conversions from/to shared memory do not need scratch memory.
+    if (!cvtNeedsSharedMemory(srcTy, dstTy))
       return 0;
-    }
-    // ConvertLayoutOp with both input/output non-shared_layout
-    // TODO: Besides of implementing ConvertLayoutOp via shared memory, it's
-    // also possible to realize it with other approaches in restricted
-    // conditions, such as warp-shuffle
-    auto scratchConfig = getScratchConfigForCvt(srcTy, dstTy);
-    auto elems = getNumScratchElements(scratchConfig.paddedRepShape);
-    return isa<PointerType>(srcTy.getElementType())
-               ? elems * kPtrBitWidth / 8
-               : elems * std::max<int>(8, srcTy.getElementTypeBitWidth()) / 8;
+    // Pesimistically take the max. We will revisit later
+    auto elems = std::max(getNumScratchElemsSwizzledCvt(srcTy, dstTy),
+                          getNumScratchElemsPaddedCvt(srcTy, dstTy));
+
+    return elems * getBitwidth(srcTy) / 8;
   }
   if (isa<AtomicRMWOp, AtomicCASOp>(op)) {
     auto value = op->getOperand(0);
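A worked example of the new scratch-size formula in defaultAllocationAnalysisScratchSizeFn (the numbers are made up for illustration, not taken from the commit): the byte count is the larger of the swizzled and padded element counts times the element size in bytes.

// Illustrative arithmetic only; in the pass these values come from
// getNumScratchElemsSwizzledCvt, getNumScratchElemsPaddedCvt and getBitwidth.
unsigned swizzledElems = 1024;
unsigned paddedElems = 1152;
unsigned bitwidth = 16;                                 // e.g. f16 elements
unsigned elems = std::max(swizzledElems, paddedElems);  // pessimistic max: 1152
unsigned bytes = elems * bitwidth / 8;                  // 1152 * 16 / 8 = 2304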
