
Commit 674bd20

Merge branch 'main' into yudong/tune-in-ci

2 parents 9bdca90 + 85682e4

82 files changed: +5281 -937 lines changed


CMakeLists.txt

Lines changed: 18 additions & 1 deletion

@@ -12,7 +12,7 @@ set(CMAKE_CXX_STANDARD 17)

 set(CMAKE_INCLUDE_CURRENT_DIR ON)

-project(triton)
+project(triton CXX)
 include(CTest)

 if(NOT WIN32)
@@ -26,8 +26,25 @@ option(TRITON_BUILD_TUTORIALS "Build C++ Triton tutorials" ON)
 option(TRITON_BUILD_PYTHON_MODULE "Build Python Triton bindings" OFF)
 option(TRITON_BUILD_PROTON "Build the Triton Proton profiler" ON)
 option(TRITON_BUILD_UT "Build C++ Triton Unit Tests" ON)
+option(TRITON_BUILD_WITH_CCACHE "Build with ccache (if available)" ON)
 set(TRITON_CODEGEN_BACKENDS "" CACHE STRING "Enable different codegen backends")

+if(TRITON_BUILD_WITH_CCACHE)
+  find_program(CCACHE_PROGRAM ccache)
+  if(CCACHE_PROGRAM)
+    set(CMAKE_C_COMPILER_LAUNCHER "${CCACHE_PROGRAM}"
+        CACHE STRING "C compiler launcher")
+    set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_PROGRAM}"
+        CACHE STRING "CXX compiler launcher")
+  else()
+    message(
+      STATUS
+      "Could not find ccache. Consider installing ccache to speed up compilation."
+    )
+  endif()
+endif()
+
+
 # Ensure Python3 vars are set correctly
 # used conditionally in this file and by lit tests

cmake/llvm-hash.txt

Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-b5cc222d7429fe6f18c787f633d5262fac2e676f
+fa57c7a6a5f594a9e3ae2dbe3542cf89a20cdd73

docs/conf.py

Lines changed: 1 addition & 1 deletion

@@ -145,7 +145,7 @@ def documenter(app, obj, parent):
 autosummary_generate = True

 # versioning config
-smv_tag_whitelist = r'^(v3.1.0)$'
+smv_tag_whitelist = r'^(v3.2.0)$'
 smv_branch_whitelist = r'^main$'
 smv_remote_whitelist = None
 smv_released_pattern = r'^tags/.*$'

include/triton/Analysis/Utility.h

Lines changed: 2 additions & 2 deletions

@@ -66,7 +66,7 @@ class ReduceOpHelper {
   // The shape of the shared memory space needed for the reduction.
   SmallVector<unsigned> getScratchRepShape();

-  SmallVector<unsigned> getOrderWithAxisAtBeginning();
+  SmallVector<unsigned> getThreadOrderWithAxisAtBeginning();

   unsigned getScratchSizeInBytes();

@@ -212,7 +212,7 @@ bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy);

 bool atomicNeedsSharedMemory(Value result);

-bool isBlockedToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstT);
+bool isBlockedToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);

 // Return true if the src and dst layout match.
 bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 3 additions & 1 deletion

@@ -108,6 +108,8 @@ def TT_FpToFpOp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape,
   let assemblyFormat = "$src attr-dict (`,` `rounding` `=` $rounding^)? `:` type($src) `->` type($result)";

   let hasVerifier = 1;
+
+  let hasFolder = 1;
 }

 //
@@ -891,7 +893,7 @@ def TT_AssertOp : TT_Op<"assert", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
     `tt.assert` takes a condition tensor and a message string.
     If the condition is false, the message is printed, and the program is aborted.
   }];
-  let arguments = (ins TT_Tensor:$condition, StrAttr:$message);
+  let arguments = (ins AnyTypeOf<[I1, I1Tensor]>:$condition, StrAttr:$message);
   let assemblyFormat = "$condition `,` $message attr-dict `:` type($condition)";
 }

include/triton/Tools/LinearLayout.h

Lines changed: 7 additions & 0 deletions

@@ -679,6 +679,13 @@ class LinearLayout {
   // (i.e. every input bit affects the output).
   llvm::MapVector<StringAttr, int32_t> getFreeVariableMasks() const;

+  // Increase an input dimension without affecting the output dimension. The
+  // added free variables are mapped to 0, ensuring that the new input
+  // dimensions correspond directly to the existing output space. The function
+  // errors out if `newInDimSize` is less than the current size or the new size
+  // is not a power of 2.
+  LinearLayout resize(StringAttr inDim, int32_t newInDimSize) const;
+
   std::string toString() const;

   friend bool operator==(LinearLayout lhs, LinearLayout rhs);
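
A minimal sketch of how the new `resize` behaves, assuming an `MLIRContext` in scope and the existing `LinearLayout::identity1D` helper; the dimension names are illustrative:

#include <cassert>
#include "mlir/IR/MLIRContext.h"
#include "triton/Tools/LinearLayout.h"

using mlir::StringAttr;
using mlir::triton::LinearLayout;

void resizeSketch(mlir::MLIRContext *ctx) {
  StringAttr kRegister = StringAttr::get(ctx, "register");
  StringAttr kDim0 = StringAttr::get(ctx, "dim0");
  // Identity layout: registers 0..3 map to outputs 0..3.
  LinearLayout layout = LinearLayout::identity1D(4, kRegister, kDim0);
  // Grow the register dimension to 8. The added high bit is a free
  // variable mapped to 0, so registers 4..7 alias registers 0..3 and the
  // output space is unchanged.
  LinearLayout resized = layout.resize(kRegister, 8);
  assert(resized.getInDimSize(kRegister) == 8);
  assert(resized.getOutDimSize(kDim0) == 4);
}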

lib/Analysis/Allocation.cpp

Lines changed: 3 additions & 0 deletions

@@ -46,6 +46,9 @@ static SmallVector<unsigned> getRepShapeForCvt(RankedTensorType srcTy,
   auto dstShapePerCTATile =
       gpu::getShapePerCTATile(dstLayout, dstTy.getShape());

+  assert(srcTy.getRank() == dstTy.getRank() &&
+         "src and dst must have the same rank");
+
   unsigned rank = dstTy.getRank();
   SmallVector<unsigned> repShape(rank);
   for (unsigned d = 0; d < rank; ++d) {

lib/Analysis/AxisInfo.cpp

Lines changed: 3 additions & 3 deletions

@@ -1213,7 +1213,7 @@ unsigned ModuleAxisInfoAnalysis::getPtrContiguity(Value ptr) {

   // Here order should be ordered by contiguous first, so the first element
   // should have the largest contiguous.
-  auto order = triton::gpu::getOrder(layout);
+  auto order = triton::gpu::getThreadOrder(layout);
   unsigned align = getPtrAlignment(ptr);

   auto uniqueContigPerThread =
@@ -1235,7 +1235,7 @@ unsigned ModuleAxisInfoAnalysis::getPtrAlignment(Value ptr) {
   if (!axisInfo)
     return 1;
   auto layout = tensorTy.getEncoding();
-  auto order = triton::gpu::getOrder(layout);
+  auto order = triton::gpu::getThreadOrder(layout);
   auto maxMultipleBytes = axisInfo->getDivisibility(order[0]);
   auto maxContig = axisInfo->getContiguity(order[0]);
   auto elemNumBits = triton::getPointeeBitWidth(tensorTy);
@@ -1262,7 +1262,7 @@ unsigned ModuleAxisInfoAnalysis::getMaskAlignment(Value mask) {
   auto *axisInfo = getAxisInfo(mask);
   if (!axisInfo)
     return 1;
-  auto maskOrder = triton::gpu::getOrder(tensorTy.getEncoding());
+  auto maskOrder = triton::gpu::getThreadOrder(tensorTy.getEncoding());
   auto alignment = std::max<unsigned>(axisInfo->getConstancy(maskOrder[0]), 1);
   LDBG("getMaskAlignment maskOrder[0] " << maskOrder[0] << " alignment "
        << alignment);

lib/Analysis/Utility.cpp

Lines changed: 45 additions & 7 deletions

@@ -34,9 +34,9 @@ int getParentAxis(Attribute layout, int axis) {
   return axis;
 }

-SmallVector<unsigned> getParentOrder(Attribute layout) {
+SmallVector<unsigned> getParentThreadOrder(Attribute layout) {
   if (auto sliceEncoding = mlir::dyn_cast<SliceEncodingAttr>(layout)) {
-    return getParentOrder(sliceEncoding.getParent());
+    return getParentThreadOrder(sliceEncoding.getParent());
   }
   return getThreadOrder(layout);
 }
@@ -46,12 +46,12 @@ SmallVector<unsigned> getParentOrder(Attribute layout) {
 // TODO(jlebar): Move this class into namespace triton.
 bool ReduceOpHelper::isReductionOnLayoutFastAxis() {
   return getParentAxis(getSrcLayout(), axis) ==
-         getParentOrder(getSrcLayout())[0];
+         getParentThreadOrder(getSrcLayout())[0];
 }

-SmallVector<unsigned> ReduceOpHelper::getOrderWithAxisAtBeginning() {
+SmallVector<unsigned> ReduceOpHelper::getThreadOrderWithAxisAtBeginning() {
   auto srcLayout = getSrcLayout();
-  auto order = getOrder(srcLayout);
+  auto order = getThreadOrder(srcLayout);
   auto it = std::find(order.begin(), order.end(), axis);
   // delete the axis from order
   order.erase(it);
@@ -543,7 +543,7 @@ bool supportMMA(Value value, int version) {
          (elemTy.isInteger(8) && version >= 2);
 }

-bool isBlockedToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
+bool isBlockedToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) {
   auto blockedLayout = dyn_cast<BlockedEncodingAttr>(srcTy.getEncoding());
   auto dotOperandLayout = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding());
   if (blockedLayout == nullptr || dotOperandLayout == nullptr)
@@ -646,8 +646,46 @@ std::optional<LinearLayout> minimalCvtLayout(RankedTensorType srcTy,
       toLinearLayout(dstTy.getShape(), dstTy.getEncoding());
   if (!(srcLayout.has_value() && dstLayout.has_value()))
     return std::nullopt;
+  StringAttr kRegister = StringAttr::get(ctx, "register");
+  StringAttr kLane = StringAttr::get(ctx, "lane");
+  StringAttr kWarp = StringAttr::get(ctx, "warp");
+  StringAttr kBlock = StringAttr::get(ctx, "block");
+  auto numSrcRegs = srcLayout->getInDimSize(kRegister);
+  auto numDstRegs = dstLayout->getInDimSize(kRegister);
+  // The `invertAndCompose` function will generate a layout that is injective
+  // by assigning new output dimensions to free variables. For instance,
+  // consider a scenario where `srcLayout` has a free variable in the lane
+  // dimension, while `dstLayout` has two free variables in the lane
+  // dimension and also a larger number of registers.
+  // The injective form of `srcLayout` will add only a single additional row
+  // to the transformation matrix, whereas the injective form of `dstLayout`
+  // will add two additional rows. This discrepancy causes misleading results
+  // because the matrices end up with a different number of rows.
+  //
+  // Take `dstLayout ⋅ srcLayout^-1` as an example:
+  //
+  //   - `injective(dstLayout)`: [n, m] → [n + 2, m]
+  //   - `injective(srcLayout)`: [n, m] → [n + 1, m]
+  //   - `injective(srcLayout)^-1`: [n + 1, m] → [m, n + 1]
+  //   - `injective(dstLayout) ⋅ injective(srcLayout)^-1`: [n + 2, m] ⋅ [m, n +
+  //     1] → [n + 2, n + 1]
+  //
+  // Here, the `(n + 1)`-th row added by `dstLayout` represents the free
+  // variable in registers, and the `(n + 2)`-th row represents the free
+  // variable in lanes. However, the `(n + 1)`-th row added by `srcLayout`
+  // represents the free variable in lanes. As a result, the `(n + 1)`-th row
+  // in two layouts do not correspond to the same free variable.
+  //
+  // To address this issue, we pad the free variables in `srcLayout` and
+  // `dstLayout` to ensure they have the same number of registers. This
+  // guarantees that the resulting matrices have the same number of rows,
+  // ensuring consistency in the composition process.
+  auto numRegs = std::max(numSrcRegs, numDstRegs);
+  auto srcLayoutWithFreeRegs = srcLayout->resize(kRegister, numRegs);
+  auto dstLayoutWithFreeRegs = dstLayout->resize(kRegister, numRegs);
   // comp describes the layout function to create dst from src.
-  LinearLayout comp = dstLayout->invertAndCompose(*srcLayout);
+  LinearLayout comp =
+      dstLayoutWithFreeRegs.invertAndCompose(srcLayoutWithFreeRegs);
   // We try to quotient by the largest subspace first
   auto dims = SmallVector<StringRef>{"block", "warp", "lane", "register"};
   for (auto dim : dims) {
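
The padding step above can be read as a standalone recipe; a hedged sketch using only the `LinearLayout` calls shown in this commit (`getInDimSize`, `resize`, `invertAndCompose`), with an illustrative function name:

#include <algorithm>
#include <cstdint>
#include "triton/Tools/LinearLayout.h"

using mlir::StringAttr;
using mlir::triton::LinearLayout;

// Pad the "register" input dimension of both layouts with free variables
// so their transformation matrices have the same number of rows, then
// compose: the result maps src register assignments to dst register
// assignments.
LinearLayout composeWithPaddedRegs(const LinearLayout &src,
                                   const LinearLayout &dst,
                                   StringAttr kRegister) {
  int32_t numRegs =
      std::max(src.getInDimSize(kRegister), dst.getInDimSize(kRegister));
  LinearLayout srcPadded = src.resize(kRegister, numRegs);
  LinearLayout dstPadded = dst.resize(kRegister, numRegs);
  return dstPadded.invertAndCompose(srcPadded);
}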

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 2 additions & 16 deletions

@@ -328,20 +328,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
     } else {
       // Cast 5. The two layouts are equivalent. We should probably remove
       // these in RemoveLayoutConversion.
-      auto dstCvt = requiresI32Conversion(dstTy);
-      auto srcCvt = requiresI32Conversion(srcTy);
-      if (dstCvt || srcCvt) {
-        auto inVals = unpackLLElements(op.getLoc(), adaptor.getSrc(), rewriter);
-        inVals = unpackI32s(inVals, srcTy, rewriter, op.getLoc(),
-                            getTypeConverter());
-        inVals =
-            packI32s(inVals, dstTy, rewriter, op.getLoc(), getTypeConverter());
-        auto res = packLLElements(op.getLoc(), getTypeConverter(), inVals,
-                                  rewriter, op.getType());
-        rewriter.replaceOp(op, res);
-      } else {
-        rewriter.replaceOp(op, adaptor.getSrc());
-      }
+      rewriter.replaceOp(op, adaptor.getSrc());
       return success();
     }
   }
@@ -358,9 +345,8 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
     auto srcTy = op.getSrc().getType();
     auto dstTy = op.getType();
     auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter);
-    inVals = unpackI32s(inVals, srcTy, rewriter, loc, getTypeConverter());
     SmallVector<Value> outVals(numRegs);
-    for (int i = 0; i < numRegs; i++) {
+    for (int i = 0; i < outVals.size(); i++) {
       // Remove free masks from the register index
       // For example, if idx = 0b00111, and masks = 0b00100, then we get
       // 0b00011. It means that register 7 (0b111) has the same value as
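
The "remove free masks" comment in the final hunk describes a bit trick; a minimal sketch of one plausible reading, as plain C++ rather than the lowering code itself:

#include <cassert>
#include <cstdint>

// Clear the index bits that correspond to free variables, so register
// indices that differ only in free bits collapse to one canonical index.
int32_t removeFreeMasks(int32_t idx, int32_t masks) {
  return idx & ~masks;
}

int main() {
  // idx = 0b00111 with masks = 0b00100 yields 0b00011: register 7
  // (0b111) reads the same value as register 3 (0b011).
  assert(removeFreeMasks(0b00111, 0b00100) == 0b00011);
  return 0;
}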
