
Commit 679954d

[Backend][NFC] Move legacy layout converter to Utility (#7506)
This PR moves the legacy layoutConvert ttg->llvm converter to the Utility library so that it can be used in multiple places. --------- Co-authored-by: Alexander Efimov <[email protected]>
1 parent ac0d4db commit 679954d

3 files changed (+292 / -253 lines)

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 19 additions & 0 deletions
@@ -624,6 +624,25 @@ inline bool isCanonicalIndex(unsigned index, unsigned freeVarMask) {
 // group code isolated from above by invoking this function.
 void makeAllWarpGroupsIsolatedFromAbove(Operation *op);
 
+/// Converts ConvertLayoutOp to LLVM using the padded pattern.
+/// This pattern adds unused memory locations after every row of the tensor's
+/// fastest-changing dimension:
+///   e0 e1 e2 e3 p p \
+///   e4 e5 e6 e7 p p \
+///   ...
+///   e  e  e  e  p p
+/// The dimension order is chosen to allow wide output reads.
+///
+/// \param op operation to convert
+/// \param src llvm structure containing operation input
+/// \param targetInfo
+/// \param typeConverter
+/// \param rewriter
+/// \returns llvm structure containing converted output
+Value transferWithinBlockPadding(triton::gpu::ConvertLayoutOp op, Value src,
+                                 const TargetInfoBase &targetInfo,
+                                 const LLVMTypeConverter *typeConverter,
+                                 RewriterBase &rewriter);
 } // namespace mlir
 
 #endif
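
For context, the padded pattern documented above skips a fixed number of unused slots after every row of the fastest-changing dimension, which is commonly done to reduce shared-memory bank conflicts while keeping wide vector reads possible. Below is a minimal, self-contained sketch of that address arithmetic; paddedOffset, rowStride, and padSize are illustrative names and not part of the Triton API, and the shift-based form used in the moved implementation is equivalent whenever both quantities are powers of two.

// Standalone sketch (not Triton code): map a logical element offset to its
// padded shared-memory offset, as in the diagram above where each row of
// rowStride elements is followed by padSize unused slots.
#include <cassert>
#include <cstdint>

inline uint32_t paddedOffset(uint32_t offset, uint32_t rowStride,
                             uint32_t padSize) {
  // The shift-based variant,
  //   offset + ((offset >> log2(rowStride)) << log2(padSize)),
  // matches this expression when rowStride and padSize are powers of two.
  return offset + (offset / rowStride) * padSize;
}

int main() {
  // With 4 elements per row and 2 padding slots, element 4 (the start of the
  // second row) lands at slot 6: e0 e1 e2 e3 p p | e4 ...
  assert(paddedOffset(4, /*rowStride=*/4, /*padSize=*/2) == 6);
  assert(paddedOffset(9, /*rowStride=*/4, /*padSize=*/2) == 13);
  return 0;
}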

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 3 additions & 253 deletions
@@ -272,13 +272,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
                            const LinearLayout &dstLayout,
                            OpAdaptor adaptor,
                            ConversionPatternRewriter &rewriter) const {
-    MLIRContext *ctx = op.getContext();
-    auto loc = op.getLoc();
-    auto b = TritonLLVMOpBuilder(loc, rewriter);
-    auto srcTy = op.getSrc().getType();
-    auto dstTy = op.getType();
-
-    assert(cvtNeedsSharedMemory(srcTy, dstTy));
+    assert(cvtNeedsSharedMemory(op.getSrc().getType(), op.getType()));
 
     // Try to use swizzling to implement the conversion
     // HACK Remove once AMD tests pass for the swizzling path
@@ -287,52 +281,9 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
       return success();
     }
 
-    SmallVector<Value> inVals =
-        unpackLLElements(loc, adaptor.getSrc(), rewriter);
-    assert(!inVals.empty());
-
-    // We munge the input values by converting i<n> (n<8) elements to i8 and
-    // pointers to i64. This is necessary because TargetInfo::loadDShared and
-    // storeDShared can't handle vectors of pointers or sub-byte elements.
-    auto elemTy = srcTy.getElementType();
-    auto isSubByteInt =
-        elemTy.isInteger() && elemTy.getIntOrFloatBitWidth() < 8;
-    auto isPtr = isa<triton::PointerType>(elemTy);
-    auto llvmElemTyOrig = getTypeConverter()->convertType(elemTy);
-    if (isSubByteInt)
-      elemTy = IntegerType::get(elemTy.getContext(), 8);
-    else if (isPtr)
-      elemTy = IntegerType::get(elemTy.getContext(), 64);
-    auto llvmElemTy = getTypeConverter()->convertType(elemTy);
-
-    // Munge input values
-    for (const auto &it : llvm::enumerate(inVals)) {
-      if (isSubByteInt) {
-        inVals[it.index()] = b.zext(llvmElemTy, it.value());
-      } else if (isPtr) {
-        inVals[it.index()] = b.ptrtoint(llvmElemTy, it.value());
-      }
-    }
-
-    // Pretty sure this is the identity function ATM
-    // It'd be better to simply call `quotient({kBlock})` and
-    // remove kBlock from transferWithinBlockImpl
-    auto srcLayoutWithinBlock = getLayoutWithinBlock(srcLayout);
-    auto dstLayoutWithinBlock = getLayoutWithinBlock(dstLayout);
-    SmallVector<Value> outVals = transferWithinBlockImpl(
-        inVals, op, srcLayoutWithinBlock, dstLayoutWithinBlock, rewriter);
-
-    // Unmunge output values
-    for (const auto &it : llvm::enumerate(outVals)) {
-      if (isSubByteInt) {
-        outVals[it.index()] = b.trunc(llvmElemTyOrig, it.value());
-      } else if (isPtr) {
-        outVals[it.index()] = b.inttoptr(llvmElemTyOrig, it.value());
-      }
-    }
+    Value result = transferWithinBlockPadding(op, adaptor.getSrc(), targetInfo,
+                                              getTypeConverter(), rewriter);
 
-    Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter,
-                                  op.getType());
     rewriter.replaceOp(op, result);
     return success();
   }
@@ -343,207 +294,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
                       DecomposedWarpConversion decomposed,
                       OpAdaptor adaptor,
                       ConversionPatternRewriter &rewriter) const;
-
-  SmallVector<Value>
-  transferWithinBlockImpl(ArrayRef<Value> inVals, ConvertLayoutOp op,
-                          const LinearLayout &srcLayout,
-                          const LinearLayout &dstLayout,
-                          ConversionPatternRewriter &rewriter) const {
-    MLIRContext *ctx = op.getContext();
-    auto loc = op.getLoc();
-    auto b = TritonLLVMOpBuilder(loc, rewriter);
-
-    StringAttr kRegister = str_attr("register");
-    StringAttr kLane = str_attr("lane");
-    StringAttr kWarp = str_attr("warp");
-    StringAttr kBlock = str_attr("block");
-    StringAttr kOffset = str_attr("offset");
-    StringAttr kIteration = str_attr("iteration");
-
-    auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc);
-
-    auto scratchConfig =
-        getScratchConfigForCvt(op.getSrc().getType(), op.getType());
-    auto tensorShapePerCTA = convertType<unsigned, int64_t>(getShapePerCTA(
-        op.getSrc().getType().getEncoding(), op.getType().getShape()));
-    // Input dims: [offset, iteration, block]
-    // Output dims: dimN-1, dimN-2, ..., dim0, where N is obtained from repShape
-    LinearLayout sharedLayout = chooseShemLayoutForRegToRegConversion(
-        ctx, tensorShapePerCTA, scratchConfig.repShape, scratchConfig.order);
-
-    // Layout for the store from registers to shared memory.
-    //
-    // Note: If two threads in the same warp write to the same shmem offset, the
-    // hardware resolves that without a stall or a bank conflict. Therefore we
-    // don't need to avoid duplicate writes.
-    // Input dims: [reg, lane, warp]
-    // Output dims: [offset, iteration]
-    bool isStMatrix = targetInfo.canUseStMatrix(
-        op.getSrc().getType(), scratchConfig.repShape,
-        scratchConfig.paddedRepShape, scratchConfig.order,
-        /*swizzleByteSize=*/0);
-    LinearLayout shmemStoreLayout =
-        isStMatrix ? chooseStMatrixLayout(ctx, op.getSrc().getType(),
-                                          /*swizzleByteSize=*/0)
-                   : srcLayout.invertAndCompose(sharedLayout);
-
-    const int shmemAllocatedNumElems =
-        getNumScratchElements(scratchConfig.paddedRepShape);
-    assert(shmemStoreLayout.getOutDimSize(kOffset) <= shmemAllocatedNumElems);
-
-    // Layout for the load from shmem to registers.
-    LinearLayout shmemLoadLayout = dstLayout.invertAndCompose(sharedLayout);
-
-    // Check that the `register` fully determines the `iteration`. That is,
-    // each thread does exactly the same reads and writes to shmem on each
-    // iteration, just with different input/output registers.
-    assert(
-        shmemStoreLayout.sublayoutIsZero({kLane, kWarp, kBlock}, {kIteration}));
-    assert(
-        shmemLoadLayout.sublayoutIsZero({kLane, kWarp, kBlock}, {kIteration}));
-
-    // iteration -> registers
-    SmallVector<SmallVector<int>> inRegsForIter =
-        collectRegsForIter(ctx, shmemStoreLayout);
-    SmallVector<SmallVector<int>> outRegsForIter =
-        collectRegsForIter(ctx, shmemLoadLayout);
-
-    Value smemBase =
-        LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation());
-    auto sharedPtrTy = smemBase.getType();
-    Type elemTy = inVals[0].getType();
-    auto outSize = shmemLoadLayout.getInDimSize(kRegister);
-    auto iterations = sharedLayout.getInDimSize(kIteration);
-    assert(scratchConfig.inVec * iterations <= inVals.size());
-    assert(scratchConfig.outVec * iterations <= outSize);
-
-    // Check only one dimension has been padded.
-    // This means the difference between the padded shape and the original shape
-    // should only be in one dimension, specifically in
-    // `scratchConfig.order[0]`.
-    auto rank = scratchConfig.repShape.size();
-    for (auto i = 0; i < rank; i++) {
-      if (i == scratchConfig.order[0]) {
-        continue;
-      }
-      assert(scratchConfig.repShape[i] == scratchConfig.paddedRepShape[i]);
-    }
-    auto paddedStride = scratchConfig.repShape[scratchConfig.order[0]];
-    auto paddedSize =
-        scratchConfig.paddedRepShape[scratchConfig.order[0]] - paddedStride;
-
-    // Linear layout function is split in two parts below:
-    //
-    //   L(r, t, w, b) = L(0, t, w, b) xor L(r, 0, 0, 0)
-    //   offset = regBase xor regIdx
-    //
-    // It is the same hack as what we've done in the emitIndices function to get
-    // around performance issues on AMD GPUs
-    auto getVecAddr = [&](LinearLayout &layout, Value &regBase,
-                          int regSlice) -> Value {
-      auto regIdx = layout
-                        .apply({{kRegister, regSlice},
-                                {kLane, 0},
-                                {kWarp, 0},
-                                {kBlock, 0}})[0]
-                        .second;
-      Value offset = b.xor_(regBase, b.i32_val(regIdx));
-      if (paddedSize > 0) {
-        assert(llvm::isPowerOf2_32(paddedStride));
-        assert(llvm::isPowerOf2_32(paddedSize));
-        auto rshiftVal = llvm::Log2_32(paddedStride);
-        auto lshiftVal = llvm::Log2_32(paddedSize);
-        offset = b.add(
-            b.shl(b.lshr(offset, b.i32_val(rshiftVal)), b.i32_val(lshiftVal)),
-            offset);
-      }
-      auto vecAddr = b.gep(sharedPtrTy, elemTy, smemBase, offset,
-                           LLVM::GEPNoWrapFlags::inbounds);
-      return vecAddr;
-    };
-
-    auto storeBase = applyLinearLayout(loc, rewriter, shmemStoreLayout,
-                                       {{kRegister, b.i32_val(0)},
-                                        {kLane, laneId},
-                                        {kWarp, warpId},
-                                        {kBlock, b.i32_val(0)}})[0]
-                         .second;
-    auto loadBase = applyLinearLayout(loc, rewriter, shmemLoadLayout,
-                                      {{kRegister, b.i32_val(0)},
-                                       {kLane, laneId},
-                                       {kWarp, warpId},
-                                       {kBlock, b.i32_val(0)}})[0]
-                        .second;
-    // register idx -> Value
-    llvm::MapVector<int, Value> outVals;
-    for (int i = 0; i < iterations; i++) {
-      if (i != 0)
-        b.barrier();
-
-      auto &inRegs = inRegsForIter[i];
-      auto &outRegs = outRegsForIter[i];
-
-      // When using `stmatrix`, we can store `inVec` elements even if they are
-      // not contiguous
-      auto inVec = isStMatrix ? shmemStoreLayout.getNumConsecutiveInOut()
-                              : scratchConfig.inVec;
-      for (int j = 0; j < inVals.size() / iterations; j += inVec) {
-        auto inRegSlice = inRegs[j];
-        Value vecAddr = getVecAddr(shmemStoreLayout, storeBase, inRegSlice);
-        SmallVector<Value> inValsVec;
-        for (int k = 0; k < inVec; k++)
-          inValsVec.push_back(inVals[inRegSlice + k]);
-        Value valsVec = packLLVector(loc, inValsVec, rewriter);
-        if (isStMatrix) {
-          targetInfo.storeMatrixShared(rewriter, loc, vecAddr, valsVec);
-        } else {
-          targetInfo.storeDShared(rewriter, loc, vecAddr, std::nullopt, valsVec,
-                                  /*pred=*/b.true_val());
-        }
-      }
-
-      b.barrier();
-
-      for (int j = 0; j < outSize / iterations; j += scratchConfig.outVec) {
-        auto outRegSlice = outRegs[j];
-        auto vecAddr = getVecAddr(shmemLoadLayout, loadBase, outRegSlice);
-        Value valsVec =
-            targetInfo.loadDShared(rewriter, loc, vecAddr, std::nullopt,
-                                   vec_ty(elemTy, scratchConfig.outVec),
-                                   /*pred=*/b.true_val());
-        for (Value v : unpackLLVector(loc, valsVec, rewriter))
-          outVals[outRegSlice++] = v;
-      }
-    }
-
-    SmallVector<Value> outValsVec;
-    for (size_t i = 0; i < outVals.size(); i++)
-      outValsVec.push_back(outVals[i]);
-    return outValsVec;
-  }
-
-  // Determine which registers are read/written in which iteration of the shmem
-  // transfer specified by `layout`.
-  SmallVector<SmallVector<int> /*registers*/>
-  collectRegsForIter(MLIRContext *ctx, const LinearLayout &layout) const {
-    StringAttr kRegister = str_attr("register");
-    StringAttr kLane = str_attr("lane");
-    StringAttr kWarp = str_attr("warp");
-    StringAttr kBlock = str_attr("block");
-    StringAttr kIteration = str_attr("iteration");
-
-    // The choice of iteration should be determined only by the register. That
-    // is, it should be correct to split the register dimension into iterations.
-    assert(layout.sublayoutIsZero({kLane, kWarp, kBlock}, {kIteration}));
-
-    LinearLayout sublayout = layout.sublayout({kRegister}, {kIteration});
-    SmallVector<SmallVector<int>> ret(sublayout.getOutDimSize(kIteration));
-    for (int reg = 0; reg < sublayout.getInDimSize(kRegister); reg++) {
-      auto idx = sublayout.apply({{kRegister, reg}});
-      ret[idx.begin()->second].push_back(reg);
-    }
-    return ret;
-  }
 };
 
 } // namespace
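
The removed getVecAddr lambda relies on the decomposition L(r, t, w, b) = L(0, t, w, b) xor L(r, 0, 0, 0), so the lane/warp base offset is computed once per thread and each register's contribution is xor-ed in per access. The sketch below only illustrates why that identity holds for an xor-linear layout; it uses a toy stand-in rather than the MLIR LinearLayout class, and all names in it are illustrative.

// Standalone sketch (not Triton code): a toy xor-linear layout where each
// input dimension contributes the xor of the basis vectors selected by the
// set bits of its index.
#include <cassert>
#include <cstdint>
#include <vector>

struct ToyLinearLayout {
  std::vector<uint32_t> regBases, laneBases, warpBases;

  static uint32_t fold(const std::vector<uint32_t> &bases, uint32_t idx) {
    uint32_t out = 0;
    for (size_t bit = 0; bit < bases.size(); ++bit)
      if (idx & (1u << bit))
        out ^= bases[bit];
    return out;
  }

  uint32_t apply(uint32_t reg, uint32_t lane, uint32_t warp) const {
    return fold(regBases, reg) ^ fold(laneBases, lane) ^ fold(warpBases, warp);
  }
};

int main() {
  ToyLinearLayout L{{1, 2}, {4, 8}, {16}};
  for (uint32_t reg = 0; reg < 4; ++reg)
    for (uint32_t lane = 0; lane < 4; ++lane)
      for (uint32_t warp = 0; warp < 2; ++warp) {
        uint32_t base = L.apply(0, lane, warp); // computed once per thread
        uint32_t regIdx = L.apply(reg, 0, 0);   // per-register contribution
        assert((base ^ regIdx) == L.apply(reg, lane, warp));
      }
  return 0;
}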
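
The removed collectRegsForIter helper buckets registers by the iteration their layout assigns them to, which is valid because the surrounding asserts guarantee that the iteration depends only on the register dimension. Below is a toy sketch of that grouping, with a plain callback standing in for layout.sublayout({kRegister}, {kIteration}); the names are hypothetical and not part of the Triton API.

// Standalone sketch (not Triton code): group register indices by the
// iteration a register-to-iteration mapping assigns them to.
#include <cassert>
#include <functional>
#include <vector>

std::vector<std::vector<int>>
collectRegsForIter(int numRegs, int numIters,
                   const std::function<int(int)> &regToIter) {
  std::vector<std::vector<int>> ret(numIters);
  for (int reg = 0; reg < numRegs; ++reg)
    ret[regToIter(reg)].push_back(reg);
  return ret;
}

int main() {
  // 8 registers split across 2 iterations by their high bit.
  auto groups = collectRegsForIter(8, 2, [](int reg) { return reg >> 2; });
  assert((groups[0] == std::vector<int>{0, 1, 2, 3}));
  assert((groups[1] == std::vector<int>{4, 5, 6, 7}));
  return 0;
}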
