Skip to content

Commit db7170e

Browse files
authored
[Backend] Follow-up refactor of getWarpLayoutConvertDecomposition (#7571)
This PR is a follow-up to #7558 to move the logic of `basisPermutationLayout` inside `getWarpLayoutConvertDecomposition` and to remove the associated unit tests. We also restore the `convert_layout_blocked_blocked_multi_rep` LIT test with changes to the tensor shape and encodings. <!--- The core Triton is a small number of people, and we receive many PRs (thank you!). To help us review your code more quickly, **if you are a new contributor (less than 3 PRs merged) we ask that you complete the following tasks and include the filled-out checklist in your PR description.** Complete the following tasks before sending your PR, and replace `[ ]` with `[x]` to indicate you have done them. --> # New contributor declaration - [x] I am not making a trivial change, such as fixing a typo in a comment. - [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how). - [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`. - Select one of the following. - [x] I have added tests. - `/test` for `lit` tests - `/unittest` for C++ tests - `/python/test` for end-to-end tests - [ ] This PR does not need a test because `FILL THIS IN`. - Select one of the following. - [ ] I have not added any `lit` tests. - [x] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)
1 parent 1e0a371 commit db7170e

File tree

5 files changed

+67
-194
lines changed

5 files changed

+67
-194
lines changed

include/triton/Tools/LayoutUtils.h

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -148,22 +148,6 @@ LinearLayout reshapeLayout(MLIRContext *ctx, LinearLayout layout,
148148
// order.
149149
LinearLayout transposeLinearLayout(LinearLayout layout, ArrayRef<int> order);
150150

151-
// Reorders the in and out dimensions to match another layout.
152-
LinearLayout reorder_like(const LinearLayout &x, const LinearLayout &y);
153-
154-
// For two layouts, `src` and `dst`, that differ only by a permutation of
155-
// their basis vectors, return a permutation layout `P` which satisfies
156-
// `dst` \circ `P` = `src`.
157-
//
158-
// The returned layout has the following properties:
159-
// - The orders of the input and output dimensions of `P` match the order of the
160-
// input dimensions of `src`.
161-
// - Prioritizes making zero (broadcasting) vectors fixed-points of the
162-
// permutation. I.e., if a vector is zero in both `src` and `dst` for the same
163-
// input coordinate, it maps to itself under `P`.
164-
LinearLayout basisPermutationLayout(const LinearLayout &src,
165-
const LinearLayout &dst);
166-
167151
} // namespace mlir::triton
168152

169153
#endif // TRITON_TOOLS_LAYOUTUTILS_H

lib/Analysis/Utility.cpp

Lines changed: 46 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -278,10 +278,11 @@ getWarpLayoutConvertDecomposition(RankedTensorType srcTy,
278278
// subsequences of consecutive lane bits from cycles involving both bit types.
279279
// Further explanation of this method is below.
280280
//
281-
// The decomposition is implemented by building bases for the layouts `pReg`
282-
// and `pLane` by walking the cycles of `P`, a permutation layout returned by
283-
// `basisPermutationLayout(S, T)` which accepts two layouts `S` and `T` which
284-
// differ only by a permutation of their basis vectors.
281+
// The decomposition is performed in two stages. First, we compute the
282+
// permutation matrix `P` by using `invertAndCompose` to generate a skeleton
283+
// and then fill in any zero columns. Second, we walk the cycles of `P` to
284+
// factor out mixed transpositions to build `mixedTranspositions`, `pReg`, and
285+
// `pLane`.
285286

286287
// We remove any broadcasting in the register dimensions of the layouts before
287288
// forming the permutation `P` as the components of the decomposition directly
@@ -310,9 +311,10 @@ getWarpLayoutConvertDecomposition(RankedTensorType srcTy,
310311
int nRegBases = std::max(nSrcRegBases, nDstRegBases);
311312
int nLaneBases = std::max(nSrcLaneBases, nDstLaneBases);
312313
// Restrict attention to the input dimensions which matter.
314+
SmallVector<StringAttr> inDimNames{kReg, kLane};
313315
auto outDimNames = llvm::to_vector(srcLayout.getOutDimNames());
314-
auto S = srcLayout.sublayout({kReg, kLane}, outDimNames);
315-
auto T = dstLayout.sublayout({kReg, kLane}, outDimNames);
316+
auto S = srcLayout.sublayout(inDimNames, outDimNames);
317+
auto T = dstLayout.sublayout(inDimNames, outDimNames);
316318
// Conditionally pad.
317319
if (nSrcRegBases != nDstRegBases || nSrcLaneBases != nDstLaneBases) {
318320
auto padWithZeros = [&](const LinearLayout &ll) {
@@ -334,10 +336,41 @@ getWarpLayoutConvertDecomposition(RankedTensorType srcTy,
334336
T = padWithZeros(T);
335337
}
336338

337-
// Now that `S` and `T` have the same basis vectors, we compute the
338-
// permutation `P` which transforms `S` into `T`.
339-
auto P = basisPermutationLayout(S, T);
340-
auto &pBases = P.getBases();
339+
// Flatten outs for ease of building `P`, and reorder outs as flattening
340+
// depends on output dimension order.
341+
if (outDimNames != llvm::to_vector(T.getOutDimNames()))
342+
T = T.transposeOuts(outDimNames);
343+
S = S.flattenOuts();
344+
T = T.flattenOuts();
345+
346+
// We compute T^transpose \circ S, which serves as a skeleton for `P`, then
347+
// fill in zero columns, prioritizing producing fixed points. As we only need
348+
// the basis vectors of `P`, we never actually produce the LinearLayout.
349+
auto pBases = S.invertAndCompose(T).getBases();
350+
351+
// Find the common and uncommon zeros of S and T
352+
SmallVector<std::pair<int32_t, int32_t>> srcFreeZeros;
353+
SmallVector<std::pair<int32_t, int32_t>> dstFreeZeros;
354+
for (auto [dimIdx, dim] : llvm::enumerate(inDimNames)) {
355+
for (int inIdx = 0; inIdx < S.getInDimSizeLog2(dim); ++inIdx) {
356+
int sVal = S.getBasis(dim, inIdx)[0];
357+
int tVal = T.getBasis(dim, inIdx)[0];
358+
if (sVal == 0 && tVal == 0) {
359+
pBases[dim][inIdx][dimIdx] = 1 << inIdx;
360+
} else if (sVal == 0) {
361+
srcFreeZeros.emplace_back(dimIdx, inIdx);
362+
} else if (tVal == 0) {
363+
dstFreeZeros.emplace_back(dimIdx, inIdx);
364+
}
365+
}
366+
}
367+
// Fill in non-fixed-point zero vectors
368+
for (auto [srcZeroLoc, dstZeroLoc] : llvm::zip(srcFreeZeros, dstFreeZeros)) {
369+
auto [srcDimIdx, srcIdx] = srcZeroLoc;
370+
auto [dstDimIdx, dstIdx] = dstZeroLoc;
371+
auto inDim = inDimNames[srcDimIdx];
372+
pBases[inDim][srcIdx][dstDimIdx] = 1 << dstIdx;
373+
}
341374

342375
// We walk the cycles of `P` to build the bases for `pReg` and `pLane` while
343376
// factoring out mixed transpositions from cycles that include both register
@@ -355,9 +388,8 @@ getWarpLayoutConvertDecomposition(RankedTensorType srcTy,
355388
return (dim == kReg) ? index : nRegBases + index;
356389
};
357390

358-
auto dimNames = llvm::to_vector(P.getInDimNames());
359-
for (auto dim : dimNames) {
360-
int inDimSize = P.getInDimSizeLog2(dim);
391+
for (auto dim : inDimNames) {
392+
int inDimSize = S.getInDimSizeLog2(dim);
361393
for (int i = 0; i < inDimSize; ++i) {
362394
if (visited.test(flatIdx(dim, i)))
363395
continue;
@@ -393,7 +425,7 @@ getWarpLayoutConvertDecomposition(RankedTensorType srcTy,
393425
int32_t nextIdx;
394426
for (auto [nextDimIdx, nextVal] : llvm::enumerate(nextVec)) {
395427
if (nextVal != 0) {
396-
nextDim = dimNames[nextDimIdx];
428+
nextDim = inDimNames[nextDimIdx];
397429
nextIdx = llvm::Log2_32(nextVal);
398430
}
399431
}

lib/Tools/LayoutUtils.cpp

Lines changed: 0 additions & 134 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
#include "triton/Tools/LayoutUtils.h"
22
#include "triton/Tools/GenericSwizzling.h"
3-
#include "llvm/ADT/SmallSet.h"
43

54
namespace mlir::triton {
65

@@ -447,137 +446,4 @@ LinearLayout transposeLinearLayout(LinearLayout layout, ArrayRef<int> order) {
447446
to_vector(layout.getOutDimNames()));
448447
}
449448

450-
LinearLayout reorder_like(const LinearLayout &x, const LinearLayout &y) {
451-
// This will check that the names are the same up to permutation, and
452-
// apply the necessary permutation:
453-
auto x2 = x.transposeOuts(llvm::to_vector(y.getOutDimNames()));
454-
auto x3 = x2.transposeIns(llvm::to_vector(y.getInDimNames()));
455-
return x3;
456-
}
457-
458-
LinearLayout basisPermutationLayout(const LinearLayout &src,
459-
const LinearLayout &dst) {
460-
// This function computes a permutation layout `P` which satisfies the
461-
// property `src = dst \circ P`. It requires that the multiset of basis
462-
// vectors for each of `src` and `dst` agree and that the nonzero values in
463-
// each of the multisets are unique. I.e., broadcasting is allowed in either
464-
// layout so long as the degree of broadcasting (the number of zero basis
465-
// vectors) is the same between the two layouts.
466-
//
467-
// The orders of the input and output dimensions of `P` are set to be the
468-
// order of the input dimensions of `src`.
469-
//
470-
// The mapping of broadcasting basis vectors prioritizes keeping such vectors
471-
// as fixed points of the permutation. I.e., if `src[inDim][i]` and
472-
// `dst[inDim][i]` are zero vectors, then `P[inDim][i][inDimIdx] == 1 << i`,
473-
// where `inDimIdx` is the index of `inDim` in `src`. Otherwise, they are
474-
// paired according to their order of appearance in the two layouts, again
475-
// following the order of the input dimensions of `src`.
476-
//
477-
// The algorithm first performs a linear scan over the columns of `dst` and
478-
// `src` to build a map from ('flattened') basis vectors to the input
479-
// vectors of `dst` while tracking the fixed-point zero vectors and 'free'
480-
// zero vectors. It then performs a second linear scan over `src` to build
481-
// the basis of `P`.
482-
483-
// Check that the input and output dimensions are equal up to ordering.
484-
auto srcInDims = src.getInDimNames();
485-
assert(std::is_permutation(srcInDims.begin(), srcInDims.end(),
486-
dst.getInDimNames().begin()) &&
487-
"Layouts must have same input dimensions");
488-
for (auto inDim : srcInDims) {
489-
assert(src.getInDimSize(inDim) == dst.getInDimSize(inDim) &&
490-
"Layouts must have same input dimension sizes");
491-
}
492-
auto srcOutDims = src.getOutDims();
493-
assert(std::is_permutation(srcOutDims.begin(), srcOutDims.end(),
494-
dst.getOutDims().begin()) &&
495-
"Layouts must have same output dimensions and dimension sizes");
496-
497-
auto srcFlat = src.flattenOuts();
498-
// Reorder the output dimensions of `dst` if necessary before flattening, as
499-
// flattening depends on the order.
500-
LinearLayout dstFlat;
501-
if (!llvm::equal(src.getOutDims(), dst.getOutDims())) {
502-
auto temp = dst.transposeOuts(llvm::to_vector(src.getOutDimNames()));
503-
dstFlat = temp.flattenOuts();
504-
} else {
505-
dstFlat = dst.flattenOuts();
506-
}
507-
508-
// Populate the map of flattened values to dst inputs and track zero vectors.
509-
// The `commonZeros` become fixed-points of `P`, while the 'free' zeros are
510-
// later paired with one another.
511-
DenseMap<int32_t, std::pair<StringAttr, int32_t>> valToDstInput;
512-
llvm::SmallDenseMap<StringAttr, llvm::SmallSet<int32_t, 4>> commonZeros;
513-
SmallVector<std::pair<StringAttr, int32_t>> dstFreeZeros;
514-
size_t srcFreeZerosCount = 0;
515-
516-
// We traverse the input dimensions according to their order in `src` so that
517-
// 'free' zero vectors for a given input dimension in `src` prefer to map to
518-
// 'free' zero vectors in the same dimension in `dst.
519-
for (auto inDim : srcInDims) {
520-
int inDimSize = dstFlat.getInDimSizeLog2(inDim);
521-
for (int i = 0; i < inDimSize; ++i) {
522-
int32_t dstVal = dstFlat.getBasis(inDim, i)[0];
523-
int32_t srcVal = srcFlat.getBasis(inDim, i)[0];
524-
if (dstVal == 0 && srcVal == 0) {
525-
commonZeros[inDim].insert(i);
526-
} else if (dstVal == 0) {
527-
dstFreeZeros.emplace_back(inDim, i);
528-
} else {
529-
auto [it, success] = valToDstInput.try_emplace(dstVal, inDim, i);
530-
assert(success && "Found duplicate nonzero vectors in dst layout");
531-
if (srcVal == 0)
532-
++srcFreeZerosCount;
533-
}
534-
}
535-
}
536-
assert(srcFreeZerosCount == dstFreeZeros.size() &&
537-
"src and dst layouts have differing number of zero bases");
538-
539-
// Build the basis vectors for the permutation layout `P`.
540-
// For each basis vector in `src`, determine its target in `dst`:
541-
// - If the vector is nonzero, find the corresponding vector in `dst`.
542-
// - If it is a zero vector common to both layouts, set it as a fixed-point.
543-
// - Otherwise, pair it with the next available free zero of `dst`.
544-
LinearLayout::BasesT pBases;
545-
size_t numDims = llvm::size(srcInDims);
546-
size_t freeZeroIdx = 0;
547-
for (auto inDim : srcInDims) {
548-
int inDimSize = srcFlat.getInDimSizeLog2(inDim);
549-
auto &inDimBases = pBases[inDim];
550-
inDimBases.reserve(inDimSize);
551-
for (int i = 0; i < inDimSize; ++i)
552-
inDimBases.emplace_back(numDims, 0);
553-
554-
for (int inIdx = 0; inIdx < inDimSize; ++inIdx) {
555-
int32_t val = srcFlat.getBasis(inDim, inIdx)[0];
556-
std::pair<StringAttr, int32_t> dstTarget;
557-
558-
if (val != 0) {
559-
auto it = valToDstInput.find(val);
560-
assert(it != valToDstInput.end() && "src basis not found in dst");
561-
dstTarget = it->second;
562-
} else if (commonZeros.lookup(inDim).count(inIdx)) {
563-
dstTarget = {inDim, inIdx};
564-
} else {
565-
dstTarget = dstFreeZeros[freeZeroIdx++];
566-
}
567-
568-
// Build the basis vector for `P` using the ordering on output dimensions
569-
// induced by the ordering on the input dimensions of `src`.
570-
auto it = llvm::find(srcInDims, dstTarget.first);
571-
int outDimIdx = std::distance(srcInDims.begin(), it);
572-
inDimBases[inIdx][outDimIdx] = 1 << dstTarget.second;
573-
}
574-
}
575-
// Declare the ordering on the `outDims` of `P` to be that of `srcInDims`.
576-
SmallVector<std::pair<StringAttr, int32_t>> outDims;
577-
for (auto outDim : srcInDims)
578-
outDims.emplace_back(outDim, srcFlat.getInDimSize(outDim));
579-
580-
return LinearLayout(std::move(pBases), outDims, /*requireSurjective=*/true);
581-
}
582-
583449
} // namespace mlir::triton

test/Conversion/tritongpu_to_llvm.mlir

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -874,6 +874,27 @@ tt.func @convert_layout_ptr_element(%arg0: tensor<16x16x!tt.ptr<i32>, #blocked0>
874874

875875
// -----
876876

877+
#blocked0 = #ttg.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 1], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
878+
#blocked1 = #ttg.blocked<{sizePerThread = [4, 8], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
879+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
880+
// CHECK: llvm.mlir.global external @global_smem
881+
// CHECK-LABEL: convert_layout_blocked_blocked_multi_rep
882+
tt.func @convert_layout_blocked_blocked_multi_rep(%arg0: tensor<32x32xf32, #blocked0>) {
883+
// CHECK: llvm.mlir.addressof @global_smem
884+
// CHECK-COUNT-4: llvm.store
885+
// CHECK: nvvm.barrier0
886+
// CHECK-COUNT-4: llvm.load
887+
// CHECK: nvvm.barrier0
888+
// CHECK-COUNT-4: llvm.store
889+
// CHECK: nvvm.barrier0
890+
// CHECK-COUNT-4: llvm.load
891+
%0 = ttg.convert_layout %arg0 : tensor<32x32xf32, #blocked0> -> tensor<32x32xf32, #blocked1>
892+
tt.return
893+
}
894+
}
895+
896+
// -----
897+
877898
#blocked0 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
878899
#shared0 = #ttg.swizzled_shared<{vec = 1, perPhase=1, maxPhase=1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
879900
#mma0 = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}>

unittest/Tools/LayoutUtilsTest.cpp

Lines changed: 0 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -45,35 +45,5 @@ TEST_F(LayoutUtilsTest, SquareSublayoutIsIdentity) {
4545
EXPECT_TRUE(squareSublayoutIsIdentity(l3, {S("in1"), S("in2")}));
4646
}
4747

48-
TEST_F(LayoutUtilsTest, BasisPermutationLayout) {
49-
LinearLayout src1(
50-
{{S("in1"), {{1, 0}, {0, 0}, {0, 2}}}, {S("in2"), {{2, 0}, {0, 1}}}},
51-
{S("out1"), S("out2")});
52-
LinearLayout dst1(
53-
{{S("in2"), {{1, 0}, {0, 0}}}, {S("in1"), {{2, 0}, {0, 1}, {0, 2}}}},
54-
{S("out2"), S("out1")});
55-
LinearLayout P1(
56-
{{S("in1"), {{2, 0}, {0, 2}, {1, 0}}}, {S("in2"), {{4, 0}, {0, 1}}}},
57-
{S("in1"), S("in2")});
58-
EXPECT_EQ(P1, basisPermutationLayout(src1, dst1));
59-
EXPECT_EQ(src1, reorder_like(P1.compose(dst1), src1));
60-
LinearLayout src2({{S("in3"), {{2, 0}, {4, 0}, {8, 0}, {0, 0}}},
61-
{S("in2"), {{0, 0}, {16, 0}, {0, 0}, {0, 1}}},
62-
{S("in1"), {{0, 2}, {0, 0}, {0, 4}}}},
63-
{{S("out1"), 32}, {S("out2"), 8}},
64-
/*requireSurjective=*/false);
65-
LinearLayout dst2({{S("in1"), {{0, 0}, {0, 16}, {2, 0}}},
66-
{S("in2"), {{0, 4}, {0, 8}, {0, 0}, {4, 0}}},
67-
{S("in3"), {{0, 0}, {0, 0}, {0, 2}, {1, 0}}}},
68-
{{S("out2"), 8}, {S("out1"), 32}},
69-
/*requireSurjective=*/false);
70-
LinearLayout P2({{S("in3"), {{4, 0, 0}, {0, 1, 0}, {0, 2, 0}, {1, 0, 0}}},
71-
{S("in2"), {{2, 0, 0}, {0, 0, 2}, {0, 4, 0}, {8, 0, 0}}},
72-
{S("in1"), {{0, 0, 4}, {0, 0, 1}, {0, 8, 0}}}},
73-
{S("in3"), S("in2"), S("in1")});
74-
EXPECT_EQ(P2, basisPermutationLayout(src2, dst2));
75-
EXPECT_EQ(src2, reorder_like(P2.compose(dst2), src2));
76-
}
77-
7848
} // namespace
7949
} // namespace mlir::triton

0 commit comments

Comments
 (0)