
Commit d254e2b

Merge commit 'c76b342a2d704b6552c1224a4e7706bb85a4b888'
2 parents: b3ca988 + c76b342

12 files changed: 108 additions, 31 deletions

include/triton/Analysis/AxisInfo.h (4 additions, 3 deletions)

@@ -27,11 +27,12 @@ class AxisInfo {
 public:
   AxisInfo() : AxisInfo({}, {}, {}) {}

-  AxisInfo(DimVectorT contiguity, DimVectorT divisibility, DimVectorT constancy)
+  AxisInfo(ArrayRef<int64_t> contiguity, ArrayRef<int64_t> divisibility,
+           ArrayRef<int64_t> constancy)
       : AxisInfo(contiguity, divisibility, constancy, std::nullopt) {}

-  AxisInfo(DimVectorT contiguity, DimVectorT divisibility, DimVectorT constancy,
-           std::optional<int64_t> constantValue)
+  AxisInfo(ArrayRef<int64_t> contiguity, ArrayRef<int64_t> divisibility,
+           ArrayRef<int64_t> constancy, std::optional<int64_t> constantValue)
       : contiguity(contiguity), divisibility(divisibility),
         constancy(constancy), constantValue(constantValue) {
     assert(divisibility.size() == contiguity.size());

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td (5 additions, 0 deletions)

@@ -778,6 +778,11 @@ def MmaEncodingTrait : AttrInterface<"MmaEncodingTrait"> {
                     (ins "int":$opIdx,
                          "int":$kWidth)>,

+    InterfaceMethod<"Get the order of reps (tiles of this layout that tile the whole tensor). The fastest-changing axis first",
+                    "SmallVector<unsigned>",
+                    "getRepOrderForOperand",
+                    (ins "int":$opIdx)>,
+
     InterfaceMethod<"Return element sizes per thread for dot operands.", "SmallVector<unsigned>",
                     "getElemsPerThreadForOperands", (ins "ArrayRef<int64_t>":$tensorShape,
                                                          "Type":$eltTy,

lib/Analysis/Allocation.cpp (4 additions, 4 deletions)

@@ -84,8 +84,8 @@ ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy,

   assert(cvtNeedsSharedMemory(srcTy, dstTy));

-  auto inOrd = gpu::getOrder(srcLayout);
-  auto outOrd = gpu::getOrder(dstLayout);
+  const auto &inOrd = gpu::getOrder(srcLayout);
+  const auto &outOrd = gpu::getOrder(dstLayout);
   scratchConfig.order = outOrd;

   unsigned srcContigPerThread =

@@ -303,7 +303,7 @@ class AllocationAnalysis {
   /// arguments are involved.
   void resolveAliasBufferLiveness(
       function_ref<Interval<size_t>(Value value)> getLiveness) {
-    for (auto aliasBufferIter : allocation->getAliasBuffer()) {
+    for (const auto &aliasBufferIter : allocation->getAliasBuffer()) {
       auto value = aliasBufferIter.first;
       auto buffers = aliasBufferIter.second;
       auto range = getLiveness(value);

@@ -443,7 +443,7 @@ class AllocationAnalysis {
           std::find_if(xBuffers.begin(), xBuffers.end(), [&](auto *buffer) {
             auto xRange = bufferRange[buffer];
             bool res = xRange.intersects(range);
-            for (auto val : tripleMap)
+            for (const auto &val : tripleMap)
               res = res &&
                     !val.second.intersects(xRange); // only one buffer intersect
             return res;

lib/Analysis/AxisInfo.cpp (5 additions, 3 deletions)

@@ -1084,9 +1084,11 @@ LogicalResult AxisInfoAnalysis::visitOperation(

 void AxisInfoAnalysis::visitForOpInductionVar(
     scf::ForOp op, ArrayRef<dataflow::Lattice<AxisInfo> *> argLattices) {
-  ProgramPoint programPoint(op);
-  auto lb = getLatticeElementFor(&programPoint, op.getLowerBound())->getValue();
-  auto step = getLatticeElementFor(&programPoint, op.getStep())->getValue();
+  ProgramPoint *programPoint = getProgramPointAfter(op);
+  const auto &lb =
+      getLatticeElementFor(programPoint, op.getLowerBound())->getValue();
+  const auto &step =
+      getLatticeElementFor(programPoint, op.getStep())->getValue();

   AxisInfo::DimVectorT knownContiguity(1, 1);
   AxisInfo::DimVectorT knownDivisibility(1, 1);

lib/Dialect/TritonGPU/IR/Dialect.cpp (18 additions, 3 deletions)

@@ -1702,7 +1702,14 @@ AMDMfmaEncodingAttr::getInstrShapeForOperand(int kWidth, int opIdx) const {
 }

 SmallVector<unsigned> AMDMfmaEncodingAttr::getRepOrder() const {
-  llvm::report_fatal_error("NYI. AMDMfmaEncodingAttr::getRepOrder");
+  auto rank = getWarpsPerCTA().size();
+  return getMatrixOrder(rank, /*rowMajor*/ true);
+}
+
+SmallVector<unsigned>
+AMDMfmaEncodingAttr::getRepOrderForOperand(int opIdx) const {
+  auto rank = getWarpsPerCTA().size();
+  return getOrderForDotOperand(opIdx, rank, /*kMajor*/ true);
 }

 SmallVector<int64_t>

@@ -1789,8 +1796,16 @@ AMDWmmaEncodingAttr::getShapePerCTATile(ArrayRef<int64_t> tensorShape) const {
   return shapePerCTATile;
 }
 SmallVector<unsigned> AMDWmmaEncodingAttr::getRepOrder() const {
-  llvm::report_fatal_error("NYI. AMDWmmaEncodingAttr::getRepOrder");
+  auto rank = getWarpsPerCTA().size();
+  return getMatrixOrder(rank, /*rowMajor*/ true);
 }
+
+SmallVector<unsigned>
+AMDWmmaEncodingAttr::getRepOrderForOperand(int opIdx) const {
+  auto rank = getWarpsPerCTA().size();
+  return getOrderForDotOperand(opIdx, rank, /*kMajor*/ true);
+}
+
 SmallVector<unsigned> AMDWmmaEncodingAttr::getCTAsPerCGA() const {
   return SmallVector<unsigned>(getCTALayout().getCTAsPerCGA());
 }

@@ -2060,7 +2075,7 @@ NvidiaMmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const {
 // DotOperand Encoding
 //===----------------------------------------------------------------------===//
 SmallVector<unsigned> DotOperandEncodingAttr::getRepOrder() const {
-  if (auto mma = mlir::dyn_cast<NvidiaMmaEncodingAttr>(getParent())) {
+  if (auto mma = mlir::dyn_cast<MmaEncodingTrait>(getParent())) {
     return mma.getRepOrderForOperand(getOpIdx());
   }
   llvm::report_fatal_error(

lib/Dialect/TritonGPU/Transforms/OptimizeAccumulatorInit.cpp (2 additions, 0 deletions)

@@ -65,6 +65,8 @@ std::optional<std::pair<Operation *, int>> findZeroInitOp(Value accUse,
       return std::nullopt;
     }
     if (auto selOp = dyn_cast<arith::SelectOp>(defOp)) {
+      if (!selOp.getCondition().getType().isInteger(1))
+        return std::nullopt;
       if (isConstantZeroTensor(selOp.getTrueValue()) ||
           isConstantZeroTensor(selOp.getFalseValue())) {
         return std::make_pair(selOp, 0);

python/test/unit/language/test_core.py (14 additions, 13 deletions)

@@ -5630,7 +5630,7 @@ def matmul_kernel( #
                   stride_cm, stride_cn, #
                   BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, #
                   low_precision_acc: tl.constexpr, #
-                  num_pipeline_stages: tl.constexpr = 3 #
+                  num_stages: tl.constexpr = 3 #
                   ):
     pid = tl.program_id(axis=0)
     num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)

@@ -5642,7 +5642,7 @@ def matmul_kernel( #
     a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
     b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
-    for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K), num_stages=num_pipeline_stages):
+    for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K), num_stages=num_stages):
         a = tl.load(a_ptrs)
         b = tl.load(b_ptrs)
         accumulator = tl.dot(a, b, acc=accumulator, max_num_imprecise_acc=low_precision_acc)

@@ -5681,7 +5681,7 @@ def test_dot_max_num_imprecise_acc(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, in_type_s
     max_num_impressive_acc = low_precision_acc if low_precision_acc <= BLOCK_K else None
     h = matmul_kernel[grid](a, b, C, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), C.stride(0),
                             C.stride(1), BLOCK_M, BLOCK_N, BLOCK_K, max_num_impressive_acc, num_warps=num_warps,
-                            num_pipeline_stages=num_stages)
+                            num_stages=num_stages)
     torch_a = torch.from_numpy(A).to(device=device)
     th_a = f8_to_f16(torch_a, in_type_str)
     torch_b = torch.from_numpy(B).to(device=device)

@@ -5873,7 +5873,7 @@ def test_tl_range(device):
     pgm = matmul_kernel[
         1,
     ](a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), BLOCK_M, BLOCK_N,
-      BLOCK_K, 0, num_pipeline_stages=5)
+      BLOCK_K, 0, num_stages=5)
     ref_out = torch.matmul(a, b).to(torch.float32)
     if is_interpreter():
         # GPU invokes tensor core for float16 matmul, which is not supported in interpreter.

@@ -5899,8 +5899,8 @@ def maxnreg_noinline2(X):
     tl.store(X, 0)


+@pytest.mark.interpreter
 def test_maxnreg(device):
-    assert not is_interpreter(), "this test won't work with the interpreter"
     if not is_cuda():
         pytest.xfail('maxnreg only works on CUDA')


@@ -5914,14 +5914,15 @@ def kernel(X):
     X = torch.empty(1, dtype=torch.int32, device=device)
     k = kernel[(1, )](X, maxnreg=42)

-    # Ensure that .maxnreg is set on the kernel function (marked with .entry)
-    # and not on either of the noinline functions (marked with .func).
-    try:
-        assert re.search(r'\.visible \.entry [^{;]*\.maxnreg 42', k.asm["ptx"])
-        assert not re.search(r'\.visible \.func [^{;]*\.maxnreg', k.asm["ptx"])
-    except AssertionError:
-        print("Failing ptx:\n", k.asm["ptx"])
-        raise
+    if not is_interpreter():
+        # Ensure that .maxnreg is set on the kernel function (marked with .entry)
+        # and not on either of the noinline functions (marked with .func).
+        try:
+            assert re.search(r'\.visible \.entry [^{;]*\.maxnreg 42', k.asm["ptx"])
+            assert not re.search(r'\.visible \.func [^{;]*\.maxnreg', k.asm["ptx"])
+        except AssertionError:
+            print("Failing ptx:\n", k.asm["ptx"])
+            raise


 @pytest.mark.interpreter
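
Editor's note: the rename above lets a single num_stages constexpr flow from the launch site into tl.range. A minimal, hypothetical sketch of that pattern (the copy_kernel name, shapes, and launch syntax below are illustrative and not part of this commit):

import triton
import triton.language as tl


@triton.jit
def copy_kernel(src_ptr, dst_ptr, N, BLOCK: tl.constexpr, num_stages: tl.constexpr = 3):
    # The constexpr is forwarded to tl.range, so a caller can tune software
    # pipelining at launch time, e.g. copy_kernel[(1, )](src, dst, N, BLOCK=128, num_stages=5).
    for start in tl.range(0, tl.cdiv(N, BLOCK), num_stages=num_stages):
        offs = start * BLOCK + tl.arange(0, BLOCK)
        mask = offs < N
        tl.store(dst_ptr + offs, tl.load(src_ptr + offs, mask=mask), mask=mask)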

python/triton/runtime/interpreter.py (5 additions, 5 deletions)

@@ -1034,9 +1034,6 @@ def _implicit_cvt(arg):

 interpreter_builder = InterpreterBuilder()

-# These keywords are not supported by the interpreter
-RESERVED_KWS = ["num_warps", "num_stages", "num_ctas", "enable_fp_fusion", "grid", "maxnreg"]
-

 class GridExecutor:


@@ -1077,10 +1074,13 @@ def _restore_args_dev(self, args_dev, args_hst, kwargs, kwargs_hst):
             kwarg_dev.data.copy_(kwarg_hst.to(kwarg_dev.device).data)

     def __call__(self, *args_dev, **kwargs):
-        # removes reserved keywords from kwargs
-        kwargs = {k: v for k, v in kwargs.items() if k not in RESERVED_KWS}
         if kwargs.pop("warmup", False):
             return
+        # Remove reserved keywords from kwargs that the kernel itself does not declare.
+        # Triton doesn't support keyword-only, variable positional, or variable keyword arguments,
+        # so it's safe to inspect only positional-or-keyword arguments (i.e., argspec.args).
+        argspec = inspect.getfullargspec(self.fn)
+        kwargs = {k: v for k, v in kwargs.items() if k in argspec.args}
         # copy arguments to the host
        args_hst, kwargs_hst = self._init_args_hst(args_dev, kwargs)
         # remaps core language functions to interpreted ones
test/TritonGPU/accumulator-init.mlir (16 additions, 0 deletions)

@@ -348,4 +348,20 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
     }
     tt.return %17 : tensor<128x16xf32, #mma1>
   }
+
+  // If the condition is a tensor, skip the optimization.
+  // CHECK-LABEL: @negative_sel_tensor
+  // CHECK-NOT: %[[ACC_NEXT:.+]] = triton_nvidia_gpu.warp_group_dot {{.*}}, {{.*}}, {{.*}}, {{.*}} : !tt.memdesc
+  tt.func @negative_sel_tensor(%A: !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>, %B: !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory>, %cnd: tensor<128x16xi1, #mma1>) -> tensor<128x16xf32, #mma1> {
+    %c0_i32 = arith.constant 0 : i32
+    %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1>
+    %c1_i32 = arith.constant 1 : i32
+    %c8_i32 = arith.constant 8 : i32
+    %17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 {
+      %acc_ = arith.select %cnd, %cst_2, %arg4 : tensor<128x16xi1, #mma1>, tensor<128x16xf32, #mma1>
+      %acc = triton_nvidia_gpu.warp_group_dot %A, %B, %acc_ : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1>
+      scf.yield %acc : tensor<128x16xf32, #mma1>
+    }
+    tt.return %17 : tensor<128x16xf32, #mma1>
+  }
 }

test/TritonGPU/coalesce.mlir (29 additions, 0 deletions)

@@ -131,3 +131,32 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war
     tt.return
   }
 }
+
+// -----
+
+// COM: Reproducer for issue #5122
+// CHECK-LABEL: @test_5122
+module {
+  tt.func public @test_5122(%arg0: i32) attributes {noinline = false} {
+    %c1_i32 = arith.constant 1 : i32
+    %0 = arith.cmpi sgt, %arg0, %c1_i32 : i32
+    scf.if %0 {
+      %1 = scf.if %0 -> (i32) {
+        scf.yield %c1_i32 : i32
+      } else {
+        scf.yield %c1_i32 : i32
+      }
+      %2 = arith.cmpi sgt, %1, %c1_i32 : i32
+      %3 = scf.if %2 -> (i32) {
+        scf.yield %c1_i32 : i32
+      } else {
+        scf.yield %c1_i32 : i32
+      }
+      %4 = scf.for %arg1 = %1 to %1 step %c1_i32 iter_args(%arg2 = %3) -> (i32) : i32 {
+        %5 = arith.addi %arg2, %c1_i32 : i32
+        scf.yield %5 : i32
+      }
+    }
+    tt.return
+  }
+}
