2 changes: 1 addition & 1 deletion .github/workflows/documentation.yml
@@ -25,7 +25,7 @@ jobs:

- name: Install dependent packages
run: |
-sudo pip3 install tabulate cmake sphinx matplotlib myst_parser sphinx-rtd-theme pandas pytest sphinx-gallery sphinx-multiversion
+sudo pip3 install tabulate cmake sphinx matplotlib myst_parser sphinx-rtd-theme pandas pytest sphinx-gallery sphinx-multiversion llnl-hatchet

#- name: Fetch dependent branches
# run: |
2 changes: 1 addition & 1 deletion cmake/llvm-hash.txt
@@ -1 +1 @@
-36adf8ecedb64047021265a1e1730773d3b3a9e8
+df0864e761107b07e38f5503e0cbee0cebb4c5e8
3 changes: 0 additions & 3 deletions docs/conf.py
@@ -159,9 +159,6 @@ def documenter(app, obj, parent):
'examples_dirs': '../python/tutorials/',
'gallery_dirs': 'getting-started/tutorials',
'filename_pattern': '',
-# TODO: Re-enable the grouped-gemm tutorial. It currently hits this
-# assertion:
-# https://github.com/triton-lang/triton/blob/main/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp#L127
'ignore_pattern': r'(__init__\.py|11.*.py)',
'within_subsection_order': FileNameSortKey,
'reference_url': {
1 change: 1 addition & 0 deletions lib/Analysis/Utility.cpp
@@ -572,6 +572,7 @@ bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
srcTy.getShape(), srcTy.getElementType(), dotOperandLayout.getParent());
auto ans = mmaLayout.getVersionMajor() == 3 &&
dotOperandLayout.getOpIdx() == 0 &&
+mmaLayout.getWarpsPerCTA()[1] == 1 &&
!cvtNeedsSharedMemory(parentTy, srcTy) &&
(elementTypeSize == 16 || elementTypeSize == 8);
return ans;
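
Context for the new warpsPerCTA guard: with MMAv3 (Hopper), warpsPerCTA describes how warps tile the accumulator, and requiring warpsPerCTA[1] == 1 limits the shared-memory-free shortcut to layouts where the N dimension stays within a single warp, so every warp already owns the values it feeds into operand A of the next dot. A schematic C++ restatement of the full predicate, with hypothetical parameter names standing in for the MLIR layout queries:

// Sketch only: mirrors the condition computed above, not the Triton API.
bool mmaV3ToDotOperandShortcut(int versionMajor, int opIdx,
                               const int warpsPerCTA[2],
                               bool cvtNeedsSharedMemory,
                               unsigned elementTypeSize) {
  return versionMajor == 3 &&       // MMAv3 (wgmma) accumulator layout
         opIdx == 0 &&              // result is reused as dot operand A
         warpsPerCTA[1] == 1 &&     // N dimension not split across warps
         !cvtNeedsSharedMemory &&   // conversion stays in registers
         (elementTypeSize == 16 || elementTypeSize == 8); // element width in bits
}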
56 changes: 23 additions & 33 deletions lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
@@ -1,10 +1,7 @@
#include "ReduceScanCommon.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Support/LLVM.h"
#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h"
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
-#include <vector>

using namespace mlir;
using namespace mlir::triton;
@@ -80,36 +77,16 @@ struct ReduceOpConversion
private:
const TargetInfoBase &targetInfo;

-void accumulate(ConversionPatternRewriter &rewriter, Region &combineOp,
-SmallVector<Value> &acc, ValueRange cur, bool isFirst) const {
-if (isFirst) {
-acc = SmallVector<Value>(cur.begin(), cur.end());
-return;
+void accumulate(Location loc, ConversionPatternRewriter &rewriter,
+Region &combineOp, SmallVector<Value> &acc, ValueRange cur,
+Value pred = {}) const {
+auto results = applyCombineOp(loc, rewriter, combineOp, acc, cur, pred);
+if (acc.size() < results.size()) {
+acc.resize(results.size());
}

-// Create a new copy of the reduce block, and inline it
-Block *currentBlock = rewriter.getBlock();
-Region &parent = *currentBlock->getParent();
-rewriter.cloneRegionBefore(combineOp, &parent.front());
-auto &newReduce = parent.front();
-auto returnOp = dyn_cast<triton::ReduceReturnOp>(newReduce.getTerminator());
-
-llvm::SmallVector<Value> combineArgs(2 * acc.size());
-for (unsigned i = 0; i < acc.size(); ++i) {
-combineArgs[i] = acc[i];
-combineArgs[acc.size() + i] = cur[i];
-}
-
-rewriter.inlineBlockBefore(&newReduce, &*rewriter.getInsertionPoint(),
-combineArgs);
-
-auto results = returnOp.getResult();
for (unsigned i = 0; i < acc.size(); ++i) {
acc[i] = results[i];
}

-// Delete the terminator, which is no longer used
-rewriter.eraseOp(returnOp);
}

SmallVector<SmallVector<Value>>
@@ -165,7 +142,7 @@ struct ReduceOpConversion
SmallVector<unsigned> key = offsets[i];
key[op.getAxis()] = 0;
bool isFirst = accs.find(key) == accs.end();
-accumulate(rewriter, *combineOp, accs[key], srcValues[i], isFirst);
+accumulate(op.getLoc(), rewriter, *combineOp, accs[key], srcValues[i]);
if (isFirst)
indices[key] = srcIndices[i];
}
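
Why dropping the isFirst argument from accumulate is sound here: accs[key] default-constructs an empty accumulator on first lookup, and the new applyCombineOp (added to ReduceScanCommon.h below) returns cur unchanged when acc is empty, so the first visit still just seeds the accumulator; isFirst now only guards the srcIndices bookkeeping. A toy scalar model of that empty-accumulator fast path, with '+' standing in for the user combine region:

#include <cassert>
#include <vector>

std::vector<int> applyCombineModel(const std::vector<int> &acc,
                                   const std::vector<int> &cur) {
  if (acc.empty())
    return cur; // uninitialized acc: cur seeds the reduction
  assert(acc.size() == cur.size());
  std::vector<int> out(acc.size());
  for (size_t i = 0; i < acc.size(); ++i)
    out[i] = acc[i] + cur[i]; // stand-in for the inlined combine region
  return out;
}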
@@ -175,17 +152,29 @@ struct ReduceOpConversion
// region and the accumulator values as source.
void warpReduce(ConversionPatternRewriter &rewriter, Location loc,
SmallVector<Value> &acc, triton::ReduceOp op,
-unsigned numLaneToReduce, unsigned interleave) const {
+unsigned numLaneToReduce, unsigned interleave,
+Value pred = {}) const {
auto success = targetInfo.warpReduce(rewriter, loc, acc, op,
numLaneToReduce, interleave);
if (success)
return;

+auto mod = op->getParentOfType<ModuleOp>();
+unsigned iWarpSize = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
+if (iWarpSize > numLaneToReduce) {
+Value threadId = getThreadId(rewriter, loc);
+Value warpSize = i32_val(iWarpSize);
+Value laneId = urem(threadId, warpSize);
+Value lanePred = icmp_slt(laneId, i32_val(numLaneToReduce));
+pred = pred ? and_(pred, lanePred) : lanePred;
+}

for (unsigned N = numLaneToReduce / 2; N > 0; N >>= 1) {
SmallVector<Value> shfl(acc.size());
for (unsigned i = 0; i < acc.size(); ++i) {
shfl[i] = targetInfo.shuffleXor(rewriter, loc, acc[i], N * interleave);
}
-accumulate(rewriter, op.getCombineOp(), acc, shfl, false);
+accumulate(op.getLoc(), rewriter, op.getCombineOp(), acc, shfl, pred);
}
}
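
What the rewritten shuffle loop computes: an XOR butterfly across numLaneToReduce lanes, with the new lanePred keeping lanes past numLaneToReduce from contributing when the warp is only partially reduced. A plain-C++ lane-by-lane simulation of the emitted pattern (shuffleXor is modeled by reading the partner lane's slot; '+' stands in for the combine region):

#include <cstdio>
#include <vector>

void butterflyReduce(std::vector<int> &lanes, unsigned numLaneToReduce) {
  const auto warpSize = static_cast<unsigned>(lanes.size());
  for (unsigned n = numLaneToReduce / 2; n > 0; n >>= 1) {
    std::vector<int> shfl(warpSize);
    for (unsigned lane = 0; lane < warpSize; ++lane)
      shfl[lane] = lanes[lane ^ n];  // shuffleXor with mask n
    for (unsigned lane = 0; lane < warpSize; ++lane)
      if (lane < numLaneToReduce)    // lanePred from the new guard
        lanes[lane] += shfl[lane];   // predicated combine
  }
}

int main() {
  std::vector<int> lanes(32, 1);
  butterflyReduce(lanes, 8);         // reduce only the first 8 lanes
  std::printf("%d\n", lanes[0]);     // prints 8; lanes 8..31 stay untouched
}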

@@ -344,7 +333,8 @@ struct ReduceOpConversion
acc[i] = targetInfo.loadShared(rewriter, loc, readPtr, elemTy,
threadIsNeeded);
}
-warpReduce(rewriter, loc, acc, op, sizeInterWarps, 1 /* interleave */);
+warpReduce(rewriter, loc, acc, op, sizeInterWarps, 1 /* interleave */,
+threadIsNeeded);
// only the first thread in each sizeInterWarps is writing
Value writeOffset = readOffset;
SmallVector<Value> writePtrs(op.getNumOperands());
94 changes: 89 additions & 5 deletions lib/Conversion/TritonGPUToLLVM/ReduceScanCommon.h
@@ -4,15 +4,14 @@
// TODO: refactor so that it doesn't fail if Allocation.h
// is included after utility.h (due to conflict in `store` macro
// and <atomic>
#include "triton/Analysis/Allocation.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Transforms/DialectConversion.h"

#include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h"
//
#include "mlir/IR/TypeUtilities.h"
#include "triton/Analysis/AxisInfo.h"
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
#include <set>
+#include <iterator>
#include <type_traits>

#define DEBUG_TYPE "ttgpu_to_llvm"
@@ -32,6 +31,91 @@ namespace ttng = ::mlir::triton::nvidia_gpu;
namespace mlir::triton {
class ReduceOp;
class ScanOp;

inline SmallVector<Value>
inlineCombineBlock(ConversionPatternRewriter &rewriter, Block &combineBlock,
Block *insertionBlock, Block::iterator insertionPoint,
ValueRange combineArgs) {
auto returnOp = combineBlock.getTerminator();
rewriter.inlineBlockBefore(&combineBlock, insertionBlock, insertionPoint,
combineArgs);

auto results = SmallVector<Value>(returnOp->getOperands());

// Delete the terminator, which is no longer used
rewriter.eraseOp(returnOp);
return results;
}

inline SmallVector<Value> applyCombineOp(Location loc,
ConversionPatternRewriter &rewriter,
Region &combineOp, ValueRange acc,
ValueRange cur, Value pred = {}) {
// Allows for passing an uninitialized acc and using cur as the neutral element
if (acc.size() == 0) {
return cur;
}
assert(cur.size() == acc.size());

// Create a new copy of the combine block, and try to speculatively inline it
Block *currentBlock = rewriter.getBlock();
Region &parent = *currentBlock->getParent();

rewriter.cloneRegionBefore(combineOp, parent,
std::next(currentBlock->getIterator()));
Block &newCombine = *currentBlock->getNextNode();

llvm::SmallVector<Value> combineArgs(2 * acc.size());
for (unsigned i = 0; i < acc.size(); ++i) {
combineArgs[i] = acc[i];
combineArgs[acc.size() + i] = cur[i];
}

auto isRegionSpeculatable =
std::all_of(newCombine.begin(), newCombine.end(),
[](auto &op) { return isSpeculatable(&op); });

if (!pred || isRegionSpeculatable) {
// Fast path, region has no side effects so we can unconditionally execute
return inlineCombineBlock(rewriter, newCombine, currentBlock,
rewriter.getInsertionPoint(), combineArgs);
}

// Slow case, create an if to only execute region when pred is true
// #currentBlock
// if (pred) {
// #newCombine
// results = combineOp(cur, acc)
// yield results
// } else {
// yield undef
// }
// #thenBlock
Block *thenBlock =
rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());

auto returnOp = newCombine.getTerminator();
auto results = SmallVector<Value>(returnOp->getOperands());

rewriter.setInsertionPointToEnd(currentBlock);
SmallVector<Value> thenBlockArgs;
thenBlockArgs.reserve(results.size());
for (auto result : results) {
auto ty = result.getType();
auto undef = rewriter.create<LLVM::UndefOp>(loc, ty);
thenBlockArgs.push_back(undef);
thenBlock->addArgument(ty, loc);
}
rewriter.create<cf::CondBranchOp>(loc, pred, &newCombine, combineArgs,
thenBlock, thenBlockArgs);

// Replace the combine block's terminator with a branch to the continuation block.
rewriter.setInsertionPointToEnd(&newCombine);
rewriter.replaceOpWithNewOp<cf::BranchOp>(returnOp, thenBlock, results);
rewriter.setInsertionPointToStart(thenBlock);
return SmallVector<Value>(thenBlock->getArguments());
}

} // namespace mlir::triton

template <typename SourceOp>
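
Taken together, applyCombineOp gives the reduce and scan lowerings one shared entry point: the combine region is inlined straight into the current block when it is safe to run unconditionally (no predicate, or every cloned op is speculatable), and only non-speculatable regions under an active predicate pay for the conditional branch; predicated-off lanes then receive undef, which is acceptable because their results are never read back. A condensed C++ model of that decision, where speculatable plays the role of the isSpeculatable scan over the cloned block:

enum class CombineLowering { InlineUnconditional, GuardedBranch };

// Sketch of the fast-path/slow-path choice made above; not MLIR code.
CombineLowering chooseLowering(bool hasPred, bool speculatable) {
  if (!hasPred || speculatable)
    return CombineLowering::InlineUnconditional; // run for all lanes; results
                                                 // of inactive lanes are ignored
  return CombineLowering::GuardedBranch;         // cond_br around the region
}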