
Commit 3854ae8

[Backend] Improve warp-local layout conversion algo using shuffles (#7558)
This PR replaces the existing `transferWithinWarp` warp-shuffle algorithm for the lowering of `ConvertLayoutOp` with a more precise algorithm which allows for broadcasting in layouts and which emits fewer `select` and `shuffle` instructions in general. We additionally implement register packing for sub-32-bit data types.

### Combinatorial point of view

The new implementation describes and decomposes a layout conversion as a permutation of hardware index bits

```math
P = P_{\text{mixed}} \circ P_{\text{lane}} \circ P_{\text{reg}}
```

where

- $P_{\text{lane}}$ is a permutation of lane bits,
- $P_{\text{reg}}$ is a permutation of register bits,
- $P_{\text{mixed}}$ is a product of disjoint transpositions which swap register and lane bits.

### Instruction count differences

Letting $m$ denote the number of such mixed transpositions and $R$ denote the number of registers, the existing algorithm:

- uses $2 \cdot (2^m - 1) \cdot R$ `select`s,
- uses $R$ `shuffle`s,

while the new algorithm, using the in-place bit-swap implementation:

- uses $2 \cdot m \cdot R$ `select`s,
- uses $(1 - (0.5)^m) \cdot R$ `shuffle`s if $P_{\text{lane}}$ is the trivial permutation, and $R$ `shuffle`s otherwise,

and in the case $m = 1$ with trivial $P_{\text{lane}}$, using the out-of-place bit-swap implementation:

- uses $1.5 \cdot R$ `select`s,
- uses $0.5 \cdot R$ `shuffle`s.

Despite these improvements, empirical results on Turing T4 GPUs, using a modification of the test in triton-lang/triton#5419 (comment), show that the shared-memory approach is faster in most cases when $m > 1$. Only the $m = 2$ case with trivial $P_{\text{lane}}$, using the out-of-place implementation twice in succession, was measured to be faster. For this reason, the new algorithm still bails out when $m > 1$.
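The instruction-count formulas above can be tabulated with a short sketch (the helper names are mine, not from the PR):

```python
# Sketch (not Triton source): tabulate the select/shuffle counts quoted above
# as a function of the number of mixed transpositions m and register count R.

def old_counts(m, R):
    # Existing algorithm: 2 * (2^m - 1) * R selects, R shuffles.
    return {"select": 2 * (2**m - 1) * R, "shuffle": R}

def new_counts_inplace(m, R, trivial_plane):
    # New algorithm (in-place bit swap): 2 * m * R selects; the shuffle
    # count drops to (1 - 0.5^m) * R only when P_lane is the identity.
    shuffles = (1 - 0.5**m) * R if trivial_plane else R
    return {"select": 2 * m * R, "shuffle": shuffles}

def new_counts_outofplace(R):
    # Out-of-place bit swap, applicable when m == 1 with trivial P_lane.
    return {"select": 1.5 * R, "shuffle": 0.5 * R}
```

Note that the old algorithm's `select` count grows exponentially in $m$ while the new one's grows linearly, which is where the improvement concentrates.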
Due to the higher arithmetic pressure, I believe it's unlikely the shuffle approach would be profitable compared to the shared-memory approach in these cases, but testing with actual kernels rather than microbenchmarks and using different hardware may reveal otherwise.

The new algorithm also handles layout conversions that have been hand-implemented in some cases, such as `convertMMAV3To8BitsDotOperand`. However, this PR does not attempt to remove or modify any of these code paths.

# New contributor declaration

- [x] I am not making a trivial change, such as fixing a typo in a comment.
- [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how).
- [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.
- Select one of the following.
  - [x] I have added tests.
    - `/test` for `lit` tests
    - `/unittest` for C++ tests
    - `/python/test` for end-to-end tests
  - [ ] This PR does not need a test because `FILL THIS IN`.
- Select one of the following.
  - [ ] I have not added any `lit` tests.
  - [x] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)

---------

Co-authored-by: apgoucher <[email protected]>
1 parent 570f24d commit 3854ae8

File tree: 10 files changed (+725 −347 lines)


include/triton/Analysis/Utility.h

Lines changed: 19 additions & 12 deletions
```diff
@@ -172,22 +172,29 @@ class GatherLoweringHelper {
   RankedTensorType dstTy;
 };
 
-// This struct represents a decomposed layout conversion within a warp into
-// three transformations: P1 and P2 represent lane-dependent register shuffles
-// and W represents a warp shuffle. P2^-1 is returned because it represents the
-// (reg, lane) -> (reg) mapping from the perspective of the destination element.
+// This struct represents the factorization of a warp-local layout conversion
+// into three components: a register-only permutation, a lane-only permutation,
+// and a set of swaps between lane and register basis vectors. Algebraically, it
+// represents the factorization P = P_mixed \circ P_lane \circ P_reg. It is used
+// to aid in the implementation of the layout conversion using warp-shuffles.
 //
-// Nearly all layout conversions that only require data movement within a warp
-// can be implemented this way.
+// `pReg` and `pLane` are square layouts each with only one input and output
+// dimension. `mixedTranspositions` holds pairs of integers (i, j)
+// corresponding to the transposition (r_i l_j) of the i-th register basis
+// vector with the j-th lane basis vector.
 struct DecomposedWarpConversion {
-  triton::LinearLayout P1, W, P2inv;
-  triton::LinearLayout reducedP1, reducedP2inv;
+  triton::LinearLayout pReg, pLane;
+  SmallVector<std::pair<int, int>> mixedTranspositions;
 };
 
-// Given the source and destination tensor types where a layout conversion only
-// involves data movement within warps, attempt to find a decomposition for a
-// warp layout conversion.
-std::optional<DecomposedWarpConversion>
+// Produces a decomposition of a permutation describing a warp-local layout
+// conversion as described in `DecomposedWarpConversion` above.
+//
+// This function handles cases where the numbers of register and lane basis
+// vectors differ between the two layouts. This is done by padding the smaller
+// dimension(s) with zero vectors, ensuring that the layout conversion can be
+// represented as a permutation.
+DecomposedWarpConversion
 getWarpLayoutConvertDecomposition(RankedTensorType srcTy,
                                   RankedTensorType dstTy);
```
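To make the factorization concrete, here is a toy model of how the three components of `DecomposedWarpConversion` act on a hardware index. The function names and the flat bit-vector encoding are my own; real Triton layouts are multi-dimensional `LinearLayout`s:

```python
# Toy model of P = P_mixed ∘ P_lane ∘ P_reg acting on a hardware index
# (reg, lane). Bit i of `reg` stands for the i-th register basis vector,
# bit j of `lane` for the j-th lane basis vector.

def permute_bits(x, perm):
    # perm[i] gives the source bit position feeding output bit i
    return sum(((x >> src) & 1) << i for i, src in enumerate(perm))

def apply_decomposition(reg, lane, p_reg, p_lane, mixed_transpositions):
    reg = permute_bits(reg, p_reg)      # register-only permutation P_reg
    lane = permute_bits(lane, p_lane)   # lane-only permutation P_lane
    for i, j in mixed_transpositions:   # disjoint swaps (r_i l_j)
        ri, lj = (reg >> i) & 1, (lane >> j) & 1
        reg ^= (ri ^ lj) << i           # exchange register bit i
        lane ^= (ri ^ lj) << j          # with lane bit j
    return reg, lane
```

With identity `p_reg`/`p_lane` and a single transposition `(0, 0)`, the model swaps the lowest register bit with the lowest lane bit, i.e. the $m = 1$ case that the shuffle lowering handles in place.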

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 3 additions & 0 deletions
```diff
@@ -327,6 +327,9 @@ LLVM::LLVMFuncOp appendOrGetExternFuncOp(RewriterBase &rewriter, Operation *op,
                                          StringRef funcName, Type funcType,
                                          StringRef libname = "",
                                          StringRef libpath = "");
+
+// Multiply a square layout with 1 input and output dimension with a vector
+Value matrixVectorProd(TritonLLVMOpBuilder &b, const LinearLayout &A, Value x);
 } // namespace gpu
 
 } // namespace triton
```

include/triton/Tools/LayoutUtils.h

Lines changed: 16 additions & 0 deletions
```diff
@@ -148,6 +148,22 @@ LinearLayout reshapeLayout(MLIRContext *ctx, LinearLayout layout,
 // order.
 LinearLayout transposeLinearLayout(LinearLayout layout, ArrayRef<int> order);
 
+// Reorders the in and out dimensions to match another layout.
+LinearLayout reorder_like(const LinearLayout &x, const LinearLayout &y);
+
+// For two layouts, `src` and `dst`, that differ only by a permutation of
+// their basis vectors, return a permutation layout `P` which satisfies
+// `dst` \circ `P` = `src`.
+//
+// The returned layout has the following properties:
+// - The orders of the input and output dimensions of `P` match the order of the
+//   input dimensions of `src`.
+// - Prioritizes making zero (broadcasting) vectors fixed-points of the
+//   permutation. I.e., if a vector is zero in both `src` and `dst` for the same
+//   input coordinate, it maps to itself under `P`.
+LinearLayout basisPermutationLayout(const LinearLayout &src,
+                                    const LinearLayout &dst);
+
 } // namespace mlir::triton
 
 #endif // TRITON_TOOLS_LAYOUTUTILS_H
```
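The matching problem `basisPermutationLayout` solves can be sketched with layouts flattened to lists of basis images (an assumption for illustration; real `LinearLayout`s carry named dimensions). We want `perm` with `dst[perm[i]] == src[i]`, i.e. `dst ∘ P = src`, preferring fixed points so that shared zero (broadcasting) vectors map to themselves:

```python
# Sketch: find a permutation matching each basis image of `src` to an equal
# basis image of `dst`, pinning fixed points first so shared zero vectors
# stay in place.

def basis_permutation(src, dst):
    used = [False] * len(dst)
    perm = [None] * len(src)
    # First pass: pin fixed points (covers zero vectors at the same slot).
    for i, v in enumerate(src):
        if dst[i] == v:
            perm[i], used[i] = i, True
    # Second pass: greedily match the remaining vectors.
    for i, v in enumerate(src):
        if perm[i] is None:
            j = next(k for k, w in enumerate(dst) if w == v and not used[k])
            perm[i], used[j] = j, True
    return perm
```

For example, with `src = [1, 2, 0, 4]` and `dst = [2, 1, 0, 4]`, the zero vector and `4` are fixed points while the first two slots swap.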
