diff --git a/.github/workflows/documentation.yml b/.github/workflows/documentation.yml index 0f7b255bce..91ed7c4865 100644 --- a/.github/workflows/documentation.yml +++ b/.github/workflows/documentation.yml @@ -25,7 +25,7 @@ jobs: - name: Install dependent packages run: | - sudo pip3 install tabulate cmake sphinx matplotlib myst_parser sphinx-rtd-theme pandas pytest sphinx-gallery sphinx-multiversion + sudo pip3 install tabulate cmake sphinx matplotlib myst_parser sphinx-rtd-theme pandas pytest sphinx-gallery sphinx-multiversion llnl-hatchet #- name: Fetch dependent branches # run: | diff --git a/cmake/llvm-hash.txt b/cmake/llvm-hash.txt index aeb8c2b09e..36344442bd 100644 --- a/cmake/llvm-hash.txt +++ b/cmake/llvm-hash.txt @@ -1 +1 @@ -36adf8ecedb64047021265a1e1730773d3b3a9e8 +df0864e761107b07e38f5503e0cbee0cebb4c5e8 diff --git a/docs/conf.py b/docs/conf.py index 9ef6d72837..eac5168d51 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -159,9 +159,6 @@ def documenter(app, obj, parent): 'examples_dirs': '../python/tutorials/', 'gallery_dirs': 'getting-started/tutorials', 'filename_pattern': '', - # TODO: Re-enable the grouped-gemm tutorial. It currently hits this - # assertion: - # https://github.com/triton-lang/triton/blob/main/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp#L127 'ignore_pattern': r'(__init__\.py|11.*.py)', 'within_subsection_order': FileNameSortKey, 'reference_url': { diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index d782555cac..6869e068c4 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -572,6 +572,7 @@ bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy, srcTy.getShape(), srcTy.getElementType(), dotOperandLayout.getParent()); auto ans = mmaLayout.getVersionMajor() == 3 && dotOperandLayout.getOpIdx() == 0 && + mmaLayout.getWarpsPerCTA()[1] == 1 && !cvtNeedsSharedMemory(parentTy, srcTy) && (elementTypeSize == 16 || elementTypeSize == 8); return ans; diff --git a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp index 30de13f6a8..16c9991a17 100644 --- a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp @@ -1,10 +1,7 @@ #include "ReduceScanCommon.h" -#include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/Support/LLVM.h" #include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" -#include "triton/Dialect/TritonGPU/Transforms/Utility.h" -#include using namespace mlir; using namespace mlir::triton; @@ -80,36 +77,16 @@ struct ReduceOpConversion private: const TargetInfoBase &targetInfo; - void accumulate(ConversionPatternRewriter &rewriter, Region &combineOp, - SmallVector &acc, ValueRange cur, bool isFirst) const { - if (isFirst) { - acc = SmallVector(cur.begin(), cur.end()); - return; + void accumulate(Location loc, ConversionPatternRewriter &rewriter, + Region &combineOp, SmallVector &acc, ValueRange cur, + Value pred = {}) const { + auto results = applyCombineOp(loc, rewriter, combineOp, acc, cur, pred); + if (acc.size() < results.size()) { + acc.resize(results.size()); } - - // Create a new copy of the reduce block, and inline it - Block *currentBlock = rewriter.getBlock(); - Region &parent = *currentBlock->getParent(); - rewriter.cloneRegionBefore(combineOp, &parent.front()); - auto &newReduce = parent.front(); - auto returnOp = dyn_cast(newReduce.getTerminator()); - - llvm::SmallVector combineArgs(2 * acc.size()); - for (unsigned i = 0; 
i < acc.size(); ++i) { - combineArgs[i] = acc[i]; - combineArgs[acc.size() + i] = cur[i]; - } - - rewriter.inlineBlockBefore(&newReduce, &*rewriter.getInsertionPoint(), - combineArgs); - - auto results = returnOp.getResult(); for (unsigned i = 0; i < acc.size(); ++i) { acc[i] = results[i]; } - - // Delete the terminator, which is no longer used - rewriter.eraseOp(returnOp); } SmallVector> @@ -165,7 +142,7 @@ struct ReduceOpConversion SmallVector key = offsets[i]; key[op.getAxis()] = 0; bool isFirst = accs.find(key) == accs.end(); - accumulate(rewriter, *combineOp, accs[key], srcValues[i], isFirst); + accumulate(op.getLoc(), rewriter, *combineOp, accs[key], srcValues[i]); if (isFirst) indices[key] = srcIndices[i]; } @@ -175,17 +152,29 @@ struct ReduceOpConversion // region and the accumulator values as source. void warpReduce(ConversionPatternRewriter &rewriter, Location loc, SmallVector &acc, triton::ReduceOp op, - unsigned numLaneToReduce, unsigned interleave) const { + unsigned numLaneToReduce, unsigned interleave, + Value pred = {}) const { auto success = targetInfo.warpReduce(rewriter, loc, acc, op, numLaneToReduce, interleave); if (success) return; + + auto mod = op->getParentOfType(); + unsigned iWarpSize = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); + if (iWarpSize > numLaneToReduce) { + Value threadId = getThreadId(rewriter, loc); + Value warpSize = i32_val(iWarpSize); + Value laneId = urem(threadId, warpSize); + Value lanePred = icmp_slt(laneId, i32_val(numLaneToReduce)); + pred = pred ? and_(pred, lanePred) : lanePred; + } + for (unsigned N = numLaneToReduce / 2; N > 0; N >>= 1) { SmallVector shfl(acc.size()); for (unsigned i = 0; i < acc.size(); ++i) { shfl[i] = targetInfo.shuffleXor(rewriter, loc, acc[i], N * interleave); } - accumulate(rewriter, op.getCombineOp(), acc, shfl, false); + accumulate(op.getLoc(), rewriter, op.getCombineOp(), acc, shfl, pred); } } @@ -344,7 +333,8 @@ struct ReduceOpConversion acc[i] = targetInfo.loadShared(rewriter, loc, readPtr, elemTy, threadIsNeeded); } - warpReduce(rewriter, loc, acc, op, sizeInterWarps, 1 /* interleave */); + warpReduce(rewriter, loc, acc, op, sizeInterWarps, 1 /* interleave */, + threadIsNeeded); // only the first thread in each sizeInterWarps is writing Value writeOffset = readOffset; SmallVector writePtrs(op.getNumOperands()); diff --git a/lib/Conversion/TritonGPUToLLVM/ReduceScanCommon.h b/lib/Conversion/TritonGPUToLLVM/ReduceScanCommon.h index 3130001cc5..9f823e2e13 100644 --- a/lib/Conversion/TritonGPUToLLVM/ReduceScanCommon.h +++ b/lib/Conversion/TritonGPUToLLVM/ReduceScanCommon.h @@ -4,15 +4,14 @@ // TODO: refactor so that it doesn't fail if Allocation.h // is included after utility.h (due to conflict in `store` macro // and -#include "triton/Analysis/Allocation.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Transforms/DialectConversion.h" -#include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h" // #include "mlir/IR/TypeUtilities.h" -#include "triton/Analysis/AxisInfo.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" -#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" -#include +#include #include #define DEBUG_TYPE "ttgpu_to_llvm" @@ -32,6 +31,91 @@ namespace ttng = ::mlir::triton::nvidia_gpu; namespace mlir::triton { class ReduceOp; class ScanOp; + +inline SmallVector +inlineCombineBlock(ConversionPatternRewriter &rewriter, Block &combineBlock, + Block *insertionBlock, Block::iterator insertionPoint, + ValueRange 
combineArgs) { + auto returnOp = combineBlock.getTerminator(); + rewriter.inlineBlockBefore(&combineBlock, insertionBlock, insertionPoint, + combineArgs); + + auto results = SmallVector(returnOp->getOperands()); + + // Delete the terminator, which is no longer used + rewriter.eraseOp(returnOp); + return results; +} + +inline SmallVector applyCombineOp(Location loc, + ConversionPatternRewriter &rewriter, + Region &combineOp, ValueRange acc, + ValueRange cur, Value pred = {}) { + // Allows for passing an unitialized acc and use cur as the neutral element + if (acc.size() == 0) { + return cur; + } + assert(cur.size() == acc.size()); + + // Create a new copy of the combine block, and try to speculatively inline it + Block *currentBlock = rewriter.getBlock(); + Region &parent = *currentBlock->getParent(); + + rewriter.cloneRegionBefore(combineOp, parent, + std::next(currentBlock->getIterator())); + Block &newCombine = *currentBlock->getNextNode(); + + llvm::SmallVector combineArgs(2 * acc.size()); + for (unsigned i = 0; i < acc.size(); ++i) { + combineArgs[i] = acc[i]; + combineArgs[acc.size() + i] = cur[i]; + } + + auto isRegionSpeculatable = + std::all_of(newCombine.begin(), newCombine.end(), + [](auto &op) { return isSpeculatable(&op); }); + + if (!pred || isRegionSpeculatable) { + // Fast path, region has no side effects so we can unconditionally execute + return inlineCombineBlock(rewriter, newCombine, currentBlock, + rewriter.getInsertionPoint(), combineArgs); + } + + // Slow case, create an if to only execute region when pred is true + // #currentBlock + // if (pred) { + // #newCombine + // results = combineOp(cur, acc) + // yield results + // } else { + // yield undef + // } + // #thenBlock + Block *thenBlock = + rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint()); + + auto returnOp = newCombine.getTerminator(); + auto results = SmallVector(returnOp->getOperands()); + + rewriter.setInsertionPointToEnd(currentBlock); + SmallVector thenBlockArgs; + thenBlockArgs.reserve(results.size()); + for (auto result : results) { + auto ty = result.getType(); + auto undef = rewriter.create(loc, ty); + thenBlockArgs.push_back(undef); + thenBlock->addArgument(ty, loc); + } + rewriter.create(loc, pred, &newCombine, combineArgs, + thenBlock, thenBlockArgs); + + // Split a block after the call. + rewriter.setInsertionPointToEnd(&newCombine); + rewriter.replaceOpWithNewOp(returnOp, thenBlock, results); + rewriter.setInsertionPointToStart(thenBlock); + return SmallVector(thenBlock->getArguments()); +} + } // namespace mlir::triton template diff --git a/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp index 675bf5a342..b07f654a1e 100644 --- a/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp @@ -1,5 +1,3 @@ -#include - #include "ReduceScanCommon.h" #include "mlir/Support/LLVM.h" #include "triton/Analysis/Utility.h" @@ -16,37 +14,13 @@ using ::mlir::LLVM::linearize; using ::mlir::triton::gpu::getTotalElemsPerThread; // apply combine region to acc and cur and accumulate it into acc -// TODO(Lezcano) This is now duplicated with ReduceOpConversion::reduce. 
-// Deduplicate -static SmallVector accumulate(ConversionPatternRewriter &rewriter, - Region &combineOp, ValueRange acc, - ValueRange cur) { - // Allows for passing an unitialized acc and use cur as the neutral element - if (acc.size() == 0) { - return cur; - } - assert(cur.size() == acc.size()); - // Create a new copy of the reduce block, and inline it - Block *currentBlock = rewriter.getBlock(); - Region &parent = *currentBlock->getParent(); - rewriter.cloneRegionBefore(combineOp, &parent.front()); - auto &newScan = parent.front(); - auto returnOp = dyn_cast(newScan.getTerminator()); - - SmallVector combineArgs(2 * acc.size()); - for (unsigned i = 0; i < acc.size(); ++i) { - combineArgs[i] = acc[i]; - combineArgs[acc.size() + i] = cur[i]; - } - - rewriter.inlineBlockBefore(&newScan, &*rewriter.getInsertionPoint(), - combineArgs); - SmallVector results; - llvm::transform(returnOp.getResult(), std::back_inserter(results), - [&](Value res) { return rewriter.getRemappedValue(res); }); - // Delete the terminator, which is no longer used - rewriter.eraseOp(returnOp); - return results; +static SmallVector accumulate(ScanLoweringHelper &helper, + ConversionPatternRewriter &rewriter, + ValueRange acc, ValueRange cur, + Value pred = {}) { + auto loc = helper.getLoc(); + auto &combineOp = helper.getCombineOp(); + return applyCombineOp(loc, rewriter, combineOp, acc, cur, pred); } // Scan a contiguous elements within a thread and update `srcValues` in place. @@ -66,8 +40,8 @@ scanThreadContiguousElements(SmallVector> &srcValues, unsigned accIndex = (srcIndex % stride) + ((srcIndex / stride) / scanElementsPerThreads) * stride; - accs[accIndex] = accumulate(rewriter, helper.getCombineOp(), accs[accIndex], - srcValues[srcIndex]); + accs[accIndex] = + accumulate(helper, rewriter, accs[accIndex], srcValues[srcIndex]); srcValues[srcIndex] = accs[accIndex]; } } @@ -95,11 +69,11 @@ static void warpScan(SmallVector> &srcValues, for (unsigned j = 0; j < acc.size(); ++j) { shfl[j] = targetInfo.shuffleUp(rewriter, loc, acc[j], i * threadStride); } + Value mask = icmp_sge(laneIdAxis, i32_val(i)); SmallVector tempAcc = - accumulate(rewriter, helper.getCombineOp(), shfl, acc); - Value mask = icmp_slt(laneIdAxis, i32_val(i)); + accumulate(helper, rewriter, shfl, acc, mask); for (unsigned j = 0; j < acc.size(); ++j) { - acc[j] = select(mask, acc[j], tempAcc[j]); + acc[j] = select(mask, tempAcc[j], acc[j]); } } srcValues[srcIndex] = acc; @@ -164,9 +138,9 @@ static void AddPartialReduce(SmallVector> &srcValues, unsigned elementStride = helper.getAxisElementStride(); unsigned threadStride = helper.getAxisThreadStride(); unsigned axisNumWarps = helper.getAxisNumWarpsWithUniqueData(); - Value maskFirstWarp = icmp_eq(warpId, i32_val(0)); - Value maskFirstLane = icmp_eq(laneIdAxis, i32_val(0)); - Value maskFirstThread = and_(maskFirstWarp, maskFirstLane); + Value maskNotFirstWarp = icmp_ne(warpId, i32_val(0)); + Value maskNotFirstLane = icmp_ne(laneIdAxis, i32_val(0)); + Value maskNotFirstThread = or_(maskNotFirstWarp, maskNotFirstLane); struct Accumulator { SmallVector acc; SmallVector maskedAcc; @@ -212,22 +186,24 @@ static void AddPartialReduce(SmallVector> &srcValues, accumulator.maskedAcc = partialReduce; continue; } - accumulator.acc = accumulate(rewriter, helper.getCombineOp(), - accumulator.acc, partialReduce); - Value mask = icmp_slt(warpId, i32_val(i + 1)); + Value mask = icmp_sge(warpId, i32_val(i + 1)); + accumulator.acc = + accumulate(helper, rewriter, accumulator.acc, partialReduce, mask); for (unsigned j = 0; 
j < helper.getNumOperands(); ++j) { accumulator.maskedAcc[j] = - select(mask, accumulator.maskedAcc[j], accumulator.acc[j]); + select(mask, accumulator.acc[j], accumulator.maskedAcc[j]); } } - auto temp = accumulate(rewriter, helper.getCombineOp(), - accumulator.maskedAcc, srcValues[srcIndex]); + + Value pred = axisBlockId == 0 ? maskNotFirstWarp : Value{}; + auto temp = accumulate(helper, rewriter, accumulator.maskedAcc, + srcValues[srcIndex], pred); if (axisBlockId == 0) { // For the first warp and first chunk we don't have anything to // accumulate. auto val = srcValues[srcIndex]; for (unsigned i = 0; i < helper.getNumOperands(); ++i) { - temp[i] = select(maskFirstWarp, val[i], temp[i]); + temp[i] = select(maskNotFirstWarp, temp[i], val[i]); } } srcValues[srcIndex] = temp; @@ -235,19 +211,18 @@ static void AddPartialReduce(SmallVector> &srcValues, SmallVector lastElement(helper.getNumOperands()); for (unsigned i = 0; i < helper.getNumOperands(); ++i) { auto elem = targetInfo.shuffleUp(rewriter, loc, temp[i], threadStride); - lastElement[i] = select(maskFirstLane, accumulator.maskedAcc[i], elem); + lastElement[i] = select(maskNotFirstLane, elem, accumulator.maskedAcc[i]); } for (unsigned i = 1; i < scanElementsPerThreads; ++i) { + pred = axisBlockId == 0 ? maskNotFirstThread : Value{}; auto laneValue = srcValues[srcIndex - i * elementStride]; - laneValue = - accumulate(rewriter, helper.getCombineOp(), lastElement, laneValue); + laneValue = accumulate(helper, rewriter, lastElement, laneValue, pred); if (axisBlockId == 0) { // For the first warp and first chunk we don't have anything to // accumulate. for (unsigned j = 0; j < helper.getNumOperands(); ++j) { - laneValue[j] = - select(maskFirstThread, - srcValues[srcIndex - i * elementStride][j], laneValue[j]); + laneValue[j] = select(maskNotFirstThread, laneValue[j], + srcValues[srcIndex - i * elementStride][j]); } } srcValues[srcIndex - i * elementStride] = laneValue; @@ -300,8 +275,8 @@ static void AddPartialReduceOneWarp(SmallVector> &srcValues, if (axisBlockId == 0) // First chunk and first block accumulator = srcValues[srcIndex]; else - srcValues[srcIndex] = accumulate(rewriter, helper.getCombineOp(), - accumulator, srcValues[srcIndex]); + srcValues[srcIndex] = + accumulate(helper, rewriter, accumulator, srcValues[srcIndex]); // Update the rest of the contiguous elements. auto lastElement = srcValues[srcIndex]; if (scanDim > 1) { @@ -319,8 +294,7 @@ static void AddPartialReduceOneWarp(SmallVector> &srcValues, } for (unsigned i = 1; i < scanElementsPerThreads; ++i) { auto laneValue = srcValues[srcIndex - i * elementStride]; - laneValue = - accumulate(rewriter, helper.getCombineOp(), lastElement, laneValue); + laneValue = accumulate(helper, rewriter, lastElement, laneValue); if (axisBlockId == 0) { for (unsigned j = 0; j < helper.getNumOperands(); ++j) { // For the first warp and first chunk we don't have anything to diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp index 5cc537d5fc..dc5f395c67 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp @@ -216,11 +216,11 @@ static void createTMAAsyncCopy( // If all the transitive uses of the given value have are used by a convert to // the same dot operand encoding, return the shared encoding that needs to be // used to be compatible with users' layouts. 
If there are imcompatible shared -// encodings set `incompatible` to true. +// encodings, raise assertion, since incompatible shared encoding has been +// handled in splitLoadsForIncompatible. static std::optional -getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible) { +getSharedEncIfAllUsersAreDotEnc(Value val) { ttg::SharedEncodingAttr attr; - incompatible = false; for (Operation *user : val.getUsers()) { ttg::SharedEncodingAttr tempAttr; if (user->getNumResults() != 1) @@ -230,8 +230,7 @@ getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible) { // First time we find a shared encoding in the chain, save it and try to // use it if it is compatible with the other users. tempAttr = cast(memDesc.getEncoding()); - if (!getSharedEncIfAllUsersAreDotEnc(user->getResult(0), incompatible) - .has_value()) + if (!getSharedEncIfAllUsersAreDotEnc(user->getResult(0)).has_value()) return std::nullopt; } else { if (!isa(user)) @@ -245,16 +244,12 @@ getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible) { auto order = ttg::getOrder(srcTy.getEncoding()); unsigned bitWidth = srcTy.getElementType().getIntOrFloatBitWidth(); tempAttr = ttg::SharedEncodingAttr::get( - val.getContext(), dotOpEnc, srcTy.getShape(), - ttg::getOrder(srcTy.getEncoding()), - ttg::getCTALayout(srcTy.getEncoding()), - srcTy.getElementType().getIntOrFloatBitWidth(), /*needTrans=*/false); + val.getContext(), dotOpEnc, srcTy.getShape(), order, CTALayout, + bitWidth, /*needTrans=*/false); } // Check that the shared encodings needed by the users are compatible. - if (attr != nullptr && attr != tempAttr) { - incompatible = true; - return std::nullopt; - } + if (attr != nullptr) + assert(attr == tempAttr && "incompatible shared encoding"); attr = tempAttr; } return attr; @@ -444,43 +439,8 @@ assignMemoryLayouts(llvm::SmallVector> loadInfo.sharedEncoding = getSharedEncoding(op, /*loadIsMMAv3=*/true).value_or(nullptr); } else if (auto dot = dyn_cast(use)) { - bool incompatible = false; loadInfo.sharedEncoding = - getSharedEncIfAllUsersAreDotEnc(op->getResult(0), incompatible) - .value_or(nullptr); - // If we can't agree on a shared encoding skip pipelinig the load. - if (incompatible) - continue; - // HACK: Triton LLVM codegen has a bug where local_loads from #shared to - // #mma layout can lead to invalid code if the loaded shape is smaller - // than the mma tile (e.g. loading a 128x1 tensor for an MMAv2 dot with - // tile {16,8} is bad because 1 < 8). To work around this, don't - // pipeline such loads. - // - // The codegen bug is caught by an assertion, so if you think you've - // fixed it, feel free to delete this code and see if the assert still - // fails. :) - if (!loadInfo.sharedEncoding) { - if (auto dotEnc = dyn_cast( - dot.getResult().getType().getEncoding())) { - auto loadTy = cast(op->getResultTypes()[0]); - auto mmaInstrShape = dotEnc.getInstrShape(); - if (loadTy.getRank() < mmaInstrShape.size()) - continue; - bool ok = true; - for (int i = 0; i < mmaInstrShape.size(); i++) { - if (loadTy.getShape()[loadTy.getRank() - mmaInstrShape.size() + - i] < mmaInstrShape[i]) { - ok = false; - break; - } - } - // If this load might trigger the bug, don't do the fallback logic - // below, which might allow the load to be pipelined. - if (!ok) - continue; - } - } + getSharedEncIfAllUsersAreDotEnc(op->getResult(0)).value_or(nullptr); } } else if (auto loadOp = dyn_cast(use)) { // The use of this loadOp is another loadOp. 
If the use is not in the @@ -516,9 +476,87 @@ assignMemoryLayouts(llvm::SmallVector> return loadToInfo; } +// Split users to groups, each group has the same shared encoding. +// If not all users are Dot encoding, return empty vector. +static DenseMap> +handleIncompatibleSharedEncoding(Operation *loadOp) { + DenseMap> loadGroups; + // Go through transitive uses of the loadOp in the same block. + for (Operation *user : loadOp->getUsers()) { + if (user->getBlock() != loadOp->getBlock()) + continue; + if (user->getNumResults() != 1) + return loadGroups; + + ttg::SharedEncodingAttr tempAttr; + if (auto memDesc = + dyn_cast(user->getResult(0).getType())) { + tempAttr = cast(memDesc.getEncoding()); + loadGroups[tempAttr].push_back(user); + } else { + if (!isa(user)) + return loadGroups; + auto dotOpEnc = dyn_cast( + cast(user->getResult(0).getType()).getEncoding()); + if (!dotOpEnc) + return loadGroups; + auto srcTy = cast(loadOp->getResult(0).getType()); + auto CTALayout = ttg::getCTALayout(srcTy.getEncoding()); + auto order = ttg::getOrder(srcTy.getEncoding()); + unsigned bitWidth = srcTy.getElementType().getIntOrFloatBitWidth(); + tempAttr = ttg::SharedEncodingAttr::get( + loadOp->getContext(), dotOpEnc, srcTy.getShape(), + ttg::getOrder(srcTy.getEncoding()), + ttg::getCTALayout(srcTy.getEncoding()), + srcTy.getElementType().getIntOrFloatBitWidth(), /*needTrans=*/false); + loadGroups[tempAttr].push_back(user); + } + } + return loadGroups; +} + +// Clone loads so each group of uses with same shared encoding will have a +// corresponding Load. +static void splitLoadsForIncompatible( + OpBuilder &builder, Operation *loadOp, + DenseMap> &lGroups) { + // The first group will use the original load, create new loads for other + // groups. + unsigned idx = 0; + builder.setInsertionPointAfter(loadOp); + for (auto pair : lGroups) { + SmallVector &group = pair.second; + if (idx++ == 0) + continue; + Operation *newLoad = builder.clone(*loadOp); + for (auto *user : group) { + user->replaceUsesOfWith(loadOp->getResult(0), newLoad->getResult(0)); + } + } +} + +static void splitLoadsWithIncompatibleEncoding(scf::ForOp forOp) { + // Get the list of all loads. + SmallVector loads; + for (Operation &op : forOp.getBody()->without_terminator()) { + if (isa(op)) { + loads.push_back(&op); + } + } + OpBuilder builder(forOp); + for (auto *loadOp : loads) { + auto lGroups = handleIncompatibleSharedEncoding(loadOp); + LDBG("groups with different encoding: " << lGroups.size() << " " + << *loadOp); + if (lGroups.size() > 1) + splitLoadsForIncompatible(builder, loadOp, lGroups); + } +} + static llvm::MapVector scheduleLoads(scf::ForOp forOp, tt::CoarseSchedule &schedule, DenseSet &rootUsers, int numStages) { + ModuleOp moduleOp = forOp->getParentOfType(); tt::ModuleAxisInfoAnalysis axisInfoAnalysis(moduleOp); @@ -537,6 +575,18 @@ scheduleLoads(scf::ForOp forOp, tt::CoarseSchedule &schedule, if (loadOpToIndLevelAndUse.empty()) return {}; + for (auto iter = loadOpToIndLevelAndUse.begin(); + iter != loadOpToIndLevelAndUse.end();) { + auto iterNext = iter + 1; + if (std::get<1>(*iter) >= numStages - 1) + // We assume loads with different dist are assigned to different stages. + // If numStages is 2, we will have no stage available for indirect loads + // with dist >= 1. In general, when dist is equal to numStages - 1, we + // should not pipeline it. + loadOpToIndLevelAndUse.erase(iter); + iter = iterNext; + } + // Check which loads are good for pipelining, and assign them // memory layouts. 
llvm::MapVector loadToInfo = @@ -1056,6 +1106,8 @@ static void invalidateBarriers(OpBuilder &builder, bool mlir::triton::preProcessLoopAndGetSchedule( scf::ForOp &forOp, int numStages, mlir::triton::PipeliningOption &options) { + splitLoadsWithIncompatibleEncoding(forOp); + // Schedule the loads and root ops (dot ops) in the loop. This will give us // a scaffold for the final schedule. DenseSet rootUsers; diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp index 81b2674576..5d772cf2ef 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp @@ -675,7 +675,7 @@ LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter, Value rangeIncrStep = rewriter.create(loc, rangeDiff, step); Value rangeDecr = rewriter.create(loc, rangeIncrStep, stepDecr); - Value totalIterations = rewriter.create(loc, rangeDecr, step); + Value totalIterations = rewriter.create(loc, rangeDecr, step); // Capture predicates for dynamic loops. SmallVector predicates(maxStage + 1); diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 7aa7d1a8f1..683ff5dfa4 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -5322,11 +5322,13 @@ def matmul_kernel( # @pytest.mark.parametrize("in_type_str", ['float8e5', 'float8e4nv', 'float8e4b15']) @pytest.mark.parametrize("low_precision_acc", [0, 32, 64, 128]) def test_dot_max_num_imprecise_acc(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, in_type_str, low_precision_acc, device): + num_stages = 3 if is_cuda(): cc = torch.cuda.get_device_capability() if cc[0] >= 9 and in_type_str == "float8e4b15": pytest.skip("Dot op does not support fp8e4b15 on CUDA arch >= 90") elif is_hip(): + num_stages = 2 if in_type_str != 'float8e5': pytest.skip('test_fp8_dot_acc for HIP currently broken in upstream.') @@ -5340,7 +5342,8 @@ def test_dot_max_num_imprecise_acc(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, in_type_s grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1) max_num_impressive_acc = low_precision_acc if low_precision_acc <= BLOCK_K else None h = matmul_kernel[grid](a, b, C, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), C.stride(0), - C.stride(1), BLOCK_M, BLOCK_N, BLOCK_K, max_num_impressive_acc, num_warps=num_warps) + C.stride(1), BLOCK_M, BLOCK_N, BLOCK_K, max_num_impressive_acc, num_warps=num_warps, + num_pipeline_stages=num_stages) torch_a = torch.from_numpy(A).to(device=device) th_a = f8_to_f16(torch_a, in_type_str) torch_b = torch.from_numpy(B).to(device=device) @@ -5697,3 +5700,52 @@ def check_loop_unroll_count(ir, opStr, loop_unroll_factor): for unroll_factor in [1, 2, 4, 5, 8]: h = _kernel[(1, )](torch.empty(1, device=device), unroll_factor) check_loop_unroll_count(h.asm["ttir"], 'tt.atomic_rmw', unroll_factor) + + +@triton.jit +def sanitize_add(a, b): + a64 = a.to(tl.int64) + b64 = b.to(tl.int64) + r64 = a64 + b64 + tl.device_assert((r64 >= -2**31) & (r64 <= 2**31 - 1)) + return a + b + + +def test_side_effectful_reduction(device): + if device != "cuda": + pytest.xfail() + + @triton.jit(debug=True) + def sanitize_sum_kernel(Z, X, BLOCK: tl.constexpr): + vals = tl.load(X + tl.arange(0, BLOCK)) + z = tl.reduce(vals, 0, sanitize_add) + tl.store(Z, z) + + BLOCK = 512 + torch.manual_seed(42) + X = torch.randint(0, 10, [BLOCK], device="cuda", dtype=torch.int32) + X[:300] = 32 + X[300:] = 0 + Z = torch.zeros((), 
device="cuda", dtype=torch.int32) + sanitize_sum_kernel[(1, )](Z, X, BLOCK=BLOCK) + torch.testing.assert_close(Z, X.sum().to(torch.int32)) + + +def test_side_effectful_scan(device): + if device != "cuda": + pytest.xfail() + + @triton.jit(debug=True) + def sanitize_cumsum_kernel(Z, X, BLOCK: tl.constexpr): + vals = tl.load(X + tl.arange(0, BLOCK)) + z = tl.associative_scan(vals, 0, sanitize_add) + tl.store(Z + tl.arange(0, BLOCK), z) + + BLOCK = 512 + torch.manual_seed(42) + X = torch.randint(0, 10, [BLOCK], device="cuda", dtype=torch.int32) + X[:300] = 32 + X[300:] = 0 + Z = torch.zeros_like(X) + sanitize_cumsum_kernel[(1, )](Z, X, BLOCK=BLOCK) + torch.testing.assert_close(Z, X.cumsum(0).to(torch.int32)) diff --git a/python/test/unit/runtime/test_autotuner.py b/python/test/unit/runtime/test_autotuner.py index 2858299330..7d7867a2ca 100644 --- a/python/test/unit/runtime/test_autotuner.py +++ b/python/test/unit/runtime/test_autotuner.py @@ -7,22 +7,25 @@ @pytest.mark.parametrize('use_cuda_graph', [False, True]) def test_kwargs(use_cuda_graph: bool, device: str): - N = 1024 - src = torch.randn(N, device=device) - dst = torch.empty(N, device=device) + M, N = 1024, 16 + src = torch.randn(M * N, device=device) + dst = torch.empty(M * N, device=device) - configs = [triton.Config(kwargs={'BLOCK_SIZE': 32}), triton.Config(kwargs={'BLOCK_SIZE': 128})] + configs = [triton.Config(kwargs={'BLOCK_SIZE_M': 32}), triton.Config(kwargs={'BLOCK_SIZE_M': 128})] - @triton.autotune(configs=configs, key=['N'], warmup=1, rep=1, use_cuda_graph=use_cuda_graph) + @triton.autotune(configs=configs, key=['M'], warmup=1, rep=1, use_cuda_graph=use_cuda_graph) @triton.jit - def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr): - offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - x = tl.load(src + offsets, mask=offsets < N) - tl.store(dst + offsets, x, mask=offsets < N) - - grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE']), ) - _kernel[grid](dst, src, N) - _kernel[grid](dst=dst, src=src, N=N) + def _kernel(dst, src, stride_m: tl.constexpr, M, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_M: tl.constexpr): + offsets_m = tl.program_id(0) * stride_m + tl.arange(0, BLOCK_SIZE_M) + offsets_n = tl.arange(0, BLOCK_SIZE_N) + x = tl.load(src + offsets_m[:, None] * BLOCK_SIZE_N + offsets_n[None, :]) + tl.store(dst + offsets_m[:, None] * BLOCK_SIZE_N + offsets_n[None, :], x) + + grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE_M']), ) + _kernel[grid](dst, src, N, M, N) + # the key word args could be in arbitrary order. 
+ _kernel[grid](dst=dst, src=src, M=M // 2, stride_m=N, BLOCK_SIZE_N=N) + assert len(_kernel.cache) == 2 def test_restore(device): diff --git a/python/triton/runtime/autotuner.py b/python/triton/runtime/autotuner.py index 59191a31b7..2b6a7ba32c 100644 --- a/python/triton/runtime/autotuner.py +++ b/python/triton/runtime/autotuner.py @@ -38,7 +38,7 @@ def __init__( self.configs = [Config({}, num_warps=4, num_stages=2, num_ctas=1)] else: self.configs = configs - self.key_idx = [arg_names.index(k) for k in key] + self.keys = key self.cache = {} self.arg_names = arg_names @@ -136,12 +136,9 @@ def run(self, *args, **kwargs): used_cached_result = True if len(self.configs) > 1: all_args = {**self.nargs, **kwargs} - _args = [] - for name in self.arg_names: - if name in all_args: - _args.append(all_args[name]) - key = [_args[i] for i in self.key_idx] - for arg in _args: + _args = {k: v for (k, v) in all_args.items() if k in self.arg_names} + key = [_args[key] for key in self.keys if key in _args] + for _, arg in _args.items(): if hasattr(arg, "dtype"): key.append(str(arg.dtype)) key = tuple(key) diff --git a/python/tutorials/03-matrix-multiplication.py b/python/tutorials/03-matrix-multiplication.py index c71c70bb8b..ce97cadbd3 100644 --- a/python/tutorials/03-matrix-multiplication.py +++ b/python/tutorials/03-matrix-multiplication.py @@ -239,19 +239,19 @@ def get_hip_autotune_config(): return [ triton.Config( {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 1, 'waves_per_eu': 2}, - num_warps=4, num_stages=0), + num_warps=4, num_stages=2), triton.Config( {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 4, 'waves_per_eu': 2}, - num_warps=8, num_stages=0), + num_warps=8, num_stages=2), triton.Config( {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, 'waves_per_eu': 2}, - num_warps=8, num_stages=0), + num_warps=8, num_stages=2), triton.Config( {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'waves_per_eu': 3}, - num_warps=4, num_stages=0), + num_warps=4, num_stages=2), triton.Config( {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, 'waves_per_eu': 8}, - num_warps=4, num_stages=0), + num_warps=4, num_stages=2), ] diff --git a/python/tutorials/06-fused-attention.py b/python/tutorials/06-fused-attention.py index 7c9674cd4b..f447077bdf 100644 --- a/python/tutorials/06-fused-attention.py +++ b/python/tutorials/06-fused-attention.py @@ -3,11 +3,13 @@ =============== This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao (https://tridao.me/publications/flash2/flash2.pdf) + Credits: OpenAI kernel team Extra Credits: -- Original flash attention paper (https://arxiv.org/abs/2205.14135) -- Rabe and Staats (https://arxiv.org/pdf/2112.05682v2.pdf) + +* Original flash attention paper (https://arxiv.org/abs/2205.14135) +* Rabe and Staats (https://arxiv.org/pdf/2112.05682v2.pdf) """ diff --git a/python/tutorials/07-extern-functions.py b/python/tutorials/07-extern-functions.py index 6ac63e35b3..45e4c697c4 100644 --- a/python/tutorials/07-extern-functions.py +++ b/python/tutorials/07-extern-functions.py @@ -3,7 +3,9 @@ ============================== Triton can invoke a custom function from an external library. In this example, we will use the `libdevice` library to apply `asin` on a tensor. 
-Please refer to https://docs.nvidia.com/cuda/libdevice-users-guide/index.html (CUDA) and/or https://github.com/ROCm/llvm-project/tree/amd-staging/amd/device-libs/ocml/src (HIP) regarding the semantics of all available libdevice functions. + +Please refer to `CUDA libdevice-users-guide `_ and/or `HIP device-lib source code `_ regarding the semantics of all available libdevice functions. + In `libdevice.py`, we try to aggregate functions with the same computation but different data types together. For example, both `__nv_asin` and `__nv_asinf` calculate the principal value of the arc sine of the input, but `__nv_asin` operates on `double` and `__nv_asinf` operates on `float`. Triton automatically selects the correct underlying device function to invoke based on input and output types. diff --git a/python/tutorials/09-persistent-matmul.py b/python/tutorials/09-persistent-matmul.py index 1ce34f9a26..699c4aa4f9 100644 --- a/python/tutorials/09-persistent-matmul.py +++ b/python/tutorials/09-persistent-matmul.py @@ -1,10 +1,22 @@ """ -Persistent FP8 Matmul +Persistent Matmul ===================== This script demonstrates persistent kernel implementations of matrix multiplication using Triton. -It includes various matmul methods, such as naive, persistent, and TMA (Tensor Memory Accelerator) based approaches, and only supports GPUs with compute capability >= 9.0. -Triton and CuBLAS implementations are benchmarked under different configurations and evaluated using the proton profiler. +Various matmul methods are included, such as naive, persistent, and TMA (Tensor Memory Accelerator) based approaches. +The kernels support both FP16 and FP8 data types but the FP8 implementation is only available on CUDA devices with compute capability >= 9.0. + +Triton and cuBLAS implementations are benchmarked under different configurations and evaluated using the proton profiler. Users can pass command-line arguments to specify matrix dimensions and iteration steps flexibly. + +.. code-block:: bash + + # FP8 + python 09-persistent-matmul.py --prec fp8 --K_range 128 1024 --K_step 128 + + # FP16 + python 09-persistent-matmul.py --prec fp16 --K_range 128 1024 --K_step 128 + +Note that currently this tutorial will fail on devices with a small shared memory size, such as RTX-4090. """ import argparse @@ -36,12 +48,12 @@ def _matmul_launch_metadata(grid, kernel, args): ret = {} M, N, K = args["M"], args["N"], args["K"] ret["name"] = f"{kernel.name} [M={M}, N={N}, K={K}]" - ret["flops8"] = 2. * M * N * K if "c_ptr" in args: bytes_per_elem = args["c_ptr"].element_size() else: bytes_per_elem = 1 if args["FP8_OUTPUT"] else 2 - ret["bytes"] = bytes_per_elem * (M * K + N * K) + ret[f"flops{bytes_per_elem * 8}"] = 2. 
* M * N * K + ret["bytes"] = bytes_per_elem * (M * K + N * K + M * N) return ret @@ -328,7 +340,7 @@ def matmul_tma_persistent(a, b): N, K = b.shape dtype = a.dtype - c = torch.zeros((M, N), device=a.device, dtype=dtype) + c = torch.empty((M, N), device=a.device, dtype=dtype) desc_a = triton.tools.experimental_descriptor.create_2d_tma_descriptor(a.data_ptr(), M, K, configs[dtype]["BLOCK_SIZE_M"], configs[dtype]["BLOCK_SIZE_K"], @@ -481,7 +493,7 @@ def matmul_device_tma_persistent(a, b, tiles_per_update): N, K = b.shape dtype = a.dtype - c = torch.zeros((M, N), device=a.device, dtype=dtype) + c = torch.empty((M, N), device=a.device, dtype=dtype) NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count tma_size = 128 workspace = torch.empty(NUM_SMS * 3 * tma_size, dtype=torch.uint8, device="cuda") @@ -511,9 +523,9 @@ def cublas_matmul(a, b): dtype = a.dtype c = torch.empty((M, N), device=a.device, dtype=dtype) bytes_per_elem = a.element_size() - flops_str = "flops8" if dtype == torch.float8_e4m3fn else "flops" - with proton.scope(f"cublas M={M}, N={N}, K={K}", - {"bytes": bytes_per_elem * (M * K + N * K), flops_str: 2. * M * N * K}): + flops_str = f"flops{bytes_per_elem * 8}" + with proton.scope(f"cublas [M={M}, N={N}, K={K}]", + {"bytes": bytes_per_elem * (M * K + N * K + M * N), flops_str: 2. * M * N * K}): cublas.matmul(a, b, c) return c @@ -521,11 +533,10 @@ def cublas_matmul(a, b): def torch_matmul(a, b): M, K = a.shape N, K = b.shape - dtype = a.dtype bytes_per_elem = a.element_size() - flops_str = "flops8" if dtype == torch.float8_e4m3fn else "flops" - with proton.scope(f"torch M={M}, N={N}, K={K}", - {"bytes": bytes_per_elem * (M * K + N * K), flops_str: 2. * M * N * K}): + flops_str = f"flops{bytes_per_elem * 8}" + with proton.scope(f"torch [M={M}, N={N}, K={K}]", + {"bytes": bytes_per_elem * (M * K + N * K + M * N), flops_str: 2. * M * N * K}): c = torch.matmul(a, b.T) return c @@ -558,10 +569,8 @@ def bench(K, dtype, tiles_per_update, reps=10): for _ in range(reps): matmul_tma_persistent(a, b) time.sleep(0.01) - flops_str = "flops8" if dtype == torch.float8_e4m3fn else "flops" with proton.scope( - f"matmul_kernel_device_tma_persistent M={M}, N={N}, K={K}, tiles_per_update={tiles_per_update:02}", - {"bytes": a.element_size() * (M * K + N * K), flops_str: 2. 
* M * N * K}): + f"matmul_kernel_device_tma_persistent [M={M}, N={N}, K={K}, tiles_per_update={tiles_per_update:02}]"): for _ in range(reps): matmul_device_tma_persistent(a, b, tiles_per_update) time.sleep(0.01) @@ -608,6 +617,17 @@ def validate(M, N, K, dtype, tiles_per_update): print() +def show_profile(precision, profile_name): + import triton.profiler.viewer as proton_viewer + metrics = ["time/ms"] + if precision == 'fp8': + metrics = ["tflop8/s"] + metrics + elif precision == 'fp16': + metrics = ["tflop16/s"] + metrics + file_name = f"{profile_name}.hatchet" + proton_viewer.parse(metrics, file_name, depth=100) + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("-K", type=int, required=False, default=512) @@ -642,3 +662,4 @@ def validate(M, N, K, dtype, tiles_per_update): for K in range(args.K_range[0], args.K_range[1] + 1, args.K_step): bench(K, dtype, args.tiles_per_update) proton.finalize() + show_profile(args.prec, "matmul") diff --git a/test/Conversion/amd/compute-base-ptr.mlir b/test/Conversion/amd/compute-base-ptr.mlir new file mode 100644 index 0000000000..e8376b1d8b --- /dev/null +++ b/test/Conversion/amd/compute-base-ptr.mlir @@ -0,0 +1,18 @@ +// RUN: triton-opt %s --split-input-file --convert-triton-amdgpu-to-llvm=arch=gfx942 | FileCheck %s + +#blocked = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 16], isTransposed = false}> +#shared = #triton_gpu.shared<{vec = 16, perPhase = 4, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.shared = 544 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: @local_load_offset + tt.func @local_load_offset(%arg0: tensor<16x16xf16, #mma>) { + %0 = triton_gpu.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<16x16xf16, #mma> -> tensor<16x16xf16, #blocked> + %1 = triton_gpu.local_alloc %0 {allocation.offset = 0 : i32} : (tensor<16x16xf16, #blocked>) -> !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory> + // This catches base ptr calculation in the computeBasePtr, checks if the gep has correct element type. 
+ // CHECK: llvm.sub + // CHECK-NEXT: llvm.getelementptr + // CHECK-SAME: (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16 + %2 = triton_gpu.local_load %1 : !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory> -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> + tt.return + } +} diff --git a/test/TritonGPU/loop-pipeline-indirect-load.mlir b/test/TritonGPU/loop-pipeline-indirect-load.mlir new file mode 100644 index 0000000000..74794b9496 --- /dev/null +++ b/test/TritonGPU/loop-pipeline-indirect-load.mlir @@ -0,0 +1,90 @@ +// RUN: triton-opt %s -split-input-file -tritongpu-pipeline=num-stages=2 | FileCheck %s +// CHECK-LABEL: @indirect_load_two_stages +// CHECK: scf.for +// CHECK: tt.dot +// CHECK: tt.load +// CHECK: async_copy_global_to_local +// CHECK: async_copy_global_to_local + +#blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [2, 1], order = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 2], order = [0, 1]}> +#blocked3 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [2, 1], order = [1, 0]}> + +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @indirect_load_two_stages(%arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: !tt.ptr {tt.divisibility = 16 : i32}, %arg7: !tt.ptr {tt.divisibility = 16 : i32}, %arg12: i32 {tt.divisibility = 16 : i32}, %arg17: i32 {tt.divisibility = 16 : i32}, %arg18: i32, %arg19: i32) attributes {noinline = false} { + %c32_i32 = arith.constant 32 : i32 + %c16_i32 = arith.constant 16 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<16x128xf32, #blocked> + + %0 = tt.get_program_id y : i32 + %1 = tt.addptr %arg3, %0 : !tt.ptr, i32 + %2 = tt.load %1 : !tt.ptr + + %7 = tt.get_program_id x : i32 + %8 = arith.muli %7, %c16_i32 : i32 + %10 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> + %15 = tt.splat %8 : i32 -> tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> + %18 = arith.addi %15, %10 : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> + + %20 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %22 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> + %34 = arith.extsi %arg12 : i32 to i64 + %35 = arith.muli %2, %34 : i64 + %36 = tt.addptr %arg2, %35 : !tt.ptr, i64 + + %47 = tt.splat %arg4 : !tt.ptr -> tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %48 = tt.addptr %47, %20 : tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>, tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + + %59 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> + %61 = arith.extsi %59 : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> to tensor<128xi64, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> + %63 = tt.expand_dims %61 {axis = 0 : i32} : tensor<128xi64, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> -> tensor<1x128xi64, #blocked3> + + %85 = arith.extsi %22 : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = 
#blocked3}>> to tensor<32xi64, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> + %107 = tt.splat %36 : !tt.ptr -> tensor<32x128x!tt.ptr, #blocked3> + %108 = tt.splat %34 : i64 -> tensor<32x1xi64, #blocked3> + %109 = tt.broadcast %63 : tensor<1x128xi64, #blocked3> -> tensor<32x128xi64, #blocked3> + + %101 = tt.splat %arg5 : !tt.ptr -> tensor<16x32x!tt.ptr, #blocked1> + %111:1 = scf.for %arg28 = %arg18 to %arg19 step %c32_i32 iter_args(%arg29 = %cst) -> (tensor<16x128xf32, #blocked>) : i32 { + %129 = tt.splat %arg28 : i32 -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %160 = tt.addptr %48, %129 : tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>, tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %161 = tt.load %160 : tensor<32x!tt.ptr, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %162 = tt.expand_dims %161 {axis = 0 : i32} : tensor<32xi64, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x32xi64, #blocked1> + %163 = tt.broadcast %162 : tensor<1x32xi64, #blocked1> -> tensor<16x32xi64, #blocked1> + %182 = tt.addptr %101, %163 : tensor<16x32x!tt.ptr, #blocked1>, tensor<16x32xi64, #blocked1> + %183 = tt.load %182 : tensor<16x32x!tt.ptr, #blocked1> + + %197 = arith.extsi %arg28 : i32 to i64 + %198 = tt.splat %197 : i64 -> tensor<32xi64, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> + %199 = arith.addi %198, %85 : tensor<32xi64, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> + %200 = tt.expand_dims %199 {axis = 1 : i32} : tensor<32xi64, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> -> tensor<32x1xi64, #blocked3> + %201 = arith.muli %200, %108 : tensor<32x1xi64, #blocked3> + %202 = tt.broadcast %201 : tensor<32x1xi64, #blocked3> -> tensor<32x128xi64, #blocked3> + %203 = arith.addi %202, %109 : tensor<32x128xi64, #blocked3> + %204 = tt.addptr %107, %203 : tensor<32x128x!tt.ptr, #blocked3>, tensor<32x128xi64, #blocked3> + %209 = tt.load %204 : tensor<32x128x!tt.ptr, #blocked3> + + %210 = triton_gpu.convert_layout %183 : tensor<16x32xf32, #blocked1> -> tensor<16x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> + %211 = triton_gpu.convert_layout %209 : tensor<32x128xf32, #blocked3> -> tensor<32x128xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> + %212 = tt.dot %210, %211, %arg29 : tensor<16x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<32x128xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<16x128xf32, #blocked> + scf.yield %212 : tensor<16x128xf32, #blocked> + } + %112 = tt.expand_dims %18 {axis = 1 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 1, parent = #blocked3}>> -> tensor<16x1xi32, #blocked3> + %113 = tt.splat %2 : i64 -> tensor<16x1xi64, #blocked3> + %114 = arith.extsi %112 : tensor<16x1xi32, #blocked3> to tensor<16x1xi64, #blocked3> + %115 = arith.addi %113, %114 : tensor<16x1xi64, #blocked3> + %116 = arith.extsi %arg17 : i32 to i64 + %117 = tt.splat %116 : i64 -> tensor<16x1xi64, #blocked3> + %118 = arith.muli %115, %117 : tensor<16x1xi64, #blocked3> + %119 = tt.expand_dims %59 {axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked3}>> -> tensor<1x128xi32, #blocked3> + %120 = tt.broadcast %118 : tensor<16x1xi64, #blocked3> -> tensor<16x128xi64, #blocked3> + %121 = arith.extsi %119 : tensor<1x128xi32, #blocked3> to tensor<1x128xi64, #blocked3> + %122 = tt.broadcast %121 : tensor<1x128xi64, #blocked3> -> tensor<16x128xi64, #blocked3> + %123 = arith.addi %120, %122 : tensor<16x128xi64, #blocked3> + %124 = tt.splat 
%arg7 : !tt.ptr -> tensor<16x128x!tt.ptr, #blocked3> + %125 = tt.addptr %124, %123 : tensor<16x128x!tt.ptr, #blocked3>, tensor<16x128xi64, #blocked3> + %128 = triton_gpu.convert_layout %111#0 : tensor<16x128xf32, #blocked> -> tensor<16x128xf32, #blocked3> + tt.store %125, %128 : tensor<16x128x!tt.ptr, #blocked3> + tt.return + } +} diff --git a/test/TritonGPU/loop-pipeline.mlir b/test/TritonGPU/loop-pipeline.mlir index 81cb2d9a01..fca72ebda7 100644 --- a/test/TritonGPU/loop-pipeline.mlir +++ b/test/TritonGPU/loop-pipeline.mlir @@ -84,7 +84,7 @@ // AMD: %[[SUBI_23:.*]] = arith.subi %[[UB]], %[[LB]] // AMD: %[[ADDI_24:.*]] = arith.addi %[[SUBI_23]], %[[STEP]] // AMD: %[[ADDI_25:.*]] = arith.addi %[[ADDI_24]], %[[SELECT_22]] -// AMD: %[[DIVUI_26:.*]] = arith.divui %[[ADDI_25]], %[[STEP]] +// AMD: %[[DIVUI_26:.*]] = arith.divsi %[[ADDI_25]], %[[STEP]] // AMD: %[[ADDI_27:.*]] = arith.addi %[[DIVUI_26]], %[[CM1]] // AMD: %[[CMPI_28:.*]] = arith.cmpi sge, %[[ADDI_27]], %[[C0]] // AMD: %[[LOCAL_LOAD_27:.*]] = triton_gpu.local_load %[[FOR]]#4 @@ -844,9 +844,16 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %14 = tt.broadcast %11 : tensor<1x16x!tt.ptr, #blocked> -> tensor<64x16x!tt.ptr, #blocked> %15 = tt.broadcast %13 : tensor<64x1xi32, #blocked> -> tensor<64x16xi32, #blocked> %16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> - // check that the load didn't get pipelined. - // COMMON-NOT: alloc - // COMMON: scf.for + // check that the load with incompatiable shared encoding gets cloned and feeds into uses with same encoding + // AMD-NOT: alloc + // AMD: scf.for + // CHECK: local_alloc + // CHECK: local_alloc + // CHECK: scf.for + // CHECK: local_load {{.*}} tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1 + // CHECK: convert_layout {{.*}} tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0 + // CHECK: tt.dot + // CHECK: tt.trans %arg %17:2 = scf.for %arg2 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg3 = %cst_1, %arg4 = %cst_2) -> (tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma>) : i32 { %18 = tt.load %16 : tensor<64x16x!tt.ptr, #blocked> %19 = triton_gpu.convert_layout %9 : tensor<128x64xf16, #blocked1> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> @@ -1453,7 +1460,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : // ----- // COMMON-LABEL: @dont_pipeline_128x1 -// COMMON-NOT: local_load{{.*}}128x1 +// AMD-NOT: local_load{{.*}}128x1 +// CHECK: local_load{{.*}}128x1 #blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> #mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8]}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { diff --git a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py index 78a3c21280..61a782f334 100644 --- a/third_party/amd/backend/compiler.py +++ b/third_party/amd/backend/compiler.py @@ -46,6 +46,13 @@ class HIPOptions: max_num_imprecise_acc_default: int = 0 backend_name: str = 'hip' + # The following option provides hints to the AMDGPU backend regarding instruction scheduling + # for all `tt.dot` operations in a kernel. The "default" variant preserves the default + # instruction scheduling of the AMDGPU backend which aims at maximizing occupancy. 
+ # The option is experimental and may change at any time regarding its semantics and/or may + # be gone entirely anytime. + instruction_sched_variant: str = 'default' + def __post_init__(self): default_libdir = Path(__file__).parent / 'lib' extern_libs = {} if self.extern_libs is None else dict(self.extern_libs) @@ -162,7 +169,7 @@ def make_ttgir(mod, metadata, options): passes.ttgpuir.add_remove_layout_conversions(pm) amd.passes.ttgpuir.add_optimize_epilogue(pm) passes.ttgpuir.add_optimize_dot_operands(pm, True) - use_new_pipeliner = os.getenv("TRITON_HIP_USE_NEW_STREAM_PIPELINE", "0") == "1" + use_new_pipeliner = os.getenv("TRITON_HIP_USE_NEW_STREAM_PIPELINE", "1") == "1" if amd.has_matrix_core_feature(options.arch): if use_new_pipeliner: # In the old pipeliner we only support num_stages = 0/1, which means something @@ -174,6 +181,7 @@ def make_ttgir(mod, metadata, options): if options.num_stages == 0: amd.passes.ttgpuir.add_stream_pipeline(pm) passes.common.add_canonicalizer(pm) + amd.passes.ttgpuir.insert_instruction_sched_hints(pm) passes.ttgpuir.add_optimize_dot_operands(pm, True) passes.ttgpuir.add_remove_layout_conversions(pm) passes.ttgpuir.add_reduce_data_duplication(pm) @@ -221,6 +229,7 @@ def make_llir(src, metadata, options): passes.common.add_canonicalizer(pm) passes.common.add_cse(pm) passes.common.add_symbol_dce(pm) + amd.passes.ttgpuir.lower_instruction_sched_hints(pm, options.instruction_sched_variant) if os.environ.get("TRITON_DISABLE_LINE_INFO", "0") == "0": passes.llvmir.add_di_scope(pm) # This pass (`add_builtin_func_to_llvmir`) serves as a temporary workaround to address the issue of excessive basic block diff --git a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td index 8960c05143..4721d14ecb 100644 --- a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td +++ b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td @@ -32,4 +32,24 @@ include "mlir/Interfaces/InferTypeOpInterface.td" include "TritonAMDGPUDialect.td" include "TritonAMDGPUAttrDefs.td" +class TT_AMDGPU_Op traits = []> : + Op { +} + +def InstructionSchedHint : TT_AMDGPU_Op<"instruction_sched_hint", []> { + let summary = "A placeholder op for instruction scheduling hints within a basic block"; + let description = [{ + A placeholder op for instruction scheduling hints applied to instructions within + a basic block where the placeholder op is located. This op is primarily intended + to be used to adjust instruction scheduling inside the resulting main loop + of a `tt.dot` operation. It's easier to identify dot ops at a high level and, thus, + to mark intended scheduling regions. The hint ops are eventually lowered + into LLVM AMDGPU instruction scheduling primitives, which are meant to control + how different kinds of instructions (valu/mfma, global/shared memory, etc.) should + interleave for better instruction level parallelism. 
+ }]; + + let assemblyFormat = [{attr-dict}]; +} + #endif diff --git a/third_party/amd/include/TritonAMDGPUToLLVM/Passes.h b/third_party/amd/include/TritonAMDGPUToLLVM/Passes.h index be9efe4033..67ff40d5b9 100644 --- a/third_party/amd/include/TritonAMDGPUToLLVM/Passes.h +++ b/third_party/amd/include/TritonAMDGPUToLLVM/Passes.h @@ -34,6 +34,10 @@ createOptimizeLDSUsagePass(StringRef arch, int32_t customLDSLimit = 0); std::unique_ptr> createConvertTritonAMDGPUToLLVMPass(StringRef targetArch, bool ftz); std::unique_ptr> createConvertBuiltinFuncToLLVMPass(); +std::unique_ptr> +createInsertInstructionSchedHintsPass(); +std::unique_ptr> +createLowerInstructionSchedHintsPass(std::string variant); #define GEN_PASS_REGISTRATION #include "TritonAMDGPUToLLVM/Passes.h.inc" diff --git a/third_party/amd/include/TritonAMDGPUToLLVM/Passes.td b/third_party/amd/include/TritonAMDGPUToLLVM/Passes.td index b27c3bf8f9..ccb2b1898f 100644 --- a/third_party/amd/include/TritonAMDGPUToLLVM/Passes.td +++ b/third_party/amd/include/TritonAMDGPUToLLVM/Passes.td @@ -55,4 +55,24 @@ def ConvertBuiltinFuncToLLVM : Pass<"convert-builtin-func-to-llvm", "mlir::Modul } +def InsertInstructionSchedHints : Pass<"insert-instruction-sched-hints", "mlir::ModuleOp"> { + let summary = "Insert instruction scheduling hints after the dot ops in the main loop"; + let constructor = "mlir::triton::createInsertInstructionSchedHintsPass()"; + + let dependentDialects = ["mlir::LLVM::LLVMDialect"]; +} + +def LowerInstructionSchedHints : Pass<"lower-insert-instruction-sched-hints", "mlir::ModuleOp"> { + let summary = "Lower instruction scheduling hints to LLVM intrinsics"; + let constructor = "mlir::triton::createLowerInstructionSchedHintsPass(\"\")"; + + let dependentDialects = ["mlir::LLVM::LLVMDialect"]; + + let options = [ + Option<"variant", "variant", "std::string", /*default*/"\"default\"", + "instruction scheduling variant">, + ]; +} + + #endif diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt b/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt index 705c4258d0..dc05155527 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt @@ -18,10 +18,12 @@ add_triton_library(TritonAMDGPUToLLVM OptimizeLDSUsage.cpp OptimizeLDSUtility.cpp SPMDOpToLLVM.cpp + SchedInstructions.cpp DEPENDS TritonAMDGPUConversionPassIncGen LINK_LIBS PUBLIC TritonGPUToLLVM + TritonAMDGPUIR ) diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandHelper.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandHelper.cpp index 740e106f4f..03b7c56b7e 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandHelper.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandHelper.cpp @@ -68,9 +68,10 @@ Value computeBasePtr(ConversionPatternRewriter &rewriter, Location loc, const SharedMemoryObject &smemObj) { Value base = smemObj.base; Type type = base.getType(); + Type elemType = smemObj.getBaseElemType(); for (int i = 0; i < smemObj.strides.size(); ++i) { Value offset = sub(i32_val(0), mul(smemObj.offsets[i], smemObj.strides[i])); - base = gep(ptr_ty(rewriter.getContext(), 3), type, base, offset); + base = gep(type, elemType, base, offset); } return base; } diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp new file mode 100644 index 0000000000..c9413a52f5 --- /dev/null 
+++ b/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp
@@ -0,0 +1,205 @@
+#include "TritonAMDGPUToLLVM/Passes.h"
+
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Pass/Pass.h"
+#include "third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h"
+#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
+#include "triton/Dialect/Triton/IR/Dialect.h"
+
+namespace mlir::triton {
+#define GEN_PASS_DEF_INSERTINSTRUCTIONSCHEDHINTS
+#define GEN_PASS_DEF_LOWERINSTRUCTIONSCHEDHINTS
+#include "TritonAMDGPUToLLVM/Passes.h.inc"
+} // namespace mlir::triton
+
+using namespace mlir;
+
+namespace {
+
+// The bitmask that encodes the instruction kinds from the AMD ISA. It is used
+// for providing instruction scheduling hints.
+enum InstructionKindMask {
+  NONE = 0x0000000,
+  ALL_ALU = 0x00000001,
+  VALU = 0x00000002,
+  SALU = 0x00000004,
+  MFMA = 0x00000008,
+  ALL_VMEM = 0x00000010,
+  VMEM_READ = 0x00000020,
+  VMEM_WRITE = 0x00000040,
+  ALL_DS = 0x00000080,
+  DS_READ = 0x00000100,
+  DS_WRITE = 0x00000200
+};
+
+// Create an intrinsic call that controls how different instruction kinds
+// should interleave for better ILP.
+void createSchedGroupBarrier(PatternRewriter &rewriter, Location loc,
+                             InstructionKindMask maskValue, int sizeValue,
+                             int groupIdValue) {
+  MLIRContext *ctx = rewriter.getContext();
+  auto intrinsicName = str_attr("llvm.amdgcn.sched.group.barrier");
+
+  Value mask =
+      LLVM::createConstantI32(loc, rewriter, static_cast<int32_t>(maskValue));
+  Value size =
+      LLVM::createConstantI32(loc, rewriter, static_cast<int32_t>(sizeValue));
+  Value groupId = LLVM::createConstantI32(loc, rewriter,
+                                          static_cast<int32_t>(groupIdValue));
+
+  LLVM::FastmathFlagsAttr defaultFlags{};
+  rewriter.create<LLVM::CallIntrinsicOp>(loc, TypeRange{}, intrinsicName,
+                                         ValueRange{mask, size, groupId},
+                                         defaultFlags);
+}
+
+// Insert an intrinsic that controls which instruction kinds are allowed to
+// cross it during instruction scheduling.
+Operation *createSchedBarrier(PatternRewriter &rewriter, Location loc,
+                              int64_t maskValue) {
+  MLIRContext *ctx = rewriter.getContext();
+  auto intrinsicName = str_attr("llvm.amdgcn.sched.barrier");
+  LLVM::FastmathFlagsAttr defaultFlags{};
+
+  Value mask =
+      LLVM::createConstantI32(loc, rewriter, static_cast<int32_t>(maskValue));
+  return rewriter.create<LLVM::CallIntrinsicOp>(loc, TypeRange{}, intrinsicName,
+                                                ValueRange{mask}, defaultFlags);
+}
+
+// Insert an experimental intrinsic for instruction group level parallelism.
+// The intrinsic takes a value that specifies the strategy.
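+// Note: in this lowering, the `iglp0` and `iglp1` scheduling variants are
+// mapped to strategy values 0 and 1, respectively (see
+// `InstructionSchedHintsRewriter` below). The set of valid strategy values is
+// defined by the LLVM AMDGPU backend, so that mapping is an assumption about
+// the backend rather than something this pass verifies.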
+Operation *createIglpOpt(PatternRewriter &rewriter, Location loc, int value) {
+  MLIRContext *ctx = rewriter.getContext();
+  auto intrinsicName = str_attr("llvm.amdgcn.iglp.opt");
+  LLVM::FastmathFlagsAttr defaultFlags{};
+  Value iglpValue =
+      LLVM::createConstantI32(loc, rewriter, static_cast<int32_t>(value));
+  return rewriter.create<LLVM::CallIntrinsicOp>(
+      loc, TypeRange{}, intrinsicName, ValueRange{iglpValue}, defaultFlags);
+}
+
+struct InstructionSchedHintsRewriter
+    : public OpRewritePattern<triton::amdgpu::InstructionSchedHint> {
+
+  InstructionSchedHintsRewriter(mlir::MLIRContext *ctx, std::string variant)
+      : OpRewritePattern(ctx) {
+    std::transform(variant.begin(), variant.end(), variant.begin(),
+                   [](unsigned char c) { return std::tolower(c); });
+
+    this->schedulingType = llvm::StringSwitch<SchedulingType>(variant)
+                               .Case("default", SchedulingType::NONE)
+                               .Case("iglp0", SchedulingType::IGLP0)
+                               .Case("iglp1", SchedulingType::IGLP1)
+                               .Default(SchedulingType::UNKNOWN);
+  }
+
+  enum class SchedulingType : uint32_t { NONE = 0, IGLP0, IGLP1, UNKNOWN };
+
+  LogicalResult
+  matchAndRewrite(triton::amdgpu::InstructionSchedHint instructionSchedHint,
+                  PatternRewriter &rewriter) const override {
+
+    if (this->schedulingType == SchedulingType::UNKNOWN) {
+      llvm::dbgs()
+          << "[" << getDebugName() << "]: "
+          << "unknown instruction scheduling variant has been provided\n";
+      return mlir::failure();
+    }
+
+    // The switch controls whether instructions are allowed to cross the basic
+    // block boundaries at the very top and at the very bottom. Note, this is
+    // not supposed to be used together with IGLP OPT according to the AMDGPU
+    // backend documentation.
+    const bool limitSchedulingRange =
+        !(schedulingType == SchedulingType::IGLP0 ||
+          schedulingType == SchedulingType::IGLP1);
+    Location loc = instructionSchedHint->getLoc();
+    Block *block = instructionSchedHint->getBlock();
+    if (limitSchedulingRange) {
+      rewriter.setInsertionPointToStart(block);
+      createSchedBarrier(rewriter, loc, InstructionKindMask::NONE);
+    }
+
+    rewriter.setInsertionPoint(block, std::prev(block->end()));
+
+    switch (schedulingType) {
+    case SchedulingType::IGLP0:
+      [[fallthrough]];
+    case SchedulingType::IGLP1: {
+      createIglpOpt(rewriter, loc, static_cast<int>(schedulingType) - 1);
+      break;
+    }
+    case SchedulingType::NONE:
+      [[fallthrough]];
+    default: {
+      break;
+    }
+    }
+
+    if (limitSchedulingRange)
+      createSchedBarrier(rewriter, loc, InstructionKindMask::NONE);
+
+    rewriter.eraseOp(instructionSchedHint);
+    return mlir::success();
+  }
+
+private:
+  SchedulingType schedulingType;
+};
+
+struct LowerInstructionSchedHints
+    : public triton::impl::LowerInstructionSchedHintsBase<
+          LowerInstructionSchedHints> {
+
+  explicit LowerInstructionSchedHints(std::string variant) {
+    this->variant = variant;
+  }
+
+  void runOnOperation() override {
+    MLIRContext *ctx = &getContext();
+    ModuleOp mod = getOperation();
+
+    ConversionTarget target(*ctx);
+    target.addLegalDialect<LLVM::LLVMDialect>();
+    target.addIllegalOp<triton::amdgpu::InstructionSchedHint>();
+
+    RewritePatternSet patterns(ctx);
+    patterns.add<InstructionSchedHintsRewriter>(ctx, this->variant);
+
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns)))) {
+      signalPassFailure();
+    }
+  }
+};
+
+struct InsertInstructionSchedHints
+    : public triton::impl::InsertInstructionSchedHintsBase<
+          InsertInstructionSchedHints> {
+  void runOnOperation() override {
+    MLIRContext *ctx = &getContext();
+    ModuleOp mod = getOperation();
+
+    mod->walk([ctx](triton::DotOp dot) {
+      if (dyn_cast<scf::ForOp>(dot->getParentOp())) {
+        mlir::OpBuilder rewriter(ctx);
+        rewriter.setInsertionPointAfter(dot);
+        rewriter.create<triton::amdgpu::InstructionSchedHint>(dot->getLoc());
+      }
+    });
+  }
+};
+} // namespace
+
+namespace mlir::triton {
+std::unique_ptr<OperationPass<ModuleOp>>
+createLowerInstructionSchedHintsPass(std::string variant) {
+  return std::make_unique<LowerInstructionSchedHints>(variant);
+}
+
+std::unique_ptr<OperationPass<ModuleOp>>
+createInsertInstructionSchedHintsPass() {
+  return std::make_unique<InsertInstructionSchedHints>();
+}
+} // namespace mlir::triton
diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp
index 08631e211e..f68714ab5e 100644
--- a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp
+++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
 #include "mlir/Pass/Pass.h"
+#include "third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h"
 #include "triton/Analysis/Allocation.h"
 #include "triton/Analysis/AxisInfo.h"
 #include "triton/Analysis/Membar.h"
@@ -57,6 +58,7 @@ class TritonLLVMConversionTarget : public ConversionTarget {
     addIllegalDialect();
     addIllegalDialect();
     addLegalOp();
+    addLegalOp<triton::amdgpu::InstructionSchedHint>();
   }
 };
 
diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipelineV2.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipelineV2.cpp
index 9d99bf79ad..f1d04b7270 100644
--- a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipelineV2.cpp
+++ b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipelineV2.cpp
@@ -208,10 +208,8 @@ getSharedEncIfAllUsersAreDotEnc(Value val) {
       auto order = ttg::getOrder(srcTy.getEncoding());
       unsigned bitWidth = srcTy.getElementType().getIntOrFloatBitWidth();
       tempAttr = ttg::SharedEncodingAttr::get(
-          val.getContext(), dotOpEnc, srcTy.getShape(),
-          ttg::getOrder(srcTy.getEncoding()),
-          ttg::getCTALayout(srcTy.getEncoding()),
-          srcTy.getElementType().getIntOrFloatBitWidth(), /*needTrans=*/false);
+          val.getContext(), dotOpEnc, srcTy.getShape(), order, CTALayout,
+          bitWidth, /*needTrans=*/false);
     }
     // Check that the shared encodings needed by the users are compatible.
     if (!tempAttr || (attr != nullptr && attr != tempAttr))
diff --git a/third_party/amd/python/triton_amd.cc b/third_party/amd/python/triton_amd.cc
index da5718ac6e..84558ea12e 100644
--- a/third_party/amd/python/triton_amd.cc
+++ b/third_party/amd/python/triton_amd.cc
@@ -44,6 +44,13 @@ void init_triton_amd_passes_ttgpuir(py::module &&m) {
   m.def("add_builtin_func_to_llvmir", [](mlir::PassManager &pm) {
     pm.addPass(createConvertBuiltinFuncToLLVMPass());
   });
+  m.def("insert_instruction_sched_hints", [](mlir::PassManager &pm) {
+    pm.addPass(createInsertInstructionSchedHintsPass());
+  });
+  m.def("lower_instruction_sched_hints",
+        [](mlir::PassManager &pm, std::string variant) {
+          pm.addPass(createLowerInstructionSchedHintsPass(variant));
+        });
   m.def("add_decompose_unsupported_conversions",
         [](mlir::PassManager &pm, const std::string &arch) {
           pm.addPass(
diff --git a/third_party/proton/proton/viewer.py b/third_party/proton/proton/viewer.py
index f77a65007f..2067466c94 100644
--- a/third_party/proton/proton/viewer.py
+++ b/third_party/proton/proton/viewer.py
@@ -180,7 +180,7 @@ def filter_frames(gf, include=None, exclude=None, threshold=None, metric=None):
     return gf
 
 
-def parse(metrics, filename, include, exclude, threshold, depth, format):
+def parse(metrics, filename, include=None, exclude=None, threshold=None, depth=100, format=None):
     with open(filename, "r") as f:
         gf, raw_metrics, device_info = get_raw_metrics(f)
         gf = format_frames(gf, format)
@@ -190,10 +190,10 @@ def parse(metrics, filename, include, exclude, threshold, depth, format):
     # TODO: generalize to support multiple metrics, not just the first one
     gf = filter_frames(gf, include, exclude, threshold, metrics[0])
     print(gf.tree(metric_column=metrics, expand_name=True, depth=depth, render_header=False))
-    emitWarnings(gf, metrics)
+    emit_warnings(gf, metrics)
 
 
-def emitWarnings(gf, metrics):
+def emit_warnings(gf, metrics):
     if "bytes (inc)" in metrics:
         byte_values = gf.dataframe["bytes (inc)"].values
         min_byte_value = np.nanmin(byte_values)
@@ -209,7 +209,6 @@ def show_metrics(file_name):
     for raw_metric in raw_metrics:
         raw_metric_no_unit = raw_metric.split("(")[0].strip().lower()
         print(f"- {raw_metric_no_unit}")
-    return
 
 
 def main():
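
For reference, a minimal sketch of how the new `instruction_sched_variant` knob could be exercised from Python. It assumes that, like other AMD-specific compile options (e.g. `waves_per_eu`), backend options are forwarded through kernel launch keyword arguments; the kernel and tensor setup below are illustrative only and not part of this change.

    import torch
    import triton
    import triton.language as tl


    @triton.jit
    def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
        pid = tl.program_id(axis=0)
        offsets = pid * BLOCK + tl.arange(0, BLOCK)
        mask = offsets < n_elements
        x = tl.load(x_ptr + offsets, mask=mask)
        y = tl.load(y_ptr + offsets, mask=mask)
        tl.store(out_ptr + offsets, x + y, mask=mask)


    # ROCm builds of PyTorch also use the "cuda" device string.
    x = torch.randn(1024, device="cuda")
    y = torch.randn(1024, device="cuda")
    out = torch.empty_like(x)
    grid = (triton.cdiv(x.numel(), 256), )
    # 'default' keeps the current scheduling; 'iglp0'/'iglp1' request the IGLP
    # strategies. Unrecognized values make the lowering pass fail.
    add_kernel[grid](x, y, out, x.numel(), BLOCK=256, instruction_sched_variant="iglp0")

On the pass-pipeline side this corresponds to `insert_instruction_sched_hints` running during `make_ttgir` and `lower_instruction_sched_hints(pm, options.instruction_sched_variant)` running during `make_llir`, as added above.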