Commit 65c7f47

Merge commit '4ff1fd66c2cea812226cc02aaa461e4355977ed7'
2 parents: be3b9ad + 4ff1fd6

File tree: 9 files changed, +197 -52 lines

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 7 additions & 9 deletions
@@ -6,6 +6,7 @@
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
 #include "triton/Analysis/Utility.h"
 #include "triton/Conversion/MLIRTypes.h"
 #include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h"
@@ -364,17 +365,14 @@ inline bool isKernel(FunctionOpInterface funcOp) {

 inline Value getStackPointer(RewriterBase &rewriter,
                              FunctionOpInterface funcOp) {
+  if (!isKernel(funcOp)) {
+    return funcOp.getArgument(funcOp.getNumArguments() - 1);
+  }
+
   auto mod = funcOp->getParentOfType<ModuleOp>();
-  LLVM::GlobalOp globalBase = nullptr;
-  mod.walk([&](LLVM::GlobalOp op) {
-    if (op.getSymName() == "global_smem")
-      globalBase = op;
-  });
+  auto globalBase = dyn_cast<LLVM::GlobalOp>(mod.lookupSymbol("global_smem"));
   assert(globalBase);
-  if (isKernel(funcOp))
-    return rewriter.create<LLVM::AddressOfOp>(funcOp.getLoc(), globalBase);
-  else
-    return funcOp.getArgument(funcOp.getNumArguments() - 1);
+  return rewriter.create<LLVM::AddressOfOp>(funcOp.getLoc(), globalBase);
 }

 inline Value getSharedMemoryBase(Location loc, RewriterBase &rewriter,
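
The rewrite above drops the walk over every LLVM::GlobalOp in favor of a direct symbol-table lookup of global_smem, and handles the non-kernel case with an early return. A minimal standalone sketch of the lookup pattern, not taken from this commit (the helper name and the dyn_cast_or_null guard are illustrative):

#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/BuiltinOps.h"

// Illustrative only: fetch the shared-memory global by symbol name instead of
// walking all globals. dyn_cast_or_null tolerates a missing symbol, whereas
// the committed code asserts that "global_smem" exists.
static mlir::LLVM::GlobalOp lookupGlobalSmem(mlir::ModuleOp mod) {
  return llvm::dyn_cast_or_null<mlir::LLVM::GlobalOp>(
      mod.lookupSymbol("global_smem"));
}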

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 11 additions & 0 deletions
@@ -602,6 +602,12 @@ def TT_GetProgramIdOp : TT_Op<"get_program_id", [Pure]> {

   let assemblyFormat = "$axis attr-dict `:` type($result)";

+  let builders = [
+    OpBuilder<(ins "int":$axis), [{
+      build($_builder, $_state, $_builder.getI32Type(), ProgramIDDimAttr::get($_builder.getContext(), ProgramIDDim(axis)));
+    }]>
+  ];
+
   let extraClassDeclaration = [{
     int32_t getAxisAsInt() {
       return static_cast<int32_t>(getAxis());
@@ -615,6 +621,11 @@ def TT_GetNumProgramsOp : TT_Op<"get_num_programs", [Pure]> {
   let results = (outs I32:$result);

   let assemblyFormat = "$axis attr-dict `:` type($result)";
+  let builders = [
+    OpBuilder<(ins "int":$axis), [{
+      build($_builder, $_state, $_builder.getI32Type(), ProgramIDDimAttr::get($_builder.getContext(), ProgramIDDim(axis)));
+    }]>
+  ];

   let extraClassDeclaration = [{
     int32_t getAxisAsInt() {
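
These TableGen builders let C++ call sites create both ops from a plain integer axis; the generated build() overload supplies the i32 result type and the ProgramIDDimAttr. A hedged sketch of such a call site (the surrounding builder and location are assumed context, not part of this commit):

#include "mlir/IR/Builders.h"
#include "triton/Dialect/Triton/IR/Dialect.h"

// Illustrative fragment: with the new builders the axis is passed directly,
// instead of constructing the result type and attribute by hand.
static mlir::Value emitProgramIdX(mlir::OpBuilder &builder, mlir::Location loc) {
  return builder.create<mlir::triton::GetProgramIdOp>(loc, /*axis=*/0);
}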

lib/Tools/LinearLayout.cpp

Lines changed: 1 addition & 1 deletion
@@ -681,11 +681,11 @@ LinearLayout::divideRight(const LinearLayout &divisor) const {
       std::move(newBases), std::move(newOutDims.takeVector()),
       /*requireSurjective=*/false);
   LDBG("candidate_quotient:" << candidateQuotient);
-  LDBG("*candidate_quotient * divisor=" << *candidateQuotient * divisor);
   if (!candidateQuotient.has_value()) {
     LDBG("candidate quotient failed invariant checks");
     return std::nullopt;
   }
+  LDBG("*candidate_quotient * divisor=" << *candidateQuotient * divisor);
   if (*candidateQuotient * divisor != *this) {
     LDBG("candidate quotient failed invariant checks");
     return std::nullopt;
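
The one-line move matters because the LDBG statement dereferences candidateQuotient; doing so before the has_value() check would dereference an empty std::optional, which is undefined behavior. A minimal, self-contained sketch of the safe ordering (names are hypothetical, not from this commit):

#include <iostream>
#include <optional>

// Sketch of the pattern the fix restores: check the optional before *deref.
static std::optional<int> tryQuotient(bool ok) {
  return ok ? std::optional<int>(42) : std::nullopt;
}

int main() {
  std::optional<int> q = tryQuotient(false);
  if (!q.has_value()) {
    std::cout << "no quotient; bail out before any use of *q\n";
    return 0;
  }
  std::cout << "quotient = " << *q << "\n"; // safe: checked above
  return 0;
}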

python/src/ir.cc

Lines changed: 2 additions & 8 deletions
@@ -1433,19 +1433,13 @@ void init_triton_ir(py::module &&m) {
            [](TritonOpBuilder &self, int axis) -> Value {
              if (axis < 0 || axis > 3)
                throw pybind11::index_error("program_id must be in [0,3]");
-             return self.create<GetProgramIdOp>(
-                 self.getBuilder().getI32Type(),
-                 ProgramIDDimAttr::get(self.getBuilder().getContext(),
-                                       ProgramIDDim(axis)));
+             return self.create<GetProgramIdOp>(axis);
            })
       .def("create_get_num_programs",
            [](TritonOpBuilder &self, int axis) -> Value {
              if (axis < 0 || axis > 3)
                throw pybind11::index_error("program_id must be in [0,3]");
-             return self.create<GetNumProgramsOp>(
-                 self.getBuilder().getI32Type(),
-                 ProgramIDDimAttr::get(self.getBuilder().getContext(),
-                                       ProgramIDDim(axis)));
+             return self.create<GetNumProgramsOp>(axis);
            })
       .def("create_dot",
            [](TritonOpBuilder &self, mlir::Value &a, mlir::Value &b,

python/test/unit/language/test_core.py

Lines changed: 2 additions & 1 deletion
@@ -3385,7 +3385,8 @@ def test_dot3d(B, num_warps, M, N, K, BLOCK_M, BLOCK_N, in_dtype_str, out_dtype_
     input_precision = "tf32" if is_cuda() and in_dtype_str == 'float32' else "ieee"

     if B == 8 and M == 64 and in_dtype_str == "float32" and out_dtype_str == "float32":
-        if torch.cuda.is_available() and triton.runtime.driver.active.utils.get_device_properties(
+        if not is_interpreter() and torch.cuda.is_available(
+        ) and triton.runtime.driver.active.utils.get_device_properties(
                 torch.cuda.current_device())["max_shared_mem"] < 131072:
             pytest.skip(
                 "Skipping tests with B = 8, M = 64, in_type = float32, out_type = float32 due to insufficient shared memory (less than 128 KB per SM) on this GPU."

python/triton/compiler/code_generator.py

Lines changed: 27 additions & 20 deletions
@@ -9,7 +9,7 @@
 from .. import language
 from .._C.libtriton import ir
 from ..language import constexpr, tensor, str_to_ty
-from ..language.core import _unwrap_if_constexpr, nv_tma_desc_type
+from ..language.core import _unwrap_if_constexpr, nv_tma_desc_type, _value
 from ..runtime.jit import _normalize_ty, get_jit_fn_file_line
 # ideally we wouldn't need any runtime component
 from ..runtime import JITFunction
@@ -47,6 +47,10 @@ def mangle_fn(name, arg_tys, constants):
     return ret


+def _is_triton_value(o: Any) -> bool:
+    return isinstance(o, _value)
+
+
 def _is_triton_tensor(o: Any) -> bool:
     return isinstance(o, tensor)

@@ -501,7 +505,7 @@ def visit_Assign(self, node):
         # by default, constexpr are assigned into python variable
         value = _unwrap_if_constexpr(value)
         if value is not None and \
-           not _is_triton_tensor(value) and \
+           not _is_triton_value(value) and \
            not isinstance(value, native_nontensor_types):
             value = language.semantic.to_tensor(value, self.builder)
         self.set_value(name, value)
@@ -802,6 +806,15 @@ def visit_UnaryOp(self, node):
         ast.USub: '__neg__', ast.UAdd: '__pos__', ast.Not: '__not__', ast.Invert: '__invert__'
     }

+    def _verify_loop_carried_variable(self, name, loop_val, live_val):
+        assert _is_triton_value(loop_val), f'cannot reassign constxpr {name} in the loop'
+        assert _is_triton_value(live_val), f'cannot reasign constexpr {name} in the loop'
+        assert type(loop_val) == type(live_val), f'Loop carried variable {name} changed type'
+        assert not _is_triton_tensor(loop_val) or loop_val.type == live_val.type, \
+            f'Loop-carried variable {name} has initial type {live_val.type} '\
+            f'but is re-assigned to {loop_val.type} in loop! '\
+            f'Please make sure that the type stays consistent.'
+
     def visit_While(self, node):
         with enter_sub_region(self) as sr:
             liveins, insert_block = sr
@@ -824,17 +837,14 @@ def visit_While(self, node):
             for name in loop_defs:
                 if name in liveins:
                     # We should not def new constexpr
-                    assert _is_triton_tensor(loop_defs[name]), f'cannot reassign constxpr {name} in the loop'
-                    assert _is_triton_tensor(liveins[name]), f'cannot reasign constexpr {name} in the loop'
-                    assert loop_defs[name].type == liveins[name].type, \
-                        f'Loop-carried variable {name} has initial type {liveins[name].type} '\
-                        f'but is re-assigned to {loop_defs[name].type} in loop! '\
-                        f'Please make sure that the type stays consistent.'
+                    loop_val = loop_defs[name]
+                    live_val = liveins[name]
+                    self._verify_loop_carried_variable(name, loop_val, live_val)

                     # these are loop-carried values
                     names.append(name)
-                    ret_types.append(loop_defs[name].type)
-                    init_args.append(liveins[name])
+                    ret_types.append(loop_val.type)
+                    init_args.append(live_val)

             self._set_insertion_point_and_loc(ip, last_loc)
             while_op = self.builder.create_while_op([ty.to_ir(self.builder) for ty in ret_types],
@@ -972,16 +982,13 @@ def visit_For(self, node):
             names = []
             for name in self.local_defs:
                 if name in liveins:
-                    assert _is_triton_tensor(self.local_defs[name]), f'cannot reassign constxpr {name} in the loop'
-                    assert _is_triton_tensor(liveins[name]), f'cannot reassign constxpr {name} in the loop'
-                    assert self.local_defs[name].type == liveins[name].type, \
-                        f'Loop-carried variable {name} has initial type {liveins[name].type} '\
-                        f'but is re-assigned to {self.local_defs[name].type} in loop! '\
-                        f'Please make sure that the type stays consistent.'
+                    loop_val = self.local_defs[name]
+                    live_val = liveins[name]
+                    self._verify_loop_carried_variable(name, loop_val, live_val)

                     names.append(name)
-                    init_args.append(language.semantic.to_tensor(liveins[name], self.builder))
-                    yields.append(language.semantic.to_tensor(self.local_defs[name], self.builder))
+                    init_args.append(live_val)
+                    yields.append(loop_val)

             # create ForOp
             self._set_insertion_point_and_loc(ip, last_loc)
@@ -1051,7 +1058,7 @@ def visit_Assert(self, node) -> Any:
     def call_JitFunction(self, fn: JITFunction, args, kwargs):
         args = inspect.getcallargs(fn.fn, *args, **kwargs)
         args = [args[name] for name in fn.arg_names]
-        args = [arg if _is_triton_tensor(arg) else constexpr(arg) for arg in args]
+        args = [arg if _is_triton_value(arg) else constexpr(arg) for arg in args]
         # generate function def
         attributes = {}
         constexprs = [i for i, arg in enumerate(args) if _is_constexpr(arg)]
@@ -1110,7 +1117,7 @@ def visit_Call(self, node):
         if isinstance(fn, JITFunction):
             _check_fn_args(node, fn, args)
             return self.call_JitFunction(fn, args, kws)
-        if (hasattr(fn, '__self__') and _is_triton_tensor(fn.__self__)) or language.core.is_builtin(fn):
+        if (hasattr(fn, '__self__') and _is_triton_value(fn.__self__)) or language.core.is_builtin(fn):
            extra_kwargs = {"_builder": self.builder}
            sig = inspect.signature(fn)
            if '_generator' in sig.parameters:

python/triton/language/core.py

Lines changed: 10 additions & 2 deletions
@@ -701,12 +701,20 @@ def get_int_dtype(bitwidth: int, signed: bool) -> dtype:
     raise ValueError(f'Unsupported bitwidth {bitwidth} and signedness {signed}')


+class _value:
+    """Base class of values that exist in the triton IR (i.e. not constexprs).
+    """
+
+    def __init__(self, handle):
+        self.handle = handle
+
+
 # -----------------------
 # tensor
 # -----------------------


-class tensor:
+class tensor(_value):
     """Represents an N-dimensional array of values or pointers.

     :code:`tensor` is the fundamental data structure in Triton programs. Most
@@ -729,7 +737,7 @@ class tensor:
     def __init__(self, handle, type: dtype):
         """Not called by user code."""
         # IR handle
-        self.handle = handle
+        super().__init__(handle)
         # Block shape
         self.shape = type.shape if type.is_block() else ()
         self.numel = 1
test/TritonGPU/amd/amd-canonicalize-pointers.mlir

Lines changed: 116 additions & 0 deletions
@@ -577,3 +577,119 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war
     tt.return %11 : tensor<1024xf32, #blocked>
   }
 }
+
+// -----
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx942", "triton_gpu.threads-per-warp" = 64 : i32} {
+  // CHECK-LABEL: scalar_pointers
+  tt.func public @scalar_pointers(%arg0: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<i32> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
+    %0 = tt.get_program_id x : i32
+    %c1_i32 = arith.constant 1 : i32
+    %c0_i64 = arith.constant 0 : i64
+    %c10_i64 = arith.constant 10 : i64
+    %c100_i32 = arith.constant 100 : i32
+    %5 = tt.addptr %arg0, %c1_i32 : !tt.ptr<i64>, i32
+    // CHECK: arith.constant 0 : i64
+    // CHECK: arith.constant 0 : i64
+    // CHECK: %[[offset0:.*]] = arith.constant 0 : i64
+    // CHECK: %[[ptr0:.*]] = tt.addptr %arg0, %c1_i32 : !tt.ptr<i64>, i32
+    // CHECK: scf.for {{.*}} iter_args({{.*}}, %[[ptr1:.*]] = %[[ptr0]], %[[offset1:.*]] = %[[offset0]])
+    %10:1 = scf.for %arg3 = %c1_i32 to %c100_i32 step %c1_i32 iter_args(%arg4 = %5) -> (!tt.ptr<i64>) : i32 {
+      // CHECK: tt.store %[[ptr1]]
+      tt.store %arg4, %c0_i64 : !tt.ptr<i64>
+      // CHECK: tt.addptr %[[ptr1]]
+      %11 = tt.addptr %arg4, %c1_i32 : !tt.ptr<i64>, i32
+      scf.yield %11 : !tt.ptr<i64>
+    }
+    tt.return
+  }
+}
+
+// -----
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "hip:gfx942", "triton_gpu.threads-per-warp" = 64 : i32} {
+  // CHECK-LABEL: @scalar_if
+  tt.func @scalar_if(%arg0: !tt.ptr<f32>, %init : tensor<1024xf32, #blocked>, %cond : i1)->f32{
+    %0 = tt.get_program_id x : i32
+    %c1_i32 = arith.constant 1 : i32
+    %c0_i64 = arith.constant 0 : i64
+    %c10_i64 = arith.constant 10 : i64
+    %c100_i32 = arith.constant 100 : i32
+    %5 = tt.addptr %arg0, %c1_i32 : !tt.ptr<f32>, i32
+    // CHECK: %[[ptr0:.*]] = tt.addptr %arg0, %{{.*}}
+    // CHECK: scf.if {{.*}} -> ({{.*}}, !tt.ptr<f32>, i64)
+    %6 = scf.if %cond -> (!tt.ptr<f32>){
+      %true = tt.addptr %5, %c1_i32 : !tt.ptr<f32>, i32
+      // CHECK: %[[ptr1:.*]] = tt.addptr %[[ptr0]]
+      // CHECK: scf.yield {{.*}}, %[[ptr1]]
+      scf.yield %true : !tt.ptr<f32>
+    } else {
+      %false = tt.addptr %5, %c100_i32 : !tt.ptr<f32>, i32
+      // CHECK: %[[ptr2:.*]] = tt.addptr %[[ptr0]]
+      // CHECK: scf.yield {{.*}}, %[[ptr2]]
+      scf.yield %false : !tt.ptr<f32>
+    }
+    %11 = tt.load %6 : !tt.ptr<f32>
+    tt.return %11 : f32
+  }
+}
+
+// -----
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}>
+module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
+  // CHECK-LABEL: tt.func @scalar_while
+  tt.func @scalar_while(%arg0: !tt.ptr<f32>, %init : f32, %cond : i1)->f32{
+    %c1024_i32 = arith.constant 1024 : i32
+    %c0 = arith.constant 0: index
+    %c128 = arith.constant 128: index
+    %c1 = arith.constant 1 : index
+    %0 = tt.get_program_id x : i32
+    %1 = arith.muli %0, %c1024_i32 : i32
+    // CHECK: %[[ptr0:.*]] = tt.addptr %arg0, %{{.*}}
+    // CHECK: scf.while ({{.*}}, {{.*}} = %arg2, %[[ptr1:.*]] = %[[ptr0]], {{.*}})
+    %2 = tt.addptr %arg0, %0 : !tt.ptr<f32>, i32
+    %6 = scf.while (%arg1 = %2, %arg2 = %cond) : (!tt.ptr<f32>, i1) -> (!tt.ptr<f32>) {
+      // CHECK: scf.condition({{.*}}) {{.*}}, %[[ptr1]]
+      scf.condition(%arg2) %arg1 : !tt.ptr<f32>
+    } do {
+      // CHECK: ^bb0({{.*}}: !tt.ptr<f32>, %[[ptr2:.*]]: !tt.ptr<f32>, {{.*}})
+      // CHECK: scf.yield %{{.*}}, {{.*}} %[[ptr2]], {{.*}}, {{.*}}
+    ^bb0(%arg1: !tt.ptr<f32>):
+      scf.yield %arg1, %cond : !tt.ptr<f32>, i1
+    }
+    %11 = tt.load %6 : !tt.ptr<f32>
+    tt.return %11 : f32
+  }
+}
+
+// -----
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}>
+module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
+  // CHECK-LABEL: tt.func @scalar_cond_branch
+  tt.func @scalar_cond_branch(%arg0 : !tt.ptr<f32>, %i1 : i1) -> f32{
+    %c1024_i32 = arith.constant 1024 : i32
+    %c0 = arith.constant 0: index
+    %c128 = arith.constant 128: index
+    %c1 = arith.constant 1 : index
+    %0 = tt.get_program_id x : i32
+    %1 = arith.muli %0, %c1024_i32 : i32
+    %6 = tt.addptr %arg0, %0 : !tt.ptr<f32>, i32
+    // CHECK: %[[ptr0:.*]] = tt.addptr %arg0
+    // CHECK: cf.cond_br %arg1, ^bb1(%{{.*}}, %[[ptr0]], {{.*}}), ^bb2(%{{.*}}, %arg0, {{.*}})
+    cf.cond_br %i1, ^bb1(%6 : !tt.ptr<f32>), ^bb2(%arg0 : !tt.ptr<f32>)
+    // CHECK: ^bb1({{.*}}, %[[ptr1:.*]]: !tt.ptr<f32>, {{.*}}):
+  ^bb1(%arg1 : !tt.ptr<f32>):
+    // CHECK: tt.load %[[ptr1]]
+    %out1 = tt.load %arg1 : !tt.ptr<f32>
+    tt.return %out1 : f32
+    // CHECK: ^bb2({{.*}}, %[[ptr2:.*]]: !tt.ptr<f32>, {{.*}}):
+  ^bb2(%arg2 : !tt.ptr<f32>): // 2 preds: ^bb0, ^bb1
+    // CHECK: tt.load %[[ptr2]]
+    %out2 = tt.load %arg2 : !tt.ptr<f32>
+    tt.return %out2 : f32
+  }
+}
