
Commit 052efe4

Merge commit 'a89c5bd1c1707f003b03d214394cba495aae2320'
Signed-off-by: Anatoly Myachev <[email protected]>
2 parents: 0f4c607 + a89c5bd

File tree

14 files changed: +430 -326 lines


include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 19 additions & 0 deletions
@@ -624,6 +624,25 @@ inline bool isCanonicalIndex(unsigned index, unsigned freeVarMask) {
 // group code isolated from above by invoking this function.
 void makeAllWarpGroupsIsolatedFromAbove(Operation *op);
 
+/// Converts ConvertLayoutOp to LLVM using the padded pattern.
+/// This pattern adds unused memory locations after every row of the tensor's
+/// fastest-changing dimension:
+///   e0 e1 e2 e3 p p \
+///   e4 e5 e6 e7 p p \
+///   ...
+///   e  e  e  e  p p
+/// The dimension order is chosen so that wide output reads can be used.
+///
+/// \param op operation to convert
+/// \param src llvm structure containing operation input
+/// \param targetInfo
+/// \param typeConverter
+/// \param rewriter
+/// \returns llvm structure containing converted output
+Value transferWithinBlockPadding(triton::gpu::ConvertLayoutOp op, Value src,
+                                 const TargetInfoBase &targetInfo,
+                                 const LLVMTypeConverter *typeConverter,
+                                 RewriterBase &rewriter);
 } // namespace mlir
 
 #endif
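
The padded pattern documented in the new transferWithinBlockPadding comment can be summarized with a short, self-contained sketch. This is illustrative C++ only, under the assumption of a single padded (fastest-changing) dimension; paddedOffset, rowLen, and pad are made-up names, not Triton APIs: each row of rowLen elements is followed by pad unused slots, so element (row, col) lands at row * (rowLen + pad) + col in the scratch buffer.

#include <cassert>
#include <cstddef>

// Hypothetical helper: map a dense (row, col) index to its slot in the
// padded scratch buffer described in the doc comment above.
static size_t paddedOffset(size_t row, size_t col, size_t rowLen, size_t pad) {
  return row * (rowLen + pad) + col;
}

int main() {
  const size_t rowLen = 4, pad = 2; // matches the "e0 e1 e2 e3 p p" picture
  assert(paddedOffset(0, 3, rowLen, pad) == 3); // e3 is the last dense slot of row 0
  assert(paddedOffset(1, 0, rowLen, pad) == 6); // e4 starts after the two pad slots
  return 0;
}

With rowLen = 4 and pad = 2 this reproduces the picture in the comment: e4 begins at offset 6, i.e. one padded row after e0.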

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 3 additions & 269 deletions
@@ -272,13 +272,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
                            const LinearLayout &dstLayout,
                            OpAdaptor adaptor,
                            ConversionPatternRewriter &rewriter) const {
-    MLIRContext *ctx = op.getContext();
-    auto loc = op.getLoc();
-    auto b = TritonLLVMOpBuilder(loc, rewriter);
-    auto srcTy = op.getSrc().getType();
-    auto dstTy = op.getType();
-
-    assert(cvtNeedsSharedMemory(srcTy, dstTy));
+    assert(cvtNeedsSharedMemory(op.getSrc().getType(), op.getType()));
 
     // Try to use swizzling to implement the conversion
     // HACK Remove once AMD tests pass for the swizzling path
@@ -287,52 +281,9 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
       return success();
     }
 
-    SmallVector<Value> inVals =
-        unpackLLElements(loc, adaptor.getSrc(), rewriter);
-    assert(!inVals.empty());
-
-    // We munge the input values by converting i<n> (n<8) elements to i8 and
-    // pointers to i64. This is necessary because TargetInfo::loadDShared and
-    // storeDShared can't handle vectors of pointers or sub-byte elements.
-    auto elemTy = srcTy.getElementType();
-    auto isSubByteInt =
-        elemTy.isInteger() && elemTy.getIntOrFloatBitWidth() < 8;
-    auto isPtr = isa<triton::PointerType>(elemTy);
-    auto llvmElemTyOrig = getTypeConverter()->convertType(elemTy);
-    if (isSubByteInt)
-      elemTy = IntegerType::get(elemTy.getContext(), 8);
-    else if (isPtr)
-      elemTy = IntegerType::get(elemTy.getContext(), 64);
-    auto llvmElemTy = getTypeConverter()->convertType(elemTy);
-
-    // Munge input values
-    for (const auto &it : llvm::enumerate(inVals)) {
-      if (isSubByteInt) {
-        inVals[it.index()] = b.zext(llvmElemTy, it.value());
-      } else if (isPtr) {
-        inVals[it.index()] = b.ptrtoint(llvmElemTy, it.value());
-      }
-    }
-
-    // Pretty sure this is the identity function ATM
-    // It'd be better to simply call `quotient({kBlock})` and
-    // remove kBlock from transferWithinBlockImpl
-    auto srcLayoutWithinBlock = getLayoutWithinBlock(srcLayout);
-    auto dstLayoutWithinBlock = getLayoutWithinBlock(dstLayout);
-    SmallVector<Value> outVals = transferWithinBlockImpl(
-        inVals, op, srcLayoutWithinBlock, dstLayoutWithinBlock, rewriter);
-
-    // Unmunge output values
-    for (const auto &it : llvm::enumerate(outVals)) {
-      if (isSubByteInt) {
-        outVals[it.index()] = b.trunc(llvmElemTyOrig, it.value());
-      } else if (isPtr) {
-        outVals[it.index()] = b.inttoptr(llvmElemTyOrig, it.value());
-      }
-    }
+    Value result = transferWithinBlockPadding(op, adaptor.getSrc(), targetInfo,
+                                              getTypeConverter(), rewriter);
 
-    Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter,
-                                  op.getType());
     rewriter.replaceOp(op, result);
     return success();
   }
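
The block removed in the hunk above (its job now lives inside transferWithinBlockPadding) widens sub-byte integers to i8 and converts pointers to i64 before the shared-memory round trip, because TargetInfo::loadDShared and storeDShared cannot handle those element types, and then narrows the results back. The stand-alone C++ below is only an analogy of that widen/stage/narrow idea, not the MLIR lowering; every name in it is hypothetical.

#include <cassert>
#include <cstdint>

int main() {
  // Scratch storage that only supports byte and 64-bit integer elements,
  // standing in for shared-memory ops that cannot take sub-byte ints or
  // pointers directly.
  uint8_t byteScratch[1];
  uint64_t intScratch[1];

  // Sub-byte integer: widen ("zext") to a full byte, stage it, then
  // truncate back to the original width.
  uint8_t i4Val = 0xA & 0xF;                // logical 4-bit payload
  byteScratch[0] = i4Val;                   // store the widened value
  uint8_t restored4 = byteScratch[0] & 0xF; // truncate back to 4 bits
  assert(restored4 == i4Val);

  // Pointer: convert to an integer ("ptrtoint"), stage it, then convert
  // back ("inttoptr").
  int payload = 42;
  intScratch[0] = reinterpret_cast<uintptr_t>(&payload);
  int *restoredPtr = reinterpret_cast<int *>(intScratch[0]);
  assert(*restoredPtr == 42);
  return 0;
}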
@@ -343,223 +294,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
                            DecomposedWarpConversion decomposed,
                            OpAdaptor adaptor,
                            ConversionPatternRewriter &rewriter) const;
-
-  SmallVector<Value>
-  transferWithinBlockImpl(ArrayRef<Value> inVals, ConvertLayoutOp op,
-                          const LinearLayout &srcLayout,
-                          const LinearLayout &dstLayout,
-                          ConversionPatternRewriter &rewriter) const {
-    MLIRContext *ctx = op.getContext();
-    auto loc = op.getLoc();
-    auto b = TritonLLVMOpBuilder(loc, rewriter);
-
-    RankedTensorType srcTy = op.getSrc().getType();
-    auto srcElemTy = srcTy.getElementType();
-    const bool isInt1 = srcElemTy.isInteger(1);
-
-    StringAttr kRegister = str_attr("register");
-    StringAttr kLane = str_attr("lane");
-    StringAttr kWarp = str_attr("warp");
-    StringAttr kBlock = str_attr("block");
-    StringAttr kOffset = str_attr("offset");
-    StringAttr kIteration = str_attr("iteration");
-
-    auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc);
-
-    auto scratchConfig =
-        getScratchConfigForCvt(op.getSrc().getType(), op.getType());
-    auto tensorShapePerCTA = convertType<unsigned, int64_t>(getShapePerCTA(
-        op.getSrc().getType().getEncoding(), op.getType().getShape()));
-    // Input dims: [offset, iteration, block]
-    // Output dims: dimN-1, dimN-2, ..., dim0, where N is obtained from repShape
-    LinearLayout sharedLayout = chooseShemLayoutForRegToRegConversion(
-        ctx, tensorShapePerCTA, scratchConfig.repShape, scratchConfig.order);
-
-    // Layout for the store from registers to shared memory.
-    //
-    // Note: If two threads in the same warp write to the same shmem offset, the
-    // hardware resolves that without a stall or a bank conflict. Therefore we
-    // don't need to avoid duplicate writes.
-    // Input dims: [reg, lane, warp]
-    // Output dims: [offset, iteration]
-    bool isStMatrix = targetInfo.canUseStMatrix(
-        op.getSrc().getType(), scratchConfig.repShape,
-        scratchConfig.paddedRepShape, scratchConfig.order,
-        /*swizzleByteSize=*/0);
-    LinearLayout shmemStoreLayout =
-        isStMatrix ? chooseStMatrixLayout(ctx, op.getSrc().getType(),
-                                          /*swizzleByteSize=*/0)
-                   : srcLayout.invertAndCompose(sharedLayout);
-
-    const int shmemAllocatedNumElems =
-        getNumScratchElements(scratchConfig.paddedRepShape);
-    assert(shmemStoreLayout.getOutDimSize(kOffset) <= shmemAllocatedNumElems);
-
-    // Layout for the load from shmem to registers.
-    LinearLayout shmemLoadLayout = dstLayout.invertAndCompose(sharedLayout);
-
-    // Check that the `register` fully determines the `iteration`. That is,
-    // each thread does exactly the same reads and writes to shmem on each
-    // iteration, just with different input/output registers.
-    assert(
-        shmemStoreLayout.sublayoutIsZero({kLane, kWarp, kBlock}, {kIteration}));
-    assert(
-        shmemLoadLayout.sublayoutIsZero({kLane, kWarp, kBlock}, {kIteration}));
-
-    // iteration -> registers
-    SmallVector<SmallVector<int>> inRegsForIter =
-        collectRegsForIter(ctx, shmemStoreLayout);
-    SmallVector<SmallVector<int>> outRegsForIter =
-        collectRegsForIter(ctx, shmemLoadLayout);
-
-    Value smemBase =
-        LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation());
-    auto sharedPtrTy = smemBase.getType();
-    Type elemTy = inVals[0].getType();
-    auto outSize = shmemLoadLayout.getInDimSize(kRegister);
-    auto iterations = sharedLayout.getInDimSize(kIteration);
-    assert(scratchConfig.inVec * iterations <= inVals.size());
-    assert(scratchConfig.outVec * iterations <= outSize);
-
-    // Check only one dimension has been padded.
-    // This means the difference between the padded shape and the original shape
-    // should only be in one dimension, specifically in
-    // `scratchConfig.order[0]`.
-    auto rank = scratchConfig.repShape.size();
-    for (auto i = 0; i < rank; i++) {
-      if (i == scratchConfig.order[0]) {
-        continue;
-      }
-      assert(scratchConfig.repShape[i] == scratchConfig.paddedRepShape[i]);
-    }
-    auto paddedStride = scratchConfig.repShape[scratchConfig.order[0]];
-    auto paddedSize =
-        scratchConfig.paddedRepShape[scratchConfig.order[0]] - paddedStride;
-
-    // Linear layout function is split in two parts below:
-    //
-    //   L(r, t, w, b) = L(0, t, w, b) xor L(r, 0, 0, 0)
-    //   offset        = regBase       xor regIdx
-    //
-    // It is the same hack as what we've done in the emitIndices function to get
-    // around performance issues on AMD GPUs
-    auto getVecAddr = [&](LinearLayout &layout, Value &regBase,
-                          int regSlice) -> Value {
-      auto regIdx = layout
-                        .apply({{kRegister, regSlice},
-                                {kLane, 0},
-                                {kWarp, 0},
-                                {kBlock, 0}})[0]
-                        .second;
-      Value offset = b.xor_(regBase, b.i32_val(regIdx));
-      if (paddedSize > 0) {
-        assert(llvm::isPowerOf2_32(paddedStride));
-        assert(llvm::isPowerOf2_32(paddedSize));
-        auto rshiftVal = llvm::Log2_32(paddedStride);
-        auto lshiftVal = llvm::Log2_32(paddedSize);
-        offset = b.add(
-            b.shl(b.lshr(offset, b.i32_val(rshiftVal)), b.i32_val(lshiftVal)),
-            offset);
-      }
-      auto vecAddr = b.gep(sharedPtrTy, elemTy, smemBase, offset,
-                           LLVM::GEPNoWrapFlags::inbounds);
-      return vecAddr;
-    };
-
-    auto storeBase = applyLinearLayout(loc, rewriter, shmemStoreLayout,
-                                       {{kRegister, b.i32_val(0)},
-                                        {kLane, laneId},
-                                        {kWarp, warpId},
-                                        {kBlock, b.i32_val(0)}})[0]
-                         .second;
-    auto loadBase = applyLinearLayout(loc, rewriter, shmemLoadLayout,
-                                      {{kRegister, b.i32_val(0)},
-                                       {kLane, laneId},
-                                       {kWarp, warpId},
-                                       {kBlock, b.i32_val(0)}})[0]
-                        .second;
-    // register idx -> Value
-    llvm::MapVector<int, Value> outVals;
-    for (int i = 0; i < iterations; i++) {
-      if (i != 0)
-        b.barrier();
-
-      auto &inRegs = inRegsForIter[i];
-      auto &outRegs = outRegsForIter[i];
-
-      // When using `stmatrix`, we can store `inVec` elements even if they are
-      // not contiguous
-      auto inVec = isStMatrix ? shmemStoreLayout.getNumConsecutiveInOut()
-                              : scratchConfig.inVec;
-      for (int j = 0; j < inVals.size() / iterations; j += inVec) {
-        auto inRegSlice = inRegs[j];
-        Value vecAddr = getVecAddr(shmemStoreLayout, storeBase, inRegSlice);
-        SmallVector<Value> inValsVec;
-        for (int k = 0; k < inVec; k++)
-          inValsVec.push_back(inVals[inRegSlice + k]);
-        Value valsVec = packLLVector(loc, inValsVec, rewriter);
-        if (isStMatrix) {
-          targetInfo.storeMatrixShared(rewriter, loc, vecAddr, valsVec);
-        } else {
-          targetInfo.storeDShared(rewriter, loc, vecAddr, std::nullopt, valsVec,
-                                  /*pred=*/b.true_val());
-        }
-      }
-
-      b.barrier();
-
-      for (int j = 0; j < outSize / iterations; j += scratchConfig.outVec) {
-        auto outRegSlice = outRegs[j];
-        auto vecAddr = getVecAddr(shmemLoadLayout, loadBase, outRegSlice);
-        Value valsVec =
-            targetInfo.loadDShared(rewriter, loc, vecAddr, std::nullopt,
-                                   vec_ty(elemTy, scratchConfig.outVec),
-                                   /*pred=*/b.true_val());
-        for (Value v : unpackLLVector(loc, valsVec, rewriter)) {
-          if (isInt1) {
-            // TODO(Intel): special handling for the boolean case required. Does
-            // this prevent a later optimization that we can't handle, or is
-            // there something about the layout/SLM loads and stores that
-            // requires special "transcribing" the boolean to the result of the
-            // cmp?
-            outVals[outRegSlice++] =
-                b.icmp_ne(v, rewriter.create<LLVM::ConstantOp>(
-                                 loc, i8_ty, rewriter.getI8IntegerAttr(0)));
-          } else {
-            outVals[outRegSlice++] = v;
-          }
-        }
-      }
-    }
-
-    SmallVector<Value> outValsVec;
-    for (size_t i = 0; i < outVals.size(); i++)
-      outValsVec.push_back(outVals[i]);
-    return outValsVec;
-  }
-
-  // Determine which registers are read/written in which iteration of the shmem
-  // transfer specified by `layout`.
-  SmallVector<SmallVector<int> /*registers*/>
-  collectRegsForIter(MLIRContext *ctx, const LinearLayout &layout) const {
-    StringAttr kRegister = str_attr("register");
-    StringAttr kLane = str_attr("lane");
-    StringAttr kWarp = str_attr("warp");
-    StringAttr kBlock = str_attr("block");
-    StringAttr kIteration = str_attr("iteration");
-
-    // The choice of iteration should be determined only by the register. That
-    // is, it should be correct to split the register dimension into iterations.
-    assert(layout.sublayoutIsZero({kLane, kWarp, kBlock}, {kIteration}));
-
-    LinearLayout sublayout = layout.sublayout({kRegister}, {kIteration});
-    SmallVector<SmallVector<int>> ret(sublayout.getOutDimSize(kIteration));
-    for (int reg = 0; reg < sublayout.getInDimSize(kRegister); reg++) {
-      auto idx = sublayout.apply({{kRegister, reg}});
-      ret[idx.begin()->second].push_back(reg);
-    }
-    return ret;
-  }
 };
 
 } // namespace
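
The removed getVecAddr above computes each shared-memory address in two steps: offset = regBase xor regIdx, which is valid because the layout is linear over xor (the comment's identity L(r, t, w, b) = L(0, t, w, b) xor L(r, 0, 0, 0)), followed by a padding adjustment done purely with shifts, which the code guards with power-of-two asserts on paddedStride and paddedSize. The sketch below checks both facts on a toy xor-linear layout; it is an illustration under those assumptions, with made-up basis values, not Triton code.

#include <cassert>
#include <cstdint>

// Hypothetical xor-linear layout: each set bit of an input contributes one
// basis value, and contributions combine with xor.
static uint32_t applyBasis(const uint32_t *basis, int bits, uint32_t in) {
  uint32_t out = 0;
  for (int i = 0; i < bits; ++i)
    if (in & (1u << i))
      out ^= basis[i];
  return out;
}

int main() {
  // Toy bases for the register and lane inputs (illustrative values only).
  const uint32_t regBasis[2] = {1u << 0, 1u << 2};
  const uint32_t laneBasis[2] = {1u << 1, 1u << 3};
  auto layout = [&](uint32_t reg, uint32_t lane) {
    return applyBasis(regBasis, 2, reg) ^ applyBasis(laneBasis, 2, lane);
  };

  // Padding parameters, assumed powers of two as the removed code asserts.
  const uint32_t paddedStride = 8; // dense elements per row
  const uint32_t paddedSize = 2;   // pad elements appended per row
  const uint32_t rshift = 3;       // log2(paddedStride)
  const uint32_t lshift = 1;       // log2(paddedSize)

  for (uint32_t reg = 0; reg < 4; ++reg) {
    for (uint32_t lane = 0; lane < 4; ++lane) {
      // Step 1: offset = regBase xor regIdx holds because the layout is
      // xor-linear.
      uint32_t regBase = layout(0, lane);
      uint32_t regIdx = layout(reg, 0);
      uint32_t offset = regBase ^ regIdx;
      assert(offset == layout(reg, lane));

      // Step 2: the shift-based padding adjustment equals inserting
      // paddedSize unused elements after every paddedStride dense elements.
      uint32_t padded = ((offset >> rshift) << lshift) + offset;
      assert(padded == offset + (offset / paddedStride) * paddedSize);
    }
  }
  return 0;
}

In other words, ((offset >> log2(paddedStride)) << log2(paddedSize)) + offset is just offset + (offset / paddedStride) * paddedSize, expressed with shifts because both quantities are powers of two.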
