Commit 3f4fdd1

Merge commit '4d2e9e5de96a5d6ea163f2de04ae5c5b6be45825'
2 parents: 492ea92 + 4d2e9e5

32 files changed: +378, -168 lines

bin/CMakeLists.txt

Lines changed: 0 additions & 17 deletions
@@ -7,13 +7,7 @@ add_llvm_executable(triton-opt triton-opt.cpp PARTIAL_SOURCES_INTENDED)
 # TODO: what's this?
 llvm_update_compile_flags(triton-opt)
 target_link_libraries(triton-opt PRIVATE
-  TritonLLVMIR
-  TritonAnalysis
-  TritonTransforms
-  TritonGPUTransforms
-  TritonNvidiaGPUTransforms
   TritonIntelLLVMIR
-  MLIRGPUToROCDLTransforms
   ${dialect_libs}
   ${conversion_libs}
   ${triton_libs}
@@ -32,11 +26,6 @@ mlir_check_all_link_libraries(triton-reduce)

 llvm_update_compile_flags(triton-reduce)
 target_link_libraries(triton-reduce PRIVATE
-  TritonLLVMIR
-  TritonAnalysis
-  TritonTransforms
-  TritonGPUTransforms
-  TritonNvidiaGPUTransforms
   ${dialect_libs}
   ${conversion_libs}
   ${triton_libs}
@@ -54,10 +43,6 @@ add_llvm_executable(triton-lsp triton-lsp.cpp PARTIAL_SOURCES_INTENDED)

 llvm_update_compile_flags(triton-lsp)
 target_link_libraries(triton-lsp PRIVATE
-  TritonAnalysis
-  TritonTransforms
-  TritonGPUTransforms
-  TritonNvidiaGPUTransforms
   ${dialect_libs}
   ${conversion_libs}
   ${triton_libs}
@@ -96,8 +81,6 @@ export_executable_symbols_for_plugins(triton-llvm-opt)

 add_llvm_executable(triton-tensor-layout triton-tensor-layout.cpp PARTIAL_SOURCES_INTENDED)
 target_link_libraries(triton-tensor-layout PRIVATE
-  TritonGPUIR
-  TritonNvidiaGPUIR
   ${triton_libs}
   ${conversion_libs}
   ${dialect_libs}

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 6 additions & 6 deletions
@@ -1154,15 +1154,15 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
 // Returns true on success.
 [[nodiscard]] bool emitTransferBetweenRegistersAndShared(
     RankedTensorType registerTy, triton::gpu::MemDescType sharedTy,
-    Type elemLlvmTy, std::optional<int32_t> maxVecElems, Value shmemBase,
-    ArrayRef<Value> shmemStrides, Location loc, RewriterBase &rewriter,
+    Type elemLlvmTy, std::optional<int32_t> maxVecElems,
+    const SharedMemoryObject &smemObj, Location loc, RewriterBase &rewriter,
     const TargetInfoBase &target,
     std::function<void(VectorType, Value /*shmemAddr*/)> perVectorCallback);

 inline DenseMap<unsigned, Value> getSwizzledSharedPtrs(
     Location loc, const TargetInfoBase &target, unsigned inVec,
     RankedTensorType srcTy, triton::gpu::SharedEncodingAttr resSharedLayout,
-    Type resElemTy, SharedMemoryObject smemObj, RewriterBase &rewriter,
+    Type resElemTy, const SharedMemoryObject &smemObj, RewriterBase &rewriter,
     ArrayRef<Value> offsetVals, ArrayRef<Value> srcStrides) {
   // This utility computes the pointers for accessing the provided swizzled
   // shared memory layout `resSharedLayout`. More specifically, it computes,
@@ -1324,14 +1324,14 @@ inline DenseMap<unsigned, Value> getSwizzledSharedPtrs(
 SmallVector<Value> loadSharedToDistributed(RankedTensorType dstTy,
                                            triton::gpu::MemDescType srcTy,
                                            Type elemLlvmTy,
-                                           SharedMemoryObject smemObj,
+                                           const SharedMemoryObject &smemObj,
                                            Location loc, RewriterBase &rewriter,
                                            const TargetInfoBase &target);

 void storeDistributedToShared(
     triton::gpu::MemDescType dstTy, RankedTensorType srcTy, Type elemLlvmTy,
-    ArrayRef<Value> srcVals, Value smemBase, ArrayRef<Value> dstStrides,
-    Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
+    ArrayRef<Value> srcVals, const SharedMemoryObject &smemObj, Location loc,
+    RewriterBase &rewriter, const TargetInfoBase &target,
     std::pair<size_t, Type> *const llvmOpCount = nullptr);

 inline Value getStructFromSharedMemoryObject(Location loc,
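
These signatures now take the whole SharedMemoryObject rather than a separate base pointer and stride list; the object also carries the view's offsets, which the new swizzled-subview path further down depends on. A minimal self-contained toy of the idea, with plain integers standing in for MLIR values (all names here are illustrative, not Triton's types):

#include <cassert>
#include <vector>

// Toy stand-in for Triton's SharedMemoryObject: a base address plus
// per-dimension strides and the offsets of this view into the allocation.
struct SmemObjModel {
  int base;
  std::vector<int> strides;
  std::vector<int> offsets; // nonzero after a subview
};

// Address of a multi-dim element, using everything the object carries.
// With only (base, strides), the view's offsets would be lost.
int elementAddr(const SmemObjModel &o, const std::vector<int> &idx) {
  int addr = o.base;
  for (size_t d = 0; d < idx.size(); ++d)
    addr += (o.offsets[d] + idx[d]) * o.strides[d];
  return addr;
}

int main() {
  // A 4x8 allocation; a subview starting at row 2, column 4.
  SmemObjModel view{/*base=*/0, /*strides=*/{8, 1}, /*offsets=*/{2, 4}};
  assert(elementAddr(view, {0, 0}) == 2 * 8 + 4); // view origin
  assert(elementAddr(view, {1, 2}) == 3 * 8 + 6);
}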

include/triton/Tools/LinearLayout.h

Lines changed: 7 additions & 0 deletions
@@ -414,6 +414,10 @@ class LinearLayout {

   bool isSurjective() const { return surjective; }

+  bool isInvertible() const {
+    return surjective && getTotalInDimSize() == getTotalOutDimSize();
+  }
+
   const BasesT &getBases() const { return bases; }

   // Get the pos'th basis vector for the inDim -> outDim mapping.
@@ -673,6 +677,9 @@ class LinearLayout {
   // don't place any guarantees on the choices made by this function.
   [[nodiscard]] LinearLayout invertAndCompose(const LinearLayout &outer) const;

+  // Get the layout that is the inverse of this layout.
+  [[nodiscard]] LinearLayout invert() const;
+
   // For each in-dim, returns a bitmask of the "free variables" in the layout
   // function.
   //
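
isInvertible() reduces invertibility to two checks because a LinearLayout is a linear map over GF(2): it is invertible exactly when it is surjective (its bases span the full output space) and square (total in-dim size equals total out-dim size). A self-contained toy of that rank condition, not Triton's implementation:

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Rank over GF(2) of the images of the input basis vectors,
// by Gaussian elimination with XOR as row addition.
int rankGF2(std::vector<uint32_t> basis) {
  int rank = 0;
  for (int bit = 31; bit >= 0; --bit) {
    auto it = std::find_if(basis.begin() + rank, basis.end(),
                           [&](uint32_t v) { return (v >> bit) & 1u; });
    if (it == basis.end())
      continue;
    std::swap(*it, basis[rank]);
    for (size_t j = 0; j < basis.size(); ++j)
      if ((int)j != rank && ((basis[j] >> bit) & 1u))
        basis[j] ^= basis[rank];
    ++rank;
  }
  return rank;
}

int main() {
  // 3 input bits -> 3 output bits; each entry is the image of one input bit.
  std::vector<uint32_t> bases = {0b001, 0b011, 0b100}; // full rank
  int outBits = 3;
  bool surjective = rankGF2(bases) == outBits;   // "isSurjective"
  bool square = bases.size() == (size_t)outBits; // in size == out size
  std::cout << "invertible: " << (surjective && square) << "\n"; // prints 1
}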

lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp

Lines changed: 5 additions & 12 deletions
@@ -25,11 +25,9 @@ void lowerDistributedToShared(
   auto outOrd = mlir::cast<SharedEncodingAttr>(dstTy.getEncoding()).getOrder();
   auto elemTy = typeConverter->convertType(srcTy.getElementType());

-  auto smemBase = smemObj.getBase();
-  auto dstStrides = smemObj.getStrides();
   auto inVals = unpackLLElements(loc, adaptorSrc, rewriter);
-  storeDistributedToShared(dstTy, srcTy, elemTy, inVals, smemBase, dstStrides,
-                           loc, rewriter, targetInfo, llvmOpCount);
+  storeDistributedToShared(dstTy, srcTy, elemTy, inVals, smemObj, loc, rewriter,
+                           targetInfo, llvmOpCount);
 }

 struct GlobalScratchAllocOpConversion
@@ -157,14 +155,9 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
       // If we remove this one, ldmatrix will IMA. It can probably be relaxed
       // though
       canUseLdmatrix &=
-          srcTy.getShape()[0] >= 8 && srcTy.getShape()[1] >= 4 * kWidth;
-      // To be removed in https://github.com/triton-lang/triton/pull/5154
-      bool legacyLoweringIsBuggy =
-          (kWidth >= 8 || (kWidth == 4 && bitwidth == 32) ||
-           dstTy.getRank() == 3) &&
-          mma.isAmpere();
-      return (mma.isHopper() && !canUseLdmatrix) ||
-             (mma.isAmpere() && legacyLoweringIsBuggy);
+          srcTy.getShape()[0] >= 8 &&
+          srcTy.getShape()[1] >= 4 * kWidth && dstTy.getRank() <= 2;
+      return !canUseLdmatrix;
     }
     if (isa<AMDMfmaEncodingAttr>(dot.getParent()))
       return true;
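
The rewritten guard drops the Ampere-only legacy-lowering workaround and keeps a single canUseLdmatrix predicate over the tile shape. A standalone restatement of the new condition with sample values (toy function; the surrounding code also folds in earlier encoding and kWidth checks before this point):

#include <cassert>

// Restated guard: ldmatrix needs at least an 8-row tile, a contiguous
// dimension covering 4 * kWidth elements, and rank at most 2.
bool canUseLdmatrix(int rows, int cols, int kWidth, int rank) {
  return rows >= 8 && cols >= 4 * kWidth && rank <= 2;
}

int main() {
  assert(canUseLdmatrix(16, 16, 4, 2));  // typical 16x16 tile: ok
  assert(!canUseLdmatrix(4, 16, 4, 2));  // too few rows
  assert(!canUseLdmatrix(16, 8, 4, 2));  // contiguous dim < 4 * kWidth
  assert(!canUseLdmatrix(16, 16, 4, 3)); // rank-3 tiles not supported
}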

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 141 additions & 29 deletions
@@ -169,10 +169,139 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
   return ret;
 }

+namespace {
+
+Value getSmemVecAddr(RankedTensorType registerTy,
+                     triton::gpu::MemDescType sharedTy, Type elemLlvmTy,
+                     Location loc, RewriterBase &rewriter,
+                     const LinearLayout &regToSharedLayout, Value regId,
+                     Value laneId, Value warpId,
+                     const SharedMemoryObject &smemObj) {
+  MLIRContext *ctx = rewriter.getContext();
+  StringAttr kBlock = str_attr("block");
+  StringAttr kRegister = str_attr("register");
+  StringAttr kLane = str_attr("lane");
+  StringAttr kWarp = str_attr("warp");
+  auto shape = sharedTy.getShape();
+  auto rank = shape.size();
+  auto allocShape = sharedTy.getAllocShape();
+  auto sharedEnc =
+      dyn_cast<triton::gpu::SharedEncodingAttr>(sharedTy.getEncoding());
+
+  auto smemBase = smemObj.getBase();
+  auto sharedOrder = triton::gpu::getOrder(sharedTy.getEncoding());
+  auto smemOffsets = smemObj.getOffsets();
+  auto smemStrides = smemObj.getStrides();
+  Value smemOffset;
+  // When loading or storing to shared memory, we consider two cases for
+  // performance reasons:
+  //
+  // 1. Non-swizzled shared memory.
+  // 2. Swizzled shared memory.
+  //
+  // Consider lowering `ttg.local_load %a`. In the first case, we can
+  // directly construct a linear layout using `%a`'s shape and shared memory
+  // encoding, irrespective of `%a`'s rank or whether it represents a slice of a
+  // larger tensor.
+  //
+  // The method does not apply for swizzled shared memory in some scenarios.
+  // Key properties of swizzling in Triton are:
+  //
+  // - Swizzling applies only to tensors with rank >= 2.
+  // - It is restricted to the last two dimensions of the tensor.
+  // - These last two dimensions are always treated as the most "minor."
+  //
+  // An important edge case arises when `%a` results from `%a = ttg.subview %b`,
+  // where `%b` is swizzled (and so is `%a`). In this case, constructing a
+  // layout and determining shared memory offsets using `%a`'s shape is
+  // incorrect. This is because swizzling depends on the original shape of `%b`,
+  // which differs from `%a`'s shape. As a result, some locations may fall
+  // outside `%a`'s contiguous view of memory. Specifically, an element `[i
+  // (row_idx), j (col_idx)]` in `%a` might map to `[i, j']` after swizzling,
+  // where `j'` lies outside `%a`'s shape but still within `%b`'s shape.
+  //
+  // We propose case 2 (see comments below), which provides a more general
+  // solution for all swizzled shared memory scenarios, including the edge case
+  // mentioned above.
+  if (/*no swizzling*/ sharedEnc.getMaxPhase() == 1 ||
+      /*swizzling but same shape*/ shape == allocShape ||
+      /*swizzling and rank-reduced and rank >= 2*/
+      (shape == allocShape.take_back(rank) && rank >= 2)) { // Case 1
+    // Get the address to load/store. The multi-dim address is (offsetX1, ...,
+    // offsetXN, block), where the offsets appear in minor-to-major order, and
+    // we drop_end to drop block, which we know from above will be 0.
+    smemOffsets = llvm::to_vector(llvm::drop_end(llvm::make_second_range(
+        applyLinearLayout(loc, rewriter, regToSharedLayout,
+                          {{kRegister, regId},
+                           {kLane, laneId},
+                           {kWarp, warpId},
+                           {kBlock, i32_val(0)}}))));
+    // Reorder strides according to `order`. This way they match the
+    // multi-dimensional offsets in regToSharedLayout.
+    smemOffset = dot(rewriter, loc, smemOffsets,
+                     applyPermutation(smemStrides, sharedOrder));
+  } else { // Case 2 -> rank-reduced swizzling
+    assert(rank >= 2 && "Swizzling only applies to tensors with rank >= 2");
+    // We define both tensor offsets and shared memory offsets:
+    //
+    // - Tensor offsets: Relative offsets within a given tensor.
+    // - Shared memory offsets: Absolute offsets within the shared memory.
+    //
+    // In Triton, the shared memory layout provides an invertible, one-to-one
+    // mapping between tensor offsets and shared memory offsets. The `base`
+    // field of any shared memory object represents both the shared memory
+    // offset and the tensor offset relative to the original tensor at
+    // allocation, prior to any subview operations.
+    //
+    // To determine the shared memory offsets for a specific register when
+    // dealing with swizzled and sliced tensors, the process involves:
+    //
+    // 1. Retrieving the original tensor's `invertAllocSharedLayout`, which
+    //    maps the allocated tensor's offsets back to shared memory offsets.
+    // 2. Reconstructing the register's offsets in the allocated tensor by
+    //    summing:
+    //    - The shared memory offsets of the current view's base, and
+    //    - The relative tensor offsets of the register.
+    //
+    // This approach ensures that "absolute" tensor offsets can be
+    // mapped to the correct shared memory addresses using
+    // `invertAllocSharedLayout`.
+    std::optional<LinearLayout> regLayout =
+        triton::gpu::toLinearLayout(shape, registerTy.getEncoding());
+    auto allocSharedLayout = triton::gpu::toLinearLayout(
+        allocShape.take_back(rank), sharedTy.getEncoding(),
+        elemLlvmTy.getIntOrFloatBitWidth());
+    assert(allocSharedLayout.has_value() &&
+           "Failed to convert layout to linear layout");
+    auto invertAllocSharedLayout = allocSharedLayout->invert();
+    auto multiDimTensorOffsets =
+        llvm::to_vector(applyLinearLayout(loc, rewriter, *regLayout,
+                                          {{kRegister, regId},
+                                           {kLane, laneId},
+                                           {kWarp, warpId},
+                                           {kBlock, i32_val(0)}}));
+    for (auto i = 0; i < rank; i++) {
+      multiDimTensorOffsets[i].second =
+          add(multiDimTensorOffsets[i].second, smemOffsets[i]);
+    }
+    smemOffset = applyLinearLayout(loc, rewriter, invertAllocSharedLayout,
+                                   multiDimTensorOffsets)[0]
+                     .second;
+    Value baseToAllocBaseDist = dot(rewriter, loc, smemOffsets, smemStrides);
+    smemOffset = sub(smemOffset, baseToAllocBaseDist);
+  }
+  auto ptrTy = smemBase.getType();
+  auto vecAddr = gep(ptrTy, elemLlvmTy, smemBase, smemOffset);
+  vecAddr.setInbounds(true);
+  return vecAddr;
+}
+
+} // namespace
+
 bool emitTransferBetweenRegistersAndShared(
     RankedTensorType registerTy, triton::gpu::MemDescType sharedTy,
-    Type elemLlvmTy, std::optional<int32_t> maxVecElems, Value shmemBase,
-    ArrayRef<Value> shmemStrides, Location loc, RewriterBase &rewriter,
+    Type elemLlvmTy, std::optional<int32_t> maxVecElems,
+    const SharedMemoryObject &smemObj, Location loc, RewriterBase &rewriter,
     const TargetInfoBase &target,
    std::function<void(VectorType, Value /*shmemAddr*/)> perVectorCallback) {
   MLIRContext *ctx = rewriter.getContext();
@@ -230,28 +359,12 @@ bool emitTransferBetweenRegistersAndShared(

   int numElems = regToSharedLayout->getInDimSize(kRegister);
   auto vecTy = vec_ty(elemLlvmTy, vecElems);
-  auto ptrTy = shmemBase.getType();
   Value zero = i32_val(0);
   SmallVector<Value> ret;
   for (int i = 0; i < numElems / vecElems; i++) {
-    // Get the address to load/store. The multi-dim address is (offsetX1, ...,
-    // offsetXN, block), where the offsets appear in minor-to-major order, and
-    // we drop_end to drop block, which we know from above will be 0.
-    auto multiDimShmemOffset =
-        llvm::to_vector(llvm::drop_end(llvm::make_second_range(
-            applyLinearLayout(loc, rewriter, *regToSharedLayout,
-                              {{kRegister, i32_val(i * vecElems)},
-                               {kLane, laneId},
-                               {kWarp, warpId},
-                               {kBlock, zero}}))));
-
-    // Reorder strides according to `order`. This way they match the
-    // multi-dimensional offsets in regToSharedLayout.
-    auto sharedOrder = triton::gpu::getOrder(sharedTy.getEncoding());
-    Value shmemOffset = dot(rewriter, loc, multiDimShmemOffset,
-                            applyPermutation(shmemStrides, sharedOrder));
-    auto vecAddr = gep(ptrTy, elemLlvmTy, shmemBase, shmemOffset);
-    vecAddr.setInbounds(true);
+    auto vecAddr = getSmemVecAddr(
+        registerTy, sharedTy, elemLlvmTy, loc, rewriter, *regToSharedLayout,
+        i32_val(i * vecElems), laneId, warpId, smemObj);

     perVectorCallback(vecTy, vecAddr);
   }
@@ -261,14 +374,13 @@ bool emitTransferBetweenRegistersAndShared(
 SmallVector<Value> loadSharedToDistributed(RankedTensorType dstTy,
                                            triton::gpu::MemDescType srcTy,
                                            Type elemLlvmTy,
-                                           SharedMemoryObject smemObj,
+                                           const SharedMemoryObject &smemObj,
                                            Location loc, RewriterBase &rewriter,
                                            const TargetInfoBase &target) {
   SmallVector<Value> ret;
   bool success = emitTransferBetweenRegistersAndShared(
-      dstTy, srcTy, elemLlvmTy, /*maxVecElems=*/std::nullopt, smemObj.getBase(),
-      smemObj.getStrides(), loc, rewriter, target,
-      [&](VectorType vecTy, Value vecAddr) {
+      dstTy, srcTy, elemLlvmTy, /*maxVecElems=*/std::nullopt, smemObj, loc,
+      rewriter, target, [&](VectorType vecTy, Value vecAddr) {
        auto vecVal = load(vecTy, vecAddr);
        vecVal.setAlignment(vecTy.getNumElements() *
                            elemLlvmTy.getIntOrFloatBitWidth() / 8);
@@ -285,14 +397,14 @@ SmallVector<Value> loadSharedToDistributed(RankedTensorType dstTy,

 void storeDistributedToShared(triton::gpu::MemDescType dstTy,
                               RankedTensorType srcTy, Type elemLlvmTy,
-                              ArrayRef<Value> srcVals, Value smemBase,
-                              ArrayRef<Value> dstStrides, Location loc,
+                              ArrayRef<Value> srcVals,
+                              const SharedMemoryObject &smemObj, Location loc,
                               RewriterBase &rewriter,
                               const TargetInfoBase &target,
                               std::pair<size_t, Type> *const llvmOpCount) {
   bool success = emitTransferBetweenRegistersAndShared(
-      srcTy, dstTy, elemLlvmTy, /*maxVecElems=*/std::nullopt, smemBase,
-      dstStrides, loc, rewriter, target, [&](VectorType vecTy, Value vecAddr) {
+      srcTy, dstTy, elemLlvmTy, /*maxVecElems=*/std::nullopt, smemObj, loc,
+      rewriter, target, [&](VectorType vecTy, Value vecAddr) {
        ArrayRef<Value> vals = srcVals.take_front(vecTy.getNumElements());
        srcVals = srcVals.drop_front(vecTy.getNumElements());

lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp

Lines changed: 1 addition & 1 deletion
@@ -394,7 +394,7 @@ struct MemDescSubviewOpConversion
     int rankReduced = srcTy.getRank() - destRank;
     for (int i = rankReduced; i < opOffsetVals.size(); i++) {
       strides.push_back(smemObj.strides[i]);
-      offsetVals.push_back(opOffsetVals[i]);
+      offsetVals.push_back(add(opOffsetVals[i], smemObj.offsets[i]));
     }
     // Compute the offset based on the original strides of the shared memory
     // object
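
This one-line change makes offsets compose across nested subviews: a subview now adds its own offsets to those recorded in the source object, so getSmemVecAddr can recover a view's absolute position in the allocation. A toy model of the accumulation (plain integers, illustrative names):

#include <cassert>
#include <vector>

struct ViewModel {
  std::vector<int> offsets; // position of this view in the allocation
};

// Before the fix: child.offsets = opOffsets (parent position lost).
// After the fix:  child.offsets = parent.offsets + opOffsets.
ViewModel subview(const ViewModel &parent, const std::vector<int> &opOffsets) {
  ViewModel child;
  for (size_t i = 0; i < opOffsets.size(); ++i)
    child.offsets.push_back(parent.offsets[i] + opOffsets[i]);
  return child;
}

int main() {
  ViewModel alloc{{0, 0}};
  ViewModel a = subview(alloc, {2, 4}); // view at (2, 4)
  ViewModel b = subview(a, {1, 1});     // nested view
  assert(b.offsets[0] == 3 && b.offsets[1] == 5); // absolute, not relative
}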

lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp

Lines changed: 1 addition & 1 deletion
@@ -272,7 +272,7 @@ createTMAAsyncCopy(scf::ForOp &forOp, tt::ExperimentalDescriptorLoadOp loadOp,

   builder.setInsertionPointAfter(viewLoad);
   auto sharedLoad = builder.createWithStage<ttg::LocalLoadOp>(
-      loc, stage, clusterId, loadOp.getType(),
+      loc, stageForFirstUse, clusterForFirstUse, loadOp.getType(),
       viewLoad /*,wait->getResult(0)*/);
   auto result = sharedLoad->getResults();
   loadOp->replaceAllUsesWith(result);

lib/Tools/LinearLayout.cpp

Lines changed: 12 additions & 0 deletions
@@ -337,6 +337,7 @@ LinearLayout::checkInvariants(bool requireSurjective) {
            "can be reached by some `in` coordinate, but was not:" +
            toString();
   }
+
   return std::nullopt;
 }

@@ -918,6 +919,17 @@ LinearLayout LinearLayout::invertAndCompose(const LinearLayout &outer) const {
   return flatComposed.reshapeIns(retInDims).reshapeOuts(retOutDims);
 }

+LinearLayout LinearLayout::invert() const {
+  // A^-1(x) = A^-1(I(x)), thus A.invert() = I.invertAndCompose(A)
+  assert(isInvertible() &&
+         "A linear layout must be surjective and square to be invertible");
+  LinearLayout identity = LinearLayout::empty();
+  for (auto outDim : getOutDimNames()) {
+    identity *= LinearLayout::identity1D(getOutDimSize(outDim), outDim, outDim);
+  }
+  return identity.invertAndCompose(*this);
+}
+
 llvm::MapVector<StringAttr, int32_t>
 LinearLayout::getFreeVariableMasks() const {
   std::unique_ptr<uint64_t[]> mat = getMatrix(*this);
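
invert() leans on the contract the in-diff comment states: X.invertAndCompose(A) yields A^-1 composed with X, so composing the identity with A leaves exactly A^-1. A self-contained GF(2) round-trip check on a small XOR swizzle (toy bit-matrix code, not Triton's LinearLayout):

#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Apply a GF(2) linear map: XOR the images of the set input bits.
uint32_t apply(const std::vector<uint32_t> &bases, uint32_t x) {
  uint32_t y = 0;
  for (size_t b = 0; b < bases.size(); ++b)
    if ((x >> b) & 1u)
      y ^= bases[b];
  return y;
}

// Invert a square, full-rank GF(2) map by Gauss-Jordan elimination,
// applying the same elementary operations to an identity matrix.
std::vector<uint32_t> invert(std::vector<uint32_t> a) {
  size_t n = a.size();
  std::vector<uint32_t> inv(n);
  for (size_t i = 0; i < n; ++i)
    inv[i] = 1u << i;
  for (size_t col = 0; col < n; ++col) {
    size_t pivot = col;
    while (!((a[pivot] >> col) & 1u))
      ++pivot; // full rank => a pivot exists
    std::swap(a[col], a[pivot]);
    std::swap(inv[col], inv[pivot]);
    for (size_t row = 0; row < n; ++row)
      if (row != col && ((a[row] >> col) & 1u)) {
        a[row] ^= a[col];
        inv[row] ^= inv[col];
      }
  }
  return inv;
}

int main() {
  // 3-bit swizzle: out = in XOR (in >> 2); bases are the images of e0..e2.
  std::vector<uint32_t> A = {0b001, 0b010, 0b101};
  auto Ainv = invert(A);
  for (uint32_t x = 0; x < 8; ++x)
    assert(apply(Ainv, apply(A, x)) == x); // A^-1 o A = I
}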
