Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
85400f8
[BACKEND] run remove backward prop until a fix point (#8776)
ThomasRaoux Nov 20, 2025
80104b7
[AMD] Preserve Denorms for precise sqrt (#8697)
mengfei-jiang Nov 20, 2025
feedad8
[AMD] Fix `lowerLoops`: only erase load ops which are converted (#8737)
leonling-ll Nov 20, 2025
fe7838c
[consan] Handle all tmem allocations (#8787)
ThomasRaoux Nov 20, 2025
2e1a036
[AMD] Enabling Buffer Atomic for RDNA4 (#8778)
saeid-rostami Nov 20, 2025
ecfaec2
[PROTON][AMD] Fix failing proton tests for AMD GPUs (#8763)
ZelboK Nov 20, 2025
89a3c6e
[AMD] Make kWidth mandatory for WMMA v3 (#8783)
borontion Nov 20, 2025
db14c2d
[AMD] Allow async load global to load block dimension duplication (#8…
AlexAUT Nov 20, 2025
31281bc
[Reland] Fix handling of unvisited operands in AxisInfoAnalysis (#8758)
neildhar Nov 20, 2025
6294db5
[KERNELS] fix persistent matmul heuristics (#8791)
aeng-openai Nov 21, 2025
4823a6e
[BACKEND] Support clamp optimization on scalars (#8796)
peterbell10 Nov 21, 2025
96bba6b
[AMD] Replace usage of llvm copysign intrinsic (#8789)
xiaohuguo2023 Nov 21, 2025
4cf9906
[AMD] Extended membar analysis with third_party ops using a trait (#8…
ravil-mobile Nov 21, 2025
046ab0e
[NFC] Simplify populating axisinfo map (#8800)
neildhar Nov 21, 2025
4b184cc
patch workaround by correctly setting stage/cluster attributes (#8797)
3gx Nov 21, 2025
29009f1
[GLUON] Allow TensorMemory layouts in `to_linear_layout` in the conte…
lezcano Nov 22, 2025
1d2b89d
Merge commit '29009f1b136b738d354ffcb4e89c4bd3f2343832'
anmyachev Dec 1, 2025
546a718
Revert "[Reland] Fix handling of unvisited operands in AxisInfoAnalys…
anmyachev Dec 1, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions include/triton/Dialect/TritonGPU/IR/Traits.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@ class LocalLoadTrait
// Optional: Add methods or verification logic here
};

// Marker trait for memory "wait" operations. In this change it is attached to
// the async_wait and async_tma_store_wait ops, and MembarAnalysis tests for it
// with op->hasTrait<mlir::OpTrait::MemWaitOpTrait>() instead of naming those
// ops explicitly, so other ops can opt in by declaring the trait.
template <typename ConcreteType>
class MemWaitOpTrait
: public mlir::OpTrait::TraitBase<ConcreteType, MemWaitOpTrait> {
// The trait currently carries no methods or verification logic; add them
// here if they become necessary.
};

} // namespace OpTrait
} // namespace mlir

Expand Down
1 change: 1 addition & 0 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUAttrBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td"
// Traits used across several attrs.
def MemDescViewTrait : NativeOpTrait<"MemDescViewTrait">;
def LocalLoadTrait : NativeOpTrait<"LocalLoadTrait">;
def MemWaitOpTrait : NativeOpTrait<"MemWaitOpTrait">;

// Common parameter helpers.
def LinearLayoutParam : AttrOrTypeParameter<"LinearLayout",
Expand Down
2 changes: 1 addition & 1 deletion include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def TTG_ConvertLayoutOp : TTG_Op<"convert_layout",
let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)";
}

def TTG_AsyncWaitOp : TTG_Op<"async_wait"> {
def TTG_AsyncWaitOp : TTG_Op<"async_wait", [MemWaitOpTrait]> {
let summary = "async wait";

let arguments = (ins Variadic<TTG_AsyncToken>:$asyncToken, I32Attr:$num);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ def TTNG_AsyncTMAScatterOp : TTNG_Op<"async_tma_scatter"> {
let hasVerifier = 1;
}

def TTNG_TMAStoreWaitOp : TTNG_Op<"async_tma_store_wait"> {
def TTNG_TMAStoreWaitOp : TTNG_Op<"async_tma_store_wait", [MemWaitOpTrait]> {
let summary = "wait until all the inputs are read.";
let arguments = (ins I32Attr:$pendings);
let description = [{
Expand Down
9 changes: 2 additions & 7 deletions lib/Analysis/AxisInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1358,13 +1358,8 @@ void ModuleAxisInfoAnalysis::initialize(FunctionOpInterface funcOp,
auto *axisInfoMap = getFuncData(funcOp);
auto updateAxisInfoMap = [&](Value value) {
auto axisInfo = analysis->getLatticeElement(value)->getValue();
AxisInfo curAxisInfo;
if (axisInfoMap->count(value)) {
curAxisInfo = AxisInfo::join(axisInfo, axisInfoMap->lookup(value));
} else {
curAxisInfo = axisInfo;
}
(*axisInfoMap)[value] = curAxisInfo;
auto &valInfo = (*axisInfoMap)[value];
valInfo = AxisInfo::join(axisInfo, valInfo);
};
funcOp.walk([&](Operation *op) {
for (auto value : op->getResults()) {
Expand Down
2 changes: 1 addition & 1 deletion lib/Analysis/Membar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo,
return;
}

if (isa<triton::gpu::AsyncWaitOp, triton::nvidia_gpu::TMAStoreWaitOp>(op) &&
if (op->hasTrait<mlir::OpTrait::MemWaitOpTrait>() &&
!isa<gpu::BarrierOp, triton::gpu::LocalBarrierOp>(op->getNextNode())) {
// If the current op is an async wait and the next op is not a barrier we
// insert a barrier op and sync
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ Value createMemDescToI64(RewriterBase &rewriter, Location loc,
const LLVMTypeConverter *typeConverter,
ttg::MemDescType memDescTy, Value sharedMemStruct) {
TritonLLVMOpBuilder b(loc, rewriter);
if (isa<ttng::TensorMemoryEncodingAttr>(memDescTy.getEncoding())) {
if (isa<ttng::TensorMemorySpaceAttr>(memDescTy.getMemorySpace())) {
return b.ptrtoint(rewriter.getIntegerType(64), sharedMemStruct);
}
assert(isa<ttg::SharedEncodingTrait>(memDescTy.getEncoding()) &&
Expand Down
4 changes: 2 additions & 2 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2505,9 +2505,9 @@ LogicalResult DotOperandEncodingAttr::verify(
return emitError()
<< "ttg.dot_op kWidth parameter must be 4/8/16 for WMMA v2 "
"(including packed cases for `scaled_dot`)";
if (parentAttr.getVersion() == 3 && !llvm::is_contained({2, 8, 16}, kWidth))
if (parentAttr.getVersion() == 3 && kWidth == 0)
return emitError()
<< "ttg.dot_op kWidth parameter must be 2/8/16 for WMMA v3";
<< "ttg.dot_op kWidth parameter is mandatory for WMMA v3 ";
return success();
}

Expand Down
41 changes: 25 additions & 16 deletions lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ class LayoutRematerialization {
}

void cleanup();
void backwardRematerialization();
bool backwardRematerialization();
void backwardRematerialization(ConvertLayoutOp convertOp);
// TODO: Merge the three hoistConvert*(); functions as they are duplicate code
void hoistConvertDotOperand();
Expand Down Expand Up @@ -1019,7 +1019,8 @@ LogicalResult LayoutRematerialization::getRematerializableSlice(
return success();
}

void LayoutRematerialization::backwardRematerialization() {
bool LayoutRematerialization::backwardRematerialization() {
bool changed = false;
// Go through each ConvertLayoutOp.
SmallVector<ConvertLayoutOp> convertOps;
funcOp.walk(
Expand All @@ -1031,8 +1032,11 @@ void LayoutRematerialization::backwardRematerialization() {
// backward slices.
addRematValue(convertOp.getSrc(), convertOp.getType().getEncoding(),
convertOp.getResult());
} else {
changed = true;
}
}
return changed;
}

void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast() {
Expand Down Expand Up @@ -1593,12 +1597,14 @@ void LayoutRematerialization::hoistConvertIntoConditionals(
rewriteSlice(slice, layout, convertOp, mapping);
}

void backwardRematerialization(ModuleOp module) {
module.walk([](FuncOp funcOp) {
bool backwardRematerialization(ModuleOp module) {
bool changed = false;
module.walk([&](FuncOp funcOp) {
LayoutRematerialization layoutRemat(funcOp);
layoutRemat.backwardRematerialization();
changed |= layoutRemat.backwardRematerialization();
layoutRemat.cleanup();
});
return changed;
}

void hoistConvert(ModuleOp module) {
Expand Down Expand Up @@ -1659,17 +1665,20 @@ class TritonGPURemoveLayoutConversionsPass

cleanupConvertOps();

// 2. For remaining convert ops, try to rematerialize the slice of producer
// operation to avoid having to convert.
backwardRematerialization(m);
LLVM_DEBUG({
DBGS() << "Module after backward remat:\n";
m.dump();
});

// Cleanup dummy converts created during backward remat.
cleanupConvertOps();

bool changed = false;
do {
changed = false;
// 2. For remaining convert ops, try to rematerialize the slice of
// producer operation to avoid having to convert.
changed = backwardRematerialization(m);
LLVM_DEBUG({
DBGS() << "Module after backward remat:\n";
m.dump();
});

// Cleanup dummy converts created during backward remat.
cleanupConvertOps();
} while (changed);
// 3. For remaining converts, try to hoist them above cast generating larger
// size types in order to reduce the cost of the convert op.
hoistConvert(m);
Expand Down
43 changes: 41 additions & 2 deletions python/src/gluon_ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -387,8 +387,47 @@ void init_gluon_ir(py::module &&m) {
std::vector<int64_t> &shape) -> py::object {
auto ctx = self.getContext();
auto linearLayout = ttg::toLinearLayout(shape, layout);
auto attr = ttg::LinearEncodingAttr::get(ctx, linearLayout);
return layoutToGluon(attr);

if (isa<ttg::DistributedEncodingTrait>(layout)) {
auto attr = ttg::LinearEncodingAttr::get(ctx, linearLayout);
return layoutToGluon(attr);
}
if (isa<ttg::SharedEncodingTrait>(layout)) {
auto alignment =
cast<ttg::SharedEncodingTrait>(layout).getAlignment();
auto attr = ttg::SharedLinearEncodingAttr::get(ctx, linearLayout,
alignment);
return layoutToGluon(attr);
}

// TensorMemory encodings: keep the LinearLayout but wrap as
// print-only Python object carrying row/col bases -> dim0/dim1.
auto inNamesRange = linearLayout.getInDimNames();
auto inNames = llvm::to_vector(inNamesRange);
bool isTmemLayout =
(inNames.size() == 2 && inNames[0].str() == "row" &&
inNames[1].str() == "col");
if (!isTmemLayout)
throw std::invalid_argument(
"Unsupported layout in to_linear_layout");

// Build Py _TensorMemoryLinearLayout(row_bases, col_bases, shape,
// repr)
py::object tmemCls =
py::module::import(
"triton.experimental.gluon.language.nvidia.blackwell")
.attr("_TensorMemoryLinearLayout");
auto bases = linearLayout.getBases();
auto rowBases = bases[mlir::StringAttr::get(ctx, "row")];
auto colBases = bases[mlir::StringAttr::get(ctx, "col")];
auto outDims = linearLayout.getOutDims();
std::vector<int> shapeVec;
for (auto &od : outDims)
shapeVec.push_back(od.second);

py::object pyObj = tmemCls(py::cast(rowBases), py::cast(colBases),
py::cast(shapeVec));
return pyObj;
})
.def("get_dot_operand_layout",
[](GluonOpBuilder &self, unsigned opIdx, Attribute parent,
Expand Down
53 changes: 17 additions & 36 deletions python/test/gluon/test_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1461,48 +1461,29 @@ def kernel(reg_type: ttgl.constexpr, shared_type: ttgl.constexpr, ref_conflicts:


@pytest.mark.parametrize(
"layout, expected",
"layout, shape",
[
(
ttgl.BlockedLayout([1], [4], [4], [0]),
ttgl.DistributedLinearLayout(
reg_bases=[],
lane_bases=[[1], [2]],
warp_bases=[[4], [8]],
block_bases=[],
shape=[16],
),
),
(
ttgl.BlockedLayout([1], [4], [4], [0], [[1], [0]]),
ttgl.DistributedLinearLayout(
reg_bases=[],
lane_bases=[[1], [2]],
warp_bases=[[4], [8]],
block_bases=[[16], [0]],
shape=[32],
),
),
(
ttgl.BlockedLayout([8, 1], [8, 4], [1, 4], [0, 1], [[0, 1]]),
ttgl.DistributedLinearLayout(
reg_bases=[[1, 0], [2, 0], [4, 0], [0, 16], [0, 32]],
lane_bases=[[8, 0], [16, 0], [32, 0], [0, 1], [0, 2]],
warp_bases=[[0, 4], [0, 8]],
block_bases=[[0, 64]],
shape=[64, 128],
),
),
(ttgl.BlockedLayout([1], [4], [4], [0]), [16]),
(ttgl.BlockedLayout([1], [4], [4], [0], [[1], [0]]), [32]),
(ttgl.BlockedLayout([8, 1], [8, 4], [1, 4], [0, 1], [[0, 1]]), [64, 128]),
(ttgl.NVMMASharedLayout(swizzle_byte_width=128, element_bitwidth=16, rank=2), [64, 64]),
(TensorMemoryLayout((64, 64), col_stride=2), [64, 64]),
],
)
def test_to_linear_layout(layout, expected):
def test_to_linear_layout(layout, shape, capsys):

@gluon.jit
def kernel(layout: ttgl.constexpr, expected: ttgl.constexpr, shape: ttgl.constexpr):
def kernel(layout: ttgl.constexpr, shape: ttgl.constexpr):
computed: ttgl.constexpr = ttgl.to_linear_layout(layout, shape)
ttgl.static_assert(computed == expected)

run_parser(kernel, args=(layout, expected, tuple(expected.shape)), target=AMPERE_TARGET)
ttgl.static_print(computed)

run_parser(kernel, args=(layout, tuple(shape)), target=AMPERE_TARGET)
out = capsys.readouterr().out
if isinstance(layout, TensorMemoryLayout):
assert "rows=" in out
assert "cols=" in out
else:
assert "DistributedLinearLayout" in out or "SharedLinearLayout" in out


@filecheck_test
Expand Down
10 changes: 10 additions & 0 deletions python/triton/experimental/gluon/language/_layouts.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from dataclasses import dataclass, field
import itertools
from typing import List

from triton.language.core import _unwrap_if_constexpr, _unwrap_shape, constexpr_type
Expand Down Expand Up @@ -636,6 +637,15 @@ def _to_ir(self, builder):
def mangle(self) -> str:
return f"SharedLinear_{self.offset_bases}_{self.block_bases}_{self.alignment}_SharedLinear"

@property
def shape(self):
    """Per-dimension extent implied by the offset and block bases.

    For every dimension, takes the largest basis component found across
    ``offset_bases`` and ``block_bases`` (never less than 1) and doubles it.

    NOTE(review): the rank is read from ``offset_bases[0]``, so an empty
    ``offset_bases`` raises ``IndexError`` -- confirm callers never build
    such a layout.
    NOTE(review): a dimension whose components are all zero is reported as
    2 because the running maximum starts at 1 -- confirm this is intended.
    """
    peak = [1] * len(self.offset_bases[0])
    for basis in itertools.chain(self.offset_bases, self.block_bases):
        for dim, component in enumerate(basis):
            # Keep the largest component seen so far for this dimension.
            if component > peak[dim]:
                peak[dim] = component
    return [2 * largest for largest in peak]

def __hash__(self):
return hash((
tuple(map(tuple, self.offset_bases)),
Expand Down
19 changes: 10 additions & 9 deletions python/triton/experimental/gluon/language/_semantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import math
from triton.language.semantic import TritonSemantic
from . import _core as ttgl
from ._layouts import AutoLayout, DistributedLayout, DistributedLinearLayout, SliceLayout, SharedLayout, CoalescedLayout
from ._layouts import AutoLayout, DistributedLayout, DistributedLinearLayout, SliceLayout, SharedLayout, CoalescedLayout, SharedLinearLayout
from triton._C.libtriton.gluon_ir import GluonOpBuilder, compute_tmem_reg_layout
from triton.compiler.code_generator import flatten_values_to_ir, unflatten_ir_values

Expand Down Expand Up @@ -301,15 +301,16 @@ def bank_conflicts(self, distr_ty, shared_ty):
distr_ty.element_ty.primitive_bitwidth)

def to_linear_layout(self, layout, shape):
_check(isinstance(layout, (DistributedLayout, SharedLayout)),
lambda: f"Expected a DistributedLayout or SharedLayout, got {type(layout)}")

if not isinstance(shape, list):
shape = list(shape)

layout = ttgl._unwrap_if_constexpr(layout)
from triton.experimental.gluon.language.nvidia.blackwell import (
TensorMemoryLayout,
TensorMemoryScalesLayout,
)
_check(
isinstance(layout, (DistributedLayout, SharedLayout, TensorMemoryLayout, TensorMemoryScalesLayout)), lambda:
f"Expected a DistributedLayout, SharedLayout, or TensorMemoryLayout or TensorMemoryScalesLayout, got {type(layout)}"
)

if isinstance(layout, (AutoLayout, DistributedLinearLayout)):
if isinstance(layout, (AutoLayout, DistributedLinearLayout, SharedLinearLayout)):
return ttgl.constexpr(layout)

return ttgl.constexpr(self.builder.to_linear_layout(layout._to_ir(self.builder), shape))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Optional, Tuple, List, TYPE_CHECKING

from dataclasses import dataclass
import itertools
from triton.runtime.jit import constexpr_function
from triton.experimental.gluon.language import _core as ttgl
from triton.experimental.gluon.language._core import builtin, base_type, base_value, _unwrap_if_constexpr
Expand All @@ -26,7 +27,9 @@
"mma_v2",
"tensor_memory_descriptor",
"TensorMemoryLayout",
"TensorMemoryScalesLayout",
"tma",
"_TensorMemoryLinearLayout",
]


Expand Down Expand Up @@ -104,6 +107,25 @@ def __hash__(self):
return hash(self.cta_split_num)


@dataclass(frozen=True)
class _TensorMemoryLinearLayout:
"""
Print-only linear layout for TMEM (row/col -> dim0/dim1).
"""
rows: List[List[int]]
cols: List[List[int]]
shape: List[int]

def _to_ir(self, builder):
raise RuntimeError("TensorMemoryLinearLayout is print-only; IR materialization is unsupported")

def mangle(self):
return f"TMLL_{self.shape}_TMLL"

def __hash__(self):
return hash((tuple(map(tuple, self.rows)), tuple(map(tuple, self.cols)), tuple(self.shape)))


@constexpr_function
def get_tmem_reg_layout(
element_ty,
Expand Down
Loading
Loading