From 5bbce9e6b9b544384287a23ab6c3c420de3c533b Mon Sep 17 00:00:00 2001
From: Whitney Tsang
Date: Fri, 22 Nov 2024 15:04:25 +0000
Subject: [PATCH] Revert "Revert "[LAYOUTS] Implement IR support for
 LinearLayouts (#5170)""

This reverts commit 7b5daa47fd38226587c1e6be35d6e6e371074735.
---
 include/triton/Dialect/TritonGPU/IR/Dialect.h |  65 +++
 .../Dialect/TritonGPU/IR/TritonGPUAttrDefs.td |  30 +-
 include/triton/Tools/LinearLayout.h           |   3 +
 .../TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp |   3 +
 .../TritonGPUToLLVM/MemoryOpToLLVM.cpp        |   5 +-
 lib/Dialect/TritonGPU/IR/Dialect.cpp          | 456 ++++++++++++++++++
 .../TritonGPU/IR/LinearLayoutConversions.cpp  | 143 +-----
 lib/Dialect/TritonGPU/IR/Ops.cpp              |  10 +-
 .../TritonGPU/Transforms/AccelerateMatmul.cpp |  48 +-
 .../Transforms/RemoveLayoutConversions.cpp    |   9 +-
 lib/Tools/LinearLayout.cpp                    |  60 +--
 test/TritonGPU/accelerate-matmul.mlir         |   6 +-
 test/TritonGPU/ops.mlir                       |  14 +
 unittest/Dialect/TritonGPU/DialectTest.cpp    | 130 ++++-
 14 files changed, 785 insertions(+), 197 deletions(-)

diff --git a/include/triton/Dialect/TritonGPU/IR/Dialect.h b/include/triton/Dialect/TritonGPU/IR/Dialect.h
index f2715043d7..e0865e12af 100644
--- a/include/triton/Dialect/TritonGPU/IR/Dialect.h
+++ b/include/triton/Dialect/TritonGPU/IR/Dialect.h
@@ -149,6 +149,71 @@ triton::gpu::BlockedEncodingAttr
 getDefaultBlockedEncoding(MLIRContext *context, ArrayRef<int64_t> shape,
                           int numWarps, int threadsPerWarp, int numCTAs);
 
+// For each output dimension d, ensure that the layout's output size (i.e., its
+// codomain) does not exceed shape[d]. Do this without changing the size of the
+// layout's inputs (i.e., leave its domain unchanged).
+//
+// This function is invariant to the order of the layout's input and output
+// dimensions.
+//
+// We achieve this by setting the largest value in each output dimension d to 0,
+// because bases that map to a location larger than shape[d]
+// effectively duplicate along that dimension. For example, consider a layout
+// with an output dimension size of 32, and we call ensureLayoutNotLargerThan
+// to shrink the output dimension size to 8:
+//
+//   L(register=1) = 8
+//   L(register=2) = 4
+//   L(register=4) = 1
+//   L(lane=1) = 2
+//   L(lane=2) = 16
+//
+// In the first step, we shrink the output dimension size to 16 by setting
+// L(lane=2) to 0:
+//
+//   L(register=1) = 8
+//   L(register=2) = 4
+//   L(register=4) = 1
+//   L(lane=1) = 2
+//   L(lane=2) = 0
+//
+// This means that lane=2 has the same data as lane=0.
+//
+// Now the output dimension of this layout has a size of 16, which is still
+// larger than 8. We find the current largest value in the output dimension,
+// which is L(register=1) = 8, and we set L(register=1) to 0:
+//
+//   L(register=1) = 0
+//   L(register=2) = 4
+//   L(register=4) = 1
+//   L(lane=1) = 2
+//   L(lane=2) = 0
+//
+// Now the output dimension of this layout has a size of 8, which is the
+// desired size. Note that this method works only because the bases are powers
+// of two, which is the case for DistributedLayouts.
+//
+// If broadcastRegisters is false, we remove (rather than zero out) any
+// register basis that is larger than the desired shape. In the example above
+// we would have:
+//
+//   L(register=1) = 4
+//   L(register=2) = 1
+//   L(lane=1) = 2
+//   L(lane=2) = 0
+LinearLayout
+ensureLayoutNotLargerThan(const LinearLayout &layout,
+                          const llvm::SmallDenseMap<StringAttr, int64_t> &shape,
+                          bool broadcastRegisters = true);
+
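+// A minimal usage sketch of the function above (editor's illustration, not
+// part of the original patch; `ctx` and `layout` are assumed to exist):
+//
+//   auto kDim0 = StringAttr::get(ctx, "dim0");
+//   auto kDim1 = StringAttr::get(ctx, "dim1");
+//   llvm::SmallDenseMap<StringAttr, int64_t> shape{{kDim0, 8}, {kDim1, 16}};
+//   LinearLayout shrunk = ensureLayoutNotLargerThan(layout, shape);
+//   // Every out-dim of `shrunk` now has size <= shape[d]; the input dims
+//   // (the domain) are unchanged.
+
+// For each out-dim d, ensure the layout's out-size (i.e. its codomain) is no
+// smaller than shape[d].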
Do this by increasing the size of the layout's inputs +// along its most-minor dimension ("register" for register layouts, "offset" for +// shared layouts). +// +// This function is invariant to the order of the layout's input dimensions, but +// it cares about the order of the output dims, which should be minor-to-major. +LinearLayout ensureLayoutNotSmallerThan( + const LinearLayout &layout, + const llvm::SmallDenseMap &shape); + // Dump information about which threads/registers contain each of the tensor // elements. void dumpLayout(RankedTensorType tensorType); diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index e02f0f7b4f..0dd89b8147 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -56,7 +56,6 @@ Right now, Triton implements two main classes of layouts: shared, and distribute code extraBaseClassDeclaration = [{ unsigned getTotalElemsPerThread(ArrayRef shape, Type eltTy) const; SmallVector getElemsPerThread(ArrayRef shape, Type eltTy) const; - ::mlir::LogicalResult verifyLayoutForArg(::mlir::Operation* op, unsigned argNo) const; }]; } @@ -147,7 +146,6 @@ addition, if there's only one CTA per CGA, then Triton canonicalizes CTAOrder to let genVerifyDecl = 1; let skipDefaultBuilders = 1; } - //===----------------------------------------------------------------------===// // Shared Layout Encoding //===----------------------------------------------------------------------===// @@ -565,6 +563,34 @@ L(T) = [ {0,8} , {1,9} , {2,10}, {3,11}, {0,8} , {1, 9} , {2, 10}, {3, 11}, }]; } +//===----------------------------------------------------------------------===// +// Linear Layout Encoding +//===----------------------------------------------------------------------===// + +def LinearEncodingAttr : DistributedEncoding<"LinearEncoding", "linear_encoding"> { + let mnemonic = "linear"; + + let description = [{ + See the docs in LinearLayout.h for the definition of linear layouts. + }]; + + let parameters = (ins "LinearLayout":$linearLayout); + + let extraClassDeclaration = extraDistributedDeclaration # [{ + SmallVector getContigPerThread() const; + SmallVector getOrder() const; + }]; + + let genVerifyDecl = 1; + // Example of assembly format: + // <{register = [[0, 1], [8, 0], [0, 8], [64, 0]], + // lane = [[0, 2], [0, 4], [1, 0], [2, 0], [4, 0]], + // warp = [[16, 0], [32, 0]], + // block = []}> + let hasCustomAssemblyFormat = 1; +} + + //===----------------------------------------------------------------------===// // Blocked Layout Encoding //===----------------------------------------------------------------------===// diff --git a/include/triton/Tools/LinearLayout.h b/include/triton/Tools/LinearLayout.h index 47e3fca79b..cfc4c0d13b 100644 --- a/include/triton/Tools/LinearLayout.h +++ b/include/triton/Tools/LinearLayout.h @@ -9,6 +9,7 @@ #include #include "mlir/IR/BuiltinAttributes.h" +#include "llvm/ADT/Hashing.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" @@ -432,6 +433,7 @@ class LinearLayout { // (e.g. by reshaping) then the order doesn't really affect anything. auto getInDimNames() const { return llvm::make_first_range(bases); } auto getOutDimNames() const { return llvm::make_first_range(outDims); } + auto getOutDimSizes() const { return llvm::make_second_range(outDims); } // Gets the position that this outDim occupies in getOutDimNames(). 
Asserts // if the dim is not present. @@ -693,6 +695,7 @@ class LinearLayout { return !(lhs == rhs); } bool equalIgnoringOutDimSizes(const LinearLayout &other) const; + friend size_t hash_value(const LinearLayout &layout); private: // Factory function that gracefully fails rather than asserts if the layout is diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 0a1b0e30fc..dd60cfb741 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -397,6 +397,9 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion if (isa(layout)) { return true; } + if (isa(layout)) { + return true; + } if (auto slice = dyn_cast(layout)) { return layoutIsOK(slice.getParent()); } diff --git a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp index 4cea14f095..b090670d95 100644 --- a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp @@ -165,8 +165,8 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern { Attribute srcLayout = srcTy.getEncoding(); Attribute dstLayout = dstTy.getEncoding(); if (isa(srcLayout) && - (isa( - dstLayout) || + (isa(dstLayout) || isSupportedDotOpLayout(dstTy))) { return lowerSharedToDistributed(op, adaptor, getTypeConverter(), rewriter); @@ -206,7 +206,6 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern { auto dstTy = op.getResult().getType(); auto dstShape = dstTy.getShape(); auto srcSharedLayout = cast(srcTy.getEncoding()); - auto dstLayout = dstTy.getEncoding(); assert((dstShape.size() <= 2 || isSupportedDotOpLayout(dstTy)) && "Unexpected rank of ConvertLayout(shared->distributed)"); diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 483e4e3274..84c0050443 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -334,6 +334,9 @@ SmallVector getOrder(Attribute layout) { if (auto sharedLayout = mlir::dyn_cast(layout)) { return llvm::to_vector(sharedLayout.getOrder()); } + if (auto linearLayout = mlir::dyn_cast(layout)) { + return linearLayout.getOrder(); + } llvm::report_fatal_error("Unimplemented usage of getOrder"); return {}; @@ -557,6 +560,102 @@ getDefaultBlockedEncoding(MLIRContext *context, ArrayRef shape, return encoding; } +LinearLayout +ensureLayoutNotLargerThan(const LinearLayout &layout, + const llvm::SmallDenseMap &shape, + bool broadcastRegisters) { + assert(shape.size() == layout.getNumOutDims()); + if (shape.empty()) { + return layout; + } + MLIRContext *ctx = shape.begin()->first.getContext(); + + auto bases = layout.getBases(); + + auto kRegister = StringAttr::get(ctx, "register"); + std::set broadcastedDims; + + for (auto outDim : llvm::enumerate(layout.getOutDimNames())) { + auto outDimName = outDim.value(); + int32_t actualSize = layout.getOutDimSize(outDimName); + int32_t desiredSize = shape.lookup(outDimName); + if (actualSize <= desiredSize) { + continue; + } + assert(actualSize % desiredSize == 0); + // + std::vector> sortedBases; + for (auto [inDimName, basis] : bases) { + for (size_t basisIdx = 0; basisIdx < basis.size(); basisIdx++) { + auto outValue = basis[basisIdx][outDim.index()]; + if (outValue == 0) { + continue; + } + assert(llvm::isPowerOf2_32(outValue)); + sortedBases.emplace_back(inDimName, basisIdx, outValue); + } + } + // From the largest basis to the smallest. 
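+    // For example, if the collected bases map to out-values {16, 8, 4, 2} and
+    // actualSize must shrink from 32 to 8, only the bases mapping to 16 and 8
+    // are zeroed (or dropped, for register bases when !broadcastRegisters).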
+ llvm::sort(sortedBases, + [](auto a, auto b) { return std::get<2>(a) > std::get<2>(b); }); + for (auto [inDimName, basisIdx, outValue] : sortedBases) { + if (actualSize <= desiredSize) { + break; + } + if (!broadcastRegisters && inDimName == kRegister) { + broadcastedDims.insert(basisIdx); + } else { + bases[inDimName][basisIdx][outDim.index()] = 0; + } + actualSize >>= 1; + } + } + if (!broadcastRegisters) { + // Remove broadcasted registers + std::vector> newBasesRegister; + for (auto [idx, basis] : llvm::enumerate(bases[kRegister])) { + // Remove if it's broadcasted + if (broadcastedDims.find(idx) == broadcastedDims.end()) { + newBasesRegister.push_back(std::move(basis)); + } + } + bases[kRegister] = std::move(newBasesRegister); + } + + return LinearLayout(std::move(bases), + llvm::to_vector(layout.getOutDimNames())); +} + +// For each out-dim d, ensure the layout's out-size (i.e. its codomain) is no +// smaller than shape[d]. Do this by increasing the size of the layout's inputs +// along its most-minor dimension ("register" for register layouts, "offset" for +// shared layouts). +// +// This function is invariant to the order of the layout's input dimensions, but +// it cares about the order of the output dims, which should be minor-to-major. +LinearLayout ensureLayoutNotSmallerThan( + const LinearLayout &layout, + const llvm::SmallDenseMap &shape) { + assert(shape.size() == layout.getNumOutDims()); + if (shape.empty()) { + return layout; + } + + MLIRContext *ctx = shape.begin()->first.getContext(); + StringAttr kDim = *layout.getInDimNames().begin(); + assert(kDim == "register" || kDim == "offset"); + + LinearLayout ret = layout; + for (StringAttr outDimName : layout.getOutDimNames()) { + int32_t actualSize = layout.getOutDimSize(outDimName); + int32_t desiredSize = shape.lookup(outDimName); + assert(actualSize > desiredSize || desiredSize % actualSize == 0); + ret *= LinearLayout::identity1D(desiredSize / actualSize, kDim, outDimName); + assert(ret.getOutDimSize(outDimName) >= desiredSize); + } + return ret; +} + } // namespace gpu } // namespace triton } // namespace mlir @@ -1244,6 +1343,360 @@ void BlockedEncodingAttr::print(mlir::AsmPrinter &printer) const { printer << "}>"; } +// FIXME Can we take the LinearLayout by const&? +LogicalResult +LinearEncodingAttr::verify(function_ref emitError, + LinearLayout linearLayout) { + // Example of LinearEncodingAttr + // <{register = [[0, 1], [8, 0], [0, 8], [64, 0]], + // lane = [[0, 2], [0, 4], [1, 0], [2, 0], [4, 0]], + // warp = [[16, 0], [32, 0]], + // block = []}> + // The input dims must be {register, lane, warp, block} + // The output dims of the linear layout should be dim0..dim[rank-1] + + static const auto expectedInDims = + SmallVector({"register", "lane", "warp", "block"}); + for (const auto &[i, dims] : llvm::enumerate( + llvm::zip(linearLayout.getInDimNames(), expectedInDims))) { + const auto &[dim, expectedDimStr] = dims; + if (dim.str() != expectedDimStr) { + return emitError() << "Expected input dimension " << i << " to be '" + << expectedDimStr << "'. Got " << dim; + } + } + + // outDims are ['dim0', 'dim1', ...] + for (auto [i, dim] : llvm::enumerate(linearLayout.getOutDimNames())) { + if (dim.str() != ("dim" + llvm::Twine(i)).str()) { + return emitError() + << "Expected output dimensions to be ['dim0', 'dim1', ...]. 
Got " + << dim << " at position " << i; + } + } + + const auto &bases = linearLayout.getBases(); + auto nonZero = [](auto val) { return val != 0; }; + for (const auto &dimBases : llvm::make_second_range(bases)) { + if (!llvm::all_of(dimBases, [&](const auto &basis) { + return std::count_if(basis.begin(), basis.end(), nonZero) <= 1; + })) { + return emitError() + << "In a distributed layout, each base must move in at most one " + "dimension."; + } + } + + return success(); +} + +void LinearEncodingAttr::print(mlir::AsmPrinter &printer) const { + // We don't use the default implementation as it's a bit too verbose + // This prints in the following format that is shape agnostic, in the sense + // that we don't print explicitly the outShape of the LL + // We always assume LLs to be surjective + // <{register = [[0, 1], [8, 0], [0, 8], [64, 0]], + // lane = [[0, 2], [0, 4], [1, 0], [2, 0], [4, 0]], + // warp = [[16, 0], [32, 0]], + // block = []}> + auto ll = getLinearLayout(); + printer << "<{" << join(ll.getBases(), ", ", [](const auto &base) { + return base.first.str() + " = " + "[" + + join(base.second, ", ", + [](const std::vector &vec) { + return "[" + join(vec, ", ") + "]"; + }) + + "]"; + }) << "}>"; +} + +Attribute LinearEncodingAttr::parse(AsmParser &parser, Type type) { + if (parser.parseLess().failed()) + return {}; + + DictionaryAttr dict; + if (parser.parseAttribute(dict).failed()) + return {}; + + if (parser.parseGreater().failed()) + return {}; + + LinearLayout::BasesT bases; + + // Parse the basis names in order (the order is relevant) + std::vector inDimNames = {"register", "lane", "warp", "block"}; + + for (const auto &inDimNameStr : inDimNames) { + auto inDimName = StringAttr::get(parser.getContext(), inDimNameStr); + Attribute value = dict.get(inDimName); + + // Expecting an array of arrays + auto arrayOfArraysAttr = mlir::dyn_cast(value); + if (!arrayOfArraysAttr) { + parser.emitError(parser.getCurrentLocation(), + "Expected array of arrays for basis of '") + << inDimName.getValue() << "'"; + return {}; + } + + std::vector> inDimBases; + for (Attribute arrayAttr : arrayOfArraysAttr) { + auto intArrayAttr = mlir::dyn_cast(arrayAttr); + if (!intArrayAttr) { + parser.emitError(parser.getCurrentLocation(), + "Expected array of integers in basis for '") + << inDimName.getValue() << "'"; + return {}; + } + std::vector basis; + for (Attribute intAttr : intArrayAttr) { + auto intValueAttr = mlir::dyn_cast(intAttr); + if (!intValueAttr) { + parser.emitError(parser.getCurrentLocation(), + "Expected integer in basis for '") + << inDimName.getValue() << "'"; + return {}; + } + basis.push_back(intValueAttr.getInt()); + } + inDimBases.push_back(std::move(basis)); + } + bases[inDimName] = std::move(inDimBases); + } + size_t rank = 0; + for (const auto &basesDim : llvm::make_second_range(bases)) { + if (!basesDim.empty()) { + rank = basesDim[0].size(); + break; + } + } + + // To implement this we'd need to serialise the rank as well. + // We can do this if we ever need it + if (rank == 0) { + parser.emitError(parser.getCurrentLocation(), "Empty Layout not supported"); + return {}; + } + + // Generate standared outDimNames (dim0, dim1, ...) 
+  SmallVector<StringAttr> outDimNames;
+  for (int i = 0; i < rank; ++i) {
+    outDimNames.push_back(
+        StringAttr::get(parser.getContext(), "dim" + llvm::Twine(i)));
+  }
+
+  // Create LinearLayout
+  LinearLayout linearLayout(std::move(bases), std::move(outDimNames));
+
+  // Create and return the LinearEncodingAttr
+  return parser.getChecked<LinearEncodingAttr>(parser.getContext(),
+                                               std::move(linearLayout));
+}
+
+SmallVector<unsigned> basesPerDim(const LinearLayout::BasesT &namedBases,
+                                  StringAttr dimName, size_t rank,
+                                  bool skipBroadcast = true) {
+  const auto &bases = namedBases.find(dimName)->second;
+
+  if (bases.empty()) {
+    return SmallVector<unsigned>(rank, 1);
+  }
+
+  SmallVector<unsigned> ret(rank, 1);
+  auto nonZero = [](auto val) { return val != 0; };
+  int nonZeroIdx = -1;
+  for (const auto &basis : bases) {
+    auto it = std::find_if(basis.begin(), basis.end(), nonZero);
+    // Bases can have one or zero non-zero elements.
+    // Skip a basis if it's broadcasting (all zeros),
+    // e.g. warps for DotOperandEncodingAttr (see ampereDotToLinearLayout)
+    if (it != basis.end()) {
+      nonZeroIdx = it - basis.begin();
+      ret[nonZeroIdx] *= 2;
+    } else if (!skipBroadcast) {
+      // Once we've seen a non-zero basis, we double the size of the previous
+      // dim. This is just needed to count the CTAsPerCGA.
+      assert(nonZeroIdx != -1);
+      ret[nonZeroIdx] *= 2;
+    }
+  }
+  return ret;
+}
+
+SmallVector<unsigned> basesPerDim(const LinearLayout &ll, StringAttr dimName,
+                                  bool skipBroadcast = true) {
+  auto shapeIter = ll.getOutDimSizes();
+  auto rank = std::distance(shapeIter.begin(), shapeIter.end());
+  return basesPerDim(ll.getBases(), dimName, rank, skipBroadcast);
+}
+
+SmallVector<unsigned> orderPerDim(const LinearLayout &ll, StringAttr dimName,
+                                  ArrayRef<unsigned> defaultOrder) {
+  const auto &bases = ll.getBases().find(dimName)->second;
+  llvm::SetVector<unsigned> order;
+  auto nonZero = [](auto val) { return val != 0; };
+  for (const auto &basis : bases) {
+    // Bases can have one or zero non-zero elements.
+    // Skip a basis if it's broadcasting (all zeros),
+    // e.g. warps for DotOperandEncodingAttr (see ampereDotToLinearLayout)
+    auto it = std::find_if(basis.begin(), basis.end(), nonZero);
+    if (it != basis.end()) {
+      auto i = it - basis.begin();
+      order.insert(i);
+    }
+  }
+  // If any dims are missing, we add them in the defaultOrder
+  for (auto i : defaultOrder) {
+    order.insert(i);
+  }
+  return SmallVector<unsigned>(order.begin(), order.end());
+}
+
+// [Note. Divergence of methods wrt. legacy layouts]
+// For smaller shapes where the CTATile is larger than the output
+// tensor, some methods return different values than the legacy layouts. I
+// think this is benign, though. An example: what is the vector of
+// `warpsPerCTA` if all the warps hold the same data? I think it should be
+// [1, 1], even if we have 4 warps. But perhaps for this we have to add some
+// masking in some places...
We'll see +SmallVector LinearEncodingAttr::getRepOrder() const { + // This is not correct, but: + // - It happens to agree in most places with the legacy layout + // - getRepOrder does not make sense for LinearEncodingAttr as it already has + // the same shape as the tensor that uses it + return getOrder(); +} +SmallVector LinearEncodingAttr::getCTAsPerCGA() const { + // CTAs are split into an identity part (SplitNum) and a broadcast part + return basesPerDim(getLinearLayout(), StringAttr::get(getContext(), "block"), + /*skipBroadcast=*/false); +} +SmallVector LinearEncodingAttr::getCTAOrder() const { + return orderPerDim(getLinearLayout(), StringAttr::get(getContext(), "block"), + getOrder()); +} +SmallVector LinearEncodingAttr::getCTASplitNum() const { + return basesPerDim(getLinearLayout(), StringAttr::get(getContext(), "block")); +} +SmallVector LinearEncodingAttr::getWarpsPerCTA() const { + return basesPerDim(getLinearLayout(), StringAttr::get(getContext(), "warp")); +} +SmallVector LinearEncodingAttr::getWarpOrder() const { + return orderPerDim(getLinearLayout(), StringAttr::get(getContext(), "warp"), + getOrder()); +} +SmallVector LinearEncodingAttr::getThreadsPerWarp() const { + return basesPerDim(getLinearLayout(), StringAttr::get(getContext(), "lane")); +} +SmallVector LinearEncodingAttr::getThreadOrder() const { + return orderPerDim(getLinearLayout(), StringAttr::get(getContext(), "lane"), + getOrder()); +} +SmallVector LinearEncodingAttr::getSizePerThread() const { + auto rank = getRepOrder().size(); + auto ll = getLinearLayout(); + auto ctx = getContext(); + auto kRegister = StringAttr::get(ctx, "register"); + + // We canonicalize on the spot, as if we use CGAs the regs are not in + // canonical form The order is [reg, lane, warp, rep, block], so we first + // remove the blocks + llvm::SmallVector ctaShape; + for (auto [shape, cgaNum] : + llvm::zip(ll.getOutDimSizes(), getCTASplitNum())) { + ctaShape.push_back(shape / cgaNum); + } + LinearLayout::BasesT bases = ll.getBases(); + + llvm::SetVector reverseRepOrder; + auto nonZero = [](auto val) { return val != 0; }; + auto ®isters = bases[StringAttr::get(ctx, "register")]; + while (!registers.empty()) { + auto &basis = registers.back(); + auto it = std::find_if(basis.begin(), basis.end(), nonZero); + // If there's broadcasting (base == zeros) there are no more reps + if (it == basis.end()) { + break; + } + auto dim = it - basis.begin(); + reverseRepOrder.insert(dim); + // As soon as we stop finding reps, we stop + if (dim != reverseRepOrder.back() || 2 * basis[dim] != ctaShape[dim]) { + break; + } + ctaShape[dim] /= 2; + registers.pop_back(); + } + return basesPerDim(bases, kRegister, rank); +} + +SmallVector LinearEncodingAttr::getOrder() const { + auto rank = getLinearLayout().getNumOutDims(); + SmallVector order(rank); + // Choose [rank-1, rank-2, ... 
0] as the default order in case + // there are dims that do not move in the register + // This order is as good as any really + std::iota(order.rbegin(), order.rend(), 0); + + return orderPerDim(getLinearLayout(), + StringAttr::get(getContext(), "register"), order); +} + +std::optional +LinearEncodingAttr::toLinearLayout(ArrayRef shape) const { + auto ll = getLinearLayout(); + auto canonicalDims = llvm::to_vector(ll.getOutDimNames()); + llvm::SmallDenseMap namedShape; + llvm::SmallVector permutedDims; + for (auto dim : getRepOrder()) { + permutedDims.push_back(canonicalDims[dim]); + namedShape[canonicalDims[dim]] = shape[dim]; + } + ll = ll.transposeOuts(permutedDims); + ll = ensureLayoutNotSmallerThan(ll, namedShape); + ll = ensureLayoutNotLargerThan(ll, namedShape, /*broadcastRegisters=*/false); + ll = ll.transposeOuts(canonicalDims); + return ll; +} + +SmallVector +LinearEncodingAttr::getElemsPerThread(ArrayRef shape, Type) const { + // We can relax this assert by calling toLinearLayout rather than + // getLinearLayout + SmallVector shapeVec(shape.begin(), shape.end()); + assert(shapeVec == llvm::to_vector(getLinearLayout().getOutDimSizes())); + auto ll = getLinearLayout(); + return basesPerDim(ll, StringAttr::get(getContext(), "register")); +} + +// Start of Selection +SmallVector LinearEncodingAttr::getContigPerThread() const { + auto ll = getLinearLayout(); + const auto ®s = + ll.getBases().find(StringAttr::get(getContext(), "register"))->second; + auto order = getOrder(); + auto rank = order.size(); + + SmallVector contig(rank, 1); + auto regIt = regs.begin(); + for (unsigned dim : order) { + std::vector basis(rank, 0); + basis[dim] = 1; + + while (regIt != regs.end() && *regIt == basis) { + contig[dim] *= 2; + basis[dim] *= 2; + ++regIt; + } + } + return contig; +} + +unsigned LinearEncodingAttr::getTotalElemsPerThread(ArrayRef shape, + Type eltTy) const { + return product(getElemsPerThread(shape, eltTy)); +} + //===----------------------------------------------------------------------===// // MMA encoding //===----------------------------------------------------------------------===// @@ -2034,6 +2487,9 @@ class TritonGPUOpAsmInterface : public OpAsmDialectInterface { } else if (auto blockedAttr = mlir::dyn_cast(attr)) { os << "blocked"; return AliasResult::FinalAlias; + } else if (auto linearAttr = mlir::dyn_cast(attr)) { + os << "linear"; + return AliasResult::FinalAlias; } else if (auto warpAttr = mlir::dyn_cast(attr)) { os << "warp"; return AliasResult::FinalAlias; diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp index 5c840d1dcb..20280a6ddb 100644 --- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp @@ -53,25 +53,6 @@ SmallVector permuteDimNames(const SmallVector &names, return ret; } -void assertIsRegisterLayout(const LinearLayout &layout) { - assert(layout.getNumInDims() > 0); - MLIRContext *ctx = layout.getInDimNames().begin()->getContext(); - StringAttr kRegister = S("register"); - StringAttr kLane = S("lane"); - StringAttr kWarp = S("warp"); - StringAttr kBlock = S("block"); - - const auto &ins = layout.getInDimNames(); - assert(llvm::SmallVector(ins.begin(), ins.end()) == - llvm::SmallVector({kRegister, kLane, kWarp, kBlock})); - - const auto &outs = layout.getOutDimNames(); - const auto &expectedOuts = standardOutDimNames(ctx, layout.getNumOutDims()); - assert(llvm::SmallDenseSet(outs.begin(), outs.end()) == - 
llvm::SmallDenseSet(expectedOuts.begin(), - expectedOuts.end())); -} - // Returns a 1D -> ND layout into [dim0, dim1, ...] that's equivalent to // creating a 1D -> 1D mapping of size product(shape) and then reshaping to // permute(shape, order). @@ -122,124 +103,6 @@ LinearLayout makeCgaLayout(CTALayoutAttr layout) { return ret.transposeOuts(outDimNames); } -// For each output dimension d, ensure that the layout's output size (i.e., its -// codomain) does not exceed shape[d]. Do this without changing the size of the -// layout's inputs (i.e., leave its domain unchanged). -// -// This function is invariant to the order of the layout's input and output -// dimensions. -// -// We achieve this by setting the largest value in each output dimension d to 0 -// because bases that map to a location larger than shape[d] -// effectively duplicate along that dimension. For example, consider a layout -// with an output dimension size of 32, and we call ensureLayoutNotLargerThan to -// shrink the output dimension size to 8: -// -// L(register=1) = 8 -// L(register=2) = 4 -// L(register=4) = 1 -// L(lane=1) = 2 -// L(lane=2) = 16 -// -// In the first step, we shrink the output dimension size to 16 by setting -// L(lane=2) to 0: -// -// L(register=1) = 8 -// L(register=2) = 4 -// L(register=4) = 1 -// L(lane=1) = 2 -// L(lane=2) = 0 -// -// This means that lane=2 has the same data as lane=0. -// -// Now the output dimension of this layout has a size of 16, which is still -// larger than 8. We find the current largest value in the output dimension, -// which is L(register=1) = 8, and we set L(register=1) to 0: -// -// L(register=1) = 0 -// L(register=2) = 4 -// L(register=4) = 1 -// L(lane=1) = 2 -// L(lane=2) = 0 -// -// Now the output dimension of this layout has a size of 8, which is the desired -// size. Note that this method works only because the bases are powers of two. -// It is unclear what to do when they are not. -LinearLayout ensureLayoutNotLargerThan( - const LinearLayout &layout, - const llvm::SmallDenseMap &shape) { - assert(shape.size() == layout.getNumOutDims()); - if (shape.empty()) { - return layout; - } - MLIRContext *ctx = shape.begin()->first.getContext(); - - auto bases = layout.getBases(); - for (auto outDim : llvm::enumerate(layout.getOutDimNames())) { - auto outDimName = outDim.value(); - int32_t actualSize = layout.getOutDimSize(outDimName); - int32_t desiredSize = shape.lookup(outDimName); - if (actualSize <= desiredSize) { - continue; - } - assert(actualSize % desiredSize == 0); - // - std::vector> sortedBases; - for (auto [inDimName, basis] : bases) { - for (size_t basisIdx = 0; basisIdx < basis.size(); basisIdx++) { - auto outValue = basis[basisIdx][outDim.index()]; - if (outValue == 0) { - continue; - } - assert(llvm::isPowerOf2_32(outValue)); - sortedBases.emplace_back(inDimName, basisIdx, outValue); - } - } - // From the largest basis to the smallest. - llvm::sort(sortedBases, - [](auto a, auto b) { return std::get<2>(a) > std::get<2>(b); }); - for (auto [inDimName, basisIdx, outValue] : sortedBases) { - if (actualSize <= desiredSize) { - break; - } - bases[inDimName][basisIdx][outDim.index()] = 0; - actualSize >>= 1; - } - } - return LinearLayout(std::move(bases), - llvm::to_vector(layout.getOutDimNames())); -} - -// For each out-dim d, ensure the layout's out-size (i.e. its codomain) is no -// smaller than shape[d]. Do this by increasing the size of the layout's inputs -// along its most-minor dimension ("register" for register layouts, "offset" for -// shared layouts). 
-// -// This function is invariant to the order of the layout's input dimensions, but -// it cares about the order of the output dims, which should be minor-to-major. -LinearLayout ensureLayoutNotSmallerThan( - const LinearLayout &layout, - const llvm::SmallDenseMap &shape) { - assert(shape.size() == layout.getNumOutDims()); - if (shape.empty()) { - return layout; - } - - MLIRContext *ctx = shape.begin()->first.getContext(); - StringAttr kDim = *layout.getInDimNames().begin(); - assert(kDim == "register" || kDim == "offset"); - - LinearLayout ret = layout; - for (StringAttr outDimName : layout.getOutDimNames()) { - int32_t actualSize = layout.getOutDimSize(outDimName); - int32_t desiredSize = shape.lookup(outDimName); - assert(actualSize > desiredSize || desiredSize % actualSize == 0); - ret *= LinearLayout::identity1D(desiredSize / actualSize, kDim, outDimName); - assert(ret.getOutDimSize(outDimName) >= desiredSize); - } - return ret; -} - // Combines the layout of a CTA (input dims [register, lane, warp]) with the // layout of a CGA (i.e. a block), and ensures that the resulting layout has the // given shape. @@ -932,10 +795,10 @@ SliceEncodingAttr::toLinearLayout(ArrayRef shape) const { std::optional toLinearLayout(ArrayRef shape, Attribute layout, std::optional elemBitWidth /*= std::nullopt*/) { + // Layouts are distributed or shared if (auto distributed = dyn_cast(layout)) { return distributed.toLinearLayout(shape); - } - if (auto shared = dyn_cast(layout)) { + } else if (auto shared = dyn_cast(layout)) { if (shared.getHasLeadingOffset()) { assert(elemBitWidth.has_value()); return sharedToLinearLayoutLeadingOffset(shape, shared, *elemBitWidth); @@ -944,7 +807,7 @@ toLinearLayout(ArrayRef shape, Attribute layout, } } - // TODO(jlebar): Other layouts + // Third party layouts return std::nullopt; } diff --git a/lib/Dialect/TritonGPU/IR/Ops.cpp b/lib/Dialect/TritonGPU/IR/Ops.cpp index 233883964f..068965468e 100644 --- a/lib/Dialect/TritonGPU/IR/Ops.cpp +++ b/lib/Dialect/TritonGPU/IR/Ops.cpp @@ -53,15 +53,17 @@ LogicalResult UpcastMXFPOp::verify() { if (!dotEncoding) { return emitOpError("Expected a DotOperandEncodingAttr for values"); } - auto blockedScale = dyn_cast(layoutScale); - if (!blockedScale) { - return emitOpError("Expected a BlockOperandEncoding for scales"); + if (!isa(layoutScale)) { + return emitOpError( + "Expected a BlockOperandEncoding or LinearOperandEncoding " + "for scales"); } if (isa(dotEncoding.getParent())) { // Necessary to keep all of the scales of a given block of values in the // same warp - auto threadsPerWarp = blockedScale.getThreadsPerWarp(); + auto threadsPerWarp = + cast(layoutScale).getThreadsPerWarp(); if (threadsPerWarp != ArrayRef({16, 2})) { return emitOpError("Expected threads per warp to be {16, 2}"); } diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp index 4dccc85da3..c07f314087 100644 --- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp +++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp @@ -400,6 +400,7 @@ class DecomposeScaledBlocked "NYI: lhs supports fp4 or fp8"); assert(bType == ScaleDotElemType::E4M3 || bType == ScaleDotElemType::E5M2 || bType == ScaleDotElemType::BF16 && "NYI: rhs supports fp8 and bf16"); + bool isFp4 = aType == ScaleDotElemType::E2M1; auto mmaEnc = getMMAEncoding(rewriter, scaledDotOp); auto versionMajor = mmaEnc.getVersionMajor(); @@ -418,7 +419,7 @@ class DecomposeScaledBlocked // types auto aKWidth = mmaEnc.isHopper() ? 
2 : 8; auto bKWidth = mmaEnc.isHopper() ? 2 : 8; - if (aType == ScaleDotElemType::E2M1) { + if (isFp4) { // Load 2x4-bit elements per thread aKWidth /= 2; } @@ -438,9 +439,43 @@ class DecomposeScaledBlocked // Necessary choice to leave all the scales of the tile in that given warp auto threadsPerWarp = SmallVector{instrShapeM, warpSize / instrShapeM}; - auto newScaleEncoding = triton::gpu::BlockedEncodingAttr::get( + + assert(versionMajor == 2 && + "NYI: MMAv3. Need to rethink the scale layout otherwise"); + + // Copy the bases + + Attribute newScaleEncoding = triton::gpu::BlockedEncodingAttr::get( ctx, {1, 1}, threadsPerWarp, newAEncoding.getWarpsPerCTA(), newAEncoding.getCTAOrder(), mmaEnc.getCTALayout()); + + auto dotBroadcastsWarpLevel = mmaEnc.getWarpsPerCTA()[1] != 1; + if (dotBroadcastsWarpLevel) { + // If mma has warpsPerCTA == {2, 2}, then newAEncoding has + // warpsPerCTA == {2, 1}. In this case, we need to broadcast the warps + // on the second dimension as per + // A: 0 1 | 0 1 + // - - | - - + // 2 3 | 2 3 + // This broadcasting is not representable by standard blocked encodings, + // so we need to use linear layouts. + // This broadcasting is implemented in ampereDotToLinearLayout + auto blocked = cast(newScaleEncoding); + auto blockedLL = *blocked.toLinearLayout(a.getType().getShape()); + LinearLayout::BasesT scaleBases = blockedLL.getBases(); + auto nBases = llvm::Log2_32(mmaEnc.getWarpsPerCTA()[1]); + auto &warps = scaleBases[StringAttr::get(ctx, "warp")]; + // Prepend the vector of zeros to the warpBases + warps.insert(warps.begin(), nBases, std::vector(rank, 0)); + auto outDims = llvm::to_vector(blockedLL.getOutDimNames()); + auto newLL = LinearLayout(scaleBases, outDims); + auto llEncoding = LinearEncodingAttr::get(ctx, std::move(newLL)); + // Adjust the shape of the layout to match the scale operand + auto scaleShape = scale.getType().getShape(); + newScaleEncoding = + LinearEncodingAttr::get(ctx, *llEncoding.toLinearLayout(scaleShape)); + } + a = createArg(rewriter, a, 0, aType, newAEncoding, scale, newScaleEncoding); // Upcast B operand @@ -543,7 +578,8 @@ class DecomposeScaledBlocked auto dotOp = rewriter.create( scaledDotOp.getLoc(), scaledDotOp.getType(), a, b, scaledDotOp.getC()); - // FIXME Waiting on the following comment to be fixed: + // Waiting for https://github.com/triton-lang/triton/pull/5003 to land + // cf. 
// https://github.com/triton-lang/triton/pull/5003#issuecomment-2445091746 // int versionMajor = getMMAVersionSafe(computeCapability, dotOp); int versionMajor = 2; @@ -559,10 +595,8 @@ class DecomposeScaledBlocked versionMajor, retShapePerCTA, dotOp.getA().getType().getElementType(), numWarps); - // FIXME Waiting on supporting LLs on convert_layout - // auto warpsPerCTA = getWarpsPerTile(dotOp, retShapePerCTA, versionMajor, - // numWarps, instrShape); - SmallVector warpsPerCTA = {(unsigned)numWarps, 1}; + auto warpsPerCTA = getWarpsPerTile(dotOp, retShapePerCTA, versionMajor, + numWarps, instrShape); return NvidiaMmaEncodingAttr::get(ctx, versionMajor, versionMinor, warpsPerCTA, CTALayout, instrShape); } diff --git a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp index bc4049dc30..70f5219111 100644 --- a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp @@ -970,7 +970,9 @@ void LayoutRematerialization::backwardRematerialization( // we don't handle conversions to DotOperandEncodingAttr // this is a heuristic to accommodate fused attention RankedTensorType targetType = convertOp.getType(); - if (isa(targetType.getEncoding())) + // We stop the rematerialization of linear layouts as we have to be a bit more + // careful with the heuristics for both correctness and perf + if (isa(targetType.getEncoding())) return; Value oldV = convertOp->getOperand(0); LDBG("check backward remat with source " << oldV << " encoding " @@ -1012,8 +1014,11 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast( ConvertLayoutOp convertOp) { // we don't handle conversions to DotOperandEncodingAttr // this is a heuristics to accommodate fused attention + // We stop the rematerialization of linear layouts as we have to be a bit more + // careful with the heuristics for both correctness and perf RankedTensorType targetType = convertOp.getType(); - if (mlir::isa(targetType.getEncoding())) + if (mlir::isa( + targetType.getEncoding())) return; auto isExtOrBroadcastOp = [](Operation *op) { diff --git a/lib/Tools/LinearLayout.cpp b/lib/Tools/LinearLayout.cpp index 4319d1f086..3a81231ac8 100644 --- a/lib/Tools/LinearLayout.cpp +++ b/lib/Tools/LinearLayout.cpp @@ -212,42 +212,6 @@ void assertCommonDimsSameOrder(T &&aDims, U &&bDims) { "\nb: " + triton::join(bDims, ", ")); } } - -void eraseEmptyInOutDims(BasesT &bases, - llvm::MapVector &outDims) { - // Erase empty out-dims. - SmallVector emptyOutDims; - for (auto [i, outDim] : llvm::enumerate( - llvm::to_vector_of(llvm::make_first_range(outDims)))) { - if (outDims[outDim] == 1) { - emptyOutDims.push_back(i); - outDims.erase(outDim); - } - } - if (outDims.empty()) { - bases.clear(); - return; - } - - for (auto &[inDim, inDimBases] : bases) { - for (auto &basis : inDimBases) { - // Erase the basis elements corresponding to the empty out-dims. - for (int i : llvm::reverse(emptyOutDims)) { - basis.erase(basis.begin() + i); - } - } - } - - // Erase empty in-dims. - // TODO: This needs a test-case. 
- for (StringAttr inDim : - llvm::to_vector_of(llvm::make_first_range(bases))) { - if (bases[inDim].empty()) { - bases.erase(inDim); - } - } -} - } // anonymous namespace /*static*/ std::optional @@ -989,6 +953,30 @@ LinearLayout::getFreeVariableMasks() const { return ret; } +size_t hash_value(const LinearLayout &layout) { + size_t seed = 0; + + // Hash the bases + for (const auto &base : layout.getBases()) { + // Hash the input dimension name + seed = llvm::hash_combine(seed, base.first); + + // Hash the vectors in bases + for (const auto &vec : base.second) { + for (int32_t val : vec) { + seed = llvm::hash_combine(seed, val); + } + } + } + + // Hash the output dimensions and their sizes + for (const auto &outDim : layout.getOutDimNames()) { + seed = llvm::hash_combine(seed, outDim, layout.getOutDimSize(outDim)); + } + // Don't hash the surjective flag as it's a cached property + return seed; +} + bool operator==(LinearLayout lhs, LinearLayout rhs) { if (!lhs.equalIgnoringOutDimSizes(rhs)) return false; diff --git a/test/TritonGPU/accelerate-matmul.mlir b/test/TritonGPU/accelerate-matmul.mlir index 703d379bf5..648a29c34f 100644 --- a/test/TritonGPU/accelerate-matmul.mlir +++ b/test/TritonGPU/accelerate-matmul.mlir @@ -168,14 +168,16 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : #blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}> #blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}> +// CHECK: #[[LINEAR:.+]] = #triton_gpu.linear<{{.*}}> module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { - // CHECK-LABEL: dot_scaled + // CHECK: dot_scaled tt.func @dot_scaled( %a: tensor<128x32xi8, #blocked2>, %scale: tensor<128x2xi8, #blocked1>, %b: tensor<64x128xbf16, #blocked>) -> tensor<128x128xf32, #blocked> { - // CHECK: triton_gpu.upcast_mxfp + // CHECK: triton_gpu.convert_layout {{.*}} : tensor<128x2xi8, #blocked1> -> tensor<128x2xi8, #[[LINEAR]]> + // CHECK: triton_gpu.upcast_mxfp {{.*}}, {{.*}} fp_type = e2m1 : tensor<128x32xi8, #triton_gpu.dot_op<{{.*}}>>, tensor<128x2xi8, #[[LINEAR]]> -> tensor<128x64xbf16, #triton_gpu.dot_op<{{.*}}>> // CHECK: tt.dot %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> %result = tt.dot_scaled %a scale %scale, %b, %cst lhs = e2m1 rhs = bf16 : tensor<128x32xi8, #blocked2>, tensor<128x2xi8, #blocked1> * tensor<64x128xbf16, #blocked> -> tensor<128x128xf32, #blocked> diff --git a/test/TritonGPU/ops.mlir b/test/TritonGPU/ops.mlir index 9184a53120..70c1a315e7 100644 --- a/test/TritonGPU/ops.mlir +++ b/test/TritonGPU/ops.mlir @@ -33,3 +33,17 @@ module attributes {"triton_gpu.target" = "cuda:0", "triton_gpu.num-ctas" = 1 : i tt.return } } +// ----- + +#blocked= #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +// CHECK: #[[$LINEAR:.*]] = #triton_gpu.linear<{{.*}}> + +module attributes {"triton_gpu.target" = "cuda:0", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: @blocked_to_linear + tt.func @blocked_to_linear(%input: tensor<32x4xi8, #blocked>) { + // The layout is the basic layout generated by 
DecomposeScaledBlocked + %output = triton_gpu.convert_layout %input {allocation.offset = 0 : i32} : tensor<32x4xi8, #blocked> -> tensor<32x4xi8, #triton_gpu.linear<{register = [], lane = [[0, 1], [1, 0], [2, 0], [4, 0], [8, 0]], warp = [[0, 0], [16, 0]], block = []}>> + // CHECK: %{{.+}} = triton_gpu.convert_layout %{{.+}} : tensor<32x4xi8, #blocked> -> tensor<32x4xi8, #[[$LINEAR]]> + tt.return + } +} diff --git a/unittest/Dialect/TritonGPU/DialectTest.cpp b/unittest/Dialect/TritonGPU/DialectTest.cpp index c27c63335e..779bc1b788 100644 --- a/unittest/Dialect/TritonGPU/DialectTest.cpp +++ b/unittest/Dialect/TritonGPU/DialectTest.cpp @@ -620,7 +620,135 @@ TEST_F(AMDMfmaLayoutTest, mfma_dot_op) { ASSERT_THAT(tdot3dOp1.getWarpOrder(), tmfma3d.getWarpOrder()); } -} // anonymous namespace +class LinearEncodingTest : public ::testing::Test { +public: + LinearEncodingTest() { ctx.getOrLoadDialect(); } + +protected: + MLIRContext ctx; +}; + +TEST_F(LinearEncodingTest, DistributedEncodingToLinearEncoding) { + // Define a tensor shape + auto rank = 2; + SmallVector> shapes = {{64, 128}, {256, 1024}}; + SmallVector> orders = {{0, 1}, {1, 0}}; + SmallVector ctaLayouts = { + triton::gpu::CTALayoutAttr::getDefault(&ctx, rank), + triton::gpu::CTALayoutAttr::get(&ctx, {4, 2}, {2, 2}, {1, 0}), + }; + SmallVector distributedEncodings; + + // Create BlockedEncodingAttr and SliceEncodingAttr + { + SmallVector sizePerThread = {4, 4}; + SmallVector threadsPerWarp = {4, 8}; + SmallVector warpsPerCTA = {2, 2}; + + for (auto ctaLayout : ctaLayouts) { + for (const auto &order : orders) { + auto blockedEncoding = triton::gpu::BlockedEncodingAttr::get( + &ctx, sizePerThread, threadsPerWarp, warpsPerCTA, order, ctaLayout); + distributedEncodings.push_back(blockedEncoding); + distributedEncodings.push_back( + triton::gpu::SliceEncodingAttr::get(&ctx, 0, blockedEncoding)); + } + } + } + + // Create an MMAv2 and DotOperandEncodingAttr (MMAv3 doesn't support linear + // layouts yet) + { + unsigned versionMajor = 2; + unsigned versionMinor = 0; + SmallVector warpsPerCTA{4, 2}; + SmallVector instrShape{16, 8}; // Instruction shape (M, N) + auto mma = triton::gpu::NvidiaMmaEncodingAttr::get( + &ctx, versionMajor, versionMinor, warpsPerCTA, ctaLayouts[0], + instrShape); + distributedEncodings.push_back(mma); + // Create an opIdx=0 and opIdx=1 encoding + for (unsigned opIdx = 0; opIdx < 2; ++opIdx) { + distributedEncodings.push_back( + triton::gpu::DotOperandEncodingAttr::get(&ctx, opIdx, mma, 2)); + } + } + + for (const auto &distributedEncoding : distributedEncodings) { + for (auto shape : shapes) { + if (auto sliceEncoding = + dyn_cast(distributedEncoding)) { + shape.erase(shape.begin() + sliceEncoding.getDim()); + } + + // Create LinearEncodingAttr from the LinearLayout + auto linearLayout = *distributedEncoding.toLinearLayout(shape); + auto linearEncoding = + triton::gpu::LinearEncodingAttr::get(&ctx, linearLayout); + + // Test that the canonical form of the LinearLayout is indeed canonical + // by expanding it to the original shape + auto expandedLL = linearEncoding.toLinearLayout(shape); + ASSERT_EQ(linearLayout, expandedLL); + + // Test that methods of DistributedEncoding return the same values + Type eltTy = FloatType::getF32(&ctx); + + ASSERT_EQ(getOrder(distributedEncoding), linearEncoding.getRepOrder()); + ASSERT_EQ(cast(distributedEncoding) + .getTotalElemsPerThread(shape, eltTy), + linearEncoding.getTotalElemsPerThread(shape, eltTy)); + ASSERT_EQ(cast(distributedEncoding) + .getElemsPerThread(shape, eltTy), + 
linearEncoding.getElemsPerThread(shape, eltTy));
+      ASSERT_EQ(distributedEncoding.getRepOrder(),
+                linearEncoding.getRepOrder());
+      ASSERT_EQ(distributedEncoding.getContigPerThread(),
+                linearEncoding.getContigPerThread());
+      // DotOperandEncodingAttr::getWarpOrder() is not defined
+      if (!isa<triton::gpu::DotOperandEncodingAttr>(distributedEncoding)) {
+        ASSERT_EQ(distributedEncoding.getWarpOrder(),
+                  linearEncoding.getWarpOrder());
+      }
+      ASSERT_EQ(distributedEncoding.getThreadOrder(),
+                linearEncoding.getThreadOrder());
+      // For slice these do not equal the total number of lanes / warps.
+      // See [Note. Divergence of methods wrt. legacy layouts]
+      if (!isa<triton::gpu::SliceEncodingAttr>(distributedEncoding)) {
+        ASSERT_EQ(distributedEncoding.getWarpsPerCTA(),
+                  linearEncoding.getWarpsPerCTA());
+        ASSERT_EQ(distributedEncoding.getThreadsPerWarp(),
+                  linearEncoding.getThreadsPerWarp());
+      }
+      // Canonicalisation for opIdx=0 takes just a [2 x 2] subtile, as it takes
+      // the second repetition along K as the second tile.
+      if (!isa<triton::gpu::DotOperandEncodingAttr>(distributedEncoding)) {
+        // FIXME: This happens to be correct for SliceLayout because of the
+        // hack in SliceEncodingAttr::toLinearLayout(). We should remove the
+        // hack and the skips in getWarpsPerCTA() and getThreadsPerWarp().
+        ASSERT_EQ(distributedEncoding.getSizePerThread(),
+                  linearEncoding.getSizePerThread());
+      }
+
+      // block level
+      // SliceEncoding is not well-defined for CGAs
+      if (!isa<triton::gpu::SliceEncodingAttr>(distributedEncoding)) {
+        ASSERT_EQ(distributedEncoding.getCTASplitNum(),
+                  linearEncoding.getCTASplitNum());
+        ASSERT_EQ(distributedEncoding.getCTAsPerCGA(),
+                  linearEncoding.getCTAsPerCGA());
+        // If we are not using CGAs, the order is meaningless
+        auto useCGA = distributedEncoding.getCTAsPerCGA() !=
+                      SmallVector<unsigned>(rank, 1);
+        if (useCGA) {
+          ASSERT_EQ(distributedEncoding.getCTAOrder(),
+                    linearEncoding.getCTAOrder());
+        }
+      }
+    }
+  }
+}
+} // namespace
 } // namespace mlir::triton::gpu
 
 int main(int argc, char *argv[]) {
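   testing::InitGoogleTest(&argc, argv);
   return RUN_ALL_TESTS();
 }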