diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index 06e2077657..e9eca9fc74 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -25,6 +25,7 @@ env: TRITON_BUILD_WITH_CLANG_LLD: "TRUE" TRITON_USE_ASSERT_ENABLED_LLVM: "TRUE" TRITON_DISABLE_LINE_INFO: 1 + PROTON_SKIP_PC_SAMPLING_TEST: 1 jobs: Runner-Preparation: runs-on: ubuntu-latest @@ -460,7 +461,7 @@ jobs: - name: Install brew dependencies run: | brew update - brew install ccache llvm + brew install ccache llvm@19 lld - name: Compute cache keys id: cache-key run: | diff --git a/.github/workflows/integration-tests.yml.in b/.github/workflows/integration-tests.yml.in index 2cfc2de824..81c95ac681 100644 --- a/.github/workflows/integration-tests.yml.in +++ b/.github/workflows/integration-tests.yml.in @@ -27,7 +27,7 @@ env: TRITON_BUILD_WITH_CLANG_LLD: "TRUE" TRITON_USE_ASSERT_ENABLED_LLVM: "TRUE" TRITON_DISABLE_LINE_INFO: 1 - + PROTON_SKIP_PC_SAMPLING_TEST: 1 jobs: Runner-Preparation: @@ -439,7 +439,7 @@ jobs: - name: Install brew dependencies run: | brew update - brew install ccache llvm + brew install ccache llvm@19 lld - *compute-cache-keys-step - *cache-build-dependencies-step diff --git a/bin/CMakeLists.txt b/bin/CMakeLists.txt index 0f2c9d52df..d37ab18b18 100644 --- a/bin/CMakeLists.txt +++ b/bin/CMakeLists.txt @@ -102,7 +102,11 @@ export_executable_symbols_for_plugins(triton-llvm-opt) add_llvm_executable(triton-tensor-layout triton-tensor-layout.cpp PARTIAL_SOURCES_INTENDED) target_link_libraries(triton-tensor-layout PRIVATE TritonGPUIR + TritonNvidiaGPUIR ${triton_libs} + ${conversion_libs} + ${dialect_libs} + TritonTestAnalysis ) add_llvm_executable(triton-translate diff --git a/bin/triton-tensor-layout.cpp b/bin/triton-tensor-layout.cpp index e0e02b49b9..4087ac1350 100644 --- a/bin/triton-tensor-layout.cpp +++ b/bin/triton-tensor-layout.cpp @@ -1,8 +1,11 @@ +#include "RegisterTritonDialects.h" + #include "mlir/AsmParser/AsmParser.h" #include "mlir/AsmParser/AsmParserState.h" #include "mlir/IR/MLIRContext.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/ErrorOr.h" @@ -114,7 +117,7 @@ LogicalResult printLayoutFromFile(MLIRContext *context, StringRef filename, return failure(); } - auto printLambda = [&](StringRef name, Attribute attr) { + auto printLambda = [&](StringRef name, mlir::Attribute attr) { ss << "Print layout attribute: #" << name << " = " << attr << "\n"; auto rankedTensorTy = RankedTensorType::get( @@ -155,7 +158,7 @@ LogicalResult printLayoutFromString(MLIRContext *context, if (layoutAttrStr.empty()) return success(); - Attribute layout = parseAttribute(layoutAttrStr, context); + mlir::Attribute layout = parseAttribute(layoutAttrStr, context); if (!layout) { llvm::errs() << "Invalid layout attribute: " << layoutAttrStr << "\n"; return failure(); @@ -178,8 +181,7 @@ int main(int argc, char **argv) { cl::ParseCommandLineOptions(argc, argv, "tensor layout printer\n"); DialectRegistry registry; - // Register all dialects that can print tensor layout. 
-  registry.insert();
+  registerTritonDialects(registry);

   MLIRContext ctx(registry);
   ctx.loadAllAvailableDialects();
@@ -189,7 +191,7 @@ int main(int argc, char **argv) {
     return 1;
   }

-  Type parsedTy = parseType(TensorStr, &ctx);
+  mlir::Type parsedTy = parseType(TensorStr, &ctx);
   if (!parsedTy) {
     llvm::errs() << "Fail to parse the tensor type argument: " << TensorStr
                  << "\n";
diff --git a/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h b/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h
index 3e3ec2b8d4..6017deb7c7 100644
--- a/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h
+++ b/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h
@@ -57,15 +57,6 @@ class TargetInfoBase {
                           unsigned numLaneToReduce,
                           unsigned interleave) const = 0;

-  // TODO (Keren): Remove this function once layout conversion using stmatrix is
-  // handled by Linear Layout.
-  virtual bool processReplicaUsingStMatrix(
-      RewriterBase &rewriter, Location loc, Value smemBase,
-      SmallVector<Value> &vals, RankedTensorType srcTy, Type elemTy,
-      ArrayRef<unsigned> paddedRepShape, ArrayRef<unsigned> origRepShape,
-      ArrayRef<unsigned> outOrd, unsigned accumNumReplicates,
-      int swizzleByteWidth = 0) const = 0;
-
   virtual std::string getMulhiFuncName(Type resultElementTy) const = 0;
   // Emits LLVM code with |rewriter| to print a message following the given
   // format from the device. |formatStrStart| is the pointer to the start of
diff --git a/include/triton/Dialect/TritonGPU/IR/Dialect.h b/include/triton/Dialect/TritonGPU/IR/Dialect.h
index 3b012a6305..74ea99b588 100644
--- a/include/triton/Dialect/TritonGPU/IR/Dialect.h
+++ b/include/triton/Dialect/TritonGPU/IR/Dialect.h
@@ -75,9 +75,32 @@ getThreadsPerWarpWithUniqueData(Attribute layout,
 SmallVector<unsigned>
 getWarpsPerCTAWithUniqueData(Attribute layout, ArrayRef<int64_t> tensorShape);

+// Returns the dimensions of the tensor from minor (fast-varying) to
+// major (slow-varying). For blocked, mma, and dotOperand layouts, the
+// elements live in registers, but the order still refers to the memory
+// layout of the original tensor in global memory.
+// For the shared layout, the order refers to which dimension of the original
+// tensor is contiguous in shared memory.
+SmallVector<unsigned> getOrder(Attribute layout);
+
+// Returns the dimensions along which warp IDs are distributed.
+// warpsPerCTA only describes the warp layout within the CTA, e.g.,
+// warpsPerCTA = [2, 4] means there are 2 warps along dim0 and 4 warps along
+// dim1. warpOrder gives the specific order in which warp IDs are distributed.
+// E.g., warpOrder = [0, 1] means the warp IDs are distributed as follows:
+//   [warp0 warp2 warp4 warp6]
+//   [warp1 warp3 warp5 warp7]
+// Note that in most cases, getWarpOrder and getOrder return the same results.
+// But this is not guaranteed.
 SmallVector<unsigned> getWarpOrder(Attribute layout);

-SmallVector<unsigned> getOrder(Attribute layout);
+// Returns the dimensions along which thread IDs are distributed.
+// Similar to warpOrder, threadOrder is needed to pin down the specific thread
+// distribution within the warp.
+// Note that, in most cases, getThreadOrder and getOrder return the same
+// results. But this is not guaranteed. One exception is the mfma.transposed
+// layout, in which getOrder returns [1, 0] but getThreadOrder returns [0, 1].
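+// As a concrete reading of the exception above (illustrative only): for a
+// 2-D mfma.transposed tensor, getOrder(layout) == [1, 0] describes the memory
+// order of the tensor, while getThreadOrder(layout) == [0, 1] means that
+// consecutive lanes within a warp step along dim0 first.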
+SmallVector<unsigned> getThreadOrder(Attribute layout);

 CTALayoutAttr getCTALayout(Attribute layout);

diff --git a/include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h b/include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h
index 7afaf94558..5f4c3b77f0 100644
--- a/include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h
+++ b/include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h
@@ -113,12 +113,134 @@ LinearLayout chooseShemLayoutForRegToRegConversion(
 //   row0  reg[0-1]  reg[4-5]
 //   row8  reg[2-3]  reg[6-7]
 //
+// When `swizzleByteSize` is non-zero, the layout is constructed differently
+// due to the leading-dimension offset and swizzling. There are two key
+// concepts to understand:
+//
+// 1. Chunks: The leading dimension (i.e., the column dimension) is divided
+//    into chunks, where each chunk's size is determined by `swizzleByteSize`.
+// 2. Swizzling within tiles: Each tile applies a swizzling pattern to its
+//    rows to optimize memory access.
+//
+// - Concept 1: Chunks
+//
+// In the swizzled layout, the leading dimension is strided by
+// `swizzleByteSize`. This introduces the concept of a "chunk", where each
+// chunk spans a certain number of columns.
+//
+// For a tile size of `stmatrix.x4` (16x16 elements), with each element being
+// 16 bits (2 bytes), each tile occupies 16 rows and 32 bytes per row (since
+// 16 elements * 2 bytes per element = 32 bytes per row).
+//
+// Given a `swizzleByteSize` of 128 bytes, the number of tiles per chunk can be
+// calculated as:
+//
+//   Number of tiles per chunk = swizzleByteSize / (bytes per row)
+//                             = 128 bytes / 32 bytes = 4 tiles
+//
+// Therefore, each chunk contains 4 tiles horizontally, spanning 64 columns
+// (since each tile is 16 columns):
+//
+//           col0-15  col16-31  col32-47  col48-63
+//   row0-15 tile0    tile1     tile2     tile3
+//
+// For a tensor of size 128x128 elements (#rows x #columns), with each element
+// being 16 bits, the tensor can be divided into multiple chunks both
+// horizontally and vertically. Chunks are stored in memory in a "column-major"
+// order based on chunks, meaning chunk1's address follows chunk0's.
+//
+// Assume we have 8 warps and assign each warp a slice of 16 rows (the rows of
+// one tile) and 128 columns (the width of two chunks). This results in each
+// warp handling one horizontal slice of the tensor.
+//
+// The overall layout can be visualized as:
+//
+//                        |<- 128 * 128 bytes ->|<- 128 * 128 bytes ->|
+//                            columns 0-63          columns 64-127
+//   warp0 | rows 0-15        chunk0                  chunk8
+//   warp1 | rows 16-31       chunk1                  chunk9
+//   warp2 | rows 32-47       chunk2                  chunk10
+//   warp3 | rows 48-63       chunk3                  chunk11
+//   warp4 | rows 64-79       chunk4                  chunk12
+//   warp5 | rows 80-95       chunk5                  chunk13
+//   warp6 | rows 96-111      chunk6                  chunk14
+//   warp7 | rows 112-127     chunk7                  chunk15
+//
+// - Concept 2: Swizzling within tiles
+//
+// Within each 16x16 tile, rows are swizzled to optimize memory access
+// patterns. This swizzling is similar to what's defined in
+// `TritonGPUAttrDefs.td`, but applied at the level of each 16x16 tile rather
+// than the entire tensor.
+//
+// Key parameters for swizzling:
+//
+// - `perPhase`: The number of rows over which to apply an XOR operation at
+//   each phase.
+// - `maxPhase`: The total number of phases.
+// - `vectorWidth`: The number of elements per vector, which is 8 in this case
+//   because `stmatrix` stores 8 contiguous elements per thread.
+// +// The offset of each element within a tile is calculated using the formula: +// +// offset = row * swizzleByteSize + (vectorWidth * ((row / perPhase) % +// maxPhase)) * elementSize +// +// where `elementSize` is the size of each element in bytes (2 bytes for 16-bit +// elements). +// +// For example, consider the element at index `(row=1, col=0)` in chunk0: +// +// Without swizzling: +// +// offset = row * swizzleByteSize + col * elementSize +// = 1 * 128 bytes + 0 * 2 bytes +// = 128 bytes +// +// With swizzling (assuming `perPhase=1`, `maxPhase=8`, `vectorWidth=8`): +// +// offset = row * swizzleByteSize + (vectorWidth * ((row / perPhase) % +// maxPhase)) * elementSize +// = 1 * 128 bytes + (8 * ((1 / 1) % 8)) * 2 bytes +// = 128 bytes + (8 * (1 % 8)) * 2 bytes +// = 128 bytes + 8 * 2 bytes +// = 128 bytes + 16 bytes +// = 144 bytes +// +// This swizzling ensures that elements are stored in a way that optimizes for +// memory bandwidth and reduces bank conflicts. +// +// - Verification through Linear Layout +// +// We can verify the offsets with the following outputs of the corresponding +// linear layout, where each element is 16 bits (2 bytes): +// +// - register=1 -> offset=1 +// register=2 -> offset=2 +// register=4 -> offset=4 +// register=8 -> offset=16 +// register=16 -> offset=32 +// register=32 -> offset=8192 +// - lane=1 -> offset=72 +// lane=2 -> offset=144 +// lane=4 -> offset=288 +// lane=8 -> offset=512 +// lane=16 -> offset=8 +// - warp=1 -> offset=1024 +// warp=2 -> offset=2048 +// warp=4 -> offset=4096 +// +// For index `(row=1, col=0)`, which corresponds to `reg=0` and `lane=1` in +// `warp=0`, the offset is calculated as 72 * 2 bytes = 144 bytes. The result +// matches our earlier calculation. +// // TODO(Keren): We should replace tensorTy with a LinearLayout and the element // bit width of the tensor in the future to support more flexible tensor // encodings -std::optional chooseStMatrixLayoutForRegToRegConversion( - MLIRContext *ctx, RankedTensorType tensorTy, ArrayRef repShape, - ArrayRef paddedRepShape, ArrayRef order); +std::optional +chooseStMatrixLayout(MLIRContext *ctx, RankedTensorType tensorTy, + ArrayRef repShape, + ArrayRef paddedRepShape, + ArrayRef order, int swizzleByteSize); } // namespace mlir::triton::gpu #endif // TRITON_DIALECT_TRITONGPU_IR_LINEARLAYOUTCONVERSIONS_H diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index 6869e068c4..20d552b15c 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -38,7 +38,7 @@ SmallVector getParentOrder(Attribute layout) { if (auto sliceEncoding = mlir::dyn_cast(layout)) { return getParentOrder(sliceEncoding.getParent()); } - return getOrder(layout); + return getThreadOrder(layout); } } // namespace @@ -77,7 +77,7 @@ unsigned ReduceOpHelper::getThreadOffsetOnReductionAxis() { threadOffset = threadsPerWarp[sliceLayout.getDim()]; } else { auto threadsPerWarp = getThreadsPerWarp(srcLayout); - auto order = getOrder(srcLayout); + auto order = getThreadOrder(srcLayout); for (unsigned i = 0; i < order.size(); i++) { if (order[i] == axis) break; diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 403cac9dec..7bb0f198f9 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -215,15 +215,9 @@ struct ConvertLayoutOpConversion if (repId != 0) { barrier(); } - auto successful = targetInfo.processReplicaUsingStMatrix( - rewriter, 
loc, smemBase, vals, srcTy, - getTypeConverter()->convertType(srcTy.getElementType()), - paddedRepShape, origRepShape, outOrd, accumNumReplicates); - if (!successful) { - processReplica(loc, rewriter, /*stNotRd*/ true, srcTy, inNumCTAsEachRep, - multiDimRepId, inVec, paddedRepShape, origRepShape, - outOrd, vals, smemBase); - } + processReplica(loc, rewriter, /*stNotRd*/ true, srcTy, inNumCTAsEachRep, + multiDimRepId, inVec, paddedRepShape, origRepShape, outOrd, + vals, smemBase); barrier(); processReplica(loc, rewriter, /*stNotRd*/ false, dstTy, outNumCTAsEachRep, multiDimRepId, outVec, paddedRepShape, origRepShape, @@ -483,9 +477,9 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion // Input dims: [reg, lane, warp] // Output dims: [offset, iteration] std::optional shmemStoreLayout = - chooseStMatrixLayoutForRegToRegConversion( - ctx, op.getSrc().getType(), scratchConfig.repShape, - scratchConfig.paddedRepShape, scratchConfig.order); + chooseStMatrixLayout(ctx, op.getSrc().getType(), scratchConfig.repShape, + scratchConfig.paddedRepShape, scratchConfig.order, + /*swizzleByteSize=*/0); bool isStMatrix = shmemStoreLayout.has_value(); if (!isStMatrix) { shmemStoreLayout = srcLayout.invertAndCompose(sharedLayout); diff --git a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp index 7a9b26f425..f78964dda9 100644 --- a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp @@ -116,7 +116,6 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern { RankedTensorType dstTy = op.getType(); Attribute srcLayout = srcTy.getEncoding(); Attribute dstLayout = dstTy.getEncoding(); - // TODO: do we need to check if src is shared ? if (isa(srcLayout) && isa( dstLayout)) { diff --git a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp index 16c9991a17..414328be50 100644 --- a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp @@ -9,6 +9,7 @@ using namespace mlir::triton; using ::mlir::LLVM::delinearize; using ::mlir::LLVM::linearize; using ::mlir::triton::gpu::getOrder; +using ::mlir::triton::gpu::getThreadOrder; using ::mlir::triton::gpu::getTotalElemsPerThread; namespace { @@ -271,7 +272,7 @@ struct ReduceOpConversion auto threadsPerWarp = triton::gpu::getThreadsPerWarpWithUniqueData(srcLayout, srcShape); - auto order = getOrder(srcLayout); + auto order = getThreadOrder(srcLayout); SmallVector multiDimLaneId = delinearize(rewriter, loc, laneId, threadsPerWarp, order); Value laneIdAxis = multiDimLaneId[axis]; diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 63f7db1ffa..6de04b9ecd 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -259,14 +259,6 @@ SmallVector getOrder(Attribute layout) { auto rank = distributedLayout.getWarpsPerCTA().size(); SmallVector order(rank); std::iota(order.rbegin(), order.rend(), 0); - auto mfmaLayout = dyn_cast(layout); - if (!mfmaLayout) - return order; - // For transposed MFMA layouts, we swap M and N dimensions, which is - // always the first two in order; as we can have an optional batch - // dimension following them. 
-    if (mfmaLayout.getIsTransposed())
-      std::swap(order[0], order[1]);
     return order;
   }
   if (auto dotLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
@@ -293,6 +285,14 @@ SmallVector<unsigned> getOrder(Attribute layout) {
   return {};
 };

+SmallVector<unsigned> getThreadOrder(Attribute layout) {
+  if (auto distributedLayout = mlir::dyn_cast<DistributedEncodingTrait>(layout))
+    return distributedLayout.getThreadOrder();
+  else
+    llvm::report_fatal_error("Unimplemented usage of getThreadOrder");
+  return {};
+};
+
 CTALayoutAttr getCTALayout(Attribute layout) {
   if (auto distributedLayout =
           mlir::dyn_cast<DistributedEncodingTrait>(layout)) {
@@ -1557,7 +1557,10 @@ SmallVector<unsigned> AMDMfmaEncodingAttr::getWarpOrder() const {
   return ::getWarpOrder(*this);
 }
 SmallVector<unsigned> AMDMfmaEncodingAttr::getThreadOrder() const {
-  return ::getOrder(*this);
+  auto order = ::getOrder(*this);
+  if (getIsTransposed())
+    std::swap(order[0], order[1]);
+  return order;
 }
 SmallVector<unsigned> AMDMfmaEncodingAttr::getThreadsPerWarp() const {
   unsigned rows, cols;
diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
index 286b1eac51..5647f0a0d7 100644
--- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
+++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
@@ -507,6 +507,13 @@ AMDMfmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
         {{kRegister, {{0, 1}, {0, 2}, {0, 8}, /*gap*/ {0, 16}}},
          {kLane, {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}, /*gap*/ {0, 4}}}},
         {outDimNames[order[0]], outDimNames[order[1]]});
+    // For the mfma.transposed layout, the element ownership among threads is
+    // "transposed" within each warp.
+    if (getIsTransposed())
+      tileLayout = LinearLayout(
+          {{kRegister, {{1, 0}, {2, 0}, {8, 0}, /*gap*/ {16, 0}}},
+           {kLane, {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, /*gap*/ {4, 0}}}},
+          {outDimNames[order[0]], outDimNames[order[1]]});
   } else {
     assert(getMDim() == 16);
     // For mfma with 16x16 output, each of the 64 threads holds 4 elements.
@@ -521,6 +528,13 @@ AMDMfmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
         {{kRegister, {{0, 1}, {0, 2}}},
          {kLane, {{1, 0}, {2, 0}, {4, 0}, {8, 0}, /*gap*/ {0, 4}, {0, 8}}}},
         {outDimNames[order[0]], outDimNames[order[1]]});
+    // For the mfma.transposed layout, the element ownership among threads is
+    // "transposed" within each warp.
+    if (getIsTransposed())
+      tileLayout = LinearLayout(
+          {{kRegister, {{1, 0}, {2, 0}}},
+           {kLane, {{0, 1}, {0, 2}, {0, 4}, {0, 8}, /*gap*/ {4, 0}, {8, 0}}}},
+          {outDimNames[order[0]], outDimNames[order[1]]});
   }
   if (hasBatchDim) {
     assert(order[2] == 0);
@@ -806,8 +820,8 @@ namespace {
 // stmatrix. These restrictions are retained from legacy code, and we could
 // relax some of them in the future.
bool canUseStMatrix(RankedTensorType tensorTy, ArrayRef repShape, - ArrayRef paddedRepShape, - ArrayRef order) { + ArrayRef paddedRepShape, ArrayRef order, + int swizzleByteSize) { auto mmaLayout = mlir::dyn_cast(tensorTy.getEncoding()); if (!mmaLayout || !mmaLayout.isHopper()) @@ -826,17 +840,87 @@ bool canUseStMatrix(RankedTensorType tensorTy, ArrayRef repShape, return false; if (paddedRepShape[1] % 8 != 0) return false; + if (swizzleByteSize != 0 && swizzleByteSize != 32 && swizzleByteSize != 64 && + swizzleByteSize != 128) + return false; return true; } -} // anonymous namespace +std::optional chooseStMatrixLayoutLeadingOffset( + MLIRContext *ctx, RankedTensorType tensorTy, ArrayRef repShape, + ArrayRef paddedRepShape, ArrayRef order, + int swizzleByteSize) { + StringAttr kReg = S("register"); + StringAttr kLane = S("lane"); + StringAttr kWarp = S("warp"); + StringAttr kCol = S("dim1"); + StringAttr kRow = S("dim0"); + StringAttr kOffset = S("offset"); + + int perPhase; + int maxPhase; + if (swizzleByteSize == 32) { + perPhase = 4; + maxPhase = 2; + } else if (swizzleByteSize == 64) { + perPhase = 2; + maxPhase = 4; + } else if (swizzleByteSize == 128) { + perPhase = 1; + maxPhase = 8; + } else { + llvm::errs() << "Illegal swizzleByteSize: " << swizzleByteSize << "\n"; + llvm::report_fatal_error("Illegal swizzleByteSize"); + } + + // stmatrix only supports 16-bit elements, and each vector has 8 elements + int elemBitWidth = 16; + int vecSize = 8; + int numRows = 16; + int numCols = 8 * swizzleByteSize / elemBitWidth; + + // Construct a single stmatrix.x4 (16x16) tile + std::vector> basesReg = {{1, 0}, {2, 0}, {4, 0}}; + std::vector> basesLane; + for (int logRow = 0; logRow < llvm::Log2_32(numRows); logRow++) { + int row = 1 << logRow; + basesLane.push_back({vecSize * ((row / perPhase) % maxPhase), row}); + } + basesLane.push_back({8, 0}); + + // Expand the tile's register dimension to fit swizzleByteSize, which is a + // "chunk" + for (int logChunk = 0; logChunk < llvm::Log2_32(numCols / 16); logChunk++) { + int chunk = 1 << logChunk; + basesReg.push_back({16 * chunk, 0}); + } + + // Construct the layout for a single chunk + LinearLayout layout = + LinearLayout({{kReg, basesReg}, {kLane, basesLane}}, {kCol, kRow}); -std::optional chooseStMatrixLayoutForRegToRegConversion( + // Expand the `warp` dimension according to warpsPerCTA. + auto mma = cast(tensorTy.getEncoding()); + layout *= + identityND(kWarp, mma.getWarpsPerCTA(), /*order=*/{0, 1}, {kRow, kCol}) + .transposeOuts(llvm::to_vector(layout.getOutDimNames())); + + // Expand the `register` dimension so the size of columns matches `n`. 
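+  // The expansion below flattens the per-chunk layout into a 1-D `offset`
+  // dimension, multiplies it by an identity layout of size n / numCols along
+  // `register`, and reshapes back, so that each warp's registers span the
+  // full instruction width `n`.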
+ int n = mma.getInstrShape()[1]; + int numWarpRows = layout.getOutDimSize(kRow); + layout = (layout.reshapeOuts({{kOffset, layout.getTotalOutDimSize()}}) * + LinearLayout::identity1D(n / numCols, kReg, kOffset)) + .reshapeOuts({{kCol, n}, {kRow, numWarpRows}}); + + auto ret = + combineCtaCgaWithShape(layout, mma.getCTALayout(), tensorTy.getShape()); + return ret.transposeOuts(llvm::to_vector(layout.getOutDimNames())) + .reshapeOuts({{kOffset, ret.getTotalOutDimSize()}, {S("iteration"), 1}}); +} + +std::optional chooseStMatrixLayoutNoLeadingOffset( MLIRContext *ctx, RankedTensorType tensorTy, ArrayRef repShape, ArrayRef paddedRepShape, ArrayRef order) { - if (!canUseStMatrix(tensorTy, repShape, paddedRepShape, order)) - return std::nullopt; - StringAttr kReg = S("register"); StringAttr kLane = S("lane"); StringAttr kWarp = S("warp"); @@ -866,4 +950,23 @@ std::optional chooseStMatrixLayoutForRegToRegConversion( {{S("offset"), ret.getTotalOutDimSize()}, {S("iteration"), 1}}); } +} // anonymous namespace + +std::optional +chooseStMatrixLayout(MLIRContext *ctx, RankedTensorType tensorTy, + ArrayRef repShape, + ArrayRef paddedRepShape, + ArrayRef order, int swizzleByteSize) { + if (!canUseStMatrix(tensorTy, repShape, paddedRepShape, order, + swizzleByteSize)) + return std::nullopt; + + if (swizzleByteSize == 0) + return chooseStMatrixLayoutNoLeadingOffset(ctx, tensorTy, repShape, + paddedRepShape, order); + else + return chooseStMatrixLayoutLeadingOffset( + ctx, tensorTy, repShape, paddedRepShape, order, swizzleByteSize); +} + } // namespace mlir::triton::gpu diff --git a/python/setup.py b/python/setup.py index 63f2df9fff..32dd09ff2d 100644 --- a/python/setup.py +++ b/python/setup.py @@ -638,6 +638,14 @@ def get_entry_points(): return entry_points +def get_git_commit_hash(length=8): + try: + cmd = ['git', 'rev-parse', f'--short={length}', 'HEAD'] + return "+git{}".format(subprocess.check_output(cmd).strip().decode('utf-8')) + except Exception: + return "" + + def get_install_requires(): install_requires = [ "packaging", # used by third_party/intel/backend/compiler.py @@ -647,7 +655,7 @@ def get_install_requires(): setup( name=os.environ.get("TRITON_WHEEL_NAME", "triton"), - version="3.0.0" + os.environ.get("TRITON_WHEEL_VERSION_SUFFIX", ""), + version="3.0.0" + get_git_commit_hash() + os.environ.get("TRITON_WHEEL_VERSION_SUFFIX", ""), author="Philippe Tillet", author_email="phil@openai.com", description="A language and compiler for custom Deep Learning operations", diff --git a/python/test/unit/hopper/test_experimental_tma.py b/python/test/unit/hopper/test_experimental_tma.py index c026132143..9695a5e47e 100644 --- a/python/test/unit/hopper/test_experimental_tma.py +++ b/python/test/unit/hopper/test_experimental_tma.py @@ -57,7 +57,7 @@ def kernel(Z, desc, SIZE: tl.constexpr, BYVAL_TMA: tl.constexpr): @triton.jit def matmul_kernel_tma(a_desc_ptr, b_desc_ptr, c_desc_ptr, # M, N, K, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - BYVAL_TMA: tl.constexpr): + BYVAL_TMA: tl.constexpr, dtype: tl.constexpr): if not BYVAL_TMA: tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr) tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr) @@ -72,11 +72,11 @@ def matmul_kernel_tma(a_desc_ptr, b_desc_ptr, c_desc_ptr, # offs_k = 0 accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): - a = tl._experimental_descriptor_load(a_desc_ptr, [offs_am, offs_k], [BLOCK_SIZE_M, BLOCK_SIZE_K], 
tl.float16) - b = tl._experimental_descriptor_load(b_desc_ptr, [offs_k, offs_bn], [BLOCK_SIZE_K, BLOCK_SIZE_N], tl.float16) + a = tl._experimental_descriptor_load(a_desc_ptr, [offs_am, offs_k], [BLOCK_SIZE_M, BLOCK_SIZE_K], dtype) + b = tl._experimental_descriptor_load(b_desc_ptr, [offs_k, offs_bn], [BLOCK_SIZE_K, BLOCK_SIZE_N], dtype) accumulator = tl.dot(a, b, acc=accumulator) offs_k += BLOCK_SIZE_K - accumulator = accumulator.to(tl.float16) + accumulator = accumulator.to(dtype) tl._experimental_descriptor_store(c_desc_ptr, accumulator, [offs_am, offs_bn]) @@ -101,7 +101,7 @@ def test_experimental_tma_matmul(num_stages, BLOCK_M, BLOCK_N, BLOCK_K, byval_tm desc_c = create_tma_desc_gmem_ptr(C.data_ptr(), [M, N], [BLOCK_M, BLOCK_N], C.element_size()) kernel = matmul_kernel_tma[(triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1, 1)](desc_a, desc_b, desc_c, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, BYVAL_TMA=byval_tma, - num_warps=8, num_stages=num_stages) + num_warps=8, num_stages=num_stages, dtype=tl.float16) ref_out = torch.matmul(A.to(torch.float32), B.to(torch.float32)).to(torch.float16) torch.testing.assert_close(ref_out, C, rtol=1e-3, atol=1e-3) if BLOCK_M >= 64 and BLOCK_N >= 64: diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 9dd0b5f6f4..2a0ad6a105 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -1613,7 +1613,7 @@ def _experimental_descriptor_load(desc_pointer, offsets, shape, dtype, _builder= This loads a tensor of data based on the descriptor and offsets. """ - type = block_type(dtype, shape) + type = block_type(_constexpr_to_value(dtype), shape) return semantic.descriptor_load(desc_pointer, offsets, "", "", type, _builder) diff --git a/python/triton/testing.py b/python/triton/testing.py index 07827ad853..d63fa0f314 100644 --- a/python/triton/testing.py +++ b/python/triton/testing.py @@ -5,6 +5,7 @@ from contextlib import contextmanager from typing import Any, Dict, List from . import language as tl +from . 
import runtime import time import logging @@ -161,7 +162,7 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flu assert return_mode in ["min", "max", "mean", "median", "all"] import torch - di = torch._dynamo.device_interface.get_interface_for_device(device_type) + di = runtime.driver.active.get_device_interface() fn() di.synchronize() diff --git a/python/tutorials/09-persistent-matmul.py b/python/tutorials/09-persistent-matmul.py index 699c4aa4f9..1464d489bc 100644 --- a/python/tutorials/09-persistent-matmul.py +++ b/python/tutorials/09-persistent-matmul.py @@ -554,7 +554,7 @@ def bench(K, dtype, tiles_per_update, reps=10): if cublas is not None: for _ in range(reps): cublas_matmul(a, b) - time.sleep(0.01) + time.sleep(0.01) if dtype == torch.float16: for _ in range(reps): torch_matmul(a, b) diff --git a/test/TritonGPU/amd/amd-loop-pipeline-v1.mlir b/test/TritonGPU/amd/amd-loop-pipeline-v1.mlir new file mode 100644 index 0000000000..45eae93880 --- /dev/null +++ b/test/TritonGPU/amd/amd-loop-pipeline-v1.mlir @@ -0,0 +1,31 @@ +// RUN: triton-opt %s -split-input-file -tritonamdgpu-stream-pipeline | FileCheck %s + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#loc = loc("/data/users/dberard/triton-env/scripts/matmul.py":6:0) +#mma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [32, 32], isTransposed = false}> +module attributes {"triton_gpu.target" = "hip:gfx942", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: tt.func @use_dep_args + tt.func @use_dep_args(%a_ptrs: tensor<64x32x!tt.ptr, #blocked>, %b_ptrs: tensor<32x64x!tt.ptr, #blocked1>, %loop_range: i32) -> (tensor<64x64xf32, #mma>, tensor<64x32x!tt.ptr, #blocked>, tensor<32x64x!tt.ptr, #blocked1>) { + %cst = arith.constant dense<32> : tensor<64x32xi32, #blocked> + %cst2 = arith.constant dense<2048> : tensor<32x64xi32, #blocked1> + %cst_0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #mma> + %c0_i32 = arith.constant 0 : i32 + %c8_i32 = arith.constant 8 : i32 + %c32_i32 = arith.constant 32 : i32 + // CHECK: tt.load + // CHECK: [[FOR_OUT:%[a-z0-9_]+]]:{{[0-9]+}} = scf.for + %for:3 = scf.for %arg6 = %c0_i32 to %loop_range step %c32_i32 iter_args(%arg7 = %cst_0, %arg8 = %a_ptrs, %arg9 = %b_ptrs) -> (tensor<64x64xf32, #mma>, tensor<64x32x!tt.ptr, #blocked>, tensor<32x64x!tt.ptr, #blocked1>) : i32 { + %63 = tt.load %arg8 : tensor<64x32x!tt.ptr, #blocked> + %64 = tt.load %arg9 : tensor<32x64x!tt.ptr, #blocked1> + %65 = triton_gpu.convert_layout %63 : tensor<64x32xbf16, #blocked> -> tensor<64x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + %66 = triton_gpu.convert_layout %64 : tensor<32x64xbf16, #blocked1> -> tensor<32x64xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> + %67 = tt.dot %65, %66, %arg7 : tensor<64x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<32x64xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<64x64xf32, #mma> + %68 = tt.addptr %arg8, %cst : tensor<64x32x!tt.ptr, #blocked>, tensor<64x32xi32, #blocked> + %69 = tt.addptr %arg9, %cst2 : tensor<32x64x!tt.ptr, #blocked1>, tensor<32x64xi32, #blocked1> + scf.yield %67, %68, %69 : tensor<64x64xf32, #mma>, tensor<64x32x!tt.ptr, 
#blocked>, tensor<32x64x!tt.ptr, #blocked1> + } + // CHECK: tt.return {{[^,]+}}, [[FOR_OUT]]#3, [[FOR_OUT]]#4 + tt.return %for#0, %for#1, %for#2 : tensor<64x64xf32, #mma>, tensor<64x32x!tt.ptr, #blocked>, tensor<32x64x!tt.ptr, #blocked1> + } +} diff --git a/third_party/amd/backend/driver.py b/third_party/amd/backend/driver.py index c1ff6e1d65..80ba6d2a5c 100644 --- a/third_party/amd/backend/driver.py +++ b/third_party/amd/backend/driver.py @@ -484,6 +484,10 @@ def __init__(self): self.utils = HIPUtils() self.launcher_cls = HIPLauncher + def get_device_interface(self): + import torch + return torch.cuda + @staticmethod def is_active(): import torch diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp index 73b0f775cb..5a2815e5fc 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -119,6 +119,10 @@ struct LoadStoreConversionBase { return axisAnalysisPass.getMaskAlignment(mask); } + unsigned getPtrAlignment(Value ptr) const { + return axisAnalysisPass.getPtrAlignment(ptr); + } + protected: const AMD::TargetInfo &targetInfo; ModuleAxisInfoAnalysis &axisAnalysisPass; @@ -193,7 +197,9 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern, // vectorized iteration through all the pointer/mask/other elements const int valueElemNBits = std::max(8u, valueElemTy.getIntOrFloatBitWidth()); + const size_t valueElemNBytes = valueElemNBits / 8; const int numVecs = numElems / vec; + int64_t ptrAlignmentBytes = getPtrAlignment(ptr) * valueElemNBytes; auto cacheMod = op.getCache(); SmallVector loadedVals; @@ -230,8 +236,8 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern, falseVal = v; } - auto loadVal = - llLoad(rewriter, loc, ptr, vecTy, pred, falseVal, cacheMod); + Value loadVal = llLoad(rewriter, loc, ptr, vecTy, pred, falseVal, + ptrAlignmentBytes, cacheMod); for (size_t ii = 0; ii < vec; ++ii) { Value vecIdx = createIndexAttrConstant( rewriter, loc, this->getTypeConverter()->getIndexType(), ii % vec); @@ -294,9 +300,10 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern, vec = std::min(vec, maskAlign); } - const size_t dtsize = - std::max(1, valueElemTy.getIntOrFloatBitWidth() / 8); - const size_t valueElemNBits = dtsize * 8; + const size_t valueElemNBits = + std::max(8, valueElemTy.getIntOrFloatBitWidth()); + const size_t valueElemNBytes = valueElemNBits / 8; + int64_t ptrAlignmentBytes = getPtrAlignment(ptr) * valueElemNBytes; auto cacheMod = op.getCache(); const int numVecs = elemsPerThread / vec; @@ -328,7 +335,7 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern, rewriter, loc, this->getTypeConverter()->getIndexType(), s); storeVal = insert_element(vecTy, storeVal, otherElem, indexVal); } - llStore(rewriter, loc, ptr, storeVal, pred, cacheMod); + llStore(rewriter, loc, ptr, storeVal, pred, ptrAlignmentBytes, cacheMod); } // end vec rewriter.eraseOp(op); return success(); diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp index 99ecd423b4..6e196e995c 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp @@ -136,15 +136,6 @@ bool TargetInfo::warpReduce(RewriterBase &rewriter, Location loc, return false; } -bool TargetInfo::processReplicaUsingStMatrix( - RewriterBase &rewriter, Location loc, Value smemBase, - SmallVector &vals, RankedTensorType srcTy, Type 
elemTy, - ArrayRef paddedRepShape, ArrayRef origRepShape, - ArrayRef outOrd, unsigned accumNumReplicates, - int swizzleByteWidth) const { - return false; -} - void TargetInfo::printfImpl(Value formatStrStart, int formatStrByteCount, ValueRange args, RewriterBase &rewriter, bool useStdErr) const { diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h index 544abb3e0b..d50f661193 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h @@ -46,15 +46,6 @@ class TargetInfo : public mlir::triton::TargetInfoBase { triton::ReduceOp op, unsigned numLaneToReduce, unsigned interleave) const override; - bool processReplicaUsingStMatrix(RewriterBase &rewriter, Location loc, - Value smemBase, SmallVector &vals, - RankedTensorType srcTy, Type elemTy, - ArrayRef paddedRepShape, - ArrayRef origRepShape, - ArrayRef outOrd, - unsigned accumNumReplicates, - int swizzleByteWidth) const override; - std::string getMulhiFuncName(Type resultElementTy) const override; void printf(RewriterBase &rewriter, Value formatStrStart, diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp index 262055d645..2e114c898f 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp @@ -189,12 +189,14 @@ Value llGetPid(Location loc, RewriterBase &rewriter, ModuleOp moduleOp, } Value llLoad(RewriterBase &rewriter, Location loc, Value ptr, Type elemTy, - Value pred, Value falseVal, triton::CacheModifier cm) { + Value pred, Value falseVal, int64_t alignmentBytes, + triton::CacheModifier cm) { // Try to emit llvm.intr.masked.load if we can. In theory the backend should // be happier because we emit less branchy code to optimize. The backend will // lower it down however it wants at some point. - if (cm == triton::CacheModifier::CG || cm == triton::CacheModifier::NONE) { + if (alignmentBytes && + (cm == triton::CacheModifier::CG || cm == triton::CacheModifier::NONE)) { // `llvm.intr.masked.load` only accepts vectors. If we see a scalar we need // to bitcast to `vector<1xelemTy>` (and back) int64_t vecSize = getNumElements(elemTy); @@ -203,7 +205,7 @@ Value llLoad(RewriterBase &rewriter, Location loc, Value ptr, Type elemTy, Value maskVal = createVectorMaskFromPredicate(rewriter, loc, pred, vecSize); bool nt = (cm == triton::CacheModifier::CG); Value vecData = rewriter.create( - loc, vecType, ptr, maskVal, falseVal, vecSize, nt); + loc, vecType, ptr, maskVal, falseVal, alignmentBytes, nt); // If it is not a vector, remember to bitcast back to a scalar vecData = bitcast(vecData, elemTy); return vecData; @@ -237,11 +239,11 @@ Value llLoad(RewriterBase &rewriter, Location loc, Value ptr, Type elemTy, } void llStore(RewriterBase &rewriter, Location loc, Value ptr, Value val, - Value pred, triton::CacheModifier cm) { + Value pred, int64_t alignmentBytes, triton::CacheModifier cm) { // Try to emit llvm.intr.masked.store if we can. In theory the backend should // be happier because we emit less branchy code to optimize. The backend will // lower it down however it wants at some point. - if (cm == triton::CacheModifier::NONE) { + if (alignmentBytes && cm == triton::CacheModifier::NONE) { // `llvm.intr.masked.store` only accepts vectors. 
If we see a scalar we need // to bitcast to `vector<1xelemTy>` Type elemTy = val.getType(); @@ -249,8 +251,8 @@ void llStore(RewriterBase &rewriter, Location loc, Value ptr, Value val, Type vecType = castToVectorType(elemTy); val = bitcast(val, vecType); Value maskVal = createVectorMaskFromPredicate(rewriter, loc, pred, vecSize); - auto op = - rewriter.create(loc, val, ptr, maskVal, vecSize); + auto op = rewriter.create(loc, val, ptr, maskVal, + alignmentBytes); return; } diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h index 631f9bcce4..123234fd48 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h @@ -30,12 +30,12 @@ Value llGetPid(Location loc, RewriterBase &rewriter, ModuleOp moduleOp, // Loads from shared or global memory with predication. // `otherElems` is used to mask out the elements that are not loaded Value llLoad(RewriterBase &rewriter, Location loc, Value ptr, Type elemTy, - Value pred, Value falseVal, + Value pred, Value falseVal, int64_t alignmentBytes = 0, triton::CacheModifier cm = triton::CacheModifier::NONE); // Stores to shared or global memory with predication. void llStore(RewriterBase &rewriter, Location loc, Value ptr, Value val, - Value pred, + Value pred, int64_t alignmentBytes = 0, triton::CacheModifier cm = triton::CacheModifier::NONE); } // namespace mlir::LLVM::AMD diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp index bf976a8138..21b74ecf99 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp @@ -269,23 +269,6 @@ class BlockedToMFMA : public RewritePattern { : RewritePattern(tt::DotOp::getOperationName(), 2, context), mfmaVersion(mfmaVersion), enforcedNonKDim(nonKDim), kPack(kPack) {} - bool isChainDot(tt::DotOp &dotOp) const { - auto filter = [&dotOp](Operation *op) { - return op->getParentRegion() == dotOp->getParentRegion(); - }; - ForwardSliceOptions fwdOpt; - fwdOpt.filter = filter; - BackwardSliceOptions bwdOpt; - bwdOpt.omitBlockArguments = true; - bwdOpt.filter = filter; - auto slices = getSlice(dotOp, bwdOpt, fwdOpt); - for (Operation *op : slices) { - if (isa(op) && (op != dotOp)) - return true; - } - return false; - } - bool isSecondDot(tt::DotOp &dotOp) const { auto filter = [&dotOp](Operation *op) { return op->getParentRegion() == dotOp->getParentRegion(); @@ -400,11 +383,12 @@ class BlockedToMFMA : public RewritePattern { auto warpsPerTile = warpsPerTileMFMA(dotOp, retShape, numWarps, {mDim, nDim}); - bool isTransposed = isChainDot(dotOp); + // Always use transposed mfma layout. 
This enables larger vectorization
+    // for global store instructions.
     mfmaEnc = ttg::AMDMfmaEncodingAttr::get(
         oldRetType.getContext(),
         /*versionMajor*/ mfmaVersion, /*versionMinor*/ 0, warpsPerTile,
-        /*instrShape*/ mDim, nDim, isTransposed, CTALayout);
+        /*instrShape*/ mDim, nDim, /*isTransposed*/ true, CTALayout);

     Type mfmaAccType;
     if (oldRetType.getElementType().isIntOrIndex())
diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp
index 8bdf9d1175..224df90283 100644
--- a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp
+++ b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp
@@ -71,7 +71,7 @@ class LoopPipeliner {
   /// shared mem and a next buffer stored in regs.
   int numStages = 2;

-  /// Arg indicies
+  /// Arg indices in pplForOp
   size_t depArgsBeginIdx;
   DenseMap<BlockArgument, size_t> depArgsIdx;

@@ -165,6 +165,9 @@ class LoopPipeliner {
   /// Collect loads to pipeline. Return success if we can pipeline this loop
   LogicalResult initialize();

+  // Update mapping from old forOp results to new pplForOp results
+  void setResultMapping(DenseMap<Value, Value> &newResults);
+
   /// Emit pipelined loads (before loop body)
   void emitPrologue();

@@ -548,6 +551,45 @@ void LoopPipeliner::emitPrologue() {
   } // for (Operation *op : orderedDeps)
 }

+void LoopPipeliner::setResultMapping(DenseMap<Value, Value> &newResults) {
+  // After pipelining, some of the depArgs have been mapped to new args.
+  // We need to remap these.
+  //
+  // For example, if we have
+  //
+  // ptr = ...
+  // c = [zeros]
+  // ret = scf.for iter_args(a_ptr=ptr, c=c)
+  //   a = load(a_ptr)
+  //   c += dot(a, ...)
+  //   a_ptr_new = a_ptr + N
+  //   scf.yield %a_ptr_new, %c
+  //
+  // then the ptr arg should be mapped to a new arg in the for loop.
+  //
+  // ptr = ...
+  // c = [zeros]
+  // load_pre = load(ptr)
+  // ptr_new = ptr + N
+  // ret = scf.for iter_args(a_ptr=ptr, c=c, ld=load_pre, A_ptr_1=ptr_new)
+  //   a_next = load(A_ptr_1)
+  //   c += dot(ld, ...)
+ // A_ptr_new = A_ptr_1 + N + // scf.yield a_ptr, c, a_next, A_ptr_new + // + // After this, if there are downstream users of a_ptr, they should reference + // ret#3 instead of ret#0 + for (const auto &origArg : llvm::enumerate(forOp.getRegionIterArgs())) { + if (depArgs.contains(origArg.value())) { + auto oldIdx = origArg.index(); + auto newIdx = depArgsIdx[origArg.value()]; + auto oldResult = forOp->getResult(oldIdx); + auto newResult = pplForOp->getResult(newIdx); + newResults[oldResult] = newResult; + } + } +} + void LoopPipeliner::emitEpilogue(DenseMap &newResults) { if (!peelLastIter) return; @@ -846,6 +888,7 @@ struct PipelinePass : public TritonAMDGPUStreamPipelineBase { DenseMap newResults; for (unsigned i = 0; i < forOp->getNumResults(); ++i) newResults[forOp->getResult(i)] = pplForOp->getResult(i); + pipeliner.setResultMapping(newResults); pipeliner.emitEpilogue(newResults); // Replace the original loop diff --git a/third_party/intel/backend/driver.py b/third_party/intel/backend/driver.py index 3890997507..a44d849b86 100644 --- a/third_party/intel/backend/driver.py +++ b/third_party/intel/backend/driver.py @@ -479,6 +479,10 @@ def get_current_target(self): warp_size = 32 return GPUTarget("xpu", dev_property, warp_size) + def get_device_interface(self): + import torch + return torch.xpu + @staticmethod def is_active(): import torch diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/TargetInfo.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/TargetInfo.cpp index ceb564f13d..94df171f91 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/TargetInfo.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/TargetInfo.cpp @@ -143,15 +143,6 @@ bool TargetInfo::warpReduce(RewriterBase &rewriter, Location loc, return true; } -bool TargetInfo::processReplicaUsingStMatrix( - RewriterBase &rewriter, Location loc, Value smemBase, - SmallVector &vals, RankedTensorType srcTy, Type elemTy, - ArrayRef paddedRepShape, ArrayRef origRepShape, - ArrayRef outOrd, unsigned accumNumReplicates, - int swizzleByteWidth) const { - return false; -} - std::string TargetInfo::getMulhiFuncName(Type resultElementTy) const { std::string funcName = resultElementTy.isInteger(32) ? 
"__imf_umulhi" : "__imf_umul64hi"; diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/TargetInfo.h b/third_party/intel/lib/TritonIntelGPUToLLVM/TargetInfo.h index fd1afcacdc..f922279940 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/TargetInfo.h +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/TargetInfo.h @@ -48,15 +48,6 @@ class TargetInfo : public mlir::triton::TargetInfoBase { triton::ReduceOp op, unsigned numLaneToReduce, unsigned interleave) const override; - bool processReplicaUsingStMatrix(RewriterBase &rewriter, Location loc, - Value smemBase, SmallVector &vals, - RankedTensorType srcTy, Type elemTy, - ArrayRef paddedRepShape, - ArrayRef origRepShape, - ArrayRef outOrd, - unsigned accumNumReplicates, - int swizzleByteWidth) const override; - std::string getMulhiFuncName(Type resultElementTy) const override; void printf(RewriterBase &rewriter, Value formatStrStart, diff --git a/third_party/nvidia/backend/driver.py b/third_party/nvidia/backend/driver.py index bf1f066d55..57c8844e1f 100644 --- a/third_party/nvidia/backend/driver.py +++ b/third_party/nvidia/backend/driver.py @@ -440,6 +440,10 @@ def get_current_target(self): warp_size = 32 return GPUTarget("cuda", capability, warp_size) + def get_device_interface(self): + import torch + return torch.cuda + @staticmethod def is_active(): import torch diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 8fb44ce644..2895da1b26 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -725,25 +725,59 @@ struct LocalAllocOpConversion else return failure(); + auto *ctx = rewriter.getContext(); Location loc = op->getLoc(); + RankedTensorType srcTy = op.getSrc().getType(); - Value smemBase = LLVM::getSharedMemoryBase(loc, rewriter, op); - auto srcs = unpackLLElements(loc, adaptor.getSrc(), rewriter); - SmallVector shape; - for (int64_t dim : srcTy.getShape()) - shape.push_back(dim); - bool loweredToStMatrix = targetInfo.processReplicaUsingStMatrix( - rewriter, loc, smemBase, srcs, srcTy, - getTypeConverter()->convertType(srcTy.getElementType()), shape, shape, - sharedLayout.getOrder(), 1, swizzleByteSize); - if (!loweredToStMatrix) + SmallVector shape = + convertType(srcTy.getShape()); + auto order = sharedLayout.getOrder(); + auto layout = chooseStMatrixLayout(rewriter.getContext(), srcTy, shape, + shape, order, swizzleByteSize); + if (!layout.has_value()) return failure(); + Value smemBase = LLVM::getSharedMemoryBase(loc, rewriter, op); + auto smemPtrTy = ptr_ty(ctx, 3); + + auto kRegister = str_attr("register"); + auto kLane = str_attr("lane"); + auto kWarp = str_attr("warp"); + auto kBlock = str_attr("block"); + + Value threadId = getThreadId(rewriter, loc); + Value threadsPerWarp = i32_val(layout->getInDimSize(kLane)); + Value laneId = urem(threadId, threadsPerWarp); + Value warpId = udiv(threadId, threadsPerWarp); + + auto regBase = applyLinearLayout(loc, rewriter, *layout, + {{kRegister, i32_val(0)}, + {kLane, laneId}, + {kWarp, warpId}, + {kBlock, i32_val(0)}})[0] + .second; + auto srcVals = unpackLLElements(loc, adaptor.getSrc(), rewriter); + auto srcVec = layout->getNumConsecutiveInOut(); + Type llvmElemTy = typeConverter->convertType(srcTy.getElementType()); + for (int i = 0; i < srcVals.size(); i += srcVec) { + auto regIdx = + layout + ->apply({{kRegister, i}, {kLane, 0}, {kWarp, 0}, {kBlock, 0}})[0] + 
.second; + Value offset = xor_(regBase, i32_val(regIdx)); + auto vecAddr = gep(smemPtrTy, llvmElemTy, smemBase, offset); + vecAddr.setInbounds(true); + SmallVector inValsVec; + for (int j = 0; j < srcVec; j++) + inValsVec.push_back(srcVals[i + j]); + Value valsVec = packLLVector(loc, inValsVec, rewriter); + targetInfo.storeMatrixShared(rewriter, loc, vecAddr, valsVec); + } + auto resultTy = cast(op.getType()); // Workaround for 3D tensors // TODO: we need to modify the pipeline pass to give a proper shared // encoding to 3D tensors - auto order = sharedLayout.getOrder(); SmallVector newOrder; if (resultTy.getShape().size() != order.size()) { for (auto i = 0; i < order.size(); ++i) @@ -752,7 +786,6 @@ struct LocalAllocOpConversion } else { newOrder = SmallVector(order.begin(), order.end()); } - auto llvmElemTy = typeConverter->convertType(resultTy.getElementType()); auto shapePerCTA = getShapePerCTA(sharedLayout, resultTy.getShape()); auto smemObj = SharedMemoryObject(smemBase, llvmElemTy, shapePerCTA, newOrder, loc, rewriter); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp index 5813b9679e..1b7f71b8cb 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp @@ -15,131 +15,6 @@ using ::mlir::LLVM::linearize; using ::mlir::triton::gpu::getShapePerCTA; using ::mlir::triton::gpu::getShapePerCTATile; namespace { -Value computeStMatrixAddr(Value laneId, int matStride, Location loc, - RewriterBase &rewriter, int swizzleByteWidth) { - Value rowInMat = urem(laneId, i32_val(8)); // row in the 8x8 matrix - // linear index of the matrix in the 2x2 matrices - // Decompose matIndex => s_0, s_1, that is the coordinate in 2x2 matrices in - // a warp. - Value matIndex = udiv(laneId, i32_val(8)); - Value s0 = urem(matIndex, i32_val(2)); - Value s1 = udiv(matIndex, i32_val(2)); - if (swizzleByteWidth >= 32) - s1 = xor_(s1, and_(laneId, i32_val(1))); - Value mIndex = add(rowInMat, mul(s0, i32_val(8))); - int m8n8Stride = 8; - Value offset = - add(mul(mIndex, i32_val(matStride)), mul(s1, i32_val(m8n8Stride))); - return offset; -} - -void stMatrixm8n8x4(Value offset, ArrayRef vals, int indexOffset, - Value smemBase, Type elemTy, Location loc, - RewriterBase &rewriter) { - SmallVector inputs; - auto prTy = ptr_ty(rewriter.getContext(), 3); - // Pack the input into 2xf16 - Type packedTy = vec_ty(vals[0].getType(), 2); - for (int i = 0; i < 4; i++) { - Value input = undef(packedTy); - for (int j = 0; j < 2; j++) { - input = insert_element(packedTy, input, vals[indexOffset + i * 2 + j], - i32_val(j)); - } - inputs.push_back(bitcast(input, i32_ty)); - } - Value addr = gep(smemBase.getType(), elemTy, smemBase, offset); - rewriter.create(loc, addr, inputs); -} -void storeDistributedToSharedWithStMatrix( - RankedTensorType tensorTy, Type elemTy, SmallVector &inVals, - Value smemBase, ArrayRef paddedRepShape, - ArrayRef origRepShape, Location loc, RewriterBase &rewriter, - int swizzleByteWidth) { - auto shapePerCTA = getShapePerCTA(tensorTy); - auto mmaLayout = mlir::cast(tensorTy.getEncoding()); - auto order = triton::gpu::getOrder(mmaLayout); - auto warpsPerCTA = mmaLayout.getWarpsPerCTA(); - auto shapePerCTATile = getShapePerCTATile(mmaLayout); - ArrayRef mmaShape = mmaLayout.getInstrShape(); - // 4xm8n8 matches exactly the size of 1 warp of wgmma layout for 16bit type - // and has a shape of 16x16. 
- int instrN = mmaShape[1] * warpsPerCTA[1]; - int instrM = mmaShape[0] * warpsPerCTA[0]; - std::array numRep = {ceil((int)origRepShape[0], instrM), - ceil((int)origRepShape[1], instrN)}; - int numBoxes = 1; - if (swizzleByteWidth == 128) { - int contigDimSizeInByte = - origRepShape[1] * elemTy.getIntOrFloatBitWidth() / 8; - numBoxes = ceil(contigDimSizeInByte, 128); - } - SmallVector boxShape = {paddedRepShape[0], paddedRepShape[1]}; - boxShape[1] = boxShape[1] / numBoxes; - Value thread = getThreadId(rewriter, loc); - Value warp = udiv(thread, i32_val(32)); - Value lane = urem(thread, i32_val(32)); - - SmallVector multiDimWarpId = - delinearize(rewriter, loc, warp, warpsPerCTA); - - // Compute the relative offset for each lane. - Value stMatrixLaneOffset = - computeStMatrixAddr(lane, boxShape[1], loc, rewriter, swizzleByteWidth); - multiDimWarpId[0] = mul(multiDimWarpId[0], i32_val(mmaShape[0])); - multiDimWarpId[1] = mul(multiDimWarpId[1], i32_val(mmaShape[1])); - SmallVector multiDimOffsetWrapped = getWrappedMultiDimOffset( - rewriter, loc, multiDimWarpId, boxShape, shapePerCTATile, shapePerCTA); - Value relativeOffset = - linearize(rewriter, loc, multiDimOffsetWrapped, boxShape, order); - relativeOffset = add(relativeOffset, stMatrixLaneOffset); - int indexOffset = 0; - int m8n8x4Stride = 16; - int numNChunk = mmaShape[1] / m8n8x4Stride; - unsigned totalNumElements = product(origRepShape); - numNChunk = numNChunk / numBoxes; - for (int m = 0; m < numRep[0]; m++) { - for (int n = 0; n < numRep[1]; n++) { - for (int box = 0; box < numBoxes; box++) { - for (int k = 0; k < numNChunk; k++) { - Value kOffset; - if (swizzleByteWidth >= 64) { - int swizzleBits = swizzleByteWidth == 128 ? 6 : 2; - Value o = lshr(and_(lane, i32_val(swizzleBits)), i32_val(1)); - Value kV = xor_(o, i32_val(k)); - kOffset = mul(kV, i32_val(m8n8x4Stride)); - } else { - kOffset = i32_val(k * m8n8x4Stride); - } - Value addr = add(relativeOffset, - i32_val(n * instrN + m * instrM * boxShape[1] + - box * (totalNumElements / numBoxes))); - addr = add(addr, kOffset); - - stMatrixm8n8x4(addr, inVals, indexOffset, smemBase, elemTy, loc, - rewriter); - indexOffset += 8; - } - } - } - } -} - -bool isStMatrixCompatible(RankedTensorType tensorTy, int swizzleByteWidth) { - auto mmaLayout = - mlir::dyn_cast(tensorTy.getEncoding()); - if (!mmaLayout || !mmaLayout.isHopper()) - return false; - if (tensorTy.getElementType().getIntOrFloatBitWidth() != 16) - return false; - if (swizzleByteWidth > 0 && mmaLayout.getInstrShape()[1] < 64) - return false; - if (swizzleByteWidth != 0 && swizzleByteWidth != 32 && - swizzleByteWidth != 64 && swizzleByteWidth != 128) - return false; - return true; -} - // declare vprintf(i8*, i8*) as external function LLVM::LLVMFuncOp getVprintfDeclaration(RewriterBase &rewriter) { auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType(); @@ -605,26 +480,22 @@ bool TargetInfo::warpReduce(RewriterBase &rewriter, Location loc, void TargetInfo::storeMatrixShared(RewriterBase &rewriter, Location loc, Value ptr, Value val) const { - auto vecTy = cast(val.getType()); - Type elemTy = vecTy.getElementType(); - stMatrixm8n8x4(i32_val(0), unpackLLVector(loc, val, rewriter), 0, ptr, elemTy, - loc, rewriter); -} - -bool TargetInfo::processReplicaUsingStMatrix( - RewriterBase &rewriter, Location loc, Value smemBase, - SmallVector &vals, RankedTensorType srcTy, Type elemTy, - ArrayRef paddedRepShape, ArrayRef origRepShape, - ArrayRef outOrd, unsigned accumNumReplicates, - int swizzleByteWidth) const { - if 
(isStMatrixCompatible(srcTy, swizzleByteWidth) && - accumNumReplicates == 1 && outOrd[0] == 1 && paddedRepShape[1] % 8 == 0) { - storeDistributedToSharedWithStMatrix(srcTy, elemTy, vals, smemBase, - paddedRepShape, origRepShape, loc, - rewriter, swizzleByteWidth); - return true; + auto vals = unpackLLVector(loc, val, rewriter); + // Ensure input consists of 4 vectors, each holding 2 elements of 16 bits + assert(vals[0].getType().getIntOrFloatBitWidth() == 16 && + "stmatrix requires elements to be 16-bit integers or floats"); + assert(vals.size() == 8 && + "stmatrix requires exactly 8 elements in the input vector"); + Type packedTy = vec_ty(vals[0].getType(), 2); + SmallVector inputs; + for (int i = 0; i < 4; i++) { + Value input = undef(packedTy); + for (int j = 0; j < 2; j++) { + input = insert_element(packedTy, input, vals[i * 2 + j], i32_val(j)); + } + inputs.push_back(bitcast(input, i32_ty)); } - return false; + rewriter.create(loc, ptr, inputs); } std::string TargetInfo::getMulhiFuncName(Type resultElementTy) const { diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.h b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.h index 011cc37f4b..97df9d840f 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.h +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.h @@ -41,15 +41,6 @@ class TargetInfo : public mlir::triton::TargetInfoBase { triton::ReduceOp op, unsigned numLaneToReduce, unsigned interleave) const override; - bool processReplicaUsingStMatrix(RewriterBase &rewriter, Location loc, - Value smemBase, SmallVector &vals, - RankedTensorType srcTy, Type elemTy, - ArrayRef paddedRepShape, - ArrayRef origRepShape, - ArrayRef outOrd, - unsigned accumNumReplicates, - int swizzleByteWidth) const override; - std::string getMulhiFuncName(Type resultElementTy) const override; void printf(RewriterBase &rewriter, Value formatStrStart, diff --git a/third_party/proton/README.md b/third_party/proton/README.md index 8b94f180c3..ccd79e7212 100644 --- a/third_party/proton/README.md +++ b/third_party/proton/README.md @@ -119,7 +119,7 @@ flops64: float # The number of 64-bit floating-point operations bytes: int # The number of bytes expected to be transferred ``` -### Command Line +### Command line Proton can be used as a command-line tool to profile Python scripts and Pytest tests. The following examples demonstrate how to use Proton command-line. @@ -149,6 +149,22 @@ More options can be found by running the following command. proton-viewer -h ``` +### Instruction sampling (experimental) + +Proton supports instruction sampling on NVIDIA GPUs. +Please note that this is an experimental feature and may not work on all GPUs. +You may experience ~20x end-to-end overhead when using instruction sampling, although the overhead for each individual GPU kernel is negligible. +The overhead is mostly caused by data transfer and processing on the CPU. +Additionally, the proton-viewer options `-i -d -t ` can be helpful for filtering out GPU kernels that are not of interest. +The following example demonstrates how to use instruction sampling: + +```python +import triton.profiler as proton + + +proton.start(name="profile_name", context="shadow", backend="cupti_pcsampling") +``` + ## Proton *vs* nsys - Runtime overhead (up to 1.5x) @@ -173,11 +189,24 @@ Proton is designed to be portable and can be used on AMD GPUs. nsys only support Proton can register hooks to analyze the metadata of triton kernels, while nsys cannot. 
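A minimal sketch of enabling these hooks (the `hook` argument here is illustrative, following the `proton.start` call shown earlier):

```python
import triton.profiler as proton

# The launch hook lets Proton attach Triton kernel metadata (e.g., the
# user-declared flops/bytes metrics) to each kernel launch in the profile.
proton.start(name="profile_name", context="shadow", hook="triton")
```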
**Note** that the hooks do add additional overhead to Proton.
 
-## Known Issues
+## Proton *vs* ncu
+
+Similar to the comparison between Proton and Nsight Systems (Nsys), Proton has a lower profiling overhead than Nsight Compute (NCU).
+However, Nsight Compute supports the collection of more detailed metrics than Proton, such as memory access patterns, memory transactions, and other instruction-level metrics.
+In contrast, Proton only supports instruction sampling and is designed to be lightweight and portable.
+We also plan to support instruction sampling on AMD GPUs.
+
+## Known issues
 
-- CUDA Graph
+- CUDA graph
 
 `hooks` cannot be used to accurately accumulate the number of FLOPs in CUDA graph mode profiling because kernels are captured and launched separately, so their metrics are not accumulated. This issue can be circumvented by using `scope` to supply FLOPs.
 
 If profiling is initiated after CUDA graph capturing, there may be minor memory leak issues. This is because the number of kernels in a graph instance (i.e., `cuGraphExec`) is unknown, preventing the deletion of mappings between the kernel ID and the graph ID.
+
+- Instruction sampling
+
+If you encounter permission-related problems when using instruction sampling, you can look up this [page](https://developer.nvidia.com/nvidia-development-tools-solutions-err_nvgpuctrperm-permission-issue-performance-counters) for help.
+
+The overhead of instruction sampling on NVIDIA GPUs is about 20x with Proton because we haven't enabled continuous sampling yet.
+Continuous sampling would allow for more runtime optimizations, but it makes it more challenging to attribute performance data back to the GPU kernels because: (1) it enables profiling of concurrent kernels, (2) it doesn't allow profiling of time and instruction samples simultaneously, and (3) it works best with a separate thread dedicated to attributing instruction samples to the GPU kernels.
diff --git a/third_party/proton/csrc/include/Data/Metric.h b/third_party/proton/csrc/include/Data/Metric.h
index 0e22f7a050..a75692877c 100644
--- a/third_party/proton/csrc/include/Data/Metric.h
+++ b/third_party/proton/csrc/include/Data/Metric.h
@@ -7,7 +7,7 @@
 
 namespace proton {
 
-enum class MetricKind { Flexible, Kernel, Count };
+enum class MetricKind { Flexible, Kernel, PCSampling, Count };
 
 using MetricValueType = std::variant;
 
@@ -143,8 +143,78 @@ class KernelMetric : public Metric {
   const static inline bool AGGREGABLE[kernelMetricKind::Count] = {
       false, false, true, true, false, false};
   const static inline std::string VALUE_NAMES[kernelMetricKind::Count] = {
-      "StartTime (ns)", "EndTime (ns)", "Count",
-      "Time (ns)", "DeviceId", "DeviceType",
+      "start_time (ns)", "end_time (ns)", "count",
+      "time (ns)", "device_id", "device_type",
   };
 };
+
+class PCSamplingMetric : public Metric {
+public:
+  enum PCSamplingMetricKind : int {
+    NumSamples,
+    NumStalledSamples,
+    StalledBranchResolving,
+    StalledNoInstruction,
+    StalledShortScoreboard,
+    StalledWait,
+    StalledLongScoreboard,
+    StalledTexThrottle,
+    StalledBarrier,
+    StalledMembar,
+    StalledIMCMiss,
+    StalledMIOThrottle,
+    StalledMathPipeThrottle,
+    StalledDrain,
+    StalledLGThrottle,
+    StalledNotSelected,
+    StalledMisc,
+    StalledDispatchStall,
+    StalledSleeping,
+    StalledSelected,
+    Count,
+  };
+
+  PCSamplingMetric()
+      : Metric(MetricKind::PCSampling, PCSamplingMetricKind::Count) {}
+
+  PCSamplingMetric(PCSamplingMetricKind kind, uint64_t samples,
+                   uint64_t stalledSamples)
+      : PCSamplingMetric() {
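+    // Record this reason's stalled-sample count under its own kind; the
+    // per-PC totals go under NumSamples and NumStalledSamples. Every value
+    // is aggregable (see isAggregable), so they are summed downstream.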
this->values[kind] = stalledSamples;
+    this->values[PCSamplingMetricKind::NumSamples] = samples;
+    this->values[PCSamplingMetricKind::NumStalledSamples] = stalledSamples;
+  }
+
+  virtual const std::string getName() const { return "PCSamplingMetric"; }
+
+  virtual const std::string getValueName(int valueId) const {
+    return VALUE_NAMES[valueId];
+  }
+
+  virtual bool isAggregable(int valueId) const { return true; }
+
+private:
+  const static inline std::string VALUE_NAMES[PCSamplingMetricKind::Count] = {
+      "num_samples",
+      "num_stalled_samples",
+      "stalled_branch_resolving",
+      "stalled_no_instruction",
+      "stalled_short_scoreboard",
+      "stalled_wait",
+      "stalled_long_scoreboard",
+      "stalled_tex_throttle",
+      "stalled_barrier",
+      "stalled_membar",
+      "stalled_imc_miss",
+      "stalled_mio_throttle",
+      "stalled_math_pipe_throttle",
+      "stalled_drain",
+      "stalled_lg_throttle",
+      "stalled_not_selected",
+      "stalled_misc",
+      "stalled_dispatch_stall",
+      "stalled_sleeping",
+      "stalled_selected",
   };
 };
 
diff --git a/third_party/proton/csrc/include/Driver/Dispatch.h b/third_party/proton/csrc/include/Driver/Dispatch.h
index 6fe2d75942..1d8ec017cd 100644
--- a/third_party/proton/csrc/include/Driver/Dispatch.h
+++ b/third_party/proton/csrc/include/Driver/Dispatch.h
@@ -63,17 +63,17 @@ template class Dispatch {
     *lib = dlopen(name, RTLD_NOLOAD);
   }
   if (*lib == nullptr) {
-    // If not found, try to load it from the default path
+    // If not found, try to load it from LD_LIBRARY_PATH
+    *lib = dlopen(name, RTLD_LOCAL | RTLD_LAZY);
+  }
+  if (*lib == nullptr) {
+    // If still not found, try to load it from the default path
     auto dir = std::string(ExternLib::defaultDir);
     if (dir.length() > 0) {
       auto fullPath = dir + "/" + name;
       *lib = dlopen(fullPath.c_str(), RTLD_LOCAL | RTLD_LAZY);
     }
   }
-  if (*lib == nullptr) {
-    // If still not found, try to load it from LD_LIBRARY_PATH
-    *lib = dlopen(name, RTLD_LOCAL | RTLD_LAZY);
-  }
   if (*lib == nullptr) {
     throw std::runtime_error("Could not find `" + std::string(name) + "`. 
Make sure it is in your " diff --git a/third_party/proton/csrc/include/Driver/GPU/CuptiApi.h b/third_party/proton/csrc/include/Driver/GPU/CuptiApi.h index 845b415bd5..495964923e 100644 --- a/third_party/proton/csrc/include/Driver/GPU/CuptiApi.h +++ b/third_party/proton/csrc/include/Driver/GPU/CuptiApi.h @@ -2,11 +2,17 @@ #define PROTON_DRIVER_GPU_CUPTI_H_ #include "cupti.h" +#include "cupti_pcsampling.h" namespace proton { namespace cupti { +template CUptiResult getVersion(uint32_t *version); + +template +CUptiResult getContextId(CUcontext context, uint32_t *pCtxId); + template CUptiResult activityRegisterCallbacks( CUpti_BuffersCallbackRequestFunc funcBufferRequested, @@ -66,6 +72,40 @@ CUptiResult getGraphExecId(CUgraphExec graph, uint32_t *pId); template CUptiResult getGraphId(CUgraph graph, uint32_t *pId); +template +CUptiResult getCubinCrc(CUpti_GetCubinCrcParams *pParams); + +template +CUptiResult +getSassToSourceCorrelation(CUpti_GetSassToSourceCorrelationParams *pParams); + +template +CUptiResult +pcSamplingGetNumStallReasons(CUpti_PCSamplingGetNumStallReasonsParams *pParams); + +template +CUptiResult +pcSamplingGetStallReasons(CUpti_PCSamplingGetStallReasonsParams *pParams); + +template +CUptiResult pcSamplingSetConfigurationAttribute( + CUpti_PCSamplingConfigurationInfoParams *pParams); + +template +CUptiResult pcSamplingEnable(CUpti_PCSamplingEnableParams *pParams); + +template +CUptiResult pcSamplingDisable(CUpti_PCSamplingDisableParams *pParams); + +template +CUptiResult pcSamplingGetData(CUpti_PCSamplingGetDataParams *pParams); + +template +CUptiResult pcSamplingStart(CUpti_PCSamplingStartParams *pParams); + +template +CUptiResult pcSamplingStop(CUpti_PCSamplingStopParams *pParams); + } // namespace cupti } // namespace proton diff --git a/third_party/proton/csrc/include/Profiler/Cupti/CuptiPCSampling.h b/third_party/proton/csrc/include/Profiler/Cupti/CuptiPCSampling.h new file mode 100644 index 0000000000..58b6e2be81 --- /dev/null +++ b/third_party/proton/csrc/include/Profiler/Cupti/CuptiPCSampling.h @@ -0,0 +1,141 @@ +#ifndef PROTON_PROFILER_CUPTI_PC_SAMPLING_H_ +#define PROTON_PROFILER_CUPTI_PC_SAMPLING_H_ + +#include "CuptiProfiler.h" +#include "Driver/GPU/CudaApi.h" +#include "Driver/GPU/CuptiApi.h" +#include "Utility/Map.h" +#include "Utility/Singleton.h" +#include +#include + +namespace proton { + +struct CubinData { + size_t cubinCrc; + const char *cubin; + size_t cubinSize; + + struct LineInfoKey { + uint32_t functionIndex; + uint64_t pcOffset; + + bool operator<(const LineInfoKey &other) const { + return functionIndex < other.functionIndex || + (functionIndex == other.functionIndex && + pcOffset < other.pcOffset); + } + }; + + struct LineInfoValue { + uint32_t lineNumber{}; + const std::string functionName{}; + const std::string dirName{}; + const std::string fileName{}; + + LineInfoValue() = default; + + LineInfoValue(uint32_t lineNumber, const std::string &functionName, + const std::string &dirName, const std::string &fileName) + : lineNumber(lineNumber), functionName(functionName), dirName(dirName), + fileName(fileName) {} + }; + + std::map lineInfo; +}; + +struct ConfigureData { + ConfigureData() = default; + + ~ConfigureData() { + if (stallReasonNames) { + for (size_t i = 0; i < numStallReasons; i++) { + if (stallReasonNames[i]) + std::free(stallReasonNames[i]); + } + std::free(stallReasonNames); + } + if (stallReasonIndices) + std::free(stallReasonIndices); + if (pcSamplingData.pPcData) { + for (size_t i = 0; i < numValidStallReasons; ++i) { + 
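+        // Each pPcData entry owns a per-stall-reason array allocated with
+        // std::calloc (see allocPCSamplingData), so release those before
+        // freeing the array itself.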
std::free(pcSamplingData.pPcData[i].stallReason); + } + std::free(pcSamplingData.pPcData); + } + } + + void initialize(CUcontext context); + + CUpti_PCSamplingConfigurationInfo configureStallReasons(); + CUpti_PCSamplingConfigurationInfo configureSamplingPeriod(); + CUpti_PCSamplingConfigurationInfo configureSamplingBuffer(); + CUpti_PCSamplingConfigurationInfo configureScratchBuffer(); + CUpti_PCSamplingConfigurationInfo configureHardwareBufferSize(); + CUpti_PCSamplingConfigurationInfo configureStartStopControl(); + CUpti_PCSamplingConfigurationInfo configureCollectionMode(); + + // The amount of data reserved on the GPU + static constexpr size_t HardwareBufferSize = 128 * 1024 * 1024; + // The amount of data copied from the hardware buffer each time + static constexpr size_t ScratchBufferSize = 16 * 1024 * 1024; + // The number of PCs copied from the scratch buffer each time + static constexpr size_t DataBufferPCCount = 1024; + // The sampling period in cycles = 2^frequency + static constexpr uint32_t DefaultFrequency = 10; + + CUcontext context{}; + uint32_t contextId; + uint32_t numStallReasons{}; + uint32_t numValidStallReasons{}; + char **stallReasonNames{}; + uint32_t *stallReasonIndices{}; + std::map stallReasonIndexToMetricIndex{}; + std::set notIssuedStallReasonIndices{}; + CUpti_PCSamplingData pcSamplingData{}; + // The memory storing configuration information has to be kept alive during + // the profiling session + std::vector configurationInfos; +}; + +class CuptiPCSampling : public Singleton { + +public: + CuptiPCSampling() = default; + virtual ~CuptiPCSampling() = default; + + void initialize(CUcontext context); + + void start(CUcontext context); + + void stop(CUcontext context, uint64_t externId, bool isAPI); + + void finalize(CUcontext context); + + void loadModule(const char *cubin, size_t cubinSize); + + void unloadModule(const char *cubin, size_t cubinSize); + +private: + ConfigureData *getConfigureData(uint32_t contextId); + + CubinData *getCubinData(uint64_t cubinCrc); + + void processPCSamplingData(ConfigureData *configureData, uint64_t externId, + bool isAPI); + + ThreadSafeMap contextIdToConfigureData; + // In case the same cubin is loaded multiple times, we need to keep track of + // all of them + ThreadSafeMap> + cubinCrcToCubinData; + ThreadSafeSet contextInitialized; + + std::atomic pcSamplingStarted{false}; + std::mutex pcSamplingMutex{}; + std::mutex contextMutex{}; +}; + +} // namespace proton + +#endif // PROTON_PROFILER_CUPTI_PC_SAMPLING_H_ diff --git a/third_party/proton/csrc/include/Profiler/CuptiProfiler.h b/third_party/proton/csrc/include/Profiler/Cupti/CuptiProfiler.h similarity index 90% rename from third_party/proton/csrc/include/Profiler/CuptiProfiler.h rename to third_party/proton/csrc/include/Profiler/Cupti/CuptiProfiler.h index 344d0fd4b9..c443ec2e39 100644 --- a/third_party/proton/csrc/include/Profiler/CuptiProfiler.h +++ b/third_party/proton/csrc/include/Profiler/Cupti/CuptiProfiler.h @@ -1,7 +1,7 @@ #ifndef PROTON_PROFILER_CUPTI_PROFILER_H_ #define PROTON_PROFILER_CUPTI_PROFILER_H_ -#include "GPUProfiler.h" +#include "Profiler/GPUProfiler.h" namespace proton { diff --git a/third_party/proton/csrc/include/Profiler/GPUProfiler.h b/third_party/proton/csrc/include/Profiler/GPUProfiler.h index 26c6d10b5d..d5033b06aa 100644 --- a/third_party/proton/csrc/include/Profiler/GPUProfiler.h +++ b/third_party/proton/csrc/include/Profiler/GPUProfiler.h @@ -31,6 +31,16 @@ class GPUProfiler : public Profiler, std::unordered_map>>; using ApiExternIdSet = 
ThreadSafeSet>;
 
+  ConcreteProfilerT &enablePCSampling() {
+    pcSamplingEnabled = true;
+    return dynamic_cast<ConcreteProfilerT &>(*this);
+  }
+  ConcreteProfilerT &disablePCSampling() {
+    pcSamplingEnabled = false;
+    return dynamic_cast<ConcreteProfilerT &>(*this);
+  }
+  bool isPCSamplingEnabled() const { return pcSamplingEnabled; }
+
 protected:
   // OpInterface
   void startOp(const Scope &scope) override {
@@ -140,6 +150,8 @@ class GPUProfiler : public Profiler,
     ConcreteProfilerT &profiler;
   };
   std::unique_ptr pImpl;
+
+  bool pcSamplingEnabled{false};
 };
 
 } // namespace proton
diff --git a/third_party/proton/csrc/include/Profiler/RoctracerProfiler.h b/third_party/proton/csrc/include/Profiler/Roctracer/RoctracerProfiler.h
similarity index 91%
rename from third_party/proton/csrc/include/Profiler/RoctracerProfiler.h
rename to third_party/proton/csrc/include/Profiler/Roctracer/RoctracerProfiler.h
index 2f1791dcb5..b9bc08de8e 100644
--- a/third_party/proton/csrc/include/Profiler/RoctracerProfiler.h
+++ b/third_party/proton/csrc/include/Profiler/Roctracer/RoctracerProfiler.h
@@ -1,7 +1,7 @@
 #ifndef PROTON_PROFILER_ROCTRACER_PROFILER_H_
 #define PROTON_PROFILER_ROCTRACER_PROFILER_H_
 
-#include "GPUProfiler.h"
+#include "Profiler/GPUProfiler.h"
 
 namespace proton {
 
diff --git a/third_party/proton/csrc/include/Utility/Atomic.h b/third_party/proton/csrc/include/Utility/Atomic.h
index d7e40e73cd..0f759e0d61 100644
--- a/third_party/proton/csrc/include/Utility/Atomic.h
+++ b/third_party/proton/csrc/include/Utility/Atomic.h
@@ -1,4 +1,8 @@
+#ifndef PROTON_UTILITY_ATOMIC_H_
+#define PROTON_UTILITY_ATOMIC_H_
+
 #include <atomic>
+#include <mutex>
 
 namespace proton {
 
@@ -16,4 +20,20 @@ template T atomicMin(std::atomic &target, T value) {
   return current;
 }
 
+template <typename Condition, typename Function>
+void doubleCheckedLock(Condition enterCondition, std::mutex &lock,
+                       Function function) {
+  if (!enterCondition())
+    return;
+
+  std::unique_lock<std::mutex> guard(lock);
+
+  if (!enterCondition())
+    return;
+
+  function();
+}
+
 } // namespace proton
+
+#endif // PROTON_UTILITY_ATOMIC_H_
diff --git a/third_party/proton/csrc/include/Utility/Errors.h b/third_party/proton/csrc/include/Utility/Errors.h
index 62d4f3f665..094723d6f7 100644
--- a/third_party/proton/csrc/include/Utility/Errors.h
+++ b/third_party/proton/csrc/include/Utility/Errors.h
@@ -1,3 +1,6 @@
+#ifndef PROTON_UTILITY_ERRORS_H_
+#define PROTON_UTILITY_ERRORS_H_
+
 #include <stdexcept>
 
 namespace proton {
@@ -8,3 +11,5 @@ class NotImplemented : public std::logic_error {
 };
 
 } // namespace proton
+
+#endif // PROTON_UTILITY_ERRORS_H_
diff --git a/third_party/proton/csrc/include/Utility/String.h b/third_party/proton/csrc/include/Utility/String.h
index b7d45ae1f7..b4a1d3ff91 100644
--- a/third_party/proton/csrc/include/Utility/String.h
+++ b/third_party/proton/csrc/include/Utility/String.h
@@ -13,6 +13,18 @@ inline std::string toLower(const std::string &str) {
   return lower;
 }
 
+inline std::string replace(const std::string &str, const std::string &src,
+                           const std::string &dst) {
+  std::string replaced = str;
+  size_t pos = replaced.find(src);
+  while (pos != std::string::npos) {
+    replaced.replace(pos, src.length(), dst);
+    pos += dst.length();
+    pos = replaced.find(src, pos);
+  }
+  return replaced;
+}
+
 } // namespace proton
 
 #endif // PROTON_UTILITY_STRING_H_
diff --git a/third_party/proton/csrc/lib/Data/TreeData.cpp b/third_party/proton/csrc/lib/Data/TreeData.cpp
index b12427f777..ec6ea1c784 100644
--- a/third_party/proton/csrc/lib/Data/TreeData.cpp
+++ b/third_party/proton/csrc/lib/Data/TreeData.cpp
@@ -180,66 +180,76 @@ void TreeData::dumpHatchet(std::ostream &os) const {
jsonNodes[Tree::TreeNode::RootId] = &(output.back()); std::set valueNames; std::map> deviceIds; - this->tree->template walk( - [&](Tree::TreeNode &treeNode) { - const auto contextName = treeNode.name; - auto contextId = treeNode.id; - json *jsonNode = jsonNodes[contextId]; - (*jsonNode)["frame"] = {{"name", contextName}, {"type", "function"}}; - (*jsonNode)["metrics"] = json::object(); - for (auto [metricKind, metric] : treeNode.metrics) { - if (metricKind == MetricKind::Kernel) { - auto kernelMetric = std::dynamic_pointer_cast(metric); - auto duration = std::get( - kernelMetric->getValue(KernelMetric::Duration)); - auto invocations = std::get( - kernelMetric->getValue(KernelMetric::Invocations)); - auto deviceId = std::get( - kernelMetric->getValue(KernelMetric::DeviceId)); - auto deviceType = std::get( - kernelMetric->getValue(KernelMetric::DeviceType)); - auto deviceTypeName = - getDeviceTypeString(static_cast(deviceType)); - (*jsonNode)["metrics"] - [kernelMetric->getValueName(KernelMetric::Duration)] = - duration; - (*jsonNode)["metrics"] - [kernelMetric->getValueName(KernelMetric::Invocations)] = - invocations; - (*jsonNode)["metrics"] - [kernelMetric->getValueName(KernelMetric::DeviceId)] = - std::to_string(deviceId); - (*jsonNode)["metrics"] - [kernelMetric->getValueName(KernelMetric::DeviceType)] = - deviceTypeName; - valueNames.insert( - kernelMetric->getValueName(KernelMetric::Duration)); - valueNames.insert( - kernelMetric->getValueName(KernelMetric::Invocations)); - deviceIds.insert({deviceType, {deviceId}}); - } else { - throw std::runtime_error("MetricKind not supported"); - } - } - for (auto [_, flexibleMetric] : treeNode.flexibleMetrics) { - auto valueName = flexibleMetric.getValueName(0); + this->tree->template walk([&](Tree::TreeNode + &treeNode) { + const auto contextName = treeNode.name; + auto contextId = treeNode.id; + json *jsonNode = jsonNodes[contextId]; + (*jsonNode)["frame"] = {{"name", contextName}, {"type", "function"}}; + (*jsonNode)["metrics"] = json::object(); + for (auto [metricKind, metric] : treeNode.metrics) { + if (metricKind == MetricKind::Kernel) { + std::shared_ptr kernelMetric = + std::dynamic_pointer_cast(metric); + uint64_t duration = + std::get(kernelMetric->getValue(KernelMetric::Duration)); + uint64_t invocations = std::get( + kernelMetric->getValue(KernelMetric::Invocations)); + uint64_t deviceId = + std::get(kernelMetric->getValue(KernelMetric::DeviceId)); + uint64_t deviceType = std::get( + kernelMetric->getValue(KernelMetric::DeviceType)); + std::string deviceTypeName = + getDeviceTypeString(static_cast(deviceType)); + (*jsonNode)["metrics"] + [kernelMetric->getValueName(KernelMetric::Duration)] = + duration; + (*jsonNode)["metrics"] + [kernelMetric->getValueName(KernelMetric::Invocations)] = + invocations; + (*jsonNode)["metrics"] + [kernelMetric->getValueName(KernelMetric::DeviceId)] = + std::to_string(deviceId); + (*jsonNode)["metrics"] + [kernelMetric->getValueName(KernelMetric::DeviceType)] = + deviceTypeName; + valueNames.insert(kernelMetric->getValueName(KernelMetric::Duration)); + valueNames.insert( + kernelMetric->getValueName(KernelMetric::Invocations)); + deviceIds.insert({deviceType, {deviceId}}); + } else if (metricKind == MetricKind::PCSampling) { + auto pcSamplingMetric = + std::dynamic_pointer_cast(metric); + for (size_t i = 0; i < PCSamplingMetric::Count; i++) { + auto valueName = pcSamplingMetric->getValueName(i); valueNames.insert(valueName); std::visit( [&](auto &&value) { (*jsonNode)["metrics"][valueName] = value; }, - 
flexibleMetric.getValues()[0]); - } - (*jsonNode)["children"] = json::array(); - auto children = treeNode.children; - for (auto _ : children) { - (*jsonNode)["children"].push_back(json::object()); + pcSamplingMetric->getValues()[i]); } - auto idx = 0; - for (auto child : children) { - auto [index, childId] = child; - jsonNodes[childId] = &(*jsonNode)["children"][idx]; - idx++; - } - }); + } else { + throw std::runtime_error("MetricKind not supported"); + } + } + for (auto [_, flexibleMetric] : treeNode.flexibleMetrics) { + auto valueName = flexibleMetric.getValueName(0); + valueNames.insert(valueName); + std::visit( + [&](auto &&value) { (*jsonNode)["metrics"][valueName] = value; }, + flexibleMetric.getValues()[0]); + } + (*jsonNode)["children"] = json::array(); + auto children = treeNode.children; + for (auto _ : children) { + (*jsonNode)["children"].push_back(json::object()); + } + auto idx = 0; + for (auto child : children) { + auto [index, childId] = child; + jsonNodes[childId] = &(*jsonNode)["children"][idx]; + idx++; + } + }); // Hints for all available metrics for (auto valueName : valueNames) { output[Tree::TreeNode::RootId]["metrics"][valueName] = 0; diff --git a/third_party/proton/csrc/lib/Driver/GPU/CuptiApi.cpp b/third_party/proton/csrc/lib/Driver/GPU/CuptiApi.cpp index 1d7e97314a..2c399d31c7 100644 --- a/third_party/proton/csrc/lib/Driver/GPU/CuptiApi.cpp +++ b/third_party/proton/csrc/lib/Driver/GPU/CuptiApi.cpp @@ -22,6 +22,11 @@ struct ExternLibCupti : public ExternLibBase { void *ExternLibCupti::lib = nullptr; +DEFINE_DISPATCH(ExternLibCupti, getVersion, cuptiGetVersion, uint32_t *); + +DEFINE_DISPATCH(ExternLibCupti, getContextId, cuptiGetContextId, CUcontext, + uint32_t *); + DEFINE_DISPATCH(ExternLibCupti, activityRegisterCallbacks, cuptiActivityRegisterCallbacks, CUpti_BuffersCallbackRequestFunc, @@ -77,6 +82,40 @@ DEFINE_DISPATCH(ExternLibCupti, getGraphExecId, cuptiGetGraphExecId, DEFINE_DISPATCH(ExternLibCupti, getGraphId, cuptiGetGraphId, CUgraph, uint32_t *); +DEFINE_DISPATCH(ExternLibCupti, getCubinCrc, cuptiGetCubinCrc, + CUpti_GetCubinCrcParams *); + +DEFINE_DISPATCH(ExternLibCupti, getSassToSourceCorrelation, + cuptiGetSassToSourceCorrelation, + CUpti_GetSassToSourceCorrelationParams *); + +DEFINE_DISPATCH(ExternLibCupti, pcSamplingGetNumStallReasons, + cuptiPCSamplingGetNumStallReasons, + CUpti_PCSamplingGetNumStallReasonsParams *); + +DEFINE_DISPATCH(ExternLibCupti, pcSamplingGetStallReasons, + cuptiPCSamplingGetStallReasons, + CUpti_PCSamplingGetStallReasonsParams *); + +DEFINE_DISPATCH(ExternLibCupti, pcSamplingSetConfigurationAttribute, + cuptiPCSamplingSetConfigurationAttribute, + CUpti_PCSamplingConfigurationInfoParams *); + +DEFINE_DISPATCH(ExternLibCupti, pcSamplingEnable, cuptiPCSamplingEnable, + CUpti_PCSamplingEnableParams *); + +DEFINE_DISPATCH(ExternLibCupti, pcSamplingDisable, cuptiPCSamplingDisable, + CUpti_PCSamplingDisableParams *); + +DEFINE_DISPATCH(ExternLibCupti, pcSamplingGetData, cuptiPCSamplingGetData, + CUpti_PCSamplingGetDataParams *); + +DEFINE_DISPATCH(ExternLibCupti, pcSamplingStart, cuptiPCSamplingStart, + CUpti_PCSamplingStartParams *); + +DEFINE_DISPATCH(ExternLibCupti, pcSamplingStop, cuptiPCSamplingStop, + CUpti_PCSamplingStopParams *); + } // namespace cupti } // namespace proton diff --git a/third_party/proton/csrc/lib/Profiler/Cupti/CuptiPCSampling.cpp b/third_party/proton/csrc/lib/Profiler/Cupti/CuptiPCSampling.cpp new file mode 100644 index 0000000000..f8fb2537a0 --- /dev/null +++ 
b/third_party/proton/csrc/lib/Profiler/Cupti/CuptiPCSampling.cpp @@ -0,0 +1,444 @@ +#include "Profiler/Cupti/CuptiPCSampling.h" +#include "Data/Metric.h" +#include "Driver/GPU/CudaApi.h" +#include "Driver/GPU/CuptiApi.h" +#include "Utility/Atomic.h" +#include "Utility/Map.h" +#include "Utility/String.h" +#include +#include + +namespace proton { + +namespace { + +uint64_t getCubinCrc(const char *cubin, size_t size) { + CUpti_GetCubinCrcParams cubinCrcParams = { + .size = CUpti_GetCubinCrcParamsSize, + .cubinSize = size, + .cubin = cubin, + .cubinCrc = 0, + }; + cupti::getCubinCrc(&cubinCrcParams); + return cubinCrcParams.cubinCrc; +} + +size_t getNumStallReasons(CUcontext context) { + size_t numStallReasons = 0; + CUpti_PCSamplingGetNumStallReasonsParams numStallReasonsParams = { + .size = CUpti_PCSamplingGetNumStallReasonsParamsSize, + .pPriv = NULL, + .ctx = context, + .numStallReasons = &numStallReasons}; + cupti::pcSamplingGetNumStallReasons(&numStallReasonsParams); + return numStallReasons; +} + +std::tuple +getSassToSourceCorrelation(const char *functionName, uint64_t pcOffset, + const char *cubin, size_t cubinSize) { + CUpti_GetSassToSourceCorrelationParams sassToSourceParams = { + .size = CUpti_GetSassToSourceCorrelationParamsSize, + .cubin = cubin, + .functionName = functionName, + .cubinSize = cubinSize, + .lineNumber = 0, + .pcOffset = pcOffset, + .fileName = NULL, + .dirName = NULL, + }; + // Get source can fail if the line mapping is not available in the cubin so we + // don't check the return value + cupti::getSassToSourceCorrelation(&sassToSourceParams); + auto fileNameStr = sassToSourceParams.fileName + ? std::string(sassToSourceParams.fileName) + : ""; + auto dirNameStr = + sassToSourceParams.dirName ? std::string(sassToSourceParams.dirName) : ""; + // It's user's responsibility to free the memory + if (sassToSourceParams.fileName) + std::free(sassToSourceParams.fileName); + if (sassToSourceParams.dirName) + std::free(sassToSourceParams.dirName); + return std::make_tuple(sassToSourceParams.lineNumber, fileNameStr, + dirNameStr); +} + +std::pair +getStallReasonNamesAndIndices(CUcontext context, size_t numStallReasons) { + char **stallReasonNames = + static_cast(std::calloc(numStallReasons, sizeof(char *))); + for (size_t i = 0; i < numStallReasons; i++) { + stallReasonNames[i] = static_cast( + std::calloc(CUPTI_STALL_REASON_STRING_SIZE, sizeof(char))); + } + uint32_t *stallReasonIndices = + static_cast(std::calloc(numStallReasons, sizeof(uint32_t))); + // Initialize the names with 128 characters to avoid buffer overflow + CUpti_PCSamplingGetStallReasonsParams stallReasonsParams = { + .size = CUpti_PCSamplingGetStallReasonsParamsSize, + .pPriv = NULL, + .ctx = context, + .numStallReasons = numStallReasons, + .stallReasonIndex = stallReasonIndices, + .stallReasons = stallReasonNames, + }; + cupti::pcSamplingGetStallReasons(&stallReasonsParams); + return std::make_pair(stallReasonNames, stallReasonIndices); +} + +size_t matchStallReasonsToIndices( + size_t numStallReasons, char **stallReasonNames, + uint32_t *stallReasonIndices, + std::map &stallReasonIndexToMetricIndex, + std::set ¬IssuedStallReasonIndices) { + // In case there's any invalid stall reasons, we only collect valid ones. 
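+  // A stall reason is treated as valid if its CUPTI name contains one of our
+  // metric value names; reasons whose CUPTI name contains "not_issued" are
+  // recorded separately so that their samples are not counted as stalled
+  // samples in processPCSamplingData().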
+ // Invalid ones are swapped to the end of the list + std::vector validIndex(numStallReasons, false); + size_t numValidStalls = 0; + for (size_t i = 0; i < numStallReasons; i++) { + bool notIssued = std::string(stallReasonNames[i]).find("not_issued") != + std::string::npos; + std::string cuptiStallName = std::string(stallReasonNames[i]); + for (size_t j = 0; j < PCSamplingMetric::PCSamplingMetricKind::Count; j++) { + auto metricName = PCSamplingMetric().getValueName(j); + if (cuptiStallName.find(metricName) != std::string::npos) { + if (notIssued) + notIssuedStallReasonIndices.insert(stallReasonIndices[i]); + stallReasonIndexToMetricIndex[stallReasonIndices[i]] = j; + validIndex[i] = true; + numValidStalls++; + break; + } + } + } + int invalidIndex = -1; + for (size_t i = 0; i < numStallReasons; i++) { + if (invalidIndex == -1 && !validIndex[i]) { + invalidIndex = i; + } else if (invalidIndex != -1 && validIndex[i]) { + std::swap(stallReasonIndices[invalidIndex], stallReasonIndices[i]); + std::swap(stallReasonNames[invalidIndex], stallReasonNames[i]); + validIndex[invalidIndex] = true; + invalidIndex++; + } + } + return numValidStalls; +} + +#define CUPTI_CUDA12_4_VERSION 22 +#define CUPTI_CUDA12_4_PC_DATA_PADDING_SIZE sizeof(uint32_t) + +CUpti_PCSamplingData allocPCSamplingData(size_t collectNumPCs, + size_t numValidStallReasons) { + uint32_t libVersion = 0; + cupti::getVersion(&libVersion); + size_t pcDataSize = sizeof(CUpti_PCSamplingPCData); + // Check cupti api version < 12.4 but cupti header version >= 12.4 + // If so, we subtract 4 bytes from the size of CUpti_PCSamplingPCData + // because it introduces a new field (i.e., correlationId) at the end of the + // struct, which is not compatible with the previous versions. + if (libVersion < CUPTI_CUDA12_4_VERSION && + CUPTI_API_VERSION >= CUPTI_CUDA12_4_VERSION) + pcDataSize -= CUPTI_CUDA12_4_PC_DATA_PADDING_SIZE; + CUpti_PCSamplingData pcSamplingData{ + .size = pcDataSize, + .collectNumPcs = collectNumPCs, + .pPcData = static_cast( + std::calloc(collectNumPCs, sizeof(CUpti_PCSamplingPCData)))}; + for (size_t i = 0; i < collectNumPCs; ++i) { + pcSamplingData.pPcData[i].stallReason = + static_cast(std::calloc( + numValidStallReasons, sizeof(CUpti_PCSamplingStallReason))); + } + return pcSamplingData; +} + +void enablePCSampling(CUcontext context) { + CUpti_PCSamplingEnableParams params = { + .size = CUpti_PCSamplingEnableParamsSize, + .pPriv = NULL, + .ctx = context, + }; + cupti::pcSamplingEnable(¶ms); +} + +void disablePCSampling(CUcontext context) { + CUpti_PCSamplingDisableParams params = { + .size = CUpti_PCSamplingDisableParamsSize, + .pPriv = NULL, + .ctx = context, + }; + cupti::pcSamplingDisable(¶ms); +} + +void startPCSampling(CUcontext context) { + CUpti_PCSamplingStartParams params = { + .size = CUpti_PCSamplingStartParamsSize, + .pPriv = NULL, + .ctx = context, + }; + cupti::pcSamplingStart(¶ms); +} + +void stopPCSampling(CUcontext context) { + CUpti_PCSamplingStopParams params = { + .size = CUpti_PCSamplingStopParamsSize, + .pPriv = NULL, + .ctx = context, + }; + cupti::pcSamplingStop(¶ms); +} + +void getPCSamplingData(CUcontext context, + CUpti_PCSamplingData *pcSamplingData) { + CUpti_PCSamplingGetDataParams params = { + .size = CUpti_PCSamplingGetDataParamsSize, + .pPriv = NULL, + .ctx = context, + .pcSamplingData = pcSamplingData, + }; + cupti::pcSamplingGetData(¶ms); +} + +void setConfigurationAttribute( + CUcontext context, + std::vector &configurationInfos) { + CUpti_PCSamplingConfigurationInfoParams infoParams = { 
+ .size = CUpti_PCSamplingConfigurationInfoParamsSize, + .pPriv = NULL, + .ctx = context, + .numAttributes = configurationInfos.size(), + .pPCSamplingConfigurationInfo = configurationInfos.data(), + }; + cupti::pcSamplingSetConfigurationAttribute(&infoParams); +} + +} // namespace + +CUpti_PCSamplingConfigurationInfo ConfigureData::configureStallReasons() { + numStallReasons = getNumStallReasons(context); + std::tie(this->stallReasonNames, this->stallReasonIndices) = + getStallReasonNamesAndIndices(context, numStallReasons); + numValidStallReasons = matchStallReasonsToIndices( + numStallReasons, stallReasonNames, stallReasonIndices, + stallReasonIndexToMetricIndex, notIssuedStallReasonIndices); + CUpti_PCSamplingConfigurationInfo stallReasonInfo{}; + stallReasonInfo.attributeType = + CUPTI_PC_SAMPLING_CONFIGURATION_ATTR_TYPE_STALL_REASON; + stallReasonInfo.attributeData.stallReasonData.stallReasonCount = + numValidStallReasons; + stallReasonInfo.attributeData.stallReasonData.pStallReasonIndex = + stallReasonIndices; + return stallReasonInfo; +} + +CUpti_PCSamplingConfigurationInfo ConfigureData::configureSamplingPeriod() { + CUpti_PCSamplingConfigurationInfo samplingPeriodInfo{}; + samplingPeriodInfo.attributeType = + CUPTI_PC_SAMPLING_CONFIGURATION_ATTR_TYPE_SAMPLING_PERIOD; + samplingPeriodInfo.attributeData.samplingPeriodData.samplingPeriod = + DefaultFrequency; + return samplingPeriodInfo; +} + +CUpti_PCSamplingConfigurationInfo ConfigureData::configureSamplingBuffer() { + CUpti_PCSamplingConfigurationInfo samplingBufferInfo{}; + samplingBufferInfo.attributeType = + CUPTI_PC_SAMPLING_CONFIGURATION_ATTR_TYPE_SAMPLING_DATA_BUFFER; + this->pcSamplingData = + allocPCSamplingData(DataBufferPCCount, numValidStallReasons); + samplingBufferInfo.attributeData.samplingDataBufferData.samplingDataBuffer = + &this->pcSamplingData; + return samplingBufferInfo; +} + +CUpti_PCSamplingConfigurationInfo ConfigureData::configureScratchBuffer() { + CUpti_PCSamplingConfigurationInfo scratchBufferInfo{}; + scratchBufferInfo.attributeType = + CUPTI_PC_SAMPLING_CONFIGURATION_ATTR_TYPE_SCRATCH_BUFFER_SIZE; + scratchBufferInfo.attributeData.scratchBufferSizeData.scratchBufferSize = + ScratchBufferSize; + return scratchBufferInfo; +} + +CUpti_PCSamplingConfigurationInfo ConfigureData::configureHardwareBufferSize() { + CUpti_PCSamplingConfigurationInfo hardwareBufferInfo{}; + hardwareBufferInfo.attributeType = + CUPTI_PC_SAMPLING_CONFIGURATION_ATTR_TYPE_HARDWARE_BUFFER_SIZE; + hardwareBufferInfo.attributeData.hardwareBufferSizeData.hardwareBufferSize = + HardwareBufferSize; + return hardwareBufferInfo; +} + +CUpti_PCSamplingConfigurationInfo ConfigureData::configureStartStopControl() { + CUpti_PCSamplingConfigurationInfo startStopControlInfo{}; + startStopControlInfo.attributeType = + CUPTI_PC_SAMPLING_CONFIGURATION_ATTR_TYPE_ENABLE_START_STOP_CONTROL; + startStopControlInfo.attributeData.enableStartStopControlData + .enableStartStopControl = true; + return startStopControlInfo; +} + +CUpti_PCSamplingConfigurationInfo ConfigureData::configureCollectionMode() { + CUpti_PCSamplingConfigurationInfo collectionModeInfo{}; + collectionModeInfo.attributeType = + CUPTI_PC_SAMPLING_CONFIGURATION_ATTR_TYPE_COLLECTION_MODE; + collectionModeInfo.attributeData.collectionModeData.collectionMode = + CUPTI_PC_SAMPLING_COLLECTION_MODE_CONTINUOUS; + return collectionModeInfo; +} + +void ConfigureData::initialize(CUcontext context) { + this->context = context; + cupti::getContextId(context, &contextId); + 
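+  // The attribute structs built below reference memory owned by this
+  // ConfigureData (e.g., stallReasonIndices and pcSamplingData), so both
+  // configurationInfos and this object must outlive the profiling session.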
configurationInfos.emplace_back(configureStallReasons()); + configurationInfos.emplace_back(configureSamplingPeriod()); + configurationInfos.emplace_back(configureHardwareBufferSize()); + configurationInfos.emplace_back(configureScratchBuffer()); + configurationInfos.emplace_back(configureSamplingBuffer()); + configurationInfos.emplace_back(configureStartStopControl()); + configurationInfos.emplace_back(configureCollectionMode()); + setConfigurationAttribute(context, configurationInfos); +} + +ConfigureData *CuptiPCSampling::getConfigureData(uint32_t contextId) { + return &contextIdToConfigureData[contextId]; +} + +CubinData *CuptiPCSampling::getCubinData(uint64_t cubinCrc) { + return &(cubinCrcToCubinData[cubinCrc].first); +} + +void CuptiPCSampling::initialize(CUcontext context) { + uint32_t contextId = 0; + cupti::getContextId(context, &contextId); + doubleCheckedLock([&]() { return !contextInitialized.contain(contextId); }, + contextMutex, + [&]() { + enablePCSampling(context); + getConfigureData(contextId)->initialize(context); + contextInitialized.insert(contextId); + }); +} + +void CuptiPCSampling::start(CUcontext context) { + uint32_t contextId = 0; + cupti::getContextId(context, &contextId); + doubleCheckedLock([&]() -> bool { return !pcSamplingStarted; }, + pcSamplingMutex, + [&]() { + initialize(context); + // Ensure all previous operations are completed + cuda::ctxSynchronize(); + startPCSampling(context); + pcSamplingStarted = true; + }); +} + +void CuptiPCSampling::processPCSamplingData(ConfigureData *configureData, + uint64_t externId, bool isAPI) { + auto *pcSamplingData = &configureData->pcSamplingData; + auto &profiler = CuptiProfiler::instance(); + auto dataSet = profiler.getDataSet(); + // In the first round, we need to call getPCSamplingData to get the unsynced + // data from the hardware buffer + bool firstRound = true; + while (pcSamplingData->totalNumPcs > 0 || + pcSamplingData->remainingNumPcs > 0 || firstRound) { + // Handle data + for (size_t i = 0; i < pcSamplingData->totalNumPcs; ++i) { + auto *pcData = pcSamplingData->pPcData + i; + auto *cubinData = getCubinData(pcData->cubinCrc); + auto key = + CubinData::LineInfoKey{pcData->functionIndex, pcData->pcOffset}; + if (cubinData->lineInfo.find(key) == cubinData->lineInfo.end()) { + auto [lineNumber, fileName, dirName] = + getSassToSourceCorrelation(pcData->functionName, pcData->pcOffset, + cubinData->cubin, cubinData->cubinSize); + cubinData->lineInfo.try_emplace(key, lineNumber, + std::string(pcData->functionName), + dirName, fileName); + } + auto &lineInfo = cubinData->lineInfo[key]; + for (size_t j = 0; j < pcData->stallReasonCount; ++j) { + auto *stallReason = &pcData->stallReason[j]; + if (!configureData->stallReasonIndexToMetricIndex.count( + stallReason->pcSamplingStallReasonIndex)) + throw std::runtime_error("Invalid stall reason index"); + for (auto *data : dataSet) { + auto scopeId = externId; + if (isAPI) + scopeId = data->addScope(externId, lineInfo.functionName); + if (lineInfo.fileName.size()) + scopeId = data->addScope( + scopeId, lineInfo.dirName + "/" + lineInfo.fileName + ":" + + lineInfo.functionName + "@" + + std::to_string(lineInfo.lineNumber)); + auto metricKind = static_cast( + configureData->stallReasonIndexToMetricIndex + [stallReason->pcSamplingStallReasonIndex]); + auto samples = stallReason->samples; + auto stalledSamples = + configureData->notIssuedStallReasonIndices.count( + stallReason->pcSamplingStallReasonIndex) + ? 
0 + : samples; + auto metric = std::make_shared(metricKind, samples, + stalledSamples); + data->addMetric(scopeId, metric); + } + } + } + if (pcSamplingData->remainingNumPcs > 0 || firstRound) { + getPCSamplingData(configureData->context, pcSamplingData); + firstRound = false; + } else + break; + } +} + +void CuptiPCSampling::stop(CUcontext context, uint64_t externId, bool isAPI) { + uint32_t contextId = 0; + cupti::getContextId(context, &contextId); + doubleCheckedLock([&]() -> bool { return pcSamplingStarted; }, + pcSamplingMutex, + [&]() { + auto *configureData = getConfigureData(contextId); + stopPCSampling(context); + pcSamplingStarted = false; + processPCSamplingData(configureData, externId, isAPI); + }); +} + +void CuptiPCSampling::finalize(CUcontext context) { + uint32_t contextId = 0; + cupti::getContextId(context, &contextId); + if (!contextInitialized.contain(contextId)) + return; + auto *configureData = getConfigureData(contextId); + contextIdToConfigureData.erase(contextId); + contextInitialized.erase(contextId); + disablePCSampling(context); +} + +void CuptiPCSampling::loadModule(const char *cubin, size_t cubinSize) { + auto cubinCrc = getCubinCrc(cubin, cubinSize); + auto *cubinData = getCubinData(cubinCrc); + cubinData->cubinCrc = cubinCrc; + cubinData->cubinSize = cubinSize; + cubinData->cubin = cubin; +} + +void CuptiPCSampling::unloadModule(const char *cubin, size_t cubinSize) { + // XXX: Unload module is supposed to be called in a thread safe manner + // i.e., no two threads will be calling unload module the same time + auto cubinCrc = getCubinCrc(cubin, cubinSize); + auto count = cubinCrcToCubinData[cubinCrc].second; + if (count > 1) + cubinCrcToCubinData[cubinCrc].second = count - 1; + else + cubinCrcToCubinData.erase(cubinCrc); +} + +} // namespace proton diff --git a/third_party/proton/csrc/lib/Profiler/CuptiProfiler.cpp b/third_party/proton/csrc/lib/Profiler/Cupti/CuptiProfiler.cpp similarity index 72% rename from third_party/proton/csrc/lib/Profiler/CuptiProfiler.cpp rename to third_party/proton/csrc/lib/Profiler/Cupti/CuptiProfiler.cpp index 573840fc6c..9ddbd7a715 100644 --- a/third_party/proton/csrc/lib/Profiler/CuptiProfiler.cpp +++ b/third_party/proton/csrc/lib/Profiler/Cupti/CuptiProfiler.cpp @@ -1,9 +1,10 @@ -#include "Profiler/CuptiProfiler.h" +#include "Profiler/Cupti/CuptiProfiler.h" #include "Context/Context.h" #include "Data/Metric.h" #include "Driver/Device.h" #include "Driver/GPU/CudaApi.h" #include "Driver/GPU/CuptiApi.h" +#include "Profiler/Cupti/CuptiPCSampling.h" #include "Utility/Map.h" #include @@ -162,6 +163,33 @@ void setGraphCallbacks(CUpti_SubscriberHandle subscriber, bool enable) { #undef CALLBACK_ENABLE } +void setResourceCallbacks(CUpti_SubscriberHandle subscriber, bool enable) { +#define CALLBACK_ENABLE(id) \ + cupti::enableCallback(static_cast(enable), subscriber, \ + CUPTI_CB_DOMAIN_RESOURCE, id) + + CALLBACK_ENABLE(CUPTI_CBID_RESOURCE_MODULE_LOADED); + CALLBACK_ENABLE(CUPTI_CBID_RESOURCE_MODULE_UNLOAD_STARTING); + CALLBACK_ENABLE(CUPTI_CBID_RESOURCE_CONTEXT_CREATED); + CALLBACK_ENABLE(CUPTI_CBID_RESOURCE_CONTEXT_DESTROY_STARTING); +#undef CALLBACK_ENABLE +} + +bool isDriverAPILaunch(CUpti_CallbackId cbId) { + return cbId == CUPTI_DRIVER_TRACE_CBID_cuLaunch || + cbId == CUPTI_DRIVER_TRACE_CBID_cuLaunchGrid || + cbId == CUPTI_DRIVER_TRACE_CBID_cuLaunchGridAsync || + cbId == CUPTI_DRIVER_TRACE_CBID_cuLaunchKernel || + cbId == CUPTI_DRIVER_TRACE_CBID_cuLaunchKernel_ptsz || + cbId == CUPTI_DRIVER_TRACE_CBID_cuLaunchKernelEx || + cbId == 
CUPTI_DRIVER_TRACE_CBID_cuLaunchKernelEx_ptsz || + cbId == CUPTI_DRIVER_TRACE_CBID_cuLaunchCooperativeKernel || + cbId == CUPTI_DRIVER_TRACE_CBID_cuLaunchCooperativeKernel_ptsz || + cbId == CUPTI_DRIVER_TRACE_CBID_cuLaunchCooperativeKernelMultiDevice || + cbId == CUPTI_DRIVER_TRACE_CBID_cuGraphLaunch || + cbId == CUPTI_DRIVER_TRACE_CBID_cuGraphLaunch_ptsz; +} + } // namespace struct CuptiProfiler::CuptiProfilerPimpl @@ -186,6 +214,7 @@ struct CuptiProfiler::CuptiProfilerPimpl static constexpr size_t AttributeSize = sizeof(size_t); CUpti_SubscriberHandle subscriber{}; + CuptiPCSampling pcSampling; ThreadSafeMap> graphIdToNumInstances; @@ -241,33 +270,58 @@ void CuptiProfiler::CuptiProfilerPimpl::callbackFn(void *userData, if (domain == CUPTI_CB_DOMAIN_RESOURCE) { auto *resourceData = static_cast(const_cast(cbData)); - auto *graphData = - static_cast(resourceData->resourceDescriptor); auto *pImpl = dynamic_cast(profiler.pImpl.get()); - uint32_t graphId = 0; - uint32_t graphExecId = 0; - if (graphData->graph) - cupti::getGraphId(graphData->graph, &graphId); - if (graphData->graphExec) - cupti::getGraphExecId(graphData->graphExec, &graphExecId); - if (cbId == CUPTI_CBID_RESOURCE_GRAPHNODE_CREATED || - cbId == CUPTI_CBID_RESOURCE_GRAPHNODE_CLONED) { - if (!pImpl->graphIdToNumInstances.contain(graphId)) - pImpl->graphIdToNumInstances[graphId] = 1; - else - pImpl->graphIdToNumInstances[graphId]++; - } else if (cbId == CUPTI_CBID_RESOURCE_GRAPHNODE_DESTROY_STARTING) { - pImpl->graphIdToNumInstances[graphId]--; - } else if (cbId == CUPTI_CBID_RESOURCE_GRAPHEXEC_CREATED) { - pImpl->graphExecIdToGraphId[graphExecId] = graphId; - } else if (cbId == CUPTI_CBID_RESOURCE_GRAPHEXEC_DESTROY_STARTING) { - pImpl->graphExecIdToGraphId.erase(graphExecId); - } else if (cbId == CUPTI_CBID_RESOURCE_GRAPH_DESTROY_STARTING) { - pImpl->graphIdToNumInstances.erase(graphId); + if (cbId == CUPTI_CBID_RESOURCE_MODULE_LOADED) { + auto *moduleResource = static_cast( + resourceData->resourceDescriptor); + if (profiler.isPCSamplingEnabled()) { + pImpl->pcSampling.loadModule(moduleResource->pCubin, + moduleResource->cubinSize); + } + } else if (cbId == CUPTI_CBID_RESOURCE_MODULE_UNLOAD_STARTING) { + auto *moduleResource = static_cast( + resourceData->resourceDescriptor); + if (profiler.isPCSamplingEnabled()) { + pImpl->pcSampling.unloadModule(moduleResource->pCubin, + moduleResource->cubinSize); + } + } else if (cbId == CUPTI_CBID_RESOURCE_CONTEXT_CREATED) { + if (profiler.isPCSamplingEnabled()) { + pImpl->pcSampling.initialize(resourceData->context); + } + } else if (cbId == CUPTI_CBID_RESOURCE_CONTEXT_DESTROY_STARTING) { + if (profiler.isPCSamplingEnabled()) { + pImpl->pcSampling.finalize(resourceData->context); + } + } else { + auto *graphData = + static_cast(resourceData->resourceDescriptor); + uint32_t graphId = 0; + uint32_t graphExecId = 0; + if (graphData->graph) + cupti::getGraphId(graphData->graph, &graphId); + if (graphData->graphExec) + cupti::getGraphExecId(graphData->graphExec, &graphExecId); + if (cbId == CUPTI_CBID_RESOURCE_GRAPHNODE_CREATED || + cbId == CUPTI_CBID_RESOURCE_GRAPHNODE_CLONED) { + if (!pImpl->graphIdToNumInstances.contain(graphId)) + pImpl->graphIdToNumInstances[graphId] = 1; + else + pImpl->graphIdToNumInstances[graphId]++; + } else if (cbId == CUPTI_CBID_RESOURCE_GRAPHNODE_DESTROY_STARTING) { + pImpl->graphIdToNumInstances[graphId]--; + } else if (cbId == CUPTI_CBID_RESOURCE_GRAPHEXEC_CREATED) { + pImpl->graphExecIdToGraphId[graphExecId] = graphId; + } else if (cbId == 
CUPTI_CBID_RESOURCE_GRAPHEXEC_DESTROY_STARTING) { + pImpl->graphExecIdToGraphId.erase(graphExecId); + } else if (cbId == CUPTI_CBID_RESOURCE_GRAPH_DESTROY_STARTING) { + pImpl->graphIdToNumInstances.erase(graphId); + } } } else { const CUpti_CallbackData *callbackData = static_cast(cbData); + auto *pImpl = dynamic_cast(profiler.pImpl.get()); if (callbackData->callbackSite == CUPTI_API_ENTER) { auto scopeId = Scope::getNewScopeId(); threadState.record(scopeId); @@ -275,7 +329,6 @@ void CuptiProfiler::CuptiProfilerPimpl::callbackFn(void *userData, size_t numInstances = 1; if (cbId == CUPTI_DRIVER_TRACE_CBID_cuGraphLaunch || cbId == CUPTI_DRIVER_TRACE_CBID_cuGraphLaunch_ptsz) { - auto *pImpl = dynamic_cast(profiler.pImpl.get()); auto graphExec = static_cast( callbackData->functionParams) ->hGraph; @@ -298,7 +351,17 @@ void CuptiProfiler::CuptiProfilerPimpl::callbackFn(void *userData, << std::endl; } profiler.correlation.correlate(callbackData->correlationId, numInstances); + if (profiler.isPCSamplingEnabled() && isDriverAPILaunch(cbId)) { + pImpl->pcSampling.start(callbackData->context); + } } else if (callbackData->callbackSite == CUPTI_API_EXIT) { + if (profiler.isPCSamplingEnabled() && isDriverAPILaunch(cbId)) { + // XXX: Conservatively stop every GPU kernel for now + auto scopeId = profiler.correlation.externIdQueue.back(); + pImpl->pcSampling.stop( + callbackData->context, scopeId, + profiler.correlation.apiExternIds.contain(scopeId)); + } threadState.exitOp(); profiler.correlation.submit(callbackData->correlationId); } @@ -306,10 +369,15 @@ void CuptiProfiler::CuptiProfilerPimpl::callbackFn(void *userData, } void CuptiProfiler::CuptiProfilerPimpl::doStart() { - cupti::activityRegisterCallbacks(allocBuffer, completeBuffer); - cupti::activityEnable(CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL); - // TODO: switch to directly subscribe the APIs and measure overhead cupti::subscribe(&subscriber, callbackFn, nullptr); + if (profiler.isPCSamplingEnabled()) { + setResourceCallbacks(subscriber, /*enable=*/true); + // Continuous PC sampling is not compatible with concurrent kernel profiling + cupti::activityEnable(CUPTI_ACTIVITY_KIND_KERNEL); + } else { + cupti::activityEnable(CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL); + } + cupti::activityRegisterCallbacks(allocBuffer, completeBuffer); setGraphCallbacks(subscriber, /*enable=*/true); setRuntimeCallbacks(subscriber, /*enable=*/true); setDriverCallbacks(subscriber, /*enable=*/true); @@ -326,8 +394,12 @@ void CuptiProfiler::CuptiProfilerPimpl::doFlush() { // If the current context is not set, we don't do any synchronization. 
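+  // Otherwise, synchronize so that in-flight kernels finish and their
+  // activity records (and any pending PC samples) can be collected below.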
CUcontext cuContext = nullptr; cuda::ctxGetCurrent(&cuContext); - if (cuContext) + if (cuContext) { cuda::ctxSynchronize(); + } + if (profiler.isPCSamplingEnabled()) { + pcSampling.finalize(cuContext); + } profiler.correlation.flush( /*maxRetries=*/100, /*sleepMs=*/10, /*flush=*/[]() { @@ -341,7 +413,12 @@ void CuptiProfiler::CuptiProfilerPimpl::doFlush() { } void CuptiProfiler::CuptiProfilerPimpl::doStop() { - cupti::activityDisable(CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL); + if (profiler.isPCSamplingEnabled()) { + setResourceCallbacks(subscriber, /*enable=*/false); + cupti::activityDisable(CUPTI_ACTIVITY_KIND_KERNEL); + } else { + cupti::activityDisable(CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL); + } setGraphCallbacks(subscriber, /*enable=*/false); setRuntimeCallbacks(subscriber, /*enable=*/false); setDriverCallbacks(subscriber, /*enable=*/false); diff --git a/third_party/proton/csrc/lib/Profiler/RoctracerProfiler.cpp b/third_party/proton/csrc/lib/Profiler/RocTracer/RoctracerProfiler.cpp similarity index 99% rename from third_party/proton/csrc/lib/Profiler/RoctracerProfiler.cpp rename to third_party/proton/csrc/lib/Profiler/RocTracer/RoctracerProfiler.cpp index 55af9eb714..68f3f0beac 100644 --- a/third_party/proton/csrc/lib/Profiler/RoctracerProfiler.cpp +++ b/third_party/proton/csrc/lib/Profiler/RocTracer/RoctracerProfiler.cpp @@ -1,4 +1,4 @@ -#include "Profiler/RoctracerProfiler.h" +#include "Profiler/Roctracer/RoctracerProfiler.h" #include "Context/Context.h" #include "Data/Metric.h" #include "Driver/GPU/HipApi.h" diff --git a/third_party/proton/csrc/lib/Session/Session.cpp b/third_party/proton/csrc/lib/Session/Session.cpp index 1db512d075..9b0ef10d37 100644 --- a/third_party/proton/csrc/lib/Session/Session.cpp +++ b/third_party/proton/csrc/lib/Session/Session.cpp @@ -2,8 +2,8 @@ #include "Context/Python.h" #include "Context/Shadow.h" #include "Data/TreeData.h" -#include "Profiler/CuptiProfiler.h" -#include "Profiler/RoctracerProfiler.h" +#include "Profiler/Cupti/CuptiProfiler.h" +#include "Profiler/Roctracer/RoctracerProfiler.h" #include "Utility/String.h" namespace proton { @@ -13,6 +13,9 @@ Profiler *getProfiler(const std::string &profilerName) { if (proton::toLower(profilerName) == "cupti") { return &CuptiProfiler::instance(); } + if (proton::toLower(profilerName) == "cupti_pcsampling") { + return &CuptiProfiler::instance().enablePCSampling(); + } if (proton::toLower(profilerName) == "roctracer") { return &RoctracerProfiler::instance(); } diff --git a/third_party/proton/proton/profile.py b/third_party/proton/proton/profile.py index 01d5a1947e..2dd7a6f53e 100644 --- a/third_party/proton/proton/profile.py +++ b/third_party/proton/proton/profile.py @@ -42,7 +42,7 @@ def start( name (str, optional): The name (with path) of the profiling session. If not provided, the default name is "~/proton.hatchet". backend (str, optional): The backend to use for profiling. - Available options are [None, "cupti", "roctracer"]. + Available options are [None, "cupti", "cupti_pcsampling", "roctracer"]. Defaults to None, which automatically selects the backend matching the current active runtime. context (str, optional): The context to use for profiling. Available options are ["shadow", "python"]. 
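For reference, selecting the PC sampling backend explicitly looks like this (a minimal sketch; the session name and workload are illustrative):

```python
import torch
import triton.profiler as proton

# CLI equivalent: python -m triton.profiler.proton -b cupti_pcsampling script.py
proton.start(name="pcsampling_example", context="shadow", backend="cupti_pcsampling")
torch.ones((1024, ), device="cuda").relu()  # any GPU work gets sampled
proton.finalize()
```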
diff --git a/third_party/proton/proton/proton.py b/third_party/proton/proton/proton.py index 7ea6413ac5..cbb7a0b6f9 100644 --- a/third_party/proton/proton/proton.py +++ b/third_party/proton/proton/proton.py @@ -13,7 +13,8 @@ def parse_arguments(): python -m triton.profiler.proton [options] script.py [script_args] [script_options] """, formatter_class=argparse.RawTextHelpFormatter) parser.add_argument("-n", "--name", type=str, help="Name of the profiling session") - parser.add_argument("-b", "--backend", type=str, help="Profiling backend", default=None, choices=["cupti"]) + parser.add_argument("-b", "--backend", type=str, help="Profiling backend", default=None, + choices=["cupti", "cupti_pcsampling", "roctracer"]) parser.add_argument("-c", "--context", type=str, help="Profiling context", default="shadow", choices=["shadow", "python"]) parser.add_argument("-d", "--data", type=str, help="Profiling data", default="tree", choices=["tree"]) diff --git a/third_party/proton/proton/scope.py b/third_party/proton/proton/scope.py index 5695b88075..26d946a8c1 100644 --- a/third_party/proton/proton/scope.py +++ b/third_party/proton/proton/scope.py @@ -5,7 +5,7 @@ from .flags import get_profiling_on from triton._C.libproton import proton as libproton -_local = threading.local() +thread_local_scopes = threading.local() MetricValueType = Union[float, int] PropertyValueType = Union[float, int, str] @@ -22,7 +22,7 @@ class scope: foo[1,](x, y) ``` - decoarator: + decorator: ```python @proton.scope("test0", {metric_name: metric_value}) def foo(x, y): @@ -36,25 +36,25 @@ def foo(x, y): def __init__(self, name: str, metrics: Optional[dict[str, MetricValueType]] = None, properties: Optional[dict[str, PropertyValueType]] = None) -> None: - self._name = name - self._metrics = metrics - self._properties = properties + self.name = name + self.metrics = metrics + self.properties = properties def __enter__(self): if not get_profiling_on(): return self - self._id = libproton.record_scope() - libproton.enter_scope(self._id, self._name) - if self._metrics: - libproton.add_metrics(self._id, self._metrics) - if self._properties: - libproton.set_properties(self._id, self._properties) + self.id = libproton.record_scope() + libproton.enter_scope(self.id, self.name) + if self.metrics: + libproton.add_metrics(self.id, self.metrics) + if self.properties: + libproton.set_properties(self.id, self.properties) return self def __exit__(self, exc_type, exc_value, traceback) -> None: if not get_profiling_on(): return - libproton.exit_scope(self._id, self._name) + libproton.exit_scope(self.id, self.name) def __call__(self, func): @@ -62,14 +62,14 @@ def __call__(self, func): def wrapper(*args, **kwargs): if get_profiling_on(): id = libproton.record_scope() - libproton.enter_scope(id, self._name) - if self._metrics: - libproton.add_metrics(id, self._metrics) - if self._properties: - libproton.set_properties(id, self._properties) + libproton.enter_scope(id, self.name) + if self.metrics: + libproton.add_metrics(id, self.metrics) + if self.properties: + libproton.set_properties(id, self.properties) ret = func(*args, **kwargs) if get_profiling_on(): - libproton.exit_scope(id, self._name) + libproton.exit_scope(id, self.name) return ret return wrapper @@ -80,9 +80,9 @@ def enter_scope(name: str, *, triton_op: bool = False, metrics: Optional[dict[st if not get_profiling_on(): return -1 id = libproton.record_scope() - if not hasattr(_local, "scopes"): - _local.scopes = [] - _local.scopes.append((id, name)) + if not hasattr(thread_local_scopes, 
"scopes"): + thread_local_scopes.scopes = [] + thread_local_scopes.scopes.append((id, name)) if triton_op: libproton.enter_op(id, name) else: @@ -97,7 +97,7 @@ def enter_scope(name: str, *, triton_op: bool = False, metrics: Optional[dict[st def exit_scope(triton_op: bool = False) -> int: if not get_profiling_on(): return -1 - id, name = _local.scopes.pop() + id, name = thread_local_scopes.scopes.pop() if triton_op: libproton.exit_op(id, name) else: diff --git a/third_party/proton/proton/viewer.py b/third_party/proton/proton/viewer.py index 2067466c94..9fe0e7e67d 100644 --- a/third_party/proton/proton/viewer.py +++ b/third_party/proton/proton/viewer.py @@ -43,7 +43,7 @@ def get_min_time_flops(df, device_info): num_sms = device_info[device_type][device_index]["num_sms"] clock_rate = device_info[device_type][device_index]["clock_rate"] for width in TritonHook.flops_width: - idx = df["DeviceId"] == device_index + idx = df["device_id"] == device_index device_frames = df[idx] if f"flops{width}" not in device_frames.columns: continue @@ -72,7 +72,7 @@ def get_min_time_bytes(df, device_info): min_time_bytes = pd.DataFrame(0.0, index=df.index, columns=["min_time"]) for device_type in device_info: for device_index in device_info[device_type]: - idx = df["DeviceId"] == device_index + idx = df["device_id"] == device_index device_frames = df[idx] memory_clock_rate = device_info[device_type][device_index]["memory_clock_rate"] # in khz bus_width = device_info[device_type][device_index]["bus_width"] # in bits @@ -105,7 +105,7 @@ def get_min_time_bytes(df, device_info): def derive_metrics(gf, metrics, raw_metrics, device_info): derived_metrics = [] original_metrics = [] - internal_frame_indices = gf.dataframe["DeviceId"].isna() + internal_frame_indices = gf.dataframe["device_id"].isna() def get_time_seconds(df): time_metric_name = match_available_metrics([time_factor_dict.name], raw_metrics)[0] @@ -135,7 +135,7 @@ def get_time_seconds(df): derived_metrics.append(f"{metric} (inc)") elif metric in avg_time_factor_dict.factor: metric_time_unit = avg_time_factor_dict.name + "/" + metric.split("/")[1] - gf.dataframe[f"{metric} (inc)"] = (get_time_seconds(gf.dataframe) / gf.dataframe['Count'] / + gf.dataframe[f"{metric} (inc)"] = (get_time_seconds(gf.dataframe) / gf.dataframe['count'] / avg_time_factor_dict.factor[metric_time_unit]) gf.dataframe.loc[internal_frame_indices, f"{metric} (inc)"] = np.nan derived_metrics.append(f"{metric} (inc)") diff --git a/third_party/proton/test/example_cuda.json b/third_party/proton/test/example_cuda.json index 0db9ace447..445f0e224c 100644 --- a/third_party/proton/test/example_cuda.json +++ b/third_party/proton/test/example_cuda.json @@ -8,10 +8,10 @@ "type": "function" }, "metrics": { - "Count": 10, - "DeviceId": "1", - "DeviceType": "CUDA", - "Time (ns)": 204800, + "count": 10, + "device_id": "1", + "device_type": "CUDA", + "time (ns)": 204800, "flops8": 1e11, "bytes": 1e8 } @@ -23,10 +23,10 @@ "type": "function" }, "metrics": { - "Count": 1, - "DeviceId": "0", - "DeviceType": "CUDA", - "Time (ns)": 204800, + "count": 1, + "device_id": "0", + "device_type": "CUDA", + "time (ns)": 204800, "flops8": 1e10, "bytes": 1e7 } @@ -37,8 +37,8 @@ "type": "function" }, "metrics": { - "Count": 0, - "Time (ns)": 0, + "count": 0, + "time (ns)": 0, "flops8": 0, "bytes": 0 } diff --git a/third_party/proton/test/example_frame.json b/third_party/proton/test/example_frame.json index 64789a3b74..0069476fbc 100644 --- a/third_party/proton/test/example_frame.json +++ 
diff --git a/third_party/proton/test/example_frame.json b/third_party/proton/test/example_frame.json
index 64789a3b74..0069476fbc 100644
--- a/third_party/proton/test/example_frame.json
+++ b/third_party/proton/test/example_frame.json
@@ -10,10 +10,10 @@
         "type": "function"
       },
       "metrics": {
-        "Count": 1,
-        "DeviceId": "0",
-        "DeviceType": "HIP",
-        "Time (ns)": 204800
+        "count": 1,
+        "device_id": "0",
+        "device_type": "HIP",
+        "time (ns)": 204800
       }
     }
   ],
@@ -27,7 +27,12 @@
           "frame": {
             "name": "test1"
           },
-          "metrics": {}
+          "metrics": {
+            "count": 1,
+            "device_id": "0",
+            "device_type": "HIP",
+            "time (ns)": 204800
+          }
         }
       ],
       "frame": {
@@ -35,8 +40,8 @@
         "type": "function"
       },
       "metrics": {
-        "Count": 0,
-        "Time (ns)": 0
+        "count": 0,
+        "time (ns)": 0
       }
     },
     {
diff --git a/third_party/proton/test/example_hip.json b/third_party/proton/test/example_hip.json
index 2fcfad3c5d..68538706cf 100644
--- a/third_party/proton/test/example_hip.json
+++ b/third_party/proton/test/example_hip.json
@@ -8,10 +8,10 @@
         "type": "function"
      },
       "metrics": {
-        "Count": 1,
-        "DeviceId": "1",
-        "DeviceType": "HIP",
-        "Time (ns)": 204800,
+        "count": 1,
+        "device_id": "1",
+        "device_type": "HIP",
+        "time (ns)": 204800,
         "flops8": 1e11,
         "bytes": 1e8
       }
@@ -23,10 +23,10 @@
         "type": "function"
       },
       "metrics": {
-        "Count": 1,
-        "DeviceId": "0",
-        "DeviceType": "HIP",
-        "Time (ns)": 204800,
+        "count": 1,
+        "device_id": "0",
+        "device_type": "HIP",
+        "time (ns)": 204800,
         "flops8": 1e10,
         "bytes": 1e7
       }
@@ -37,8 +37,8 @@
         "type": "function"
       },
       "metrics": {
-        "Count": 0,
-        "Time (ns)": 0,
+        "count": 0,
+        "time (ns)": 0,
         "flops8": 0,
         "bytes": 0
       }
diff --git a/third_party/proton/test/test_profile.py b/third_party/proton/test/test_profile.py
index 1a69608a26..13cb9bd99c 100644
--- a/third_party/proton/test/test_profile.py
+++ b/third_party/proton/test/test_profile.py
@@ -25,7 +25,7 @@ def test_torch(context):
     if context == "shadow":
         assert len(data[0]["children"]) == 1
         assert data[0]["children"][0]["frame"]["name"] == "test"
-        assert data[0]["children"][0]["children"][0]["metrics"]["Time (ns)"] > 0
+        assert data[0]["children"][0]["children"][0]["metrics"]["time (ns)"] > 0
     elif context == "python":
         assert len(data[0]["children"]) == 1
         # The last frame is the torch kernel
@@ -111,7 +111,7 @@ def fn():
         assert len(test_frame["children"]) >= 2
     else:
         assert len(test_frame["children"]) >= 3
-    assert test_frame["children"][0]["metrics"]["Time (ns)"] > 0
+    assert test_frame["children"][0]["metrics"]["time (ns)"] > 0
 
 
 def test_metrics():
@@ -197,7 +197,41 @@ def foo(x, size: tl.constexpr, y):
     assert data[0]["children"][0]["frame"]["name"] == "test0"
     assert data[0]["children"][0]["children"][0]["frame"]["name"] == "foo_test_1ctas_1elems"
     assert data[0]["children"][0]["children"][0]["metrics"]["flops32"] == 1.0
-    assert data[0]["children"][0]["children"][0]["metrics"]["Time (ns)"] > 0
+    assert data[0]["children"][0]["children"][0]["metrics"]["time (ns)"] > 0
+
+
+def test_pcsampling():
+    if is_hip():
+        pytest.skip("HIP backend does not support pc sampling")
+
+    import os
+    if os.environ.get("PROTON_SKIP_PC_SAMPLING_TEST", "0") == "1":
+        pytest.skip("PC sampling test is disabled")
+
+    @triton.jit
+    def foo(x, y, size: tl.constexpr):
+        offs = tl.arange(0, size)
+        for _ in range(1000):
+            tl.store(y + offs, tl.load(x + offs))
+
+    with tempfile.NamedTemporaryFile(delete=True, suffix=".hatchet") as f:
+        proton.start(f.name.split(".")[0], hook="triton", backend="cupti_pcsampling")
+        with proton.scope("init"):
+            x = torch.ones((1024, ), device="cuda", dtype=torch.float32)
+            y = torch.zeros_like(x)
+        with proton.scope("test"):
+            foo[(1, )](x, y, x.size()[0], num_warps=4)
+        proton.finalize()
+        data = json.load(f)
+        init_frame = data[0]["children"][0]
+        test_frame = data[0]["children"][1]
+        # With line mapping
+        assert "foo" in test_frame["children"][0]["frame"]["name"]
+        assert test_frame["children"][0]["children"][0]["metrics"]["num_samples"] > 0
+        assert "@" in test_frame["children"][0]["children"][0]["frame"]["name"]
+        # Without line mapping
+        assert "elementwise" in init_frame["children"][0]["frame"]["name"]
+        assert init_frame["children"][0]["metrics"]["num_samples"] > 0
 
 
 def test_deactivate():
@@ -211,6 +245,6 @@ def test_deactivate():
     proton.finalize()
     data = json.load(f)
     # Root shouldn't have device id
-    assert "DeviceId" not in data[0]["metrics"]
+    assert "device_id" not in data[0]["metrics"]
     assert len(data[0]["children"]) == 1
-    assert "DeviceId" in data[0]["children"][0]["metrics"]
+    assert "device_id" in data[0]["children"][0]["metrics"]
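
Outside the test harness, the new backend is selected the same way as in test_pcsampling above; a minimal sketch, where "sampling_demo" and the init workload are placeholders.

```python
import torch
import triton.profiler as proton

# hook="triton" mirrors the test; it attaches the Triton launch hook.
proton.start("sampling_demo", hook="triton", backend="cupti_pcsampling")
with proton.scope("init"):
    x = torch.ones((1024, ), device="cuda")
proton.finalize()
# Frames in sampling_demo.hatchet carry a "num_samples" metric; per the
# asserts above, kernel frame names gain an "@" suffix when line mapping
# is available.
```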
diff --git a/third_party/proton/test/test_viewer.py b/third_party/proton/test/test_viewer.py
index c8343e1267..998825bbc8 100644
--- a/third_party/proton/test/test_viewer.py
+++ b/third_party/proton/test/test_viewer.py
@@ -52,8 +52,8 @@ def test_min_time_flops():
     with open(cuda_example_file, "r") as f:
         gf, _, device_info = get_raw_metrics(f)
         ret = get_min_time_flops(gf.dataframe, device_info)
-        device0_idx = gf.dataframe["DeviceId"] == "0"
-        device1_idx = gf.dataframe["DeviceId"] == "1"
+        device0_idx = gf.dataframe["device_id"] == "0"
+        device1_idx = gf.dataframe["device_id"] == "1"
         # sm89
         np.testing.assert_allclose(ret[device0_idx].to_numpy(), [[0.000025]], atol=1e-5)
         # sm90
@@ -61,8 +61,8 @@ def test_min_time_flops():
     with open(hip_example_file, "r") as f:
         gf, _, device_info = get_raw_metrics(f)
         ret = get_min_time_flops(gf.dataframe, device_info)
-        device0_idx = gf.dataframe["DeviceId"] == "0"
-        device1_idx = gf.dataframe["DeviceId"] == "1"
+        device0_idx = gf.dataframe["device_id"] == "0"
+        device1_idx = gf.dataframe["device_id"] == "1"
         # MI200
         np.testing.assert_allclose(ret[device0_idx].to_numpy(), [[0.000026]], atol=1e-5)
         # MI300
@@ -73,8 +73,8 @@ def test_min_time_bytes():
     with open(cuda_example_file, "r") as f:
         gf, _, device_info = get_raw_metrics(f)
         ret = get_min_time_bytes(gf.dataframe, device_info)
-        device0_idx = gf.dataframe["DeviceId"] == "0"
-        device1_idx = gf.dataframe["DeviceId"] == "1"
+        device0_idx = gf.dataframe["device_id"] == "0"
+        device1_idx = gf.dataframe["device_id"] == "1"
         # sm89
         np.testing.assert_allclose(ret[device0_idx].to_numpy(), [[9.91969e-06]], atol=1e-6)
         # sm90
@@ -82,8 +82,8 @@ def test_min_time_bytes():
     with open(hip_example_file, "r") as f:
         gf, _, device_info = get_raw_metrics(f)
         ret = get_min_time_bytes(gf.dataframe, device_info)
-        device0_idx = gf.dataframe["DeviceId"] == "0"
-        device1_idx = gf.dataframe["DeviceId"] == "1"
+        device0_idx = gf.dataframe["device_id"] == "0"
+        device1_idx = gf.dataframe["device_id"] == "1"
         # MI200
         np.testing.assert_allclose(ret[device0_idx].to_numpy(), [[6.10351e-06]], atol=1e-6)
         # MI300
diff --git a/third_party/proton/tutorials/dynamic_net.py b/third_party/proton/tutorials/dynamic_net.py
index a1a82b53e2..5793bebd09 100644
--- a/third_party/proton/tutorials/dynamic_net.py
+++ b/third_party/proton/tutorials/dynamic_net.py
@@ -85,13 +85,14 @@ def run():
 argparser.add_argument("--profile", action="store_true")
 argparser.add_argument("--mode", default="torch", choices=["torch", "torchinductor"])
 argparser.add_argument("--context", default="shadow", choices=["shadow", "python"])
+argparser.add_argument("--backend", default=None, choices=["cupti", "roctracer", "cupti_pcsampling"])
 args = argparser.parse_args()
 
 mode = args.mode
 
 if args.profile:
-    func = proton.profile(run, name="dynamic_net", context=args.context)
+    func = proton.profile(run, name="dynamic_net", context=args.context, backend=args.backend)
 else:
     func = run
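
Programmatically, the tutorial's new flag corresponds to passing backend through proton.profile; a sketch with a placeholder run function standing in for the tutorial's training loop.

```python
import triton.profiler as proton

def run():
    pass  # stand-in for the tutorial's workload

# Roughly equivalent to: python dynamic_net.py --profile --backend cupti_pcsampling
func = proton.profile(run, name="dynamic_net", context="shadow",
                      backend="cupti_pcsampling")
func()
```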
diff --git a/unittest/Dialect/TritonGPU/DialectTest.cpp b/unittest/Dialect/TritonGPU/DialectTest.cpp
index 67b1d9e9bc..e3f521f1b3 100644
--- a/unittest/Dialect/TritonGPU/DialectTest.cpp
+++ b/unittest/Dialect/TritonGPU/DialectTest.cpp
@@ -559,7 +559,7 @@ TEST_F(AMDMfmaLayoutTest, mfma32) {
 
   auto tmfma2d = createTransposedMFMA(32, 32, {2, 4});
   ASSERT_THAT(tmfma2d.getThreadOrder(), testing::ElementsAre(0u, 1u));
-  ASSERT_THAT(tmfma2d.getWarpOrder(), testing::ElementsAre(0u, 1u));
+  ASSERT_THAT(tmfma2d.getWarpOrder(), testing::ElementsAre(1u, 0u));
 
   auto mfma3d = createMFMA(32, 32, {2, 4, 1});
   ASSERT_THAT(mfma3d.getThreadOrder(), testing::ElementsAre(2u, 1u, 0u));
@@ -567,7 +567,7 @@ TEST_F(AMDMfmaLayoutTest, mfma32) {
 
   auto tmfma3d = createTransposedMFMA(32, 32, {2, 4, 1});
   ASSERT_THAT(tmfma3d.getThreadOrder(), testing::ElementsAre(1u, 2u, 0u));
-  ASSERT_THAT(tmfma3d.getWarpOrder(), testing::ElementsAre(1u, 2u, 0u));
+  ASSERT_THAT(tmfma3d.getWarpOrder(), testing::ElementsAre(2u, 1u, 0u));
 }
 
 TEST_F(AMDMfmaLayoutTest, mfma16) {
@@ -577,7 +577,7 @@ TEST_F(AMDMfmaLayoutTest, mfma16) {
 
   auto tmfma2d = createTransposedMFMA(16, 16, {2, 4});
   ASSERT_THAT(tmfma2d.getThreadOrder(), testing::ElementsAre(0u, 1u));
-  ASSERT_THAT(tmfma2d.getWarpOrder(), testing::ElementsAre(0u, 1u));
+  ASSERT_THAT(tmfma2d.getWarpOrder(), testing::ElementsAre(1u, 0u));
 
   auto mfma3d = createMFMA(16, 16, {2, 4, 1});
   ASSERT_THAT(mfma3d.getThreadOrder(), testing::ElementsAre(2u, 1u, 0u));
@@ -585,7 +585,7 @@ TEST_F(AMDMfmaLayoutTest, mfma16) {
 
   auto tmfma3d = createTransposedMFMA(16, 16, {2, 4, 1});
   ASSERT_THAT(tmfma3d.getThreadOrder(), testing::ElementsAre(1u, 2u, 0u));
-  ASSERT_THAT(tmfma3d.getWarpOrder(), testing::ElementsAre(1u, 2u, 0u));
+  ASSERT_THAT(tmfma3d.getWarpOrder(), testing::ElementsAre(2u, 1u, 0u));
 }
 
 }  // anonymous namespace
diff --git a/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp b/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp
index 0b7a0f7821..7d918602a7 100644
--- a/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp
+++ b/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp
@@ -529,14 +529,14 @@ TEST_F(LinearLayoutConversionsTest, MFMA32_2x4Warps) {
             LinearLayout(
                 {{S("register"), {{0, 1}, {0, 2}, {0, 8}, {0, 16}}},
                  {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}, {0, 4}}},
-                 {S("warp"), {{32, 0}, {0, 0}, {0, 0}}},
+                 {S("warp"), {{0, 0}, {0, 0}, {32, 0}}},
                  {S("block"), {}}},
                {S("dim0"), S("dim1")}));
   EXPECT_EQ(toLinearLayout({128, 128}, mfmaT),
            LinearLayout(
                {{S("register"), {{0, 1}, {0, 2}, {0, 8}, {0, 16}, {64, 0}}},
                 {S("lane"), {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}, {0, 4}}},
-                {S("warp"), {{32, 0}, {0, 32}, {0, 64}}},
+                {S("warp"), {{0, 32}, {0, 64}, {32, 0}}},
                 {S("block"), {}}},
               {S("dim0"), S("dim1")}));
 }
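
The warp-basis reordering above follows from the transposed-MFMA warp order flipped in DialectTest.cpp: with warpsPerCTA = {2, 4} and a warp order of [1, 0], warp IDs advance along dim1 first. A toy Python model of that mapping, for intuition only (not a Triton API):

```python
def warp_coord(warp_id, warps_per_cta, warp_order):
    # warp_order lists dimensions from minor (fastest-varying) to major.
    coord = [0] * len(warps_per_cta)
    for dim in warp_order:
        warp_id, coord[dim] = divmod(warp_id, warps_per_cta[dim])
    return coord

# Warps 0..3 walk dim1; warp 4 flips dim0.
print([warp_coord(w, [2, 4], [1, 0]) for w in range(8)])
# -> [[0,0], [0,1], [0,2], [0,3], [1,0], [1,1], [1,2], [1,3]]
# With a 32x32 fragment per warp in the 128x128 case, this ordering gives
# the warp basis vectors {{0, 32}, {0, 64}, {32, 0}} asserted above.
```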