From c373fc7a96316cb00974c4a1006ebd2cb0e85ed4 Mon Sep 17 00:00:00 2001 From: Mario Lezcano Casado <3291265+lezcano@users.noreply.github.com> Date: Thu, 24 Oct 2024 13:57:42 +0100 Subject: [PATCH 1/2] Add LL::quotient and remove uses of divideRight and sublayoutIsIdentity (#4968) We add a new abstraction `LL::quotient` that abstracts the idea of "a linear layout does not permute certain dimensions". Doing so, allows us to remove `divideRight` and subsume them into this higher-level abstraction. We also fix a bug in `isCrossCTAConversion`. We also remove some code duplication from `transferWithinThreads` and `cvtReorderRegisters` in favour of a more generic approach. We fix a bug in `sublayout` that meant that `sublayout` would reorder `outDims` at will by using a set instead of a vector. I am missing adding tests for LL::quotient, will do in a minute. --- include/triton/Analysis/Utility.h | 9 + .../TritonGPU/IR/LinearLayoutConversions.h | 4 - include/triton/Tools/LinearLayout.h | 45 ++-- lib/Analysis/Utility.cpp | 77 +++--- .../TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp | 149 +++++------ .../TritonGPU/IR/LinearLayoutConversions.cpp | 19 -- lib/Tools/LinearLayout.cpp | 204 +++++---------- .../AccelerateAMDMatmul.cpp | 2 +- unittest/Tools/LinearLayoutTest.cpp | 241 ++++++++---------- 9 files changed, 287 insertions(+), 463 deletions(-) diff --git a/include/triton/Analysis/Utility.h b/include/triton/Analysis/Utility.h index ae05e20498..cb3e3d292e 100644 --- a/include/triton/Analysis/Utility.h +++ b/include/triton/Analysis/Utility.h @@ -6,6 +6,7 @@ #include "mlir/Support/LLVM.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Tools/LinearLayout.h" namespace mlir { @@ -189,6 +190,14 @@ bool supportMMA(triton::DotOp op, int version); bool supportMMA(Value value, int version); +// Conversion from `srcTy` to `dstTy` involving the minimum amount of data +// transfer provided that both types can be converted to LL (if it can't it'll +// return nullopt). The output will be such that layout.getInDimNames() == +// layout.getOutDimNames() and the conversion will not include kBlock (resp. +// kWarp or kLane) if it can be avoided +std::optional +minimalCvtLayout(RankedTensorType srcTy, RankedTensorType dstTy); + // Conversion from `srcTy` to `dstTy` only involves reordering of registers. // There is no need for data exchange across threads, warps, or blocks. bool cvtReordersRegisters(RankedTensorType srcTy, RankedTensorType dstTy); diff --git a/include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h b/include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h index 1124daec6d..71bb5e50aa 100644 --- a/include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h +++ b/include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h @@ -44,10 +44,6 @@ std::optional toLinearLayout(ArrayRef shape, Attribute layout, std::optional elemBitWidth = std::nullopt); -// Given a linear layout with input dims and output dims containing a "block" -// dimension, determines if the layout moves data across block boundaries. -bool isCrossCTAConversion(const LinearLayout &layout); - // Given a linear layout where the input dimensions contain a "block" dimension, // this method sets the "block" dimension to 0 and removes the corresponding // output dimensions. 
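A quick illustration of the new abstraction (a minimal sketch, not part of the patch; it mirrors the new unit tests at the bottom and uses their S(...) StringAttr helper): quotient() strips dimensions on which the layout acts as the identity, and refuses when the layout mixes those dimensions with anything else.

  // A conversion that permutes "register" but is the identity on "lane" and
  // "warp": quotienting by {lane, warp} leaves only the register permutation.
  LinearLayout cvt =
      LinearLayout({{S("register"), {{2}, {1}, {4}}}}, {S("register")}) *
      LinearLayout::identity1D(32, S("lane"), S("lane")) *
      LinearLayout::identity1D(4, S("warp"), S("warp"));
  std::optional<LinearLayout> q = cvt.quotient({S("lane"), S("warp")});
  // q now holds the 8-element register permutation, while
  // cvt.quotient({S("register")}) returns std::nullopt because the layout is
  // not the identity on "register".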
diff --git a/include/triton/Tools/LinearLayout.h b/include/triton/Tools/LinearLayout.h index 41a3621a97..c728cfbb32 100644 --- a/include/triton/Tools/LinearLayout.h +++ b/include/triton/Tools/LinearLayout.h @@ -575,29 +575,20 @@ class LinearLayout { return *this; } - // divideLeft and divideRight are the inverses of operator*. - // - // Consider `a = c.divideRight(b)`, where `a` is a linear layout with - // `in-dims(a) == in-dims(b)` and `out-dims(a) == out-dims(c)`. We may remove - // some empty dimensions from `a` to form `a'` and still have `a' * b == c`. - // Therefore, there are multiple possible values that we could return for - // `(a * b).divideRight(b)` which would satisfy - // `((a * b).divideRight(b)) * b == a * b`. - // - // In the following example, we have `a * b == a' * b` when "in1" is an empty - // dimension that maps everything to 0: - // - // a = L("in1", "in2") -> ("out1", "out2") - // a' = L("in1") -> ("out1") - // b = L("in2") -> ("out2") - // - // divideLeft and divideRight resolve this ambiguity by always returning the - // "canonical" quotient, namely the one with the fewest possible size-zero - // input and output dimensions. - // - // TODO(jlebar): Implement divideLeft. - // std::optional divideLeft(const LinearLayout &divisor); - std::optional divideRight(const LinearLayout &divisor) const; + // Returns true if this layout acts trivially (as the identity) on the given + // dimensions. This means that it's the identity on those dimensions, and it + // does not map other dimensions onto those or these onto other dimensions. + bool isTrivialOver(ArrayRef dimNames) const; + + // For an endomorphism on dimNames (linear map that maps dimNames to dimNames) + // checks whether it is the identity map on these dimensions (i.e + // LinearLayouts::isTrivialOver) and if so, returns the sublayout of the + // remaining dimensions. + // nb. The isTrivialOver condition is more restrictive than the usual + // "leaves the subspace invariant" condition in maths. + // We can always relax it if we know how to take advantage of a conversion + // layout being block-diagonal in the future. + std::optional quotient(ArrayRef dimNames) const; // Gets a layout with only these in/out dimensions. // @@ -614,10 +605,10 @@ class LinearLayout { bool sublayoutIsZero(ArrayRef inDimNames, ArrayRef outDimNames) const; - // Is the sublayout restricted to inDimNames + outDimNames and then flattened - // to 1D the identity layout (ignoring out-dim sizes)? - bool sublayoutIsIdentity(ArrayRef inDimNames, - ArrayRef outDimNames) const; + // Is the sublayout defined from dimNames to dimNames the identity? + // In particular, is the input and output size in these dimensions + // the same, and are the bases the identity? + bool squareSublayoutIsIdentity(ArrayRef dimNames) const; // Computes and returns L(x, y, z). // diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index 8ac303c977..f51b3c1657 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -647,57 +647,56 @@ bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy, return ans; } -bool cvtReordersRegisters(RankedTensorType srcTy, RankedTensorType dstTy) { +// We get the smallest submap of srcTy^{-1} * dstTy that is not the identity +// under kBlock, kWarp or kLane (in that order). The idea here is that if we +// have a transformation that's the identity on kBlock, we don't need to use +// distributed shared memory. 
If it's also the identity on kWarp, we can +// transfer via warp-shuffles, and if it's the identity on kLane just have to +// reorder the registers +std::optional minimalCvtLayout(RankedTensorType srcTy, + RankedTensorType dstTy) { MLIRContext *ctx = srcTy.getContext(); std::optional srcLayout = toLinearLayout(srcTy.getShape(), srcTy.getEncoding()); std::optional dstLayout = toLinearLayout(dstTy.getShape(), dstTy.getEncoding()); - if (srcLayout.has_value() && dstLayout.has_value()) { - // comp describes the layout function for converting from src to dst. - LinearLayout comp = srcLayout->invertAndCompose(*dstLayout); - StringAttr kLane = StringAttr::get(ctx, "lane"); - StringAttr kWarp = StringAttr::get(ctx, "warp"); - StringAttr kBlock = StringAttr::get(ctx, "block"); - // TODO(jlebar): These checks are overly-restrictive. For example, we can - // transfer by shuffling registers (case 1) if and only if all of the bases - // for `register` have 0s for lane, warp, and block. But the check below is - // stronger than this, checking also that the choice of lane/warp/block does - // not affect the permutation of registers. If we allow different - // lane/warp/blocks to have different permutations, we can generalize this. - if (comp.divideRight(LinearLayout::identity1D(comp.getInDimSize(kLane), - kLane, kLane) * - LinearLayout::identity1D(comp.getInDimSize(kWarp), - kWarp, kWarp) * - LinearLayout::identity1D(comp.getInDimSize(kBlock), - kBlock, kBlock)) - .has_value()) { - return true; + if (!(srcLayout.has_value() && dstLayout.has_value())) + return std::nullopt; + // comp describes the layout function to create dst from src. + LinearLayout comp = dstLayout->invertAndCompose(*srcLayout); + // We try to quotient by the largest subspace first + auto dims = SmallVector{"block", "warp", "lane", "register"}; + for (auto dim : dims) { + auto quotient = comp.quotient(StringAttr::get(ctx, dim)); + if (!quotient.has_value()) { + break; } + comp = *quotient; } - return false; + return comp; +} + +bool cvtReordersRegisters(RankedTensorType srcTy, RankedTensorType dstTy) { + auto layout = minimalCvtLayout(srcTy, dstTy); + MLIRContext *ctx = srcTy.getContext(); + if (!layout.has_value()) { + return false; + } + auto kRegister = StringAttr::get(ctx, "register"); + auto outDims = llvm::to_vector(layout->getOutDimNames()); + return outDims.empty() || ArrayRef(outDims) == ArrayRef({kRegister}); } bool cvtNeedsWarpShuffle(RankedTensorType srcTy, RankedTensorType dstTy) { + auto layout = minimalCvtLayout(srcTy, dstTy); MLIRContext *ctx = srcTy.getContext(); - std::optional srcLayout = - toLinearLayout(srcTy.getShape(), srcTy.getEncoding()); - std::optional dstLayout = - toLinearLayout(dstTy.getShape(), dstTy.getEncoding()); - if (srcLayout.has_value() && dstLayout.has_value()) { - // comp describes the layout function for converting from src to dst. 
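To summarize the resulting contract (an illustrative reading, assuming the usual register/lane/warp/block input-dimension order of distributed layouts): the loop above quotients away the dimensions the conversion leaves untouched, so the dimensions that survive in the returned layout tell the caller the cheapest resource that still has to move data:

  {}                            -> src and dst are equivalent
  {register}                    -> registers are permuted within each thread
                                   (together with the empty case, this is what
                                   cvtReordersRegisters checks)
  {register, lane}              -> data moves between lanes, so warp shuffles
                                   suffice (cvtNeedsWarpShuffle)
  {register, lane, warp}        -> data moves between warps: shared memory
  {register, lane, warp, block} -> data moves between CTAs: distributed shared
                                   memory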
- LinearLayout comp = srcLayout->invertAndCompose(*dstLayout); - StringAttr kWarp = StringAttr::get(ctx, "warp"); - StringAttr kBlock = StringAttr::get(ctx, "block"); - if (comp.divideRight(LinearLayout::identity1D(comp.getInDimSize(kWarp), - kWarp, kWarp) * - LinearLayout::identity1D(comp.getInDimSize(kBlock), - kBlock, kBlock)) - .has_value()) { - return true; - } + if (!layout.has_value()) { + return false; } - return false; + auto kRegister = StringAttr::get(ctx, "register"); + auto kLane = StringAttr::get(ctx, "lane"); + return llvm::to_vector(layout->getOutDimNames()) == + llvm::SmallVector{kRegister, kLane}; } bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy) { diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 893afc6590..43c7095b21 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -282,83 +282,67 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion const auto &shape = op.getType().getShape(); auto srcTy = op.getSrc().getType(); auto dstTy = op.getType(); - std::optional srcLayout = - toLinearLayout(shape, srcTy.getEncoding()); - std::optional dstLayout = - toLinearLayout(shape, dstTy.getEncoding()); - if (!srcLayout.has_value() || !dstLayout.has_value()) { - return failure(); - } - // There are four cases to handle. - // - // 1. Transfer between values in the same thread, in which case we simply - // reorder the elements of adaptor.getSrc(). - // 2. Transfer between values in the same warp, in which case we try to - // move values using warp shuffles, though if the pattern is complicated - // enough we may fall back to using shared memory (case 3). - // 3. Transfer between values in the same CTA, in which case we move values - // through shared memory. - // 4. Transfer between values in different CTAs, in which case we move - // values through distributed shared memory. - // - // We can tell which case we're in by examining `conversion`. - // For example, if the block -> block mapping is an identity layout: {1, 2, - // 4, ...}, then there's no movement between data in different CTAs, and we - // know we're not in case 4. - if (cvtReordersRegisters(srcTy, dstTy)) { // Case 1. - return transferWithinThread(op, *srcLayout, *dstLayout, adaptor, - rewriter); + auto conversion = minimalCvtLayout(srcTy, dstTy); + if (!conversion.has_value()) { + return rewriter.notifyMatchFailure( + op, "NYI. srcTy and/or dstTy don't implement LLs yet"); } - if (cvtNeedsWarpShuffle(srcTy, dstTy)) { // Case 2. - return transferWithinLane(op, *srcLayout, *dstLayout, adaptor, rewriter); + assert(to_vector(conversion->getInDimNames()) == + to_vector(conversion->getOutDimNames())); + auto dims = conversion->getInDimNames(); + if (llvm::is_contained(dims, str_attr("block"))) { + // Case 1: Transfer between values in different CTAs. + // This requires moving values through distributed shared memory. + return rewriter.notifyMatchFailure( + op, "NYI: Transfer between different CTAs"); + } else if (llvm::is_contained(dims, str_attr("warp"))) { + // Case 2: Transfer between values in the same CTA, in which case we move + // values through shared memory. 
+ LinearLayout srcLayout = + *toLinearLayout(srcTy.getShape(), srcTy.getEncoding()); + LinearLayout dstLayout = + *toLinearLayout(dstTy.getShape(), dstTy.getEncoding()); + return transferWithinBlock(op, srcLayout, dstLayout, adaptor, rewriter); + } else if (llvm::is_contained(dims, str_attr("lane"))) { + // Case 3. Transfer between values in the same warp, in which case we try + // to move values using warp shuffles, though if the pattern is + // complicated enough we may fall back to using shared memory + // TODO(Keren): implement warp shuffle instead of using the general + // approach that uses shared memory + LinearLayout srcLayout = + *toLinearLayout(srcTy.getShape(), srcTy.getEncoding()); + LinearLayout dstLayout = + *toLinearLayout(dstTy.getShape(), dstTy.getEncoding()); + return transferWithinBlock(op, srcLayout, dstLayout, adaptor, rewriter); + } else if (llvm::is_contained(dims, str_attr("register"))) { + // Case 4. Transfer between values in the same thread, in which case we + // simply reorder the elements of adaptor.getSrc(). + return transferWithinThread(op, *conversion, adaptor, rewriter); + } else { + // The two layouts are equivalent. We should probably remove these in + // RemoveLayoutConversion. + rewriter.replaceOp(op, adaptor.getSrc()); + return success(); } - - return transferWithinBlockOrGroup(op, *srcLayout, *dstLayout, adaptor, - rewriter); // Case 3 and 4 } LogicalResult - transferWithinThread(ConvertLayoutOp op, const LinearLayout &srcLayout, - const LinearLayout &dstLayout, OpAdaptor adaptor, + transferWithinThread(ConvertLayoutOp op, const LinearLayout &conversion, + OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { MLIRContext *ctx = op.getContext(); auto loc = op.getLoc(); StringAttr kRegister = str_attr("register"); - StringAttr kLane = str_attr("lane"); - StringAttr kWarp = str_attr("warp"); - StringAttr kBlock = str_attr("block"); - - // There are three possible cases: - // - // 1. `srcLayout` has the same number of registers as `dstLayout`. - // 2. `srcLayout` has fewer registers than `dstLayout`. - // 3. `srcLayout` has more registers than `dstLayout`. - // - // In the second case `srcLayout . dstLayout^-1` is not surjective - // because not all destination registers are covered. - // Since the goal is to cover all of the destination - // registers, we can instead use `dstLayout . srcLayout^-1`. 
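A tiny worked example of the register-only path implemented just below (the conversion here is hypothetical, not taken from the patch): if the quotiented conversion has kRegister bases {{2}, {1}}, it swaps the two bits of the register index, so

  conversion.apply({{kRegister, 1}})  yields  {{kRegister, 2}}
  conversion.apply({{kRegister, 2}})  yields  {{kRegister, 1}}

and the loop produces outVals = {inVals[0], inVals[2], inVals[1], inVals[3]}.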
- LinearLayout conversion = dstLayout.invertAndCompose(srcLayout); - auto dstToSrc = conversion.divideRight( - LinearLayout::identity1D(conversion.getInDimSize(kLane), kLane, kLane) * - LinearLayout::identity1D(conversion.getInDimSize(kWarp), kWarp, kWarp) * - LinearLayout::identity1D(conversion.getInDimSize(kBlock), kBlock, - kBlock)); - assert(!cvtNeedsSharedMemory(op.getSrc().getType(), op.getType())); - assert(ArrayRef(to_vector(dstToSrc->getInDimNames())) == - ArrayRef{kRegister}); - assert(ArrayRef(to_vector(dstToSrc->getOutDimNames())) == - ArrayRef{kRegister}); auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter); SmallVector outVals; - outVals.resize(dstToSrc->getInDimSize(kRegister)); - for (int i = 0; i < dstToSrc->getInDimSize(kRegister); i++) { - auto srcIdx = dstToSrc->apply({{kRegister, i}}); - outVals[i] = inVals[srcIdx.begin()->second]; + outVals.resize(conversion.getInDimSize(kRegister)); + for (int i = 0; i < conversion.getInDimSize(kRegister); i++) { + auto srcIdx = conversion.apply({{kRegister, i}}).begin()->second; + outVals[i] = inVals[srcIdx]; } Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter, op.getType()); @@ -366,27 +350,11 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion return success(); } - LogicalResult transferWithinLane(ConvertLayoutOp op, - const LinearLayout &srcLayout, - const LinearLayout &dstLayout, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - // TODO(Keren): implement warp shuffle instead of using the general approach - // that uses shared memory - return transferWithinBlockOrGroup(op, srcLayout, dstLayout, adaptor, - rewriter); - } - - LogicalResult - transferWithinBlockOrGroup(ConvertLayoutOp op, const LinearLayout &srcLayout, - const LinearLayout &dstLayout, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - LinearLayout conversion = srcLayout.invertAndCompose(dstLayout); - - // TODO(Keren): LLs support cross-CTA conversions, this function does not - if (isCrossCTAConversion(conversion)) - return failure(); - + LogicalResult transferWithinBlock(ConvertLayoutOp op, + const LinearLayout &srcLayout, + const LinearLayout &dstLayout, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { MLIRContext *ctx = op.getContext(); auto loc = op.getLoc(); auto srcTy = op.getSrc().getType(); @@ -445,11 +413,14 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion } } + // Pretty sure this is the identity function ATM + // It'd be better to simply call `quotient({kBlock})` and + // remove kBlock from transferWithinBlockImpl auto srcLayoutWithinBlock = getLayoutWithinBlock(srcLayout); auto dstLayoutWithinBlock = getLayoutWithinBlock(dstLayout); SmallVector outVals = - transferWithinBlock(inVals, op, srcLayoutWithinBlock, - dstLayoutWithinBlock, adaptor, rewriter); + transferWithinBlockImpl(inVals, op, srcLayoutWithinBlock, + dstLayoutWithinBlock, adaptor, rewriter); // Unmunge output values for (const auto &it : llvm::enumerate(outVals)) { @@ -467,10 +438,10 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion } SmallVector - transferWithinBlock(ArrayRef inVals, ConvertLayoutOp op, - const LinearLayout &srcLayout, - const LinearLayout &dstLayout, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { + transferWithinBlockImpl(ArrayRef inVals, ConvertLayoutOp op, + const LinearLayout &srcLayout, + const LinearLayout &dstLayout, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { MLIRContext *ctx = op.getContext(); auto loc = op.getLoc(); diff 
--git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp index 2839b36680..77b26d90a9 100644 --- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp @@ -939,25 +939,6 @@ toLinearLayout(ArrayRef shape, Attribute layout, return std::nullopt; } -bool isCrossCTAConversion(const LinearLayout &layout) { - assert(!layout.getInDimNames().empty()); - MLIRContext *ctx = layout.getInDimNames().begin()->getContext(); - - StringAttr kBlock = S("block"); - assert(layout.hasInDim(kBlock)); - assert(layout.hasOutDim(kBlock)); - - SetVector nonBlockInDims(layout.getInDimNames().begin(), - layout.getInDimNames().end()); - nonBlockInDims.remove(kBlock); - - // This layout moves data between CTAs if - // - the value for any input dim other than block affects the output block, or - // - input (0, ..., block=i) does not map to output (0, ..., block=i). - return !layout.sublayoutIsZero(nonBlockInDims.getArrayRef(), {kBlock}) || - !layout.sublayoutIsIdentity({kBlock}, {kBlock}); -} - LinearLayout getLayoutWithinBlock(const LinearLayout &layout) { assert(!layout.getInDimNames().empty()); MLIRContext *ctx = layout.getInDimNames().begin()->getContext(); diff --git a/lib/Tools/LinearLayout.cpp b/lib/Tools/LinearLayout.cpp index 460792439d..bf017f8c64 100644 --- a/lib/Tools/LinearLayout.cpp +++ b/lib/Tools/LinearLayout.cpp @@ -657,153 +657,62 @@ LinearLayout operator*(LinearLayout inner, LinearLayout outer) { inner.isSurjective() && outer.isSurjective()); } -std::optional -LinearLayout::divideRight(const LinearLayout &divisor) const { - assertCommonDimsSameOrder(getOutDimNames(), divisor.getOutDimNames()); - assertCommonDimsSameOrder(getInDimNames(), divisor.getInDimNames()); - - // Strip off the top N bases for each input dimension of divisor. This - // gives a candidate quotient. Then check if quotient * divisor equals - // `this`. - BasesT newBases = bases; - for (StringAttr inDim : divisor.getInDimNames()) { - if (getInDimSizeLog2(inDim) < divisor.getInDimSizeLog2(inDim)) { - return std::nullopt; - } - auto &newInDimBases = newBases[inDim]; - newInDimBases.resize(newInDimBases.size() - - divisor.getInDimSizeLog2(inDim)); - } - - // Check if the size of the new out-dims are large enough. - // If yes, we can divide the out-dims. - // If no, we return nullopt to indicate that the division is not possible. 
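With isCrossCTAConversion gone, the same question is now asked through the quotient (an illustrative sketch; staysWithinCTA is a made-up name and kBlock the usual str_attr("block")):

  // A conversion needs no cross-CTA data movement iff its "block" dimension
  // can be quotiented away; the dispatch in ConvertLayoutOpToLLVM.cpp above
  // expresses the same thing by checking which dimensions survive.
  bool staysWithinCTA = comp.quotient(kBlock).has_value();

By construction this is at least as strict as the removed check, since quotienting also requires that "block" does not feed into any non-block output dimension.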
- llvm::MapVector newOutDims = outDims; - for (const auto [outDimName, outDimSize] : divisor.outDims) { - if (newOutDims[outDimName] < outDimSize) { - return std::nullopt; +bool LinearLayout::isTrivialOver(ArrayRef dimNames) const { + for (StringAttr dim : dimNames) { + if (!llvm::is_contained(getInDimNames(), dim) && + !llvm::is_contained(getOutDimNames(), dim)) { + return false; } - newOutDims[outDimName] /= outDimSize; - } - - LDBG("Checking candidate_quotient * divisor == *this"); - LDBG("this:" << *this); - LDBG("divisor:" << divisor); - LDBG("newBases: " << triton::join(newBases, ", ", [](auto &p) { - return p.first.str() + "=" + std::to_string(p.second.size()); - })); - LDBG("newOutDims: " << triton::join(newOutDims, ", ", [](auto &p) { - return p.first.str() + "=" + std::to_string(p.second); - })); - std::optional candidateQuotient = LinearLayout::tryCreate( - std::move(newBases), std::move(newOutDims.takeVector()), - /*requireSurjective=*/false); - LDBG("candidate_quotient:" << candidateQuotient); - if (!candidateQuotient.has_value()) { - LDBG("candidate quotient failed invariant checks"); - return std::nullopt; - } - LDBG("*candidate_quotient * divisor=" << *candidateQuotient * divisor); - if (*candidateQuotient * divisor != *this) { - LDBG("candidate quotient failed invariant checks"); - return std::nullopt; } - // Now that we have a candidate quotient, we need to eliminate any empty - // dimensions from the candidate quotient but still ensure that - // quotient * divisor == *this. - newBases = candidateQuotient->bases; - newOutDims = candidateQuotient->outDims; - - // We only remove the trailing empty output dimensions from `quotient`. - // - // In the multiplication `quotient * divisor == result`, the output dimensions - // of `quotient` always come before those of `divisor` in `result`. Removing - // any non-trailing empty dimensions from `quotient` would change the - // order of the output dimensions in `result`. - // - // The following loop iterates through the output dimensions of `result` from - // right to left. During the iteration, the following conditions are checked: - // - // 1. If an output dimension exists only in `divisor` and not in `quotient`, - // the loop continues. - // 2. If an output dimension exists only in `quotient` and not in `divisor`, - // we stop the loop. - // 3. If an output dimension exists in both `quotient` and `divisor`, it may - // be removed, but only if it is a size-1 dimension and meets one of the - // following conditions: - // - The dimension immediately following it in `quotient` has already been - // removed. - // - It is the last dimension of `quotient`. - // Otherwise, removing this dimension could alter the structure of `result`. - // - // Consider the quotient l = o / r, where: - // out-dims(o) = ["out0", "out1", "out2", "out3"] - // out-dims(r) = ["out1", "out3"] - // - // Only "out1" is a size-1 dimension. If we remove "out1" from o, the - // resulting output dimensions would be: - // out-dims(l) = ["out0", "out2", "out3"] - // - // Performing the multiplication l * r results in: - // out-dims(l * r) = ["out0", "out2", "out3"] * ["out1", "out3"] = ["out0", - // "out2", "out3", "out1"] - // This outcome does not match the original out-dims(o). 
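A concrete instance of the block-matrix condition above (a minimal sketch in the style of the unit tests and their S(...) helper; "data" and "lane" are made-up dimension names):

  // Block-diagonal with the identity on "lane": B == 0, C == 0, D == I.
  LinearLayout ok({{S("data"), {{1, 0}, {2, 0}}},
                   {S("lane"), {{0, 1}, {0, 2}}}},
                  {S("data"), S("lane")});
  assert(ok.isTrivialOver({S("lane")}));
  assert(ok.quotient({S("lane")}).has_value());

  // One basis of "lane" also writes into "data" (B != 0), so the layout is no
  // longer trivial over "lane" and quotienting fails.
  LinearLayout bad({{S("data"), {{1, 0}, {2, 0}}},
                    {S("lane"), {{1, 1}, {0, 2}}}},
                   {S("data"), S("lane")});
  assert(!bad.isTrivialOver({S("lane")}));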
- // - // However, if we remove only "out3" from o, we get: - // out-dims(l) = ["out0", "out1", "out2"] - // - // Then, performing the multiplication l * r yields: - // out-dims(l * r) = ["out0", "out1", "out2"] * ["out1", "out3"] = ["out0", - // "out1", "out2", "out3"] - // This result matches the original out-dims(o). - llvm::SmallVector emptyOutDimIndices; - for (const auto [outDimName, outDimSize] : llvm::reverse(outDims)) { - if (newOutDims.contains(outDimName) && !divisor.hasOutDim(outDimName)) { - break; - } - if (newOutDims.contains(outDimName) && divisor.hasOutDim(outDimName) && - candidateQuotient->getOutDimSize(outDimName) == 1) { - auto lastOutDimName = newOutDims.rbegin()->first; - if (outDimName != lastOutDimName) { - break; + auto getRemainingDimNames = [&](auto allDimNames) { + SmallVector remainingDimNames; + for (StringAttr dim : allDimNames) { + if (!llvm::is_contained(dimNames, dim)) { + remainingDimNames.push_back(dim); } - emptyOutDimIndices.push_back(getOutDimIndex(outDimName)); - newOutDims.erase(outDimName); } + return remainingDimNames; + }; + SmallVector remainingInDimNames = + getRemainingDimNames(getInDimNames()); + SmallVector remainingOutDimNames = + getRemainingDimNames(getOutDimNames()); + + // Think of this as a block-matrix multiplying a vector: + // [[A, B], * [v_1, + // [C, D]] v_2] + // where v_2 is the dimNames and v_1 is the remainingInDimNames + // We can quotient out dimNames iff they don't affect the remainingInDimNames + // in the result. In other words, we want to check that B is zero, and C is + // zero, and D is the identity + return squareSublayoutIsIdentity(dimNames) && + sublayoutIsZero(remainingInDimNames, dimNames) && + sublayoutIsZero(dimNames, remainingOutDimNames); +} + +std::optional +LinearLayout::quotient(ArrayRef dimNames) const { + if (!isTrivialOver(dimNames)) { + return std::nullopt; } - // Erase the basis elements corresponding to the empty out-dims. - for (auto &[inDim, inDimBases] : newBases) { - for (auto &basis : inDimBases) { - for (int i : emptyOutDimIndices) { - basis.erase(basis.begin() + i); + // This should probably be even less general, where we ask inDimNames == + // outDimNames + auto getRemainingDimNames = [&](auto allDimNames) { + SmallVector remainingDimNames; + for (StringAttr dim : allDimNames) { + if (!llvm::is_contained(dimNames, dim)) { + remainingDimNames.push_back(dim); } } - } + return remainingDimNames; + }; - // Erase trailing empty in-dims. 
- for (auto inDimName : llvm::reverse(getInDimNames())) { - if (newBases[inDimName].empty() && divisor.hasInDim(inDimName)) { - newBases.erase(inDimName); - } else { - break; - } - } + SmallVector inDimNames = getRemainingDimNames(getInDimNames()); + SmallVector outDimNames = getRemainingDimNames(getOutDimNames()); - LDBG("Eliminated empty dims from candidate_quotient"); - LDBG("newBases: " << triton::join(newBases, ", ", [](auto &p) { - return p.first.str() + "=" + std::to_string(p.second.size()); - })); - LDBG("newOutDims: " << triton::join(newOutDims, ", ", [](auto &p) { - return p.first.str() + "=" + std::to_string(p.second); - })); - auto quotient = LinearLayout::tryCreate(std::move(newBases), - std::move(newOutDims).takeVector(), - /*requireSurjective=*/false); - LDBG("quotient:" << quotient); - assert(quotient.has_value()); - return quotient; + return sublayout(inDimNames, outDimNames); } LinearLayout LinearLayout::sublayout(ArrayRef inDimNames, @@ -813,10 +722,10 @@ LinearLayout LinearLayout::sublayout(ArrayRef inDimNames, SmallDenseSet inDimSet(inDimNames.begin(), inDimNames.end()); SmallDenseSet outDimSet(outDimNames.begin(), outDimNames.end()); - SmallDenseSet outDimIndicesToKeep; + SmallVector outDimIndicesToKeep; for (auto [i, outDim] : llvm::enumerate(getOutDimNames())) { if (outDimSet.contains(outDim)) { - outDimIndicesToKeep.insert(i); + outDimIndicesToKeep.push_back(i); } } BasesT newBases; @@ -856,13 +765,22 @@ bool LinearLayout::sublayoutIsZero(ArrayRef inDimNames, return true; } -bool LinearLayout::sublayoutIsIdentity(ArrayRef inDimNames, - ArrayRef outDimNames) const { - LinearLayout sl = - sublayout(inDimNames, outDimNames).flattenIns().flattenOuts(); - if (sl.getNumInDims() == 0 || sl.getNumOutDims() == 0) { +bool LinearLayout::squareSublayoutIsIdentity( + ArrayRef dimNames) const { + // The empty layout is the identity + if (dimNames.size() == 0) { return true; } + // Check that the input-output sizes are the same + LinearLayout sl = sublayout(dimNames, dimNames); + for (StringAttr dim : dimNames) { + if (getInDimSize(dim) != getOutDimSize(dim)) { + return false; + } + } + // Once the inputs and output dimensions are the same, we can just check + // that the basis for the single remaining dimension is the identity. 
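(Illustration of the flattening step just below, not part of the patch: for dimNames = {"in1", "in2"} with identity bases {1, 2} and {1, 2, 4}, flattening the inputs and outputs yields a single dimension of size 32 whose bases are {1, 2, 4, 8, 16}, so the identity check reduces to verifying that the b-th basis of one flattened dimension equals 1 << b.)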
+ sl = sl.flattenIns().flattenOuts(); int b = 0; const auto &inDimBases = sl.bases.begin()->second; for (auto basis : inDimBases) { diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp index 9296962983..c341963649 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp @@ -60,7 +60,7 @@ warpsPerTile(Operation *dotOp, ArrayRef shape, int numWarps, return {(unsigned)numWarps, 1}; SmallVector tensorShape = {shape[0], shape[1]}; - SmallVector ret = {1, 1}; + SmallVector ret = {1, 1}; do { if (ret[0] * ret[1] >= numWarps) break; diff --git a/unittest/Tools/LinearLayoutTest.cpp b/unittest/Tools/LinearLayoutTest.cpp index 8be680562c..f006447002 100644 --- a/unittest/Tools/LinearLayoutTest.cpp +++ b/unittest/Tools/LinearLayoutTest.cpp @@ -561,104 +561,6 @@ TEST_F(LinearLayoutTest, NumConsecutiveInOut) { .getNumConsecutiveInOut()); } -TEST_F(LinearLayoutTest, DivideRight_Simple) { - EXPECT_EQ(LinearLayout::identity1D(8, S("in"), S("out")) - .divideRight(LinearLayout::identity1D(4, S("in"), S("out"))), - LinearLayout::identity1D(2, S("in"), S("out"))); - - EXPECT_EQ(LinearLayout::identity1D(8, S("in"), S("out")) - .divideRight(LinearLayout::identity1D(8, S("in"), S("out"))), - LinearLayout::empty()); -} - -TEST_F(LinearLayoutTest, DivideRight_2D) { - LinearLayout l1( - { - {S("in1"), {{1, 1}, {2, 2}, {0, 8}, {0, 4}}}, - {S("in2"), {{0, 2}, {0, 1}}}, - }, - {S("out1"), S("out2")}); - LinearLayout l2({{S("in1"), {{2}, {1}}}}, {S("out2")}); - LinearLayout l3( - { - {S("in1"), {{1, 1}, {2, 2}}}, - {S("in2"), {{0, 2}, {0, 1}}}, - }, - {S("out1"), S("out2")}); - ASSERT_EQ(l1.divideRight(l2), l3); - EXPECT_EQ(l1.divideRight(l2).value() * l2, l1); -} - -TEST_F(LinearLayoutTest, DivideRight_EliminateInDim) { - LinearLayout l1( - { - {S("in2"), {{0, 1}, {1, 0}}}, - {S("in1"), {{2, 0}, {0, 2}}}, - }, - {S("out1"), S("out2")}); - LinearLayout l2({{S("in1"), {{1, 0}, {0, 1}}}}, {S("out1"), S("out2")}); - LinearLayout l3({{S("in2"), {{0, 1}, {1, 0}}}}, {S("out1"), S("out2")}); - ASSERT_EQ(l3 * l2, l1); - EXPECT_EQ(l1.divideRight(l2), l3); - - LinearLayout l4({{S("in1"), {{0, 1}, {0, 2}}}, {S("in2"), {}}}, - {S("out1"), S("out2")}); - LinearLayout l5({{S("in1"), {{0, 1}, {0, 2}}}}, {S("out1"), S("out2")}); - LinearLayout l6({{S("in2"), {}}}, {S("out1"), S("out2")}); - ASSERT_EQ(l5 * l6, l4); - EXPECT_EQ(l4.divideRight(l6), l5); - - LinearLayout l7({{S("in1"), {}}, {S("in2"), {{0, 1}}}, {S("in3"), {}}}, - {S("out1"), S("out2")}); - LinearLayout l8({{S("in2"), {{0, 1}}}}, {S("out1"), S("out2")}); - LinearLayout l9({{S("in1"), {}}, {S("in2"), {}}, {S("in3"), {}}}, {}); - ASSERT_EQ(l9 * l8, l7); - EXPECT_EQ(l7.divideRight(l8), l9); -} - -TEST_F(LinearLayoutTest, DivideRight_EliminateOutDim) { - LinearLayout l1( - { - {S("in2"), {{1, 0}, {1, 0}}}, - {S("in1"), {{2, 0}, {0, 1}}}, - }, - {S("out1"), S("out2")}); - LinearLayout l2({{S("in1"), {{1, 0}, {0, 1}}}}, {S("out1"), S("out2")}); - LinearLayout l3({{S("in2"), {{1}, {1}}}}, {S("out1")}); - ASSERT_EQ(l3 * l2, l1); - EXPECT_EQ(l1.divideRight(l2), l3); - - LinearLayout l4( - { - {S("in1"), {{0, 1}, {0, 2}}}, - }, - {S("out1"), S("out2")}); - LinearLayout l5({{S("in1"), {{1}, {2}}}}, {S("out2")}); - using BasesArray = - ArrayRef>>>; - LinearLayout l6(BasesArray{}, {S("out1")}); - ASSERT_EQ(l6 * l5, l4); - EXPECT_EQ(l4.divideRight(l5), l6); -} - -TEST_F(LinearLayoutTest, 
DivideRight_Assertion) { - LinearLayout l1({{S("register"), - {{0, 1, 0, 0}, {0, 2, 0, 0}, {0, 0, 2, 0}, {1, 0, 0, 0}}}, - {S("lane"), - {{0, 4, 0, 0}, - {0, 8, 0, 0}, - {0, 16, 0, 0}, - {0, 0, 1, 0}, - {2, 0, 0, 0}}}, - {S("warp"), {{4, 0, 0, 0}, {8, 0, 0, 0}}}, - {S("block"), {}}}, - {S("register"), S("lane"), S("warp"), S("block")}); - LinearLayout l2 = LinearLayout::identity1D(32, S("lane"), S("lane")) * - LinearLayout::identity1D(4, S("warp"), S("warp")) * - LinearLayout::identity1D(1, S("block"), S("block")); - EXPECT_EQ(l1.divideRight(l2), std::nullopt); -} - TEST_F(LinearLayoutTest, EqualsChecksOutDimSizes) { EXPECT_FALSE(LinearLayout::identity1D(4, S("in"), S("out")) == LinearLayout({{S("in"), {{1}, {2}}}}, {{S("out"), 8}}, @@ -710,52 +612,33 @@ TEST_F(LinearLayoutTest, SublayoutIsZero) { EXPECT_FALSE(l1.sublayoutIsZero({S("in2")}, {S("out2")})); } -TEST_F(LinearLayoutTest, SublayoutIsIdentity) { - EXPECT_TRUE(LinearLayout::identity1D(4, S("in"), S("out")) - .sublayoutIsIdentity({S("in")}, {S("out")})); - EXPECT_TRUE(LinearLayout::identity1D(4, S("in"), S("out")) - .sublayoutIsIdentity({}, {S("out")})); - EXPECT_TRUE(LinearLayout::identity1D(4, S("in"), S("out")) - .sublayoutIsIdentity({S("in")}, {})); - EXPECT_TRUE(LinearLayout::identity1D(4, S("in"), S("out")) - .sublayoutIsIdentity({}, {})); +TEST_F(LinearLayoutTest, SquareSublayoutIsIdentity) { + EXPECT_TRUE(LinearLayout::identity1D(4, S("in"), S("in")) + .squareSublayoutIsIdentity({S("in")})); + EXPECT_TRUE(LinearLayout::identity1D(4, S("in"), S("in")) + .squareSublayoutIsIdentity({})); LinearLayout l1( {{S("in1"), {{1, 1}, {2, 2}, {4, 4}}}, {S("in2"), {{2, 1}, {1, 2}}}}, - {{S("out1"), 8}, {S("out2"), 8}}, /*requireSurjective=*/false); - EXPECT_FALSE(l1.sublayoutIsIdentity({S("in1"), S("in2")}, {S("out1")})); - EXPECT_FALSE(l1.sublayoutIsIdentity({S("in1"), S("in2")}, {S("out2")})); - EXPECT_FALSE(l1.sublayoutIsIdentity({S("in1")}, {S("out1"), S("out2")})); - EXPECT_FALSE(l1.sublayoutIsIdentity({S("in1")}, {S("out2"), S("out1")})); - EXPECT_TRUE(l1.sublayoutIsIdentity({S("in1")}, {S("out1")})); - EXPECT_TRUE(l1.sublayoutIsIdentity({S("in1")}, {S("out2")})); - EXPECT_FALSE(l1.sublayoutIsIdentity({S("in2")}, {S("out1")})); - EXPECT_TRUE(l1.sublayoutIsIdentity({S("in2")}, {S("out2")})); - - LinearLayout l2 = - LinearLayout::identity1D(4, S("in1"), S("out1")) * - LinearLayout::identity1D(8, S("in2"), S("out2")) * - LinearLayout({{S("in3"), {{1, 1, 1}}}}, - {{S("out1"), 2}, {S("out2"), 2}, {S("out3"), 2}}, - /*requireSurjective=*/false); - EXPECT_TRUE(l2.sublayoutIsIdentity({S("in1")}, {S("out1")})); - EXPECT_TRUE(l2.sublayoutIsIdentity({S("in2")}, {S("out2")})); - EXPECT_TRUE(l2.sublayoutIsIdentity({S("in3")}, {S("out3")})); - EXPECT_FALSE( - l2.sublayoutIsIdentity({S("in1"), S("in2")}, {S("out1"), S("out2")})); - EXPECT_FALSE(l2.sublayoutIsIdentity({S("in1"), S("in2")}, {S("out1")})); - EXPECT_TRUE(l2.sublayoutIsIdentity({S("in1"), S("in3")}, {S("out1")})); - - LinearLayout l3 = LinearLayout::identity1D(4, S("in1"), S("out1")) * - LinearLayout::identity1D(8, S("in2"), S("out2")); - EXPECT_TRUE(l3.sublayoutIsIdentity({S("in1")}, {S("out1")})); - EXPECT_TRUE(l3.sublayoutIsIdentity({S("in2")}, {S("out2")})); - EXPECT_FALSE(l3.sublayoutIsIdentity({S("in1")}, {S("out2")})); - EXPECT_FALSE(l3.sublayoutIsIdentity({S("in2")}, {S("out1")})); - EXPECT_FALSE(l3.sublayoutIsIdentity({S("in1"), S("in2")}, {S("out1")})); - EXPECT_FALSE(l3.sublayoutIsIdentity({S("in1"), S("in2")}, {S("out2")})); - EXPECT_TRUE( - 
l3.sublayoutIsIdentity({S("in1"), S("in2")}, {S("out1"), S("out2")})); + {{S("in1"), 8}, {S("in2"), 8}}, /*requireSurjective=*/false); + EXPECT_TRUE(l1.squareSublayoutIsIdentity({S("in1")})); + EXPECT_FALSE(l1.squareSublayoutIsIdentity({S("in2")})); + + LinearLayout l2 = LinearLayout::identity1D(4, S("in1"), S("in1")) * + LinearLayout::identity1D(8, S("in2"), S("in2")) * + LinearLayout({{S("in3"), {{1, 1, 1}}}}, + {{S("in1"), 2}, {S("in2"), 2}, {S("in3"), 2}}, + /*requireSurjective=*/false); + EXPECT_FALSE(l2.squareSublayoutIsIdentity({S("in1")})); + EXPECT_FALSE(l2.squareSublayoutIsIdentity({S("in2")})); + EXPECT_TRUE(l2.squareSublayoutIsIdentity({S("in3")})); + EXPECT_FALSE(l2.squareSublayoutIsIdentity({S("in1"), S("in2")})); + + LinearLayout l3 = LinearLayout::identity1D(4, S("in1"), S("in1")) * + LinearLayout::identity1D(8, S("in2"), S("in2")); + EXPECT_TRUE(l3.squareSublayoutIsIdentity({S("in1")})); + EXPECT_TRUE(l3.squareSublayoutIsIdentity({S("in2")})); + EXPECT_TRUE(l3.squareSublayoutIsIdentity({S("in1"), S("in2")})); } TEST_F(LinearLayoutTest, FreeVariableMasks) { @@ -788,6 +671,82 @@ TEST_F(LinearLayoutTest, FreeVariableMasks) { AR({{S("in1"), 0b100}, {S("in2"), 0b10}})); } +TEST_F(LinearLayoutTest, QuotientOneDimension) { + LinearLayout layout( + { + {S("dim1"), {{1, 0}}}, + {S("dim2"), {{0, 0}}}, + }, + {{S("dim1"), 2}, {S("dim2"), 1}}, /*requireSurjective=*/false); + + // Quotient over dim1, which is trivial + auto quotientLayout = layout.quotient({S("dim1")}); + ASSERT_TRUE(quotientLayout.has_value()); + EXPECT_EQ(*quotientLayout, LinearLayout::zeros1D(2, S("dim2"), S("dim2"))); + // dim2 is zero, not the identity + ASSERT_FALSE(quotientLayout->quotient({S("dim2")}).has_value()); +} + +TEST_F(LinearLayoutTest, QuotientSeveralDimensions) { + LinearLayout layout( + { + {S("dim1"), {{1, 0}, {2, 0}, {4, 0}}}, + {S("dim2"), {{0, 1}, {0, 2}}}, + }, + {S("dim1"), S("dim2")}); + + auto quotientLayout = layout.quotient({S("dim1"), S("dim2")}); + EXPECT_TRUE(quotientLayout.has_value()); +} + +TEST_F(LinearLayoutTest, QuotientMultipleTrivialDimensions) { + LinearLayout layout( + { + {S("dim1"), {{1, 0, 2}, {2, 0, 1}}}, + {S("dim2"), {{0, 1, 0}, {0, 2, 0}, {0, 4, 0}}}, + {S("dim3"), {{0, 0, 1}, {0, 0, 2}}}, + }, + {S("dim1"), S("dim2"), S("dim3")}); + + // Quotient over dim2 is trivial, even if there's some funny business + // going on in the other dimensions + auto quotientLayout = layout.quotient({S("dim2")}); + ASSERT_TRUE(quotientLayout.has_value()); + + layout = LinearLayout( + { + {S("dim1"), {{1, 0, 2}, {2, 0, 1}}}, + {S("dim2"), {{0, 1, 0}, {0, 2, 0}, {0, 4, 0}}}, + {S("dim3"), {{0, 1, 1}, {0, 0, 2}}}, + }, + {S("dim1"), S("dim2"), S("dim3")}); + + // As soon as one maps into the dimension being quotiented or out of it + // (in this case dim3 depends on dim2), we cannot quotient + quotientLayout = layout.quotient({S("dim2")}); + ASSERT_FALSE(quotientLayout.has_value()); +} + +TEST_F(LinearLayoutTest, QuotientEmptyLayout) { + LinearLayout layout = LinearLayout::empty(); + + // Quotienting over a dimension that doesn't exist is invalid + auto quotientLayout = layout.quotient({S("dim1")}); + ASSERT_FALSE(quotientLayout.has_value()); +} + +TEST_F(LinearLayoutTest, QuotientIdentityMultipleDimensions) { + // Test quotient on identity layout with multiple dimensions + LinearLayout layout = LinearLayout::identity1D(8, S("dim1"), S("dim1")) * + LinearLayout::identity1D(2, S("dim2"), S("dim2")) * + LinearLayout::identity1D(4, S("dim3"), S("dim3")); + + // We can quotient over all 
dimensions in any order + auto quotientLayout = layout.quotient({S("dim1"), S("dim3")}); + ASSERT_TRUE(quotientLayout.has_value()); + ASSERT_TRUE(quotientLayout->quotient({S("dim2")}).has_value()); +} + } // anonymous namespace } // namespace mlir::triton From 7370611d4f195b45b28cc648a17c9e49fd177042 Mon Sep 17 00:00:00 2001 From: victor-eds Date: Mon, 28 Oct 2024 11:40:55 +0000 Subject: [PATCH 2/2] [TritonIntelGPUToLLVM] Adapt layout conversion to new LL interface Replace `divideRight` calls with `quotient` calls and simplify code following upstream model. Signed-off-by: victor-eds --- test/Conversion/intel/tritongpu_to_gen.mlir | 1 - .../ConvertLayoutOpToLLVM.cpp | 179 +++++++----------- 2 files changed, 73 insertions(+), 107 deletions(-) diff --git a/test/Conversion/intel/tritongpu_to_gen.mlir b/test/Conversion/intel/tritongpu_to_gen.mlir index 28b4a81e4f..d83a0b4b25 100644 --- a/test/Conversion/intel/tritongpu_to_gen.mlir +++ b/test/Conversion/intel/tritongpu_to_gen.mlir @@ -1752,7 +1752,6 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : // CHECK-NOT: llvm.store // CHECK-NOT: llvm.load // CHECK: llvm.insertvalue - // CHECK: llvm.extractvalue tt.func public @convert_single_element() attributes {noinline = false} { %cst = arith.constant dense<1.000000e+03> : tensor<1xf32, #blocked1> %0 = triton_gpu.convert_layout %cst : tensor<1xf32, #blocked1> -> tensor<1xf32, #blocked> diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp index aaec293bac..18144d91fc 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -455,9 +455,8 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion StringAttr kBlock = str_attr("block"); LinearLayout comp = dstLayout.invertAndCompose(srcLayout); - std::optional conversion = comp.divideRight( - LinearLayout::identity1D(comp.getInDimSize(kWarp), kWarp, kWarp) * - LinearLayout::identity1D(comp.getInDimSize(kBlock), kBlock, kBlock)); + std::optional conversion = + comp.quotient(kBlock)->quotient(kWarp); assert(conversion && "Expecting valid conversion"); // Expected conversion is: // - register=1 -> (0, 1) @@ -516,85 +515,87 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion const auto &shape = op.getType().getShape(); auto srcTy = op.getSrc().getType(); auto dstTy = op.getType(); - std::optional srcLayout = - toLinearLayout(shape, srcTy.getEncoding()); - std::optional dstLayout = - toLinearLayout(shape, dstTy.getEncoding()); - if (!srcLayout.has_value() || !dstLayout.has_value()) { - return failure(); - } - // There are four cases to handle. - // - // 1. Transfer between values in the same thread, in which case we simply - // reorder the elements of adaptor.getSrc(). - // 2. Transfer between values in the same warp, in which case we try to - // move values using warp shuffles, though if the pattern is complicated - // enough we may fall back to using shared memory (case 3). - // 3. Transfer between values in the same CTA, in which case we move values - // through shared memory. - // 4. Transfer between values in different CTAs, in which case we move - // values through distributed shared memory. - // - // We can tell which case we're in by examining `conversion`. 
- // For example, if the block -> block mapping is an identity layout: {1, 2, - // 4, ...}, then there's no movement between data in different CTAs, and we - // know we're not in case 4. - if (cvtReordersRegisters(srcTy, dstTy)) { // Case 1. - return transferWithinThread(op, *srcLayout, *dstLayout, adaptor, - rewriter); + auto conversion = minimalCvtLayout(srcTy, dstTy); + if (!conversion.has_value()) { + return rewriter.notifyMatchFailure( + op, "NYI. srcTy and/or dstTy don't implement LLs yet"); } + LinearLayout srcLayout = + *toLinearLayout(srcTy.getShape(), srcTy.getEncoding()); + LinearLayout dstLayout = + *toLinearLayout(dstTy.getShape(), dstTy.getEncoding()); - if (cvtNeedsWarpShuffle(srcTy, dstTy)) { // Case 2. - return transferWithinLane(op, *srcLayout, *dstLayout, adaptor, rewriter); - } + StringAttr kBlock = str_attr("block"); + StringAttr kWarp = str_attr("warp"); + StringAttr kLane = str_attr("lane"); + StringAttr kRegister = str_attr("register"); - // TODO: match transferWithinBlockOrGroup from - // TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp - return transferWithinBlockGroup(op, *srcLayout, *dstLayout, adaptor, - rewriter); + assert(to_vector(conversion->getInDimNames()) == + to_vector(conversion->getOutDimNames())); + auto dims = conversion->getInDimNames(); + if (llvm::is_contained(dims, str_attr("block"))) { + // Case 1: Transfer between values in different CTAs. + // This requires moving values through distributed shared memory. + return rewriter.notifyMatchFailure( + op, "NYI: Transfer between different CTAs"); + } else if (llvm::is_contained(dims, str_attr("warp"))) { + return rewriter.notifyMatchFailure( + op, "NYI: Transfer between different warps"); + } else if (llvm::is_contained(dims, str_attr("lane"))) { + // Case 2: Transfer between values in the same CTA, in which case we move + // values through shared memory. + // If the operation is a supported sub-group shuffle, perform via shuffle + // operations. + if (isSubGroupShuffle(srcLayout, dstLayout) && + isSupportedSubGroupShuffle(op, adaptor)) { + performSubGroupShuffle(op, srcLayout, dstLayout, adaptor, rewriter); + return success(); + } + // If the operation is a supported sub-group transposition, perform via + // SLM. + if (isSubGroupTranspose(srcLayout, dstLayout) && + isSupportedSubGroupTranspose(op, adaptor)) { + performSubGroupTranspose(op, srcLayout, dstLayout, adaptor, rewriter); + return success(); + } + // TODO(jlebar): Implement me. + return failure(); + } else if (llvm::is_contained(dims, str_attr("register"))) { + // Case 4. Transfer between values in the same thread, in which case we + // simply reorder the elements of adaptor.getSrc(). + return transferWithinThread( + op, dstLayout.getFreeVariableMasks()[kRegister], + dstLayout.getInDimSize(kRegister), *conversion, adaptor, rewriter); + } else { + // The two layouts are equivalent. We should probably remove these in + // RemoveLayoutConversion. 
+ rewriter.replaceOp(op, adaptor.getSrc()); + return success(); + } } LogicalResult - transferWithinThread(ConvertLayoutOp op, const LinearLayout &srcLayout, - const LinearLayout &dstLayout, OpAdaptor adaptor, + transferWithinThread(ConvertLayoutOp op, int32_t regMasks, int32_t numRegs, + const LinearLayout &conversion, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { MLIRContext *ctx = op.getContext(); auto loc = op.getLoc(); StringAttr kRegister = str_attr("register"); - StringAttr kLane = str_attr("lane"); - StringAttr kWarp = str_attr("warp"); - StringAttr kBlock = str_attr("block"); - - // There are three possible cases: - // - // 1. `srcLayout` has the same number of registers as `dstLayout`. - // 2. `srcLayout` has fewer registers than `dstLayout`. - // 3. `srcLayout` has more registers than `dstLayout`. - // - // In the second case `srcLayout . dstLayout^-1` is not surjective - // because not all destination registers are covered. - // Since the goal is to cover all of the destination - // registers, we can instead use `dstLayout . srcLayout^-1`. - LinearLayout conversion = dstLayout.invertAndCompose(srcLayout); - auto dstToSrc = conversion.divideRight( - LinearLayout::identity1D(conversion.getInDimSize(kLane), kLane, kLane) * - LinearLayout::identity1D(conversion.getInDimSize(kWarp), kWarp, kWarp) * - LinearLayout::identity1D(conversion.getInDimSize(kBlock), kBlock, - kBlock)); - assert(!cvtNeedsSharedMemory(op.getSrc().getType(), op.getType())); - assert(ArrayRef(to_vector(dstToSrc->getInDimNames())) == - ArrayRef{kRegister}); - assert(ArrayRef(to_vector(dstToSrc->getOutDimNames())) == - ArrayRef{kRegister}); auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter); - SmallVector outVals; - outVals.resize(dstToSrc->getInDimSize(kRegister)); - for (int i = 0; i < dstToSrc->getInDimSize(kRegister); i++) { - auto srcIdx = dstToSrc->apply({{kRegister, i}}); - outVals[i] = inVals[srcIdx.begin()->second]; + SmallVector outVals(numRegs); + for (int i = 0; i < outVals.size(); i++) { + // Remove free masks from the register index + // For example, if idx = 0b00111, and masks = 0b00100, then we get + // 0b00011. It means that register 7 (0b111) has the same value as + // register 3 (0b011). + auto idx = i & (~regMasks); + auto srcIdx = conversion.hasInDim(kRegister) + ? conversion.apply({{kRegister, idx}}).begin()->second + : idx; + outVals[i] = inVals[srcIdx]; } Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter, op.getType()); @@ -611,9 +612,8 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion StringAttr kBlock = str_attr("block"); LinearLayout comp = dstLayout.invertAndCompose(srcLayout); - std::optional conversion = comp.divideRight( - LinearLayout::identity1D(comp.getInDimSize(kWarp), kWarp, kWarp) * - LinearLayout::identity1D(comp.getInDimSize(kBlock), kBlock, kBlock)); + std::optional conversion = + comp.quotient(kBlock)->quotient(kWarp); assert(conversion && "Expecting valid conversion"); // TODO: Support more kind of shuffles. 
// Expected conversion is: @@ -667,11 +667,8 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion StringAttr kWarp = str_attr("warp"); StringAttr kBlock = str_attr("block"); LinearLayout comp = dstLayout.invertAndCompose(srcLayout); - std::optional conversion = comp.divideRight( - LinearLayout::identity1D(comp.getInDimSize(kWarp), kWarp, kWarp) * - LinearLayout::identity1D(comp.getInDimSize(kBlock), kBlock, kBlock)); - assert(conversion && "Expecting valid layout"); - int32_t subGroupSize = conversion->getOutDimSize(kLane); + LinearLayout conversion = *comp.quotient(kBlock)->quotient(kWarp); + int32_t subGroupSize = conversion.getOutDimSize(kLane); Location loc = op.getLoc(); @@ -772,28 +769,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion .Default(false); } - LogicalResult transferWithinLane(ConvertLayoutOp op, - const LinearLayout &srcLayout, - const LinearLayout &dstLayout, - OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - // If the operation is a supported sub-group shuffle, perform via shuffle - // operations. - if (isSubGroupShuffle(srcLayout, dstLayout) && - isSupportedSubGroupShuffle(op, adaptor)) { - performSubGroupShuffle(op, srcLayout, dstLayout, adaptor, rewriter); - return success(); - } - // If the operation is a supported sub-group transposition, perform via SLM. - if (isSubGroupTranspose(srcLayout, dstLayout) && - isSupportedSubGroupTranspose(op, adaptor)) { - performSubGroupTranspose(op, srcLayout, dstLayout, adaptor, rewriter); - return success(); - } - // TODO(jlebar): Implement me. - return failure(); - } - bool isValidTypeForSubGroupTranspose(Type type) const { return TypeSwitch(type) .Case([](IntegerType intTy) { @@ -967,14 +942,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion } return unwrapFromVectors(loc, transposedVecs, rewriter); } - - LogicalResult - transferWithinBlockGroup(ConvertLayoutOp op, const LinearLayout &srcLayout, - const LinearLayout &dstLayout, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - // TODO(jlebar): Implement me. - return failure(); - } }; } // namespace
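A closing remark on the pattern used throughout this backend (illustrative, not part of the patch): peeling dimensions off one at a time, as in

  std::optional<LinearLayout> conversion = comp.quotient(kBlock)->quotient(kWarp);

should be equivalent to quotienting them jointly with comp.quotient({kBlock, kWarp}) whenever the conversion is trivial over both dimensions; note that the chained form dereferences the intermediate optional, so it relies on "block" always being trivial at these call sites.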