Skip to content

Commit 7e60978

Browse files
authored
Merge branch 'main' into llvm-head
2 parents aef3768 + c1ed673 commit 7e60978

File tree

24 files changed

+1096
-800
lines changed

24 files changed

+1096
-800
lines changed

include/triton/Dialect/Triton/IR/Dialect.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,10 +72,12 @@ class DialectInferLayoutInterface
7272

7373
virtual LogicalResult
7474
inferJoinOpEncoding(Attribute srcEnc, Attribute &dstEnc,
75+
ArrayRef<int64_t> shape,
7576
std::optional<Location> loc) const = 0;
7677

7778
virtual LogicalResult
7879
inferSplitOpEncoding(Attribute srcEnc, Attribute &dstEnc,
80+
ArrayRef<int64_t> shape,
7981
std::optional<Location> loc) const = 0;
8082

8183
// Verify that the encodings are compatible to be used together in a dot

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -620,7 +620,9 @@ def LinearEncodingAttr : DistributedEncoding<"LinearEncoding", "linear_encoding"
620620
unsigned getTotalElemsPerThread(ArrayRef<int64_t> shape) const;
621621
SmallVector<unsigned> getElemsPerThread(ArrayRef<int64_t> shape) const;
622622

623+
SmallVector<unsigned int> getContig(const char *, SmallVector<unsigned int>) const;
623624
SmallVector<unsigned> getContigPerThread() const;
625+
SmallVector<unsigned> getContigPerWarp() const;
624626
SmallVector<unsigned> getOrder() const;
625627

626628
// Generalizes get{Warp,Thread,CTA}Order to linear layouts.

include/triton/Dialect/TritonGPU/Transforms/Utility.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,9 @@ unsigned
4949
getNumElementsPerThread(Operation *op, SmallVector<unsigned> order,
5050
triton::ModuleAxisInfoAnalysis &axisInfoAnalysis);
5151

52+
// Returns whether the op is a "view op", i.e. doesn't move any data
53+
bool isView(Operation *op);
54+
5255
/* Dump Triton IR in graphviz dot format.
5356
*
5457
* You can override `onValue` and `onOperation` in a subclass to mark

include/triton/Tools/LayoutUtils.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,11 @@ LinearLayout ensureLayoutNotSmallerThan(
8787
// are "dim0", "dim1", etc.
8888
SmallVector<StringAttr> standardOutDimNames(MLIRContext *ctx, int rank);
8989

90+
// Return a vector of the standard out dimension name/value pairs, i.e.
91+
// ("dim0", dstShape[0]), ("dim1", dstShape[1]), etc.
92+
SmallVector<std::pair<StringAttr, int32_t>>
93+
standardOutDimPairs(MLIRContext *ctx, ArrayRef<int64_t> dstShape);
94+
9095
// Return an identity mapping from `inDimName` to the standard out dimensions,
9196
// with the dimensions sized according to the shape. The bases are sorted
9297
// according to `order`, with the most minor dimension first.

lib/Analysis/Membar.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,9 @@ void MembarAnalysis::resolve(FunctionOpInterface funcOp,
6060
// the outputBlockInfo, we skip the successors
6161
continue;
6262
}
63-
// Update the current block
64-
outputBlockInfoMap[block].join(inputBlockInfo);
63+
// Update the current block. The block transfer function is not monotonic,
64+
// so overwrite the output state entirely.
65+
outputBlockInfoMap[block] = inputBlockInfo;
6566
// Update the successors
6667
for (auto *successor : successors) {
6768
inputBlockInfoMap[successor].join(outputBlockInfoMap[block]);

lib/Analysis/Utility.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -638,7 +638,11 @@ bool supportMMA(triton::DotOp op, int version) {
638638
return false;
639639
if (op.getType().getRank() != 2)
640640
return false;
641-
if (!(numWarps % 4 == 0 && retShapePerCTA[rank - 2] % 64 == 0 &&
641+
if (numWarps != 4 && numWarps != 8) {
642+
// Currently, only numWarps of 4 or 8 is supported for TMEM load and store.
643+
return false;
644+
}
645+
if (!(retShapePerCTA[rank - 2] % 64 == 0 &&
642646
retShapePerCTA[rank - 1] % 8 == 0))
643647
return false;
644648
return true;

lib/Dialect/Triton/IR/Ops.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1046,7 +1046,7 @@ JoinOp::inferReturnTypes(MLIRContext *context, std::optional<Location> location,
10461046
Attribute retEnc;
10471047
if (srcEnc) {
10481048
if (cast<DialectInferLayoutInterface>(&srcEnc.getDialect())
1049-
->inferJoinOpEncoding(srcEnc, retEnc, location)
1049+
->inferJoinOpEncoding(srcEnc, retEnc, srcTy.getShape(), location)
10501050
.failed()) {
10511051
return failure();
10521052
}
@@ -1079,7 +1079,7 @@ LogicalResult SplitOp::inferReturnTypes(
10791079
Attribute retEnc;
10801080
if (srcEnc) {
10811081
if (cast<DialectInferLayoutInterface>(&srcEnc.getDialect())
1082-
->inferSplitOpEncoding(srcEnc, retEnc, location)
1082+
->inferSplitOpEncoding(srcEnc, retEnc, srcTy.getShape(), location)
10831083
.failed()) {
10841084
return failure();
10851085
}

0 commit comments

Comments
 (0)