
Commit 90e9aec

Merge commit '988d3885e17d94d651187899d94c54e45db58fba'
2 parents: d635ba8 + 988d388

12 files changed: +391 −118 lines

lib/Dialect/TritonGPU/Transforms/Pipeliner/LowerLoops.cpp

Lines changed: 8 additions & 4 deletions
@@ -600,11 +600,13 @@ scf::ForOp lowerLoads(scf::ForOp forOp, CoarseSchedule &schedule) {
 
   createTMABarrierAndWait(forOp, asyncLoads, loadGroups, schedule);
 
+  bool hasAsyncLoads = false;
   for (auto [op, asyncLoad] : asyncLoads) {
     auto [insertIdx, extractIdx, phase, _] = loadGroups[asyncLoad.stageDiff];
     if (auto loadOp = dyn_cast<tt::LoadOp>(op)) {
       createAsyncCopy(forOp, loadOp, asyncLoad.alloc, insertIdx, extractIdx,
                       schedule);
+      hasAsyncLoads = true;
     } else if (auto loadOp = dyn_cast<tt::DescriptorLoadOp>(op)) {
       createTMAAsyncLoad(forOp, loadOp, asyncLoad.alloc, insertIdx, extractIdx,
                          asyncLoad.barrier, asyncLoad.waitOp, schedule);
@@ -628,10 +630,12 @@ scf::ForOp lowerLoads(scf::ForOp forOp, CoarseSchedule &schedule) {
   // correct stages.
   scheduleDependencies(forOp, schedule);
 
-  // Insert sync point for any possibly outstanding loads after the loop. This
-  // can happen as we speculatively execute loads in the loop.
-  builder.setInsertionPointAfter(forOp);
-  builder.create<ttg::AsyncWaitOp>(loc, ValueRange({}), 0);
+  if (hasAsyncLoads) {
+    // Insert sync point for any possibly outstanding loads after the loop. This
+    // can happen as we speculatively execute loads in the loop.
+    builder.setInsertionPointAfter(forOp);
+    builder.create<ttg::AsyncWaitOp>(loc, ValueRange({}), 0);
+  }
 
   // Make sure all ops have attributes.
   for (Operation &op : forOp.getBody()->without_terminator()) {

lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp

Lines changed: 146 additions & 10 deletions
@@ -1057,6 +1057,40 @@ void LayoutRematerialization::hoistConvertIntoConditionals() {
   }
 }
 
+static bool isExpensiveMathOp(Operation *op) {
+  // These operations are either multiple instructions or have throughput
+  // lower than 16 according to the arithmetic instructions table in:
+  // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#arithmetic-instructions
+  return isa<arith::DivFOp, math::ErfcOp, math::SinhOp, math::CoshOp,
+             math::TanhOp, math::AsinhOp, math::AcoshOp, math::AtanhOp,
+             math::CtPopOp, math::CountLeadingZerosOp,
+             math::CountTrailingZerosOp, math::ExpOp, math::Exp2Op,
+             math::ExpM1Op, math::LogOp, math::Log2Op, math::Log10Op,
+             math::Log1pOp, math::SinOp, math::CosOp, math::TanOp, math::AsinOp,
+             math::AcosOp, math::AtanOp, math::Atan2Op, math::PowFOp,
+             math::SqrtOp, math::RsqrtOp, math::ErfOp, math::CbrtOp>(op);
+}
+
+static int64_t getByteCount(Value result, int64_t minElementCount = 0,
+                            int64_t minBitWidth = 0) {
+  int64_t elementCount = 0;
+  int64_t dtypeBitWidth = 0;
+  if (auto tensorTy = dyn_cast<RankedTensorType>(result.getType())) {
+    elementCount = tensorTy.getNumElements();
+    auto elemType = tensorTy.getElementType();
+    if (elemType.isIntOrFloat()) {
+      dtypeBitWidth = elemType.getIntOrFloatBitWidth();
+    }
+  }
+  if (elementCount < minElementCount) {
+    elementCount = minElementCount;
+  }
+  if (dtypeBitWidth < minBitWidth) {
+    dtypeBitWidth = minBitWidth;
+  }
+  return (elementCount * dtypeBitWidth) >> 3;
+}
+
 void LayoutRematerialization::backwardRematerialization(
     ConvertLayoutOp convertOp) {
   // DotOperand is hoisted by hoistDotOperand
@@ -1088,12 +1122,112 @@ void LayoutRematerialization::backwardRematerialization(
     return;
   }
 
+  // 2. Determine whether rematerialisation is beneficial.
+
+  // Identify all operations in the slice
+  SetVector<Operation *> sliceOps;
+  for (Value v : slice) {
+    if (Operation *op = v.getDefiningOp()) {
+      sliceOps.insert(op);
+    }
+  }
+
+  // Compute single-use operations
+  DenseMap<Operation *, bool> isSingleUse;
+  std::function<bool(Operation *)> isOpSingleUse;
+  isOpSingleUse = [&](Operation *op) -> bool {
+    // lookup in memoization array:
+    auto it = isSingleUse.find(op);
+    if (it != isSingleUse.end()) {
+      return it->second;
+    }
+
+    bool singleUse = true;
+
+    for (Value result : op->getResults()) {
+      for (Operation *user : result.getUsers()) {
+        if (user == convertOp) {
+          continue;
+        }
+        if (sliceOps.contains(user)) {
+          if (!isOpSingleUse(user)) {
+            singleUse = false;
+            break;
+          }
+        } else {
+          singleUse = false;
+          break;
+        }
+      }
+      if (!singleUse) {
+        break;
+      }
+    }
+
+    // insert into memoization array:
+    isSingleUse[op] = singleUse;
+    return singleUse;
+  };
+
+  // Measure the number of bytes that we're manipulating with the
+  // ConvertLayoutOp. We pessimistically assume that we round-trip
+  // through shared memory and that we cannot vectorise sub-register
+  // loads/stores, so we set a minimum element count of 32 (the warp
+  // size and number of shared memory banks) and minimum bitwidth of
+  // 32 (the width per bank of the shared memory load/store unit).
+  int64_t convertLayoutBytes = getByteCount(convertOp.getSrc(), 32, 32);
+
+  // We measure costs in standardised milli-SM-cycles. This gives:
+  // smem load/store: 8 * byte count
+  // synchronisation: 1024 (assuming 4 warps per block)
+  int64_t convertLayoutCost = 16 * convertLayoutBytes + 1024;
+  int64_t rematerialisationCost = 0;
+
+  // Evaluate single-use status for every operation in slice
+  for (Operation *op : sliceOps) {
+    auto dialect = op->getDialect();
+    if (isOpSingleUse(op)) {
+      // when we rematerialise, this operation does not get duplicated
+      // so it does not contribute to our cost model:
+      continue;
+    } else if (isa<arith::ConstantOp>(op)) {
+      // special-case: arith.constant has zero cost
+      continue;
+    } else if (isa<LoadOp>(op)) {
+      // optimistically assume L1-cached:
+      for (Value result : op->getResults()) {
+        rematerialisationCost += 8 * getByteCount(result);
+      }
+    } else if (isa<arith::ArithDialect, math::MathDialect>(dialect)) {
+      // this is an arithmetic operation; we distinguish between cheap
+      // operations (such as floating point add/mul which can be fused
+      // as halves of a single-cycle FMA instruction) and expensive
+      // operations which use the special function unit and/or involve
+      // multiple instructions.
+      int64_t multiplier = isExpensiveMathOp(op) ? 8 : 1;
+      for (Value result : op->getResults()) {
+        rematerialisationCost += multiplier * getByteCount(result);
+      }
+    }
+  }
+
+  LLVM_DEBUG({
+    DBGS() << " convert layout cost: " << convertLayoutCost << "\n";
+    DBGS() << " rematerialisation cost: " << rematerialisationCost << "\n";
+  });
+
+  if (rematerialisationCost > convertLayoutCost) {
+    LDBG(" skipped rematerialization due to higher cost");
+    return;
+  }
+
   LLVM_DEBUG({
     DBGS() << " remat convert op " << convertOp << '\n';
     for (Value v : slice)
       DBGS() << " " << v << '\n';
   });
-  // 2. Rewrite the slice.
+
+  // 3. Rewrite the slice.
   rewriteSlice(slice, layout, convertOp);
 }
 
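To see what the new heuristic does with concrete numbers, here is a rough back-of-the-envelope check in plain Python (a sketch only: the 128x128 f16 tensor and the single multi-use expensive math op in the slice are hypothetical, and the constants mirror the comments in the hunk above):

def byte_count(num_elements, bit_width, min_elements=0, min_bits=0):
    # Mirrors getByteCount: clamp to the minimums, then convert bits to bytes.
    return (max(num_elements, min_elements) * max(bit_width, min_bits)) >> 3

# Cost of keeping the ConvertLayoutOp: assumed shared-memory round trip plus a sync.
convert_bytes = byte_count(128 * 128, 16, min_elements=32, min_bits=32)  # 65536
convert_cost = 16 * convert_bytes + 1024                                 # 1049600

# Cost of rematerialising the slice: one multi-use expensive math op (multiplier 8).
remat_cost = 8 * byte_count(128 * 128, 16)                               # 262144

# remat_cost <= convert_cost, so backwardRematerialization goes ahead in this case.
print(convert_cost, remat_cost)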

@@ -1179,30 +1313,32 @@ void LayoutRematerialization::hoistConvertDotOperand(
           { DBGS() << " Block arguments not supported. Got " << v << "\n"; });
       return;
     }
-    auto loadOp = dyn_cast<LoadOp>(v.getDefiningOp());
-    // We expect the leaves of the slice to be Load or arith::Constant
-    // This could be generalised if necessary
-    if (!loadOp) {
+
+    // We expect the leaves of the slice to be Load, DescriptorLoad or
+    // arith::Constant This could be generalised if necessary
+    if (!isa<LoadOp, DescriptorLoadOp>(v.getDefiningOp())) {
       auto op = v.getDefiningOp();
       if (isa<arith::ConstantOp>(op) || noDataMovement(op)) {
         innerSlice.insert(v);
         continue;
       } else {
         LLVM_DEBUG({
-          DBGS() << " Leaves must be Load or Constant. Got " << v << "\n";
+          DBGS() << " Leaves must be Load, DescriptorLoad or Constant. Got "
+                 << v << "\n";
         });
         return;
       }
     }
+    Operation *loadOp = v.getDefiningOp();
     builder.setInsertionPointAfter(loadOp);
-    auto type = dyn_cast<RankedTensorType>(loadOp.getType());
+    auto type = dyn_cast<RankedTensorType>(loadOp->getResult(0).getType());
     if (!type)
       continue;
     auto newType = RankedTensorType::get(type.getShape(), type.getElementType(),
-                                         layout[loadOp]);
+                                         layout[loadOp->getResult(0)]);
     auto newConvertOp = builder.create<ConvertLayoutOp>(
-        convertOp.getLoc(), newType, loadOp.getResult());
-    mapping.map(loadOp.getResult(), newConvertOp.getResult());
+        convertOp.getLoc(), newType, loadOp->getResult(0));
+    mapping.map(loadOp->getResult(0), newConvertOp.getResult());
   }
 
   if (innerSlice.empty()) {

lib/Dialect/TritonNvidiaGPU/Transforms/TensorMemoryAllocation.cpp

Lines changed: 12 additions & 13 deletions
@@ -175,21 +175,20 @@ static TMemChunk allocFirstFit(MemoryBitMap &memoryMap,
 }
 
 static Operation *getAlloc(Value value) {
-  Operation *op = value.getDefiningOp();
-  while (auto subOp = dyn_cast<triton::gpu::MemDescSubviewOp>(op)) {
-    if (subOp.getSrc().getDefiningOp()) {
-      op = subOp.getSrc().getDefiningOp();
-    } else {
-      auto arg = cast<BlockArgument>(subOp.getSrc());
-      auto partitions =
-          cast<WarpSpecializePartitionsOp>(arg.getOwner()->getParentOp());
-      WarpSpecializeOp wsOp = partitions.getParentOp();
-      auto capture = wsOp.getExplicitCaptures()[arg.getArgNumber()];
-      op = capture.getDefiningOp();
+  while (true) {
+    if (auto allocOp = value.getDefiningOp<TMEMAllocOp>())
+      return allocOp;
+    if (auto subviewOp = value.getDefiningOp<triton::gpu::MemDescSubviewOp>()) {
+      value = subviewOp.getSrc();
+      continue;
     }
+    auto arg = dyn_cast<BlockArgument>(value);
+    if (!arg || !isa<WarpSpecializePartitionsOp>(arg.getOwner()->getParentOp()))
+      llvm::report_fatal_error("expected to find a TMEM alloc op");
+    auto partitions =
+        cast<WarpSpecializePartitionsOp>(arg.getOwner()->getParentOp());
+    value = partitions.getParentOp().getExplicitCaptures()[arg.getArgNumber()];
   }
-  assert(isa<triton::nvidia_gpu::TMEMAllocOp>(op) && "Expected a TMEMAllocOp");
-  return op;
 }
 
 class RowIdConstraints {
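The rewritten getAlloc walks the def chain iteratively instead of assuming every step has a defining op. A toy Python sketch of the same control flow (stand-in Node type and hypothetical field names, not the MLIR API):

from dataclasses import dataclass
from typing import Optional

@dataclass
class Node:
    kind: str                      # "tmem_alloc", "memdesc_subview" or "ws_capture"
    src: Optional["Node"] = None   # subview source or captured value, if any

def find_alloc(node: Node) -> Node:
    # Follow subview sources and warp-specialize captures until the TMEM alloc is found.
    while True:
        if node.kind == "tmem_alloc":
            return node
        if node.kind in ("memdesc_subview", "ws_capture") and node.src is not None:
            node = node.src
            continue
        raise RuntimeError("expected to find a TMEM alloc op")

alloc = Node("tmem_alloc")
assert find_alloc(Node("memdesc_subview", src=Node("ws_capture", src=alloc))) is alloc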

lib/Tools/LinearLayout.cpp

Lines changed: 7 additions & 1 deletion
@@ -1088,7 +1088,13 @@ LinearLayout LinearLayout::removeZeroBasesAlongDim(StringAttr stripDim) const {
       }
     }
   }
-  return LinearLayout(std::move(result), llvm::to_vector(getOutDimNames()));
+  SmallVector<std::pair<StringAttr, int32_t>> newOutDimSizes;
+  for (auto outDim : getOutDimNames()) {
+    newOutDimSizes.push_back({outDim, getOutDimSize(outDim)});
+  }
+  auto newLayout = LinearLayout(std::move(result), ArrayRef(newOutDimSizes),
+                                this->isSurjective());
+  return newLayout;
 }
 
 size_t hash_value(const LinearLayout &layout) {
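The point of passing the original output-dimension sizes and the surjectivity flag explicitly is that rebuilding a LinearLayout from its bases alone infers the smallest output sizes spanned by the bases (and treats the result as surjective), which can shrink a non-surjective layout. A toy Python illustration (not the LinearLayout API; the numbers are made up):

def inferred_out_size(bases):
    # Smallest power of two covering everything the bases can reach (toy 1-D model).
    span = 0
    for b in bases:
        span |= b
    size = 1
    while size <= span:
        size *= 2
    return size

bases, declared_out_size = [1, 2, 0], 8   # hypothetical layout: spans 0..3 inside [0, 8)
stripped = [b for b in bases if b != 0]   # what removeZeroBasesAlongDim keeps
print(inferred_out_size(stripped))        # 4 -- what rebuilding from bases alone would declare
print(declared_out_size)                  # 8 -- what the fix now preserves explicitly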

setup.py

Lines changed: 18 additions & 11 deletions
@@ -776,6 +776,22 @@ def get_git_version_suffix():
 # keep it separate for easy substitution
 TRITON_VERSION = "3.3.0" + get_git_version_suffix() + os.environ.get("TRITON_WHEEL_VERSION_SUFFIX", "")
 
+# Dynamically define supported Python versions and classifiers
+MIN_PYTHON = (3, 9)
+MAX_PYTHON = (3, 13)
+
+PYTHON_REQUIRES = f">={MIN_PYTHON[0]}.{MIN_PYTHON[1]},<{MAX_PYTHON[0]}.{MAX_PYTHON[1] + 1}"
+BASE_CLASSIFIERS = [
+    "Development Status :: 4 - Beta",
+    "Intended Audience :: Developers",
+    "Topic :: Software Development :: Build Tools",
+    "License :: OSI Approved :: MIT License",
+]
+PYTHON_CLASSIFIERS = [
+    f"Programming Language :: Python :: {MIN_PYTHON[0]}.{m}" for m in range(MIN_PYTHON[1], MAX_PYTHON[1] + 1)
+]
+CLASSIFIERS = BASE_CLASSIFIERS + PYTHON_CLASSIFIERS
+
 setup(
     name=os.environ.get("TRITON_WHEEL_NAME", "triton"),
     version=TRITON_VERSION,
@@ -807,17 +823,8 @@ def get_git_version_suffix():
     # for PyPI
     keywords=["Compiler", "Deep Learning"],
     url="https://github.com/triton-lang/triton/",
-    classifiers=[
-        "Development Status :: 4 - Beta",
-        "Intended Audience :: Developers",
-        "Topic :: Software Development :: Build Tools",
-        "License :: OSI Approved :: MIT License",
-        "Programming Language :: Python :: 3.9",
-        "Programming Language :: Python :: 3.10",
-        "Programming Language :: Python :: 3.11",
-        "Programming Language :: Python :: 3.12",
-        "Programming Language :: Python :: 3.13",
-    ],
+    python_requires=PYTHON_REQUIRES,
+    classifiers=CLASSIFIERS,
     test_suite="tests",
     extras_require={
         "build": [
