Skip to content

Commit 4f35147

Browse files
Merge OpenAI Triton commit 1e0a371 (#4804)
This PR change the Triton base from 570f24d to 1e0a371 (Jul 20). Pass rate: 98.62%
2 parents 82f505a + fd748f4 commit 4f35147

File tree

17 files changed

+747
-370
lines changed

17 files changed

+747
-370
lines changed

include/triton/Analysis/Utility.h

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -172,22 +172,29 @@ class GatherLoweringHelper {
172172
RankedTensorType dstTy;
173173
};
174174

175-
// This struct represents a decomposed layout conversion within a warp into
176-
// three transformations: P1 and P2 represent lane-dependent register shuffles
177-
// and W represents a warp shuffle. P2^-1 is returned because it represents the
178-
// (reg, lane) -> (reg) mapping from the perspective of the destination element.
175+
// This struct represents the factorization of a warp-local layout conversion
176+
// into three components: a register-only permutation, a lane-only permutation,
177+
// and a set of swaps between lane and register basis vectors. Algebraically, it
178+
// represents the factorization P = P_mixed \circ P_lane \circ P_reg. It is used
179+
// to aid in the implementation of the layout conversion using warp-shuffles.
179180
//
180-
// Nearly all layout conversions that only require data movement within a warp
181-
// can be implemented this way.
181+
// `pReg` and `pLane` are square layouts each with only one input and output
182+
// dimension. `mixedTranspositions` holds pairs of integers (i, j)
183+
// corresponding to the transposition (r_i l_j) of the i-th register basis
184+
// vector with the j-th lane basis vector.
182185
struct DecomposedWarpConversion {
183-
triton::LinearLayout P1, W, P2inv;
184-
triton::LinearLayout reducedP1, reducedP2inv;
186+
triton::LinearLayout pReg, pLane;
187+
SmallVector<std::pair<int, int>> mixedTranspositions;
185188
};
186189

187-
// Given the source and destination tensor types where a layout conversion only
188-
// involves data movement within warps, attempt to find a decomposition for a
189-
// warp layout conversion.
190-
std::optional<DecomposedWarpConversion>
190+
// Produces a decomposition of a permutation describing a warp-local layout
191+
// conversion as described in `DecomposedWarpConversion` above.
192+
//
193+
// This function handles cases where the numbers of register and lane basis
194+
// vectors differ between the two layouts. This is done by padding the smaller
195+
// dimension(s) with zero vectors, ensuring that the layout conversion can be
196+
// represented as a permutation.
197+
DecomposedWarpConversion
191198
getWarpLayoutConvertDecomposition(RankedTensorType srcTy,
192199
RankedTensorType dstTy);
193200

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,9 @@ LLVM::LLVMFuncOp appendOrGetExternFuncOp(RewriterBase &rewriter, Operation *op,
327327
StringRef funcName, Type funcType,
328328
StringRef libname = "",
329329
StringRef libpath = "");
330+
331+
// Multiply a square layout with 1 input and output dimension with a vector
332+
Value matrixVectorProd(TritonLLVMOpBuilder &b, const LinearLayout &A, Value x);
330333
} // namespace gpu
331334

332335
} // namespace triton

include/triton/Dialect/TritonGPU/IR/Traits.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,12 @@ class MemDescViewTrait
1616
// Optional: Add methods or verification logic here
1717
};
1818

19+
template <typename ConcreteType>
20+
class LocalLoadTrait
21+
: public mlir::OpTrait::TraitBase<ConcreteType, LocalLoadTrait> {
22+
// Optional: Add methods or verification logic here
23+
};
24+
1925
} // namespace OpTrait
2026
} // namespace mlir
2127

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ def TritonGPU_AttrTrait : AttrInterface<"TritonGPU_AttrTrait"> {
1717

1818
def MemDescViewTrait : NativeOpTrait<"MemDescViewTrait">;
1919

20+
def LocalLoadTrait : NativeOpTrait<"LocalLoadTrait">;
2021

2122
class TritonGPU_Attr<string name, string attrMnemonic, list<Trait> traits = [],
2223
Dialect dialect = TritonGPU_Dialect,

include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,7 @@ def TTG_MemDescReinterpretOp : TTG_Op<"memdesc_reinterpret", [Pure, MemDescViewT
322322
let hasFolder = 1;
323323
}
324324

325-
def TTG_LocalLoadOp : TTG_Op<"local_load"> {
325+
def TTG_LocalLoadOp : TTG_Op<"local_load", [LocalLoadTrait]> {
326326
let summary = "Load a buffer from local memory into a distributed tensor";
327327

328328
let description = [{

include/triton/Tools/LayoutUtils.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,22 @@ LinearLayout reshapeLayout(MLIRContext *ctx, LinearLayout layout,
148148
// order.
149149
LinearLayout transposeLinearLayout(LinearLayout layout, ArrayRef<int> order);
150150

151+
// Reorders the in and out dimensions to match another layout.
152+
LinearLayout reorder_like(const LinearLayout &x, const LinearLayout &y);
153+
154+
// For two layouts, `src` and `dst`, that differ only by a permutation of
155+
// their basis vectors, return a permutation layout `P` which satisfies
156+
// `dst` \circ `P` = `src`.
157+
//
158+
// The returned layout has the following properties:
159+
// - The orders of the input and output dimensions of `P` match the order of the
160+
// input dimensions of `src`.
161+
// - Prioritizes making zero (broadcasting) vectors fixed-points of the
162+
// permutation. I.e., if a vector is zero in both `src` and `dst` for the same
163+
// input coordinate, it maps to itself under `P`.
164+
LinearLayout basisPermutationLayout(const LinearLayout &src,
165+
const LinearLayout &dst);
166+
151167
} // namespace mlir::triton
152168

153169
#endif // TRITON_TOOLS_LAYOUTUTILS_H

0 commit comments

Comments
 (0)