3 changes: 2 additions & 1 deletion .github/workflows/integration-tests.yml
@@ -25,6 +25,7 @@ env:
   TRITON_BUILD_WITH_CLANG_LLD: "TRUE"
   TRITON_USE_ASSERT_ENABLED_LLVM: "TRUE"
   TRITON_DISABLE_LINE_INFO: 1
+  PROTON_SKIP_PC_SAMPLING_TEST: 1
 jobs:
   Runner-Preparation:
     runs-on: ubuntu-latest
@@ -460,7 +461,7 @@ jobs:
       - name: Install brew dependencies
         run: |
           brew update
-          brew install ccache llvm
+          brew install ccache llvm@19 lld
       - name: Compute cache keys
         id: cache-key
         run: |
4 changes: 2 additions & 2 deletions .github/workflows/integration-tests.yml.in
@@ -27,7 +27,7 @@ env:
   TRITON_BUILD_WITH_CLANG_LLD: "TRUE"
   TRITON_USE_ASSERT_ENABLED_LLVM: "TRUE"
   TRITON_DISABLE_LINE_INFO: 1
-
+  PROTON_SKIP_PC_SAMPLING_TEST: 1

 jobs:
   Runner-Preparation:
@@ -439,7 +439,7 @@ jobs:
       - name: Install brew dependencies
         run: |
           brew update
-          brew install ccache llvm
+          brew install ccache llvm@19 lld

       - *compute-cache-keys-step
       - *cache-build-dependencies-step
4 changes: 4 additions & 0 deletions bin/CMakeLists.txt
@@ -102,7 +102,11 @@ export_executable_symbols_for_plugins(triton-llvm-opt)
 add_llvm_executable(triton-tensor-layout triton-tensor-layout.cpp PARTIAL_SOURCES_INTENDED)
 target_link_libraries(triton-tensor-layout PRIVATE
   TritonGPUIR
+  TritonNvidiaGPUIR
+  ${triton_libs}
+  ${conversion_libs}
+  ${dialect_libs}
   TritonTestAnalysis
 )

 add_llvm_executable(triton-translate
12 changes: 7 additions & 5 deletions bin/triton-tensor-layout.cpp
@@ -1,8 +1,11 @@
#include "RegisterTritonDialects.h"

#include "mlir/AsmParser/AsmParser.h"
#include "mlir/AsmParser/AsmParserState.h"
#include "mlir/IR/MLIRContext.h"

#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"

#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ErrorOr.h"
@@ -114,7 +117,7 @@ LogicalResult printLayoutFromFile(MLIRContext *context, StringRef filename,
     return failure();
   }

-  auto printLambda = [&](StringRef name, Attribute attr) {
+  auto printLambda = [&](StringRef name, mlir::Attribute attr) {
     ss << "Print layout attribute: #" << name << " = " << attr << "\n";

     auto rankedTensorTy = RankedTensorType::get(
@@ -155,7 +158,7 @@ LogicalResult printLayoutFromString(MLIRContext *context,
   if (layoutAttrStr.empty())
     return success();

-  Attribute layout = parseAttribute(layoutAttrStr, context);
+  mlir::Attribute layout = parseAttribute(layoutAttrStr, context);
   if (!layout) {
     llvm::errs() << "Invalid layout attribute: " << layoutAttrStr << "\n";
     return failure();
@@ -178,8 +181,7 @@ int main(int argc, char **argv) {
   cl::ParseCommandLineOptions(argc, argv, "tensor layout printer\n");

   DialectRegistry registry;
-  // Register all dialects that can print tensor layout.
-  registry.insert<triton::gpu::TritonGPUDialect>();
+  registerTritonDialects(registry);

   MLIRContext ctx(registry);
   ctx.loadAllAvailableDialects();
@@ -189,7 +191,7 @@ int main(int argc, char **argv) {
     return 1;
   }

-  Type parsedTy = parseType(TensorStr, &ctx);
+  mlir::Type parsedTy = parseType(TensorStr, &ctx);
   if (!parsedTy) {
     llvm::errs() << "Fail to parse the tensor type argument: " << TensorStr
                  << "\n";
9 changes: 0 additions & 9 deletions include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h
@@ -57,15 +57,6 @@ class TargetInfoBase {
                              unsigned numLaneToReduce,
                              unsigned interleave) const = 0;

-  // TODO (Keren): Remove this function once layout conversion using stmatrix is
-  // handled by Linear Layout.
-  virtual bool processReplicaUsingStMatrix(
-      RewriterBase &rewriter, Location loc, Value smemBase,
-      SmallVector<Value> &vals, RankedTensorType srcTy, Type elemTy,
-      ArrayRef<unsigned> paddedRepShape, ArrayRef<unsigned> origRepShape,
-      ArrayRef<unsigned> outOrd, unsigned accumNumReplicates,
-      int swizzleByteWidth = 0) const = 0;
-
   virtual std::string getMulhiFuncName(Type resultElementTy) const = 0;
   // Emits LLVM code with |rewriter| to print a message following the given
   // format from the device. |formatStrStart| is the pointer to the start of
25 changes: 24 additions & 1 deletion include/triton/Dialect/TritonGPU/IR/Dialect.h
@@ -75,9 +75,32 @@ getThreadsPerWarpWithUniqueData(Attribute layout,
 SmallVector<unsigned>
 getWarpsPerCTAWithUniqueData(Attribute layout, ArrayRef<int64_t> tensorShape);

+// Returns the dimensions of the tensor from minor (fast-varying) to
+// major (slow-varying). For blocked, mma, and dotOperand layouts, the
+// elements live in registers, but the order still refers to the memory
+// layout of the original tensor in global memory.
+// For shared layouts, the order refers to which dimension of the original
+// tensor is contiguous in shared memory.
+SmallVector<unsigned> getOrder(Attribute layout);
+
+// Returns the dimensions along which warp IDs are distributed.
+// warpsPerCTA only gives the warp layout within the CTA, e.g.,
+// warpsPerCTA = [2, 4] means there are 2 warps along dim0 and 4 warps
+// along dim1. warpOrder gives the specific order in which warp IDs are
+// distributed. E.g., warpOrder = [0, 1] distributes warp IDs as follows:
+//   [warp0 warp2 warp4 warp6]
+//   [warp1 warp3 warp5 warp7]
+// Note that in most cases getWarpOrder and getOrder return the same
+// result, but this is not guaranteed.
+SmallVector<unsigned> getWarpOrder(Attribute layout);

-SmallVector<unsigned> getOrder(Attribute layout);
+// Returns the dimensions along which thread IDs are distributed.
+// Similar to warpOrder, threadOrder is necessary to pin down the specific
+// thread distribution within a warp.
+// Note that in most cases getThreadOrder and getOrder return the same
+// result, but this is not guaranteed. One exception is the mfma.transposed
+// layout, where getOrder returns [1, 0] but getThreadOrder returns [0, 1].
+SmallVector<unsigned> getThreadOrder(Attribute layout);

 CTALayoutAttr getCTALayout(Attribute layout);
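To make the `getWarpOrder` contract above concrete, here is a minimal, self-contained sketch of how a linear warp ID maps to multi-dimensional coordinates under a given order. This is not Triton's implementation: `delinearizeWarpId` is a hypothetical helper, and the real lowering works on MLIR `Value`s rather than plain integers.

```cpp
#include <cstdio>
#include <vector>

// Delinearize a warp ID into per-dimension coordinates, filling the
// fastest-varying dimension (order[0]) first.
std::vector<unsigned> delinearizeWarpId(unsigned warpId,
                                        const std::vector<unsigned> &warpsPerCTA,
                                        const std::vector<unsigned> &order) {
  std::vector<unsigned> coords(warpsPerCTA.size());
  unsigned rem = warpId;
  for (unsigned dim : order) { // minor-to-major traversal
    coords[dim] = rem % warpsPerCTA[dim];
    rem /= warpsPerCTA[dim];
  }
  return coords;
}

int main() {
  // warpsPerCTA = [2, 4] and warpOrder = [0, 1] reproduce the grid from the
  // comment: dim0 varies fastest, so warp1 sits directly below warp0:
  //   [warp0 warp2 warp4 warp6]
  //   [warp1 warp3 warp5 warp7]
  for (unsigned w = 0; w < 8; ++w) {
    auto c = delinearizeWarpId(w, {2, 4}, {0, 1});
    std::printf("warp%u -> (dim0=%u, dim1=%u)\n", w, c[0], c[1]);
  }
}
```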
128 changes: 125 additions & 3 deletions include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h
@@ -113,12 +113,134 @@ LinearLayout chooseShemLayoutForRegToRegConversion(
 // row0  reg[0-1]  reg[4-5]
 // row8  reg[2-3]  reg[6-7]
 //
+// When `swizzleByteSize` is non-zero, the layout is constructed differently
+// due to the leading-dimension offset and swizzling. There are two key
+// concepts to understand:
+//
+//   1. Chunks: The leading dimension (i.e., the column dimension) is divided
+//      into chunks, where each chunk's size is determined by `swizzleByteSize`.
+//   2. Swizzling within tiles: Each tile applies a swizzling pattern to its
+//      rows to optimize memory access.
+//
+// - Concept 1: Chunks
+//
+// In the swizzled layout, the leading dimension is strided by
+// `swizzleByteSize`. This introduces the concept of a "chunk", where each
+// chunk spans a certain number of columns.
+//
+// For a tile size of `stmatrix.x4` (16x16 elements), with each element being
+// 16 bits (2 bytes), each tile occupies 16 rows and 32 bytes per row
+// (16 elements * 2 bytes per element = 32 bytes per row).
+//
+// Given a `swizzleByteSize` of 128 bytes, the number of tiles per chunk is:
+//
+//   tiles per chunk = swizzleByteSize / bytes per row
+//                   = 128 bytes / 32 bytes = 4 tiles
+//
+// Therefore, each chunk contains 4 tiles horizontally, spanning 64 columns
+// (since each tile is 16 columns):
+//
+//            col0-15  col16-31  col32-47  col48-63
+//   row0-15   tile0    tile1     tile2     tile3
+//
+// For a tensor of 128x128 elements (#rows x #columns), with each element
+// being 16 bits, the tensor can be divided into multiple chunks both
+// horizontally and vertically. Chunks are stored in memory in column-major
+// order with respect to chunks, i.e., chunk1's address follows chunk0's.
+//
+// Assume we have 8 warps and assign each warp a slice of 16 rows (the rows
+// of one tile) by 128 columns (the width of two chunks), so that each warp
+// handles one horizontal slice of the tensor.
+//
+// The overall layout can be visualized as:
+//
+//                        |<- 128 * 128 bytes ->|<- 128 * 128 bytes ->|
+//                            columns 0-63          columns 64-127
+//   warp0 | rows 0-15        chunk0                  chunk8
+//   warp1 | rows 16-31       chunk1                  chunk9
+//   warp2 | rows 32-47       chunk2                  chunk10
+//   warp3 | rows 48-63       chunk3                  chunk11
+//   warp4 | rows 64-79       chunk4                  chunk12
+//   warp5 | rows 80-95       chunk5                  chunk13
+//   warp6 | rows 96-111      chunk6                  chunk14
+//   warp7 | rows 112-127     chunk7                  chunk15
+//
+// - Concept 2: Swizzling within tiles
+//
+// Within each 16x16 tile, rows are swizzled to optimize memory access
+// patterns. This swizzling is similar to what's defined in
+// `TritonGPUAttrDefs.td`, but applied at the level of each 16x16 tile rather
+// than the entire tensor.
+//
+// Key parameters for swizzling:
+//
+//   - `perPhase`: The number of rows over which to apply an XOR operation at
+//     each phase.
+//   - `maxPhase`: The total number of phases.
+//   - `vectorWidth`: The number of elements per vector, which is 8 here
+//     because `stmatrix` stores 8 contiguous elements per thread.
+//
+// The offset of each element within a tile is calculated as:
+//
+//   offset = row * swizzleByteSize
+//          + (vectorWidth * ((row / perPhase) % maxPhase)) * elementSize
+//
+// where `elementSize` is the size of each element in bytes (2 bytes for
+// 16-bit elements).
+//
+// For example, consider the element at index `(row=1, col=0)` in chunk0:
+//
+// Without swizzling:
+//
+//   offset = row * swizzleByteSize + col * elementSize
+//          = 1 * 128 bytes + 0 * 2 bytes
+//          = 128 bytes
+//
+// With swizzling (assuming `perPhase=1`, `maxPhase=8`, `vectorWidth=8`):
+//
+//   offset = row * swizzleByteSize
+//          + (vectorWidth * ((row / perPhase) % maxPhase)) * elementSize
+//          = 1 * 128 bytes + (8 * ((1 / 1) % 8)) * 2 bytes
+//          = 128 bytes + (8 * (1 % 8)) * 2 bytes
+//          = 128 bytes + 8 * 2 bytes
+//          = 128 bytes + 16 bytes
+//          = 144 bytes
+//
+// This swizzling ensures that elements are stored in a way that optimizes
+// memory bandwidth and reduces bank conflicts.
+//
+// - Verification through Linear Layout
+//
+// We can verify the offsets against the outputs of the corresponding linear
+// layout, where each element is 16 bits (2 bytes):
+//
+//   - register=1  -> offset=1
+//     register=2  -> offset=2
+//     register=4  -> offset=4
+//     register=8  -> offset=16
+//     register=16 -> offset=32
+//     register=32 -> offset=8192
+//   - lane=1      -> offset=72
+//     lane=2      -> offset=144
+//     lane=4      -> offset=288
+//     lane=8      -> offset=512
+//     lane=16     -> offset=8
+//   - warp=1      -> offset=1024
+//     warp=2      -> offset=2048
+//     warp=4      -> offset=4096
+//
+// For index `(row=1, col=0)`, which corresponds to `reg=0` and `lane=1` in
+// `warp=0`, the offset is 72 elements * 2 bytes = 144 bytes, matching the
+// earlier calculation.
+//
+// TODO(Keren): In the future we should replace tensorTy with a LinearLayout
+// and the element bit width of the tensor to support more flexible tensor
+// encodings.
-std::optional<LinearLayout> chooseStMatrixLayoutForRegToRegConversion(
-    MLIRContext *ctx, RankedTensorType tensorTy, ArrayRef<unsigned> repShape,
-    ArrayRef<unsigned> paddedRepShape, ArrayRef<unsigned> order);
+std::optional<LinearLayout>
+chooseStMatrixLayout(MLIRContext *ctx, RankedTensorType tensorTy,
+                     ArrayRef<unsigned> repShape,
+                     ArrayRef<unsigned> paddedRepShape,
+                     ArrayRef<unsigned> order, int swizzleByteSize);
 } // namespace mlir::triton::gpu

 #endif // TRITON_DIALECT_TRITONGPU_IR_LINEARLAYOUTCONVERSIONS_H
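The arithmetic in the comment above is easy to check in isolation. Below is a small sketch of the per-tile swizzle, assuming the usual XOR-based scheme in which a column's vector group is XOR'd with the row's phase; for `col = 0` it reduces exactly to the formula in the comment. `swizzledByteOffset` is an illustrative helper, not part of the Triton API.

```cpp
#include <cstdio>

// Swizzled byte offset of element (row, col) within one chunk, under the
// XOR-based scheme sketched above. For col = 0 this matches the comment's
// formula: row * swizzleByteSize + vectorWidth * phase * elementSize.
unsigned swizzledByteOffset(unsigned row, unsigned col,
                            unsigned swizzleByteSize, unsigned perPhase,
                            unsigned maxPhase, unsigned vectorWidth,
                            unsigned elementSize) {
  unsigned phase = (row / perPhase) % maxPhase;
  unsigned vecGroup = (col / vectorWidth) ^ phase; // swizzled vector index
  unsigned colInVec = col % vectorWidth;
  return row * swizzleByteSize +
         (vecGroup * vectorWidth + colInVec) * elementSize;
}

int main() {
  // Parameters from the worked example: 16-bit elements (2 bytes),
  // swizzleByteSize = 128, perPhase = 1, maxPhase = 8, vectorWidth = 8.
  unsigned tilesPerChunk = 128 / (16 * 2); // 128 B / 32 B per tile row = 4
  unsigned off = swizzledByteOffset(/*row=*/1, /*col=*/0,
                                    /*swizzleByteSize=*/128, /*perPhase=*/1,
                                    /*maxPhase=*/8, /*vectorWidth=*/8,
                                    /*elementSize=*/2);
  std::printf("tiles per chunk = %u\n", tilesPerChunk);   // prints 4
  std::printf("offset(row=1, col=0) = %u bytes\n", off);  // prints 144
}
```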
4 changes: 2 additions & 2 deletions lib/Analysis/Utility.cpp
@@ -38,7 +38,7 @@ SmallVector<unsigned> getParentOrder(Attribute layout) {
   if (auto sliceEncoding = mlir::dyn_cast<SliceEncodingAttr>(layout)) {
     return getParentOrder(sliceEncoding.getParent());
   }
-  return getOrder(layout);
+  return getThreadOrder(layout);
 }

 } // namespace
@@ -77,7 +77,7 @@ unsigned ReduceOpHelper::getThreadOffsetOnReductionAxis() {
     threadOffset = threadsPerWarp[sliceLayout.getDim()];
   } else {
     auto threadsPerWarp = getThreadsPerWarp(srcLayout);
-    auto order = getOrder(srcLayout);
+    auto order = getThreadOrder(srcLayout);
     for (unsigned i = 0; i < order.size(); i++) {
       if (order[i] == axis)
         break;
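For the `getThreadOffsetOnReductionAxis` change above, the point of switching to `getThreadOrder` is that the lane-ID stride along the reduction axis is the product of `threadsPerWarp` over the dimensions that precede that axis in the *thread* order, not the register order. A minimal sketch, with illustrative names and values rather than Triton's actual signatures:

```cpp
#include <cstdio>
#include <vector>

// Lane-ID stride between neighbors along `axis`, given the per-dimension
// thread counts and the order in which thread IDs are distributed.
unsigned threadOffsetOnAxis(const std::vector<unsigned> &threadsPerWarp,
                            const std::vector<unsigned> &threadOrder,
                            unsigned axis) {
  unsigned offset = 1;
  for (unsigned dim : threadOrder) {
    if (dim == axis)
      break;
    offset *= threadsPerWarp[dim]; // dims before `axis` vary faster
  }
  return offset;
}

int main() {
  // An illustrative case: threadsPerWarp = [4, 16] with threadOrder = [0, 1].
  // Along axis 1, neighboring lanes are 4 lane IDs apart; using an order of
  // [1, 0] instead would wrongly give a stride of 1.
  std::printf("stride along axis 1: %u\n",
              threadOffsetOnAxis({4, 16}, {0, 1}, /*axis=*/1)); // prints 4
}
```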
18 changes: 6 additions & 12 deletions lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
@@ -215,15 +215,9 @@ struct ConvertLayoutOpConversion
       if (repId != 0) {
         barrier();
       }
-      auto successful = targetInfo.processReplicaUsingStMatrix(
-          rewriter, loc, smemBase, vals, srcTy,
-          getTypeConverter()->convertType(srcTy.getElementType()),
-          paddedRepShape, origRepShape, outOrd, accumNumReplicates);
-      if (!successful) {
-        processReplica(loc, rewriter, /*stNotRd*/ true, srcTy, inNumCTAsEachRep,
-                       multiDimRepId, inVec, paddedRepShape, origRepShape,
-                       outOrd, vals, smemBase);
-      }
+      processReplica(loc, rewriter, /*stNotRd*/ true, srcTy, inNumCTAsEachRep,
+                     multiDimRepId, inVec, paddedRepShape, origRepShape, outOrd,
+                     vals, smemBase);
       barrier();
       processReplica(loc, rewriter, /*stNotRd*/ false, dstTy, outNumCTAsEachRep,
                      multiDimRepId, outVec, paddedRepShape, origRepShape,
@@ -483,9 +477,9 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
     // Input dims: [reg, lane, warp]
     // Output dims: [offset, iteration]
     std::optional<LinearLayout> shmemStoreLayout =
-        chooseStMatrixLayoutForRegToRegConversion(
-            ctx, op.getSrc().getType(), scratchConfig.repShape,
-            scratchConfig.paddedRepShape, scratchConfig.order);
+        chooseStMatrixLayout(ctx, op.getSrc().getType(), scratchConfig.repShape,
+                             scratchConfig.paddedRepShape, scratchConfig.order,
+                             /*swizzleByteSize=*/0);
     bool isStMatrix = shmemStoreLayout.has_value();
     if (!isStMatrix) {
       shmemStoreLayout = srcLayout.invertAndCompose(sharedLayout);
1 change: 0 additions & 1 deletion lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp
@@ -116,7 +116,6 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
     RankedTensorType dstTy = op.getType();
     Attribute srcLayout = srcTy.getEncoding();
     Attribute dstLayout = dstTy.getEncoding();
-    // TODO: do we need to check if src is shared ?
     if (isa<SharedEncodingAttr>(srcLayout) &&
         isa<BlockedEncodingAttr, MmaEncodingTrait, SliceEncodingAttr>(
             dstLayout)) {
3 changes: 2 additions & 1 deletion lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
@@ -9,6 +9,7 @@ using namespace mlir::triton;
 using ::mlir::LLVM::delinearize;
 using ::mlir::LLVM::linearize;
 using ::mlir::triton::gpu::getOrder;
+using ::mlir::triton::gpu::getThreadOrder;
 using ::mlir::triton::gpu::getTotalElemsPerThread;

 namespace {
@@ -271,7 +272,7 @@ struct ReduceOpConversion

     auto threadsPerWarp =
         triton::gpu::getThreadsPerWarpWithUniqueData(srcLayout, srcShape);
-    auto order = getOrder(srcLayout);
+    auto order = getThreadOrder(srcLayout);
     SmallVector<Value> multiDimLaneId =
         delinearize(rewriter, loc, laneId, threadsPerWarp, order);
     Value laneIdAxis = multiDimLaneId[axis];