
Commit e302ae6

Merge commit '6f5baf6801b44e51b7ba8eedaa619e39c912bef6'
2 parents: 2c6a850 + 6f5baf6

File tree: 32 files changed, +1761 −254 lines


.github/workflows/integration-tests.yml

Lines changed: 4 additions & 5 deletions
@@ -408,10 +408,6 @@ jobs:
           cd python
           ccache --zero-stats
           pip install -v -e '.[tests]'
-      - name: Clean up after an unsuccessful build
-        if: ${{ !success() && steps.amd-install-triton.outcome != 'success' }}
-        run: |
-          rm -rf ~/.triton
       - name: CCache Stats
         run: ccache --print-stats
       - name: Run lit tests
@@ -477,8 +473,11 @@ jobs:
             ~/.ccache
           key: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ env.RUNNER_TYPE }}-llvm-${{ steps.cache-key.outputs.llvm }}-${{ steps.cache-key.outputs.datetime }}
       - name: Clean up caches
+        # Always cleanup the worker, even if builds or tests failed
+        if: always()
         run: |
-          rm -rf ~/.triton/cache
+          rm -rf ~/.triton
+          rm -rf ~/.ccache
   Build-Tests:
     needs: Runner-Preparation
     if: needs.Runner-Preparation.outputs.matrix-MACOS != ''

.github/workflows/integration-tests.yml.in

Lines changed: 4 additions & 6 deletions
@@ -402,11 +402,6 @@ jobs:
           ccache --zero-stats
           pip install -v -e '.[tests]'

-      - name: Clean up after an unsuccessful build
-        if: ${{ !success() && steps.amd-install-triton.outcome != 'success' }}
-        run: |
-          rm -rf ~/.triton
-
       - *print-ccache-stats
       - *run-lit-tests-step

@@ -442,8 +437,11 @@ jobs:
       - *save-build-artifacts-step

       - name: Clean up caches
+        # Always cleanup the worker, even if builds or tests failed
+        if: always()
         run: |
-          rm -rf ~/.triton/cache
+          rm -rf ~/.triton
+          rm -rf ~/.ccache

   Build-Tests:
     needs: Runner-Preparation

README.md

Lines changed: 5 additions & 5 deletions
@@ -211,10 +211,10 @@ For detailed instructions on how to debug Triton's frontend, please refer to thi
 - `LLVM_ENABLE_TIMING` dumps the timing information for each LLVM pass.
 - `TRITON_DEFAULT_FP_FUSION` overrides the default behavior of allowing fp fusion (mul+add->fma).
 - `MLIR_ENABLE_REMARK` enables the performance warnings that are emitted as remarks.
-- `TRITON_KERNEL_DUMP` enables the dumping of the IR from each compilation stage and the final ptx.
-- `TRITON_DUMP_DIR` specifies the directory to save the dumped IR and ptx when `TRITON_KERNEL_DUMP` is set to 1.
-- `TRITON_KERNEL_OVERRIDE` enables the override of the compiled kernel with a user-specified IR/ptx at the beginning of each compilation stage.
-- `TRITON_OVERRIDE_DIR` specifies the directory from which to load the IR/ptx files when `TRITON_KERNEL_OVERRIDE` is set to 1.
+- `TRITON_KERNEL_DUMP` enables the dumping of the IR from each compilation stage and the final ptx/amdgcn.
+- `TRITON_DUMP_DIR` specifies the directory to save the dumped IR and ptx/amdgcn when `TRITON_KERNEL_DUMP` is set to 1.
+- `TRITON_KERNEL_OVERRIDE` enables the override of the compiled kernel with a user-specified IR/ptx/amdgcn at the beginning of each compilation stage.
+- `TRITON_OVERRIDE_DIR` specifies the directory from which to load the IR/ptx/amdgcn files when `TRITON_KERNEL_OVERRIDE` is set to 1.

 **Kernel Override Steps**

@@ -224,7 +224,7 @@ export TRITON_KERNEL_DUMP=1
 export TRITON_DUMP_DIR=<dump_dir>
 export TRITON_KERNEL_OVERRIDE=1
 export TRITON_OVERRIDE_DIR=<override_dir>
-# Step 1: Run the kernel once to dump kernel's IRs and ptx in $TRITON_DUMP_DIR
+# Step 1: Run the kernel once to dump kernel's IRs and ptx/amdgcn in $TRITON_DUMP_DIR
 # Step 2: Copy $TRITON_DUMP_DIR/<kernel_hash> to $TRITON_OVERRIDE_DIR
 # Step 3: Delete the stages that you do not want to override and modify the stage you do want to override
 # Step 4: Run the kernel again to see the overridden result
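For orientation, a hypothetical layout of the dump directory after Step 1 (file names invented for illustration; the exact stage names depend on the backend):

    $TRITON_DUMP_DIR/
      <kernel_hash>/
        kernel.ttir     # Triton IR
        kernel.ttgir    # Triton GPU IR
        kernel.llir     # LLVM IR
        kernel.ptx      # or kernel.amdgcn on AMD backends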

include/triton/Analysis/Utility.h

Lines changed: 2 additions & 0 deletions
@@ -161,6 +161,8 @@ class GatherLoweringHelper {

   // Get the shared memory scratch size required by this op.
   unsigned getScratchSizeInBytes();
+  // Determine if the gather can be performed completely within a warp.
+  bool isWarpLocal();

 private:
   triton::GatherOp gatherOp;
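A minimal sketch of how a lowering pass might consume the new query; the surrounding control flow is invented, only `GatherLoweringHelper`, `isWarpLocal`, and `getScratchSizeInBytes` come from this diff:

    GatherLoweringHelper helper(gatherOp);
    if (helper.isWarpLocal()) {
      // All data movement stays inside a warp: emit index shuffles directly,
      // no shared-memory staging required (scratch size is 0).
    } else {
      // Fall back to staging the whole source tensor through shared memory.
      unsigned scratchBytes = helper.getScratchSizeInBytes();
      (void)scratchBytes;
    }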

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 2 additions & 2 deletions
@@ -1123,8 +1123,8 @@ emitBaseIndexForLayout(Location loc, RewriterBase &rewriter,
   return idx;
 }

-// Emit code to compute the (blockId, warpId, laneId) for the current thread.
-std::tuple</*blockId=*/Value, /*warpId=*/Value, /*laneId=*/Value>
+// Emit code to compute the (laneId, warpId, blockId) for the current thread.
+std::tuple</*laneId=*/Value, /*warpId=*/Value, /*blockId=*/Value>
 emitHardwareTuple(Location loc, RewriterBase &rewriter,
                   const TargetInfoBase &target, bool withCTAOffset,
                   unsigned threadsPerWarp);
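The reordered tuple matters at every call site that destructures the result. A hedged sketch of a hypothetical caller after this change (`targetInfo` and the argument values are assumptions, not part of the diff):

    // After this change, structured bindings must list laneId first:
    auto [laneId, warpId, blockId] = emitHardwareTuple(
        loc, rewriter, targetInfo, /*withCTAOffset=*/true,
        /*threadsPerWarp=*/32);
    // Pre-change code that bound (blockId, warpId, laneId) would silently
    // swap blockId and laneId, so callers must be updated in lockstep.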

include/triton/Dialect/TritonGPU/IR/Dialect.h

Lines changed: 5 additions & 0 deletions
@@ -214,7 +214,12 @@ LinearLayout ensureLayoutNotSmallerThan(
     const LinearLayout &layout,
     const llvm::SmallDenseMap<StringAttr, int64_t> &shape);

+// Return a vector of the standard out dimension names for tensor layouts. These
+// are "dim0", "dim1", etc.
 SmallVector<StringAttr> standardOutDimNames(MLIRContext *ctx, int rank);
+// Return an identity mapping from `inDimName` to the standard out dimensions,
+// with the dimensions sized according to the shape. The bases are sorted
+// according to `order`, with the most minor dimension first.
 LinearLayout identityStandardND(StringAttr inDimName, ArrayRef<unsigned> shape,
                                 ArrayRef<unsigned> order);
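A hedged illustration of the documented behavior; the shape, order, and input-dim name are invented for the example, and `ctx` is assumed in scope:

    Builder b(ctx);
    // Out dims come from standardOutDimNames: "dim0", "dim1".
    LinearLayout ll = identityStandardND(b.getStringAttr("register"),
                                         /*shape=*/{16, 8}, /*order=*/{1, 0});
    // With order = {1, 0}, dim1 is most minor: the bases enumerate dim1
    // (1, 2, 4) before dim0 (1, 2, 4, 8), i.e. the identity layout of a
    // 16x8 tensor traversed in row-major order.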

include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h

Lines changed: 3 additions & 0 deletions
@@ -2,6 +2,8 @@
 #define TRITON_TRITONGPU_TRANSFORMS_PIPELINER_PIPELINING_UTILITY_H_

 #include "mlir/Dialect/SCF/IR/SCF.h"
+#include <optional>
+#include <utility>
 #include <vector>

 namespace mlir {
@@ -38,6 +40,7 @@ void replaceUsesAndPropagateType(OpBuilder &builder, Operation *oldUse,
 // Return the minClusterId and maxClusterId for the given ForOp.
 std::pair<int, int> getMinMaxCluster(scf::ForOp &forOp);
 std::pair<int, int> getStageCluster(Operation *op);
+std::optional<std::pair<int, int>> maybeGetStageCluster(Operation *op);
 void setStageCluster(Operation *op, int stage, int cluster);
 } // namespace triton
 } // namespace mlir
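A sketch of how the new optional-returning accessor might be used; `op` and `newOp` are invented for illustration:

    // Hypothetical caller: handle ops that may lack stage/cluster
    // annotations instead of requiring them as getStageCluster presumably
    // does.
    if (std::optional<std::pair<int, int>> sc = maybeGetStageCluster(op)) {
      auto [stage, cluster] = *sc;
      setStageCluster(newOp, stage, cluster); // propagate to a cloned op
    }
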
New file

Lines changed: 99 additions & 0 deletions
@@ -0,0 +1,99 @@
+#pragma once
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "triton/Dialect/Triton/IR/Dialect.h"
+
+namespace mlir::triton::nvidia_gpu {
+
+constexpr inline int TMA_SIZE_BYTES = 128;
+constexpr inline int TMA_ALIGN = 128;
+
+template <typename BuilderT>
+mlir::LogicalResult createTMADesc(mlir::Value tmaPtr,
+                                  mlir::triton::MakeTensorDescOp op,
+                                  BuilderT &builder) {
+  using namespace mlir;
+  MLIRContext *ctx = op.getContext();
+  auto loc = op.getLoc();
+  auto mkI32Constant = [&](int32_t val) {
+    return builder.template create<arith::ConstantOp>(
+        loc, builder.getI32Type(), builder.getI32IntegerAttr(val));
+  };
+
+  auto elemType = op.getBase().getType().getPointeeType();
+  auto elemSize = elemType.getIntOrFloatBitWidth() / 8;
+
+  int32_t contig_dim_size = op.getTensorShape().back();
+  int32_t contig_dim_size_in_bytes = contig_dim_size * elemSize;
+  if (contig_dim_size_in_bytes > 128) {
+    contig_dim_size = 128 / elemSize;
+  }
+  llvm::SmallVector<Value> boxDim;
+  boxDim.push_back(mkI32Constant(contig_dim_size));
+  for (int k = op.getTensorShape().size() - 2; k >= 0; --k) {
+    boxDim.push_back(mkI32Constant(op.getTensorShape()[k]));
+  }
+
+  int32_t swizzle_mode;
+  if (contig_dim_size_in_bytes >= 128) {
+    swizzle_mode = 3;
+  } else if (contig_dim_size_in_bytes == 64) {
+    swizzle_mode = 2;
+  } else if (contig_dim_size_in_bytes == 32) {
+    swizzle_mode = 1;
+  } else {
+    op->emitError()
+        << "contiguous box dimension must be at least 32 bytes but got "
+        << contig_dim_size_in_bytes;
+    return failure();
+  }
+
+  Value elemSizeVal = builder.template create<arith::ConstantOp>(
+      loc, builder.getI64Type(), builder.getI64IntegerAttr(elemSize));
+  Value globalStride = builder.template create<arith::MulIOp>(
+      loc, op.getStrides()[0], elemSizeVal);
+  // TODO: Workaround for ptxas bug, remove when we update ptxas
+  Value four = builder.template create<arith::ConstantOp>(
+      loc, builder.getI64Type(), builder.getI64IntegerAttr(4));
+  globalStride =
+      builder.template create<arith::ShRSIOp>(loc, globalStride, four);
+
+  int elemTypeEnum;
+  switch (elemSize) {
+  case 1: {
+    elemTypeEnum = 0;
+    break;
+  }
+  case 2: {
+    elemTypeEnum = 1;
+    break;
+  }
+  case 4: {
+    elemTypeEnum = 2;
+    break;
+  }
+  default: {
+    op->emitError()
+        << "Tensor descriptor element type must have size 1, 2, or 4 but got "
+        << elemSize;
+    return failure();
+  }
+  }
+
+  auto one = mkI32Constant(1);
+  builder.template create<triton::ExperimentalTensormapCreateOp>(
+      loc,
+      /*desc_ptr=*/tmaPtr,
+      /*global_address=*/op.getBase(),
+      /*box_dim=*/boxDim,
+      /*global_dim=*/ValueRange{op.getShape()[1], op.getShape()[0]},
+      /*global_stride=*/ValueRange{globalStride},
+      /*element_strides=*/ValueRange{one, one},
+      /*elem_type*/ builder.getI32IntegerAttr(elemTypeEnum),
+      /*interleave_layout*/ builder.getI32IntegerAttr(0),
+      /*swizzle_mode=*/builder.getI32IntegerAttr(swizzle_mode),
+      /*fill_mode=*/builder.getI32IntegerAttr(0));
+  return success();
+}
+
+} // namespace mlir::triton::nvidia_gpu
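The box-dimension clamp and swizzle-mode selection above are easy to misread, so here is a tiny standalone trace of the same arithmetic for one assumed case (fp16 with a 256-element innermost dimension; the values are invented for illustration):

    #include <cstdint>
    #include <cstdio>

    // Traces the clamp + swizzle selection from createTMADesc for one case.
    int main() {
      int32_t elemSize = 2;                       // fp16 = 2 bytes (assumed)
      int32_t contigDim = 256;                    // innermost dim (assumed)
      int32_t contigBytes = contigDim * elemSize; // 512 bytes
      if (contigBytes > 128)
        contigDim = 128 / elemSize;               // box clamped to 64 elements
      // >=128 bytes -> mode 3 (128B swizzle); 64 -> 2; 32 -> 1; else error
      int32_t swizzle = contigBytes >= 128 ? 3 : (contigBytes == 64 ? 2 : 1);
      std::printf("boxDim=%d swizzleMode=%d\n", contigDim, swizzle); // 64, 3
      return 0;
    }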

lib/Analysis/Utility.cpp

Lines changed: 82 additions & 2 deletions
@@ -419,13 +419,93 @@ GatherLoweringHelper::GatherLoweringHelper(triton::GatherOp gatherOp)
     : gatherOp(gatherOp) {}

 unsigned GatherLoweringHelper::getScratchSizeInBytes() {
-  // For now, lower the gather op by writing the source tensor to shared memory.
-  // TODO(jeff): Leverage locality to avoid using scratch space when possible.
+  // If the gather is warp-local, no scratch space is needed.
+  if (isWarpLocal())
+    return 0;
+
+  // Otherwise, performing the gather will require scratch space to communicate
+  // the source tensor across threads. For now, assume the whole source tensor
+  // is written back to shared memory.
   RankedTensorType srcType = gatherOp.getSrc().getType();
   return product(srcType.getShape()) *
          ceil<unsigned>(srcType.getElementTypeBitWidth(), 8);
 }

+bool GatherLoweringHelper::isWarpLocal() {
+  // The gather is warp-local if, for each column along the gather axis in the
+  // source and index tensors, all the elements are owned by the same warp.
+  RankedTensorType srcType = gatherOp.getSrc().getType();
+  RankedTensorType idxType = gatherOp.getIndices().getType();
+  std::optional<LinearLayout> srcLayout =
+      toLinearLayout(srcType.getShape(), srcType.getEncoding());
+  std::optional<LinearLayout> idxLayout =
+      toLinearLayout(idxType.getShape(), idxType.getEncoding());
+
+  // FIXME: If an unsupported layout was encountered, assume the gather is not
+  // warp-local.
+  if (!srcLayout || !idxLayout)
+    return false;
+
+  Builder b(gatherOp.getContext());
+  StringAttr kBlock = b.getStringAttr("block");
+  StringAttr kWarp = b.getStringAttr("warp");
+  StringAttr kLane = b.getStringAttr("lane");
+  StringAttr kGatherDim =
+      b.getStringAttr("dim" + std::to_string(gatherOp.getAxis()));
+
+  // The tensor layouts must be distributed layouts, where the basis matrix is
+  // a subpermutation matrix (a permutation matrix plus zeros for
+  // broadcasting). FIXME(jeff): Check this invariant somehow.
+  //
+  // We want to know if all elements of a column along the gather axis are
+  // mapped to the same set of warps, which means the gather can be performed
+  // entirely within the warp. We need to query
+  //
+  //   srcLayout.invert().sublayoutIsZero({kGatherDim}, {kBlock, kWarp})
+  //
+  // but due to broadcasting, the matrix might not be invertible. Since the
+  // matrix is a permutation matrix (checked below), we can instead query
+  //
+  //   srcLayout.sublayoutIsZero({kBlock, kWarp}, {kGatherDim})
+  //
+  // which implies that changing the warp will not change the gather dimension.
+  // And since there is no swizzling, this applies to all warps.
+  if (!srcLayout->sublayoutIsZero({kBlock, kWarp}, kGatherDim) ||
+      !idxLayout->sublayoutIsZero({kBlock, kWarp}, kGatherDim))
+    return false;
+
+  SmallVector<StringAttr> otherDims;
+  for (unsigned dim = 0, rank = srcType.getRank(); dim < rank; ++dim) {
+    if (dim != gatherOp.getAxis()) {
+      otherDims.push_back(b.getStringAttr("dim" + Twine(dim)));
+    }
+  }
+
+  // The gather axis `dimN` is invariant to the warp, but the `(block, warp)`
+  // mapping to all other dimensions must also be the same for both layouts. If
+  // so, then the warp that owns a particular index element also owns all the
+  // source elements it could index into.
+  if (srcLayout->sublayout({kBlock, kWarp}, otherDims) !=
+      idxLayout->sublayout({kBlock, kWarp}, otherDims))
+    return false;
+
+  // The two constraints above ensure that the data movement to perform the
+  // gather operation is contained within a warp. The subsequent constraints
+  // simplify codegen.
+
+  // Require that for any given gather column, the threads mapped to the column
+  // in the index and source tensors are the same. This means we don't need to
+  // xor shuffle across threads before emitting index shuffles; we push warp
+  // shuffling to layout conversions.
+  if (srcLayout->sublayout(kLane, otherDims) !=
+      idxLayout->sublayout(kLane, otherDims))
+    return false;
+
+  // Otherwise, the source layout has to be invertible. This primarily means
+  // the codegen path doesn't support broadcasted source layouts.
+  return srcLayout->isInvertible();
+}
+
 unsigned getNumScratchElements(ArrayRef<unsigned> shape) {
   if (shape.empty())
     return 0;
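To make the first check concrete, here is a toy standalone illustration of the `sublayoutIsZero({block, warp}, gatherDim)` idea: a linear layout maps input bits (lane/warp/block bits) to output coordinate bits, and if every warp-bit basis vector has a zero component along the gather dimension, changing the warp never changes the gather-dim coordinate, so each gather column stays within one warp. The basis values below are invented for illustration:

    #include <array>
    #include <cstdio>

    int main() {
      // Each basis: {dim0 component, dim1 component}; gather axis = dim0.
      std::array<std::array<int, 2>, 2> warpBases = {{{0, 2}, {0, 4}}};
      bool warpInvariant = true;
      for (auto &basis : warpBases)
        warpInvariant &= (basis[0] == 0); // zero along the gather dim
      std::printf("warp-local candidate: %s\n", warpInvariant ? "yes" : "no");
      return 0;
    }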
