
Commit 1d2b89d
Merge commit '29009f1b136b738d354ffcb4e89c4bd3f2343832'
2 parents: 82d5122 + 29009f1

38 files changed: +439 additions, -275 deletions


include/triton/Dialect/TritonGPU/IR/Traits.h

Lines changed: 6 additions & 0 deletions
@@ -22,6 +22,12 @@ class LocalLoadTrait
   // Optional: Add methods or verification logic here
 };
 
+template <typename ConcreteType>
+class MemWaitOpTrait
+    : public mlir::OpTrait::TraitBase<ConcreteType, MemWaitOpTrait> {
+  // Optional: Add methods or verification logic here
+};
+
 } // namespace OpTrait
 } // namespace mlir
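The new MemWaitOpTrait is an empty marker trait in MLIR's usual CRTP style: it carries no state, and its only job is to let passes ask whether an op is "wait-like" generically instead of listing concrete op types (see the Membar.cpp change below). For readers unfamiliar with the pattern, here is a minimal, MLIR-free C++ sketch of the same mechanism; the real query (Operation::hasTrait) is dynamic, while this sketch uses a static check to stay self-contained, and WaitLikeTrait, FakeAsyncWaitOp, and hasWaitTrait are illustrative names, not Triton APIs.

#include <iostream>
#include <type_traits>

// Empty marker trait: no data, only a compile-time tag. MLIR's
// OpTrait::TraitBase works the same way (with extra verification hooks);
// the CRTP parameter is deliberately unused here.
template <typename ConcreteType> class WaitLikeTrait {};

// An op "declares" the trait by inheriting from it, analogous to listing
// MemWaitOpTrait in the op's TableGen definition.
struct FakeAsyncWaitOp : WaitLikeTrait<FakeAsyncWaitOp> {};
struct FakeLoadOp {};

// Generic query, analogous in spirit to op->hasTrait<...>() in MLIR.
template <typename OpT> constexpr bool hasWaitTrait() {
  return std::is_base_of<WaitLikeTrait<OpT>, OpT>::value;
}

int main() {
  std::cout << hasWaitTrait<FakeAsyncWaitOp>() << "\n"; // 1
  std::cout << hasWaitTrait<FakeLoadOp>() << "\n";      // 0
}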

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrBase.td

Lines changed: 1 addition & 0 deletions
@@ -14,6 +14,7 @@ include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td"
 // Traits used across several attrs.
 def MemDescViewTrait : NativeOpTrait<"MemDescViewTrait">;
 def LocalLoadTrait : NativeOpTrait<"LocalLoadTrait">;
+def MemWaitOpTrait : NativeOpTrait<"MemWaitOpTrait">;
 
 // Common parameter helpers.
 def LinearLayoutParam : AttrOrTypeParameter<"LinearLayout",

include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td

Lines changed: 1 addition & 1 deletion
@@ -42,7 +42,7 @@ def TTG_ConvertLayoutOp : TTG_Op<"convert_layout",
   let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)";
 }
 
-def TTG_AsyncWaitOp : TTG_Op<"async_wait"> {
+def TTG_AsyncWaitOp : TTG_Op<"async_wait", [MemWaitOpTrait]> {
   let summary = "async wait";
 
   let arguments = (ins Variadic<TTG_AsyncToken>:$asyncToken, I32Attr:$num);

include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td

Lines changed: 1 addition & 1 deletion
@@ -401,7 +401,7 @@ def TTNG_AsyncTMAScatterOp : TTNG_Op<"async_tma_scatter"> {
   let hasVerifier = 1;
 }
 
-def TTNG_TMAStoreWaitOp : TTNG_Op<"async_tma_store_wait"> {
+def TTNG_TMAStoreWaitOp : TTNG_Op<"async_tma_store_wait", [MemWaitOpTrait]> {
   let summary = "wait until all the inputs are read.";
   let arguments = (ins I32Attr:$pendings);
   let description = [{

lib/Analysis/AxisInfo.cpp

Lines changed: 17 additions & 26 deletions
@@ -1079,11 +1079,10 @@ AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver,
 LogicalResult AxisInfoAnalysis::visitOperation(
     Operation *op, ArrayRef<const dataflow::Lattice<AxisInfo> *> operands,
     ArrayRef<dataflow::Lattice<AxisInfo> *> results) {
-  // TODO: For sure not the right way to do this
-  // but why is scf.if not initialized otherwise?
+  // If any operands are not yet ready, skip this operation for now.
   for (auto op : operands)
     if (op->getValue().getRank() == 0)
-      setToEntryState((dataflow::Lattice<AxisInfo> *)op);
+      return success();
   AxisInfo curr = visitors.apply(op, operands);
   if (curr.getRank() == 0) {
     setAllToEntryStates(results);

@@ -1112,9 +1111,11 @@ void AxisInfoAnalysis::visitForOpInductionVar(
   ProgramPoint *programPoint = getProgramPointAfter(op);
   auto *lbLattice = getLatticeElementFor(programPoint, op.getLowerBound());
   auto *stepLattice = getLatticeElementFor(programPoint, op.getStep());
-  for (auto op_iter : {lbLattice, stepLattice})
-    if (op_iter->getValue().getRank() == 0)
-      setToEntryState((dataflow::Lattice<AxisInfo> *)op_iter);
+  // If lb or step is not yet ready, skip this operation for now.
+  if (lbLattice->getValue().getRank() == 0 ||
+      stepLattice->getValue().getRank() == 0) {
+    return;
+  }
 
   AxisInfo::DimVectorT knownContiguity(1, 1);
   AxisInfo::DimVectorT knownDivisibility(1, 1);

@@ -1188,24 +1189,15 @@ void AxisInfo::initDimVectorFromHint(Attribute attr, DimVectorT *vec) {
       initPessimisticStateFromFunc(blockArg.getArgNumber(), fun,
                                    &knownContiguity, &knownDivisibility,
                                    &knownConstancy);
-    } else if (isa<RegionBranchOpInterface, gpu::WarpSpecializePartitionsOp>(
-                   op)) {
-      // scf::ForOp, scf::IfOp, scf::WhileOp, gpu::WarpSpecializePartitionsOp
-      // Control flow operations are initialized with "unknown" state:
-      // the maximum possible divisibility, contiguity, and constancy.
+    } else if (isa<gpu::WarpSpecializePartitionsOp>(op)) {
+      // Initialize the arguments to gpu::WarpSpecializePartitionsOp with
+      // "unknown" state: the maximum possible divisibility, contiguity, and
+      // constancy.
       knownDivisibility = DimVectorT(rank, kMaxDivisor);
       knownConstancy = DimVectorT(rank, kMaxDivisor);
       knownContiguity = DimVectorT(rank, kMaxDivisor);
     }
   } else if (Operation *op = value.getDefiningOp()) {
-    if (isa<RegionBranchOpInterface>(op)) {
-      // scf::ForOp, scf::IfOp, scf::WhileOp
-      // Control flow operations are initialized with "unknown" state:
-      // the maximum possible divisibility, contiguity, and constancy.
-      knownDivisibility = DimVectorT(rank, kMaxDivisor);
-      knownConstancy = DimVectorT(rank, kMaxDivisor);
-      knownContiguity = DimVectorT(rank, kMaxDivisor);
-    }
     // Other operations are conservatively initialized with the lowest possible
     // divisibility, contiguity, and constancy unless they have specified.
     AxisInfo::initDimVectorFromHint(op->getDiscardableAttr("tt.divisibility"),

@@ -1358,13 +1350,12 @@ void ModuleAxisInfoAnalysis::initialize(FunctionOpInterface funcOp,
   auto *axisInfoMap = getFuncData(funcOp);
   auto updateAxisInfoMap = [&](Value value) {
     auto axisInfo = analysis->getLatticeElement(value)->getValue();
-    AxisInfo curAxisInfo;
-    if (axisInfoMap->count(value)) {
-      curAxisInfo = AxisInfo::join(axisInfo, axisInfoMap->lookup(value));
-    } else {
-      curAxisInfo = axisInfo;
-    }
-    (*axisInfoMap)[value] = curAxisInfo;
+    // If we could not determine the AxisInfo for this value, assume the
+    // pessimistic state.
+    if (axisInfo.getRank() == 0)
+      axisInfo = AxisInfo::getPessimisticValueState(value);
+    auto &valInfo = (*axisInfoMap)[value];
+    valInfo = AxisInfo::join(axisInfo, valInfo);
   };
   funcOp.walk([&](Operation *op) {
     for (auto value : op->getResults()) {
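The rewritten updateAxisInfoMap drops the explicit count/lookup branching: an undetermined AxisInfo (rank 0) is first replaced by the pessimistic state, and the result is then joined straight into the map slot, relying on the join treating a default-constructed entry as an identity. The sketch below shows the same fallback-then-join shape with a toy single-field lattice; the gcd-based join and all names are illustrative, not the real AxisInfo semantics.

#include <iostream>
#include <map>
#include <numeric>
#include <string>

// Toy stand-in for AxisInfo: a single divisibility fact, 0 = "not computed".
struct ToyInfo {
  int divisibility = 0;
  static ToyInfo pessimistic() { return {1}; } // weakest valid fact
  static ToyInfo join(const ToyInfo &a, const ToyInfo &b) {
    if (a.divisibility == 0) return b; // joining with "not computed" is a no-op
    if (b.divisibility == 0) return a;
    return {std::gcd(a.divisibility, b.divisibility)}; // keep the common fact
  }
};

int main() {
  std::map<std::string, ToyInfo> infoMap;
  auto update = [&](const std::string &value, ToyInfo info) {
    // Same shape as the patched updateAxisInfoMap: fall back to the
    // pessimistic state when the analysis produced nothing, then join into
    // whatever the map already holds (a default entry joins as identity).
    if (info.divisibility == 0)
      info = ToyInfo::pessimistic();
    ToyInfo &slot = infoMap[value];
    slot = ToyInfo::join(info, slot);
  };

  update("%x", {8}); // first sighting: divisibility 8
  update("%x", {4}); // joined with 8 -> gcd 4
  update("%y", {0}); // unknown -> pessimistic 1
  std::cout << infoMap["%x"].divisibility << " "
            << infoMap["%y"].divisibility << "\n"; // prints "4 1"
}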

lib/Analysis/Membar.cpp

Lines changed: 1 addition & 1 deletion
@@ -171,7 +171,7 @@ void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo,
     return;
   }
 
-  if (isa<triton::gpu::AsyncWaitOp, triton::nvidia_gpu::TMAStoreWaitOp>(op) &&
+  if (op->hasTrait<mlir::OpTrait::MemWaitOpTrait>() &&
       !isa<gpu::BarrierOp, triton::gpu::LocalBarrierOp>(op->getNextNode())) {
     // If the current op is an async wait and the next op is not a barrier we
     // insert a barrier op and sync
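With the trait in place, Membar no longer needs to know AsyncWaitOp or TMAStoreWaitOp by name: any op carrying MemWaitOpTrait gets a trailing barrier unless the next op already is one, and future wait ops opt in simply by declaring the trait. Below is a standalone sketch of that "insert a barrier after each wait unless a barrier follows" rule over a toy instruction list; the enum and helper names are illustrative, not Triton APIs.

#include <iostream>
#include <vector>

enum class Kind { Compute, AsyncWait, TmaStoreWait, Barrier };

// Toy stand-in for the trait query: which kinds count as memory waits.
static bool isMemWait(Kind k) {
  return k == Kind::AsyncWait || k == Kind::TmaStoreWait;
}

// Walk the op list and append a Barrier after every wait-like op whose
// successor is not already a barrier, mirroring the Membar rule above.
static std::vector<Kind> insertBarriers(const std::vector<Kind> &ops) {
  std::vector<Kind> out;
  for (size_t i = 0; i < ops.size(); ++i) {
    out.push_back(ops[i]);
    bool nextIsBarrier = i + 1 < ops.size() && ops[i + 1] == Kind::Barrier;
    if (isMemWait(ops[i]) && !nextIsBarrier)
      out.push_back(Kind::Barrier);
  }
  return out;
}

int main() {
  std::vector<Kind> ops = {Kind::Compute, Kind::AsyncWait, Kind::Compute,
                           Kind::TmaStoreWait, Kind::Barrier};
  // Only the AsyncWait needs a barrier inserted; the TmaStoreWait is
  // already followed by one.
  std::cout << insertBarriers(ops).size() << "\n"; // prints 6
}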

lib/Conversion/TritonInstrumentToLLVM/InstrumentationToLLVM.cpp

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@ Value createMemDescToI64(RewriterBase &rewriter, Location loc,
                          const LLVMTypeConverter *typeConverter,
                          ttg::MemDescType memDescTy, Value sharedMemStruct) {
   TritonLLVMOpBuilder b(loc, rewriter);
-  if (isa<ttng::TensorMemoryEncodingAttr>(memDescTy.getEncoding())) {
+  if (isa<ttng::TensorMemorySpaceAttr>(memDescTy.getMemorySpace())) {
     return b.ptrtoint(rewriter.getIntegerType(64), sharedMemStruct);
   }
   assert(isa<ttg::SharedEncodingTrait>(memDescTy.getEncoding()) &&

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 2 additions & 2 deletions
@@ -2505,9 +2505,9 @@ LogicalResult DotOperandEncodingAttr::verify(
     return emitError()
            << "ttg.dot_op kWidth parameter must be 4/8/16 for WMMA v2 "
               "(including packed cases for `scaled_dot`)";
-  if (parentAttr.getVersion() == 3 && !llvm::is_contained({2, 8, 16}, kWidth))
+  if (parentAttr.getVersion() == 3 && kWidth == 0)
     return emitError()
-           << "ttg.dot_op kWidth parameter must be 2/8/16 for WMMA v3";
+           << "ttg.dot_op kWidth parameter is mandatory for WMMA v3 ";
   return success();
 }

lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp

Lines changed: 25 additions & 16 deletions
@@ -127,7 +127,7 @@ class LayoutRematerialization {
   }
 
   void cleanup();
-  void backwardRematerialization();
+  bool backwardRematerialization();
   void backwardRematerialization(ConvertLayoutOp convertOp);
   // TODO: Merge the three hoistConvert*(); functions as they are duplicate code
   void hoistConvertDotOperand();

@@ -1019,7 +1019,8 @@ LogicalResult LayoutRematerialization::getRematerializableSlice(
   return success();
 }
 
-void LayoutRematerialization::backwardRematerialization() {
+bool LayoutRematerialization::backwardRematerialization() {
+  bool changed = false;
   // Go through each ConvertLayoutOp.
   SmallVector<ConvertLayoutOp> convertOps;
   funcOp.walk(

@@ -1031,8 +1032,11 @@ void LayoutRematerialization::backwardRematerialization() {
       // backward slices.
       addRematValue(convertOp.getSrc(), convertOp.getType().getEncoding(),
                     convertOp.getResult());
+    } else {
+      changed = true;
     }
   }
+  return changed;
 }
 
 void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast() {

@@ -1593,12 +1597,14 @@ void LayoutRematerialization::hoistConvertIntoConditionals(
     rewriteSlice(slice, layout, convertOp, mapping);
 }
 
-void backwardRematerialization(ModuleOp module) {
-  module.walk([](FuncOp funcOp) {
+bool backwardRematerialization(ModuleOp module) {
+  bool changed = false;
+  module.walk([&](FuncOp funcOp) {
     LayoutRematerialization layoutRemat(funcOp);
-    layoutRemat.backwardRematerialization();
+    changed |= layoutRemat.backwardRematerialization();
     layoutRemat.cleanup();
   });
+  return changed;
 }
 
 void hoistConvert(ModuleOp module) {

@@ -1659,17 +1665,20 @@ class TritonGPURemoveLayoutConversionsPass
 
     cleanupConvertOps();
 
-    // 2. For remaining convert ops, try to rematerialize the slice of producer
-    // operation to avoid having to convert.
-    backwardRematerialization(m);
-    LLVM_DEBUG({
-      DBGS() << "Module after backward remat:\n";
-      m.dump();
-    });
-
-    // Cleanup dummy converts created during backward remat.
-    cleanupConvertOps();
-
+    bool changed = false;
+    do {
+      changed = false;
+      // 2. For remaining convert ops, try to rematerialize the slice of
+      // producer operation to avoid having to convert.
+      changed = backwardRematerialization(m);
+      LLVM_DEBUG({
+        DBGS() << "Module after backward remat:\n";
+        m.dump();
+      });
+
+      // Cleanup dummy converts created during backward remat.
+      cleanupConvertOps();
+    } while (changed);
    // 3. For remaining converts, try to hoist them above cast generating larger
    // size types in order to reduce the cost of the convert op.
    hoistConvert(m);
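The pass now reports from backwardRematerialization whether any convert was actually removed, and the driver repeats backward rematerialization plus cleanup until a round makes no change, i.e. it runs the simplification to a fixed point instead of exactly once. A minimal sketch of that driver shape, using a toy one-round "simplify" step on a list of integers (everything here is illustrative, not the actual pass):

#include <iostream>
#include <vector>

// Toy "one round of rewriting": drop one adjacent equal pair and report
// whether anything changed -- the analogue of backwardRematerialization()
// now returning bool.
static bool simplifyOnce(std::vector<int> &xs) {
  for (size_t i = 0; i + 1 < xs.size(); ++i) {
    if (xs[i] == xs[i + 1]) {
      xs.erase(xs.begin() + i, xs.begin() + i + 2);
      return true;
    }
  }
  return false;
}

int main() {
  std::vector<int> xs = {1, 2, 2, 1, 3};
  // Same driver shape as the patched pass: keep iterating while a round
  // reports progress, so rewrites enabled by earlier rounds get picked up
  // (here, removing {2,2} exposes the {1,1} pair).
  bool changed = false;
  do {
    changed = simplifyOnce(xs);
  } while (changed);
  for (int x : xs)
    std::cout << x << " "; // prints "3"
  std::cout << "\n";
}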

python/src/gluon_ir.cc

Lines changed: 41 additions & 2 deletions
@@ -387,8 +387,47 @@ void init_gluon_ir(py::module &&m) {
               std::vector<int64_t> &shape) -> py::object {
             auto ctx = self.getContext();
             auto linearLayout = ttg::toLinearLayout(shape, layout);
-            auto attr = ttg::LinearEncodingAttr::get(ctx, linearLayout);
-            return layoutToGluon(attr);
+
+            if (isa<ttg::DistributedEncodingTrait>(layout)) {
+              auto attr = ttg::LinearEncodingAttr::get(ctx, linearLayout);
+              return layoutToGluon(attr);
+            }
+            if (isa<ttg::SharedEncodingTrait>(layout)) {
+              auto alignment =
+                  cast<ttg::SharedEncodingTrait>(layout).getAlignment();
+              auto attr = ttg::SharedLinearEncodingAttr::get(ctx, linearLayout,
+                                                             alignment);
+              return layoutToGluon(attr);
+            }
+
+            // TensorMemory encodings: keep the LinearLayout but wrap as
+            // print-only Python object carrying row/col bases -> dim0/dim1.
+            auto inNamesRange = linearLayout.getInDimNames();
+            auto inNames = llvm::to_vector(inNamesRange);
+            bool isTmemLayout =
+                (inNames.size() == 2 && inNames[0].str() == "row" &&
+                 inNames[1].str() == "col");
+            if (!isTmemLayout)
+              throw std::invalid_argument(
+                  "Unsupported layout in to_linear_layout");
+
+            // Build Py _TensorMemoryLinearLayout(row_bases, col_bases, shape,
+            // repr)
+            py::object tmemCls =
+                py::module::import(
+                    "triton.experimental.gluon.language.nvidia.blackwell")
+                    .attr("_TensorMemoryLinearLayout");
+            auto bases = linearLayout.getBases();
+            auto rowBases = bases[mlir::StringAttr::get(ctx, "row")];
+            auto colBases = bases[mlir::StringAttr::get(ctx, "col")];
+            auto outDims = linearLayout.getOutDims();
+            std::vector<int> shapeVec;
+            for (auto &od : outDims)
+              shapeVec.push_back(od.second);
+
+            py::object pyObj = tmemCls(py::cast(rowBases), py::cast(colBases),
+                                       py::cast(shapeVec));
+            return pyObj;
           })
       .def("get_dot_operand_layout",
            [](GluonOpBuilder &self, unsigned opIdx, Attribute parent,
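The tensor-memory branch treats a LinearLayout as a table of basis vectors keyed by named input dimensions ("row", "col") plus named output dimensions whose sizes give the shape; the binding checks the input-dim names, pulls out the two basis lists, and collects the output sizes before handing everything to the Python-side wrapper. A standalone sketch of that data shuffling with plain standard-library containers follows; the struct and field names are illustrative, not the LinearLayout API, and the check is simplified relative to the ordered name test in the binding.

#include <iostream>
#include <map>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

// Toy layout: per input dim, a list of basis vectors over the output dims;
// output dims are listed with their sizes, which together form the shape.
struct ToyLinearLayout {
  std::map<std::string, std::vector<std::vector<int>>> bases;
  std::vector<std::pair<std::string, int>> outDims;
};

struct TmemWrapper {
  std::vector<std::vector<int>> rowBases, colBases;
  std::vector<int> shape;
};

static TmemWrapper wrapTensorMemory(const ToyLinearLayout &ll) {
  // Accept only layouts whose input dims are exactly {"row", "col"},
  // in the spirit of the isTmemLayout check in the binding.
  if (ll.bases.size() != 2 || !ll.bases.count("row") || !ll.bases.count("col"))
    throw std::invalid_argument("unsupported layout");
  TmemWrapper w;
  w.rowBases = ll.bases.at("row");
  w.colBases = ll.bases.at("col");
  for (const auto &od : ll.outDims) // output-dim sizes become the shape
    w.shape.push_back(od.second);
  return w;
}

int main() {
  ToyLinearLayout ll;
  ll.bases["row"] = {{1, 0}, {2, 0}}; // row bit i maps to offset 1<<i in dim0
  ll.bases["col"] = {{0, 1}, {0, 2}}; // col bit i maps to offset 1<<i in dim1
  ll.outDims = {{"dim0", 4}, {"dim1", 4}};
  auto w = wrapTensorMemory(ll);
  std::cout << w.shape[0] << "x" << w.shape[1] << "\n"; // prints "4x4"
}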
