Commit 24acc39

Merge OpenAI Triton commit c24aa15 (#4113)
This PR changes the Triton base from efa8774 to c24aa15 (May 2). Pass rate: 94.64%.
2 parents 8d5b317 + 268b414 commit 24acc39

File tree

74 files changed: +3171 -1291 lines


README.md

Lines changed: 1 addition & 1 deletion

@@ -152,7 +152,7 @@ arbitrary LLVM version.
 
 - Do a local build. Run command `pip install -e .`
 - Get the full path to the `compile_commands.json` file produced by the build:
-  `find python/build -name 'compile_commands.json' | xargs readlink -f`.
+  `find ./build -name 'compile_commands.json' | xargs readlink -f`.
   You might get a full path similar to `/Users/{username}/triton/build/cmake.macosx-11.1-arm64-cpython-3.12/compile_commands.json`
 - In vscode, install the
   [C/C++

bin/RegisterTritonDialects.h

Lines changed: 2 additions & 0 deletions

@@ -51,6 +51,7 @@ void registerTestAlignmentPass();
 void registerTestAllocationPass();
 void registerTestLivenessPass();
 void registerTestMembarPass();
+void registerTestAMDGPUMembarPass();
 void registerTestTritonAMDGPURangeAnalysis();
 } // namespace test
 } // namespace mlir
@@ -66,6 +67,7 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
   mlir::test::registerTestAllocationPass();
   mlir::test::registerTestLivenessPass();
   mlir::test::registerTestMembarPass();
+  mlir::test::registerTestAMDGPUMembarPass();
   mlir::test::registerTestTritonAMDGPURangeAnalysis();
   mlir::triton::registerConvertTritonToTritonGPUPass();
   mlir::triton::intel::registerConvertTritonToTritonGPUWarpPass();

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 7 additions & 80 deletions

@@ -718,89 +718,16 @@ void storeDistributedToShared(
     RewriterBase &rewriter, const TargetInfoBase &target,
     std::pair<size_t, Type> *const llvmOpCount = nullptr);
 
-inline SmallVector<Value> unpackLLElements(Location loc, Value llvmStruct,
-                                           RewriterBase &rewriter) {
-  assert(bool(llvmStruct) && "can not unpack null values");
-  if (llvmStruct.getType().isIntOrIndexOrFloat() ||
-      isa<triton::PointerType>(llvmStruct.getType()) ||
-      isa<LLVM::LLVMPointerType>(llvmStruct.getType()))
-    return {llvmStruct};
-  ArrayRef<Type> types =
-      cast<LLVM::LLVMStructType>(llvmStruct.getType()).getBody();
-  SmallVector<Value> results(types.size());
-  auto b = TritonLLVMOpBuilder(loc, rewriter);
-  for (unsigned i = 0; i < types.size(); ++i) {
-    Type type = types[i];
-    results[i] = b.extract_val(type, llvmStruct, i);
-  }
-  return results;
-}
-
-inline Value packLLElements(Location loc,
-                            const LLVMTypeConverter *typeConverter,
-                            ValueRange resultVals, RewriterBase &rewriter,
-                            Type type) {
-  auto structType =
-      dyn_cast<LLVM::LLVMStructType>(typeConverter->convertType(type));
-  if (!structType) {
-    assert(resultVals.size() == 1);
-    return *resultVals.begin();
-  }
-
-  auto elementTypes = structType.getBody();
-  if (elementTypes.size() != resultVals.size()) {
-    emitError(loc) << " size mismatch when packing elements for LLVM struct"
-                   << " expected " << elementTypes.size() << " but got "
-                   << resultVals.size();
-  }
-  Value llvmStruct = rewriter.create<LLVM::UndefOp>(loc, structType);
-  auto b = TritonLLVMOpBuilder(loc, rewriter);
-  for (const auto &v : llvm::enumerate(resultVals)) {
-    if (!v.value()) {
-      emitError(loc)
-          << "cannot insert null values into struct, but tried to insert"
-          << v.value();
-    }
-    if (v.value().getType() != elementTypes[v.index()]) {
-      LDBG("type " << type << " structType " << structType);
-      LDBG("value " << v.value());
-      emitError(loc) << "invalid element type in packLLElements. Expected "
-                     << elementTypes[v.index()] << " but got "
-                     << v.value().getType();
-    }
-    llvmStruct = b.insert_val(structType, llvmStruct, v.value(), v.index());
-  }
-  return llvmStruct;
-}
+SmallVector<Value> unpackLLElements(Location loc, Value llvmStruct,
+                                    RewriterBase &rewriter);
 
-inline SmallVector<Value> unpackLLVector(Location loc, Value llvmVec,
-                                         RewriterBase &rewriter) {
-  assert(bool(llvmVec) && "cannot unpack null value");
-  if (llvmVec.getType().isIntOrIndexOrFloat() ||
-      isa<triton::PointerType>(llvmVec.getType()) ||
-      isa<LLVM::LLVMPointerType>(llvmVec.getType()))
-    return {llvmVec};
+Value packLLElements(Location loc, const LLVMTypeConverter *typeConverter,
+                     ValueRange resultVals, RewriterBase &rewriter, Type type);
 
-  auto b = TritonLLVMOpBuilder(loc, rewriter);
-  SmallVector<Value> results;
-  for (int i = 0; i < cast<VectorType>(llvmVec.getType()).getNumElements();
-       i++) {
-    results.push_back(b.extract_element(llvmVec, b.i32_val(i)));
-  }
-  return results;
-}
+SmallVector<Value> unpackLLVector(Location loc, Value llvmVec,
+                                  RewriterBase &rewriter);
 
-inline Value packLLVector(Location loc, ValueRange vals,
-                          RewriterBase &rewriter) {
-  assert(vals.size() > 0);
-  auto vecType = vec_ty(vals[0].getType(), vals.size());
-  auto b = TritonLLVMOpBuilder(loc, rewriter);
-  Value vec = b.undef(vecType);
-  for (int i = 0; i < vals.size(); i++) {
-    vec = b.insert_element(vec, vals[i], b.i32_val(i));
-  }
-  return vec;
-}
+Value packLLVector(Location loc, ValueRange vals, RewriterBase &rewriter);
 
 inline bool
 isSimpleSharedMemoryAccess(ArrayRef<int64_t> shape,
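
The four helpers keep their behavior; the commit only moves their definitions out of the header. As a standalone illustration of the pack/unpack round trip that packLLVector and unpackLLVector perform, here is a minimal sketch in plain C++ (not Triton code) with std::vector standing in for an LLVM vector value; all names are illustrative.

  // Sketch: models the pack/unpack round trip of packLLVector/unpackLLVector,
  // with std::vector<int> as a stand-in for an LLVM vector value.
  #include <cassert>
  #include <cstdio>
  #include <vector>

  using PackedVec = std::vector<int>; // stand-in for an LLVM vector value

  // Analogue of packLLVector: build one packed value from individual elements.
  PackedVec packElements(const std::vector<int> &vals) {
    assert(!vals.empty());
    PackedVec vec(vals.size());
    for (size_t i = 0; i < vals.size(); ++i)
      vec[i] = vals[i]; // insert_element(vec, vals[i], i)
    return vec;
  }

  // Analogue of unpackLLVector: recover the individual elements.
  std::vector<int> unpackElements(const PackedVec &vec) {
    std::vector<int> results;
    for (size_t i = 0; i < vec.size(); ++i)
      results.push_back(vec[i]); // extract_element(vec, i)
    return results;
  }

  int main() {
    std::vector<int> vals = {1, 2, 3, 4};
    assert(unpackElements(packElements(vals)) == vals);
    std::puts("pack/unpack round trip ok");
  }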

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 3 additions & 2 deletions

@@ -493,6 +493,7 @@ def TT_BroadcastOp : TT_Op<"broadcast", [Pure,
   let hasVerifier = 1;
 }
 
+// Cat is not pure because it may reorder elements.
 def TT_CatOp : TT_Op<"cat", [NoMemoryEffect,
                              SameTypeOperands,
                              SameOperandsAndResultElementType]> {
@@ -506,7 +507,7 @@ def TT_CatOp : TT_Op<"cat", [NoMemoryEffect,
 }
 
 def TT_JoinOp : TT_Op<"join", [
-  NoMemoryEffect, SameTypeOperands]> {
+  Pure, SameTypeOperands]> {
   let summary = "join two tensors along a new, minor dimension";
   let description = [{
     For example, if the two input tensors are 4x8xf32, returns a tensor of
@@ -526,7 +527,7 @@ def TT_JoinOp : TT_Op<"join", [
 }
 
 def TT_SplitOp : TT_Op<"split", [
-  NoMemoryEffect,
+  Pure,
   InferTypeOpWithLayoutEquivalence,
   TypesMatchWith<"outLHS and outRHS types match",
                  "outLHS", "outRHS", "$_self">,

include/triton/Dialect/TritonGPU/IR/Dialect.h

Lines changed: 5 additions & 0 deletions

@@ -9,6 +9,7 @@
 // TritonGPU depends on Triton
 #include "triton/Dialect/Triton/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/IR/Attributes.h"
+#include "triton/Dialect/TritonGPU/IR/Traits.h"
 #include "triton/Dialect/TritonGPU/IR/Types.h"
 
 #include <unordered_map>
@@ -278,6 +279,10 @@ bool areLayoutsEquivalent(ArrayRef<int64_t> shape, Attribute lhs,
 
 // Return true if the innermost numElems are contiguous.
 bool isInnermostContiguous(MemDescType type, unsigned numElems);
+
+LinearLayout inferReshapeLinearLayout(ArrayRef<int64_t> srcShape,
+                                      Attribute srcEnc,
+                                      ArrayRef<int64_t> dstShape);
 } // namespace mlir::triton::gpu
 
 #endif // TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_

include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h

Lines changed: 3 additions & 8 deletions

@@ -54,9 +54,9 @@ LinearLayout toLinearLayout(ArrayRef<int64_t> shape, Attribute layout);
 //
 // If `disableSwizzle` is set, then the resulting layout does not include
 // swizzling.
-LinearLayout sharedToLinearLayoutLeadingOffset(ArrayRef<int64_t> shape,
-                                               NVMMASharedEncodingAttr shared,
-                                               bool disableSwizzle = false);
+LinearLayout nvmmaSharedToLinearLayout(ArrayRef<int64_t> shape,
+                                       NVMMASharedEncodingAttr shared,
+                                       bool disableSwizzle = false);
 
 // Given a linear layout where the input dimensions contain a "block" dimension,
 // this method sets the "block" dimension to 0 and removes the corresponding
@@ -282,11 +282,6 @@ LinearLayout chooseScaledMfmaScaleLayout(
     MLIRContext *ctx, int dotOperandIdx,
     const std::vector<std::vector<int32_t>> &dotOperandWarpBasis,
     ArrayRef<int64_t> dotOperandShape, unsigned mfmaMDim);
-
-// Create LinearLayout for nvidia mma tile.
-LinearLayout nvidiaMmaTile(MLIRContext *ctx, ArrayRef<unsigned> tileShape,
-                           unsigned kWidth, ArrayRef<unsigned> order,
-                           ArrayRef<unsigned> repOrder);
 } // namespace mlir::triton::gpu
 
 #endif // TRITON_DIALECT_TRITONGPU_IR_LINEARLAYOUTCONVERSIONS_H

include/triton/Dialect/TritonGPU/IR/Traits.h

Lines changed: 22 additions & 0 deletions

@@ -0,0 +1,22 @@
+#ifndef TRITONGPU_IR_TRAITS_H_
+#define TRITONGPU_IR_TRAITS_H_
+
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "mlir/Support/LogicalResult.h"
+#include "triton/Dialect/Triton/IR/Types.h"
+
+namespace mlir {
+namespace OpTrait {
+
+template <typename ConcreteType>
+class MemDescViewTrait
+    : public mlir::OpTrait::TraitBase<ConcreteType, MemDescViewTrait> {
+  // Optional: Add methods or verification logic here
+};
+
+} // namespace OpTrait
+} // namespace mlir
+
+#endif
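
MemDescViewTrait is a pure marker: it adds no methods or verification, but once the view-like memdesc ops (memdesc_subview, memdesc_trans, and the new memdesc_reshape below) are tagged with it, passes can recognize them generically instead of enumerating each op. A minimal sketch of such a check, assuming MLIR's standard Operation::hasTrait API; the helper name is illustrative, not from the commit:

  // Sketch: a generic "is this a memdesc view?" query via the new marker trait.
  #include "mlir/IR/Operation.h"
  #include "triton/Dialect/TritonGPU/IR/Traits.h"

  static bool isMemDescView(mlir::Operation *op) {
    // Replaces per-op isa<> chains over the individual view ops.
    return op->hasTrait<mlir::OpTrait::MemDescViewTrait>();
  }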

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 45 additions & 32 deletions

@@ -15,6 +15,9 @@ def TritonGPU_AttrTrait : AttrInterface<"TritonGPU_AttrTrait"> {
   ];
 }
 
+def MemDescViewTrait : NativeOpTrait<"MemDescViewTrait">;
+
+
 class TritonGPU_Attr<string name, string attrMnemonic, list<Trait> traits = [],
                      Dialect dialect = TritonGPU_Dialect,
                      string baseCppClass = "::mlir::Attribute">
@@ -309,46 +312,54 @@ When vec=2, elements are swizzled in pairs of 2. In other words, the element at
     if(!mmaEnc)
      return get(context, 1, 1, 1, order, CTALayout);
 
+    int opIdx = dotOpEnc.getOpIdx();
+    auto shapePerCTA = getShapePerCTA(CTALayout.getCTASplitNum(), shape);
+
+    // number of rows per phase
+
+    // index of the inner dimension in `order`
+    unsigned inner = (opIdx == 0) ? 0 : 1;
+
     // ---- begin Ampere & Hopper ----
     if (mmaEnc.isAmpere() || mmaEnc.isHopper()) {
-      return get(context, dotOpEnc.getOpIdx(), dotOpEnc.getKWidth(), shape, order, CTALayout, typeWidthInBit, needTrans);
+      int perPhase = 128 / (shapePerCTA[order[0]] * 4 / dotOpEnc.getKWidth());
+      perPhase = std::max<int>(perPhase, 1);
+      std::vector<size_t> matShape = {8, 8, 4 * dotOpEnc.getKWidth()};
+      int vecWidth = 32 / typeWidthInBit;
+      if (vecWidth != dotOpEnc.getKWidth() && order[0] == inner) {
+        perPhase = std::max<int>(perPhase, 2 * vecWidth);
+      }
+      int rank = order.size();
+      // --- handle A operand ---
+      if (opIdx == 0) { // compute swizzling for A operand
+        int m = (needTrans) ? matShape[2] : matShape[0];
+        int k = (needTrans) ? matShape[0] : matShape[2];
+        int vec = (order[0] == rank-1) ? k : m;
+        int mmaStride = (order[0] == rank-1) ? m : k;
+        int maxPhase = std::max(mmaStride / perPhase, 1);
+        return get(context, vec, perPhase, maxPhase, order, CTALayout);
+      }
+
+      // --- handle B operand ---
+      if (opIdx == 1) {
+        // we compute vec and maxPhase m, n and k size of the mma
+        // instruction. when matmul operands is transposed, we should
+        // consider that to get m, n and k.
+        int n = needTrans ? matShape[2] : matShape[1];
+        int k = needTrans ? matShape[1] : matShape[2];
+        int vec = (order[0] == rank-1) ? n : k;
+        int mmaStride = (order[0] == rank-1) ? k : n;
+        int maxPhase = std::max(mmaStride / perPhase, 1);
+        return get(context, vec, perPhase, maxPhase, order, CTALayout);
+      }
+
+      llvm_unreachable("invalid operand index");
     }
 
     // ---- not implemented ----
     llvm_unreachable("unsupported swizzling for provided MMA version");
   }]>,
 
-  // NVIDIA constructor!
-  // TODO(lezcano): We should totally get rid of all these constructors...
-  AttrBuilder<(ins "int":$opIdx,
-                   "unsigned":$kWidth,
-                   "ArrayRef<int64_t>":$shape,
-                   "ArrayRef<unsigned>":$order,
-                   "CTALayoutAttr":$CTALayout,
-                   "unsigned":$bitwidth,
-                   "bool":$needTrans), [{
-    int K = getShapePerCTA(CTALayout.getCTASplitNum(), shape)[order[0]];
-    // Elems necessary to cover all the banks divided by the inner dimension
-    // This packs a few rows together for small K
-    int perPhase = std::max<int>(1024 / (bitwidth * K), 1);
-
-    int mmaStride = 8;
-    int vec = 4 * kWidth;
-    // needsTrans is equiv. to flipping the opIdx
-    if (needTrans)
-      std::swap(vec, mmaStride);
-    assert(opIdx == 0 || opIdx == 1);
-    int rank = order.size();
-    int kDim = opIdx == 0 ? rank-1 : rank-2;
-    if (order[0] != kDim)
-      std::swap(vec, mmaStride);
-    // Count how many vec elements are needed to cover all the banks
-    int maxPhase = std::max(std::min<int>(mmaStride, 1024 / (vec * bitwidth)), 1);
-    // Account for the row packing from perPhase: mmaStride / perPhase
-    maxPhase = std::max(maxPhase / perPhase, 1);
-    return get(context, vec, perPhase, maxPhase, order, CTALayout);
-  }]>,
-
   AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc,
                    "ArrayRef<int64_t>":$shape,
                    "ArrayRef<unsigned>":$order,
@@ -387,6 +398,8 @@ def NVMMASharedEncodingAttr :
     This is meant to represent 2d tiled blocked layout.
     The full layout representation is described here:
     https://docs.nvidia.com/cuda/parallel-thread-execution/#asynchronous-warpgroup-level-matrix-shared-memory-layout
+    When the memdesc has more than 2 dimensions the tiling is applied to 8 rows even if the first outer dimension is smaller than 8.
+    In this case `transposed` means that the contiguous dimension is the most outer dimension of the memdesc.
   }];
 
 
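
The new Ampere/Hopper branch computes the swizzling parameters (vec, perPhase, maxPhase) directly from the dot-operand encoding rather than delegating to the removed NVIDIA constructor. The sketch below replays the A-operand arithmetic from the diff with illustrative inputs (K = 64, kWidth = 4, 16-bit elements, order = {1, 0}, no transpose; these numbers are assumptions, not taken from the commit).

  // Sketch: A-operand swizzle parameters per the new builder logic above.
  #include <algorithm>
  #include <cstdio>
  #include <vector>

  int main() {
    // Assumed inputs: opIdx == 0, contiguous dim holds K = 64 elements.
    int K = 64, kWidth = 4, typeWidthInBit = 16;
    std::vector<int> order = {1, 0}; // order[0] is the contiguous dimension
    int rank = order.size();
    bool needTrans = false;
    int inner = 0; // inner dimension index for the A operand

    int perPhase = std::max(128 / (K * 4 / kWidth), 1); // 128 / 64 = 2
    std::vector<int> matShape = {8, 8, 4 * kWidth};     // {8, 8, 16}
    int vecWidth = 32 / typeWidthInBit;                 // 2
    if (vecWidth != kWidth && order[0] == inner)
      perPhase = std::max(perPhase, 2 * vecWidth);      // not taken: order[0] != inner

    int m = needTrans ? matShape[2] : matShape[0];      // 8
    int k = needTrans ? matShape[0] : matShape[2];      // 16
    int vec = (order[0] == rank - 1) ? k : m;           // 16
    int mmaStride = (order[0] == rank - 1) ? m : k;     // 8
    int maxPhase = std::max(mmaStride / perPhase, 1);   // 4
    std::printf("vec=%d perPhase=%d maxPhase=%d\n", vec, perPhase, maxPhase);
    return 0;
  }
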
include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td

Lines changed: 25 additions & 1 deletion

@@ -198,7 +198,7 @@ def TTG_LocalDeallocOp : TTG_Op<"local_dealloc"> {
   let assemblyFormat = [{$src attr-dict `:` qualified(type($src))}];
 }
 
-def TTG_MemDescSubviewOp : TTG_Op<"memdesc_subview", [Pure]> {
+def TTG_MemDescSubviewOp : TTG_Op<"memdesc_subview", [Pure, MemDescViewTrait]> {
   let summary = "take a subview of the descriptor.";
 
   let description = [{
@@ -224,6 +224,7 @@ def TTG_MemDescSubviewOp : TTG_Op<"memdesc_subview", [Pure]> {
 }
 
 def TTG_MemDescTransOp : TTG_Op<"memdesc_trans", [Pure,
+                                                  MemDescViewTrait,
                                                   TransposeOpInterface,
                                                   InferTypeOpWithLayoutEquivalence,
                                                   SameOperandsAndResultElementType]> {
@@ -248,6 +249,29 @@ def TTG_MemDescTransOp : TTG_Op<"memdesc_trans", [Pure,
   let hasFolder = 1;
 }
 
+def TTG_MemDescReshapeOp : TTG_Op<"memdesc_reshape", [Pure,
+                                                      MemDescViewTrait,
+                                                      SameOperandsAndResultElementType]> {
+  let summary = "creates a descriptor for the new shape";
+
+  let description = [{
+    This operation returns a new descriptor representing a reshaped view of the underlying buffer.
+    This doesn't affect the memory.
+  }];
+
+  let arguments = (ins TTG_MemDescType:$src);
+
+  let arguments = (
+    ins TTG_MemDescType:$src
+  );
+
+  let results = (outs TTG_MemDescType:$result);
+
+  let assemblyFormat = "$src attr-dict `:` qualified(type($src)) `->` qualified(type($result))";
+
+  let hasVerifier = 1;
+}
+
 def TTG_LocalLoadOp : TTG_Op<"local_load"> {
   let summary = "Load a buffer from local memory into a distributed tensor";
 

include/triton/Dialect/TritonGPU/Transforms/Partition.h

Lines changed: 6 additions & 0 deletions

@@ -45,6 +45,7 @@ class WarpSchedule {
     ArrayRef<Operation *> getOps() const { return ops; }
 
     void insert(Operation *op) { ops.push_back(op); }
+    void remove(Operation *op) { ops.erase(llvm::find(ops, op)); }
 
   private:
     void setIndex(int idx) { this->idx = idx; }
@@ -62,6 +63,8 @@ class WarpSchedule {
   Partition *addPartition(unsigned stage);
   // Give each partition a new index and order. The indices must be unique.
   void reorderPartitions(ArrayRef<unsigned> order);
+  // Update the op to partition mapping.
+  void updatePartitions();
 
   // Get the partition the op belongs to.
   Partition *getPartition(Operation *op);
@@ -115,6 +118,9 @@ class WarpSchedule {
       scf::ForOp loop, const Partition *partition,
       function_ref<void(OpResult, OpOperand &, unsigned)> callback) const;
 
+  // Debug dump the schedule.
+  LLVM_DUMP_METHOD void dump() const;
+
 private:
   // Partitions are numbered [0, N).
   SmallVector<std::unique_ptr<Partition>> partitions;
