Skip to content

Commit 40b9889

Browse files
authored
Assert whether tl.advance is unused at the AST level. (#6768)
Previously, we would check whether it is unused at the IR level, but that has false positives: for example, DCE after resolving compile-time branches and static_range expressions could result in the tl.advance return value being discarded due to no error on the user's behalf. We instead inspect the AST and see whether a tl.advance Call is the child of an Expr node; if so, the user has almost certainly erred and we should loudly communicate this.
1 parent 236f6b5 commit 40b9889

File tree

5 files changed

+55
-13
lines changed

5 files changed

+55
-13
lines changed

python/src/ir.cc

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -772,19 +772,7 @@ void init_triton_ir(py::module &&m) {
772772
},
773773
ret::reference)
774774
// .def("has_attr", &::FuncOp::hasAttr)
775-
.def("finalize",
776-
[](FuncOp &self) -> void {
777-
// Check if the result of tl.advance is used
778-
self.walk([&](AdvanceOp op) {
779-
if (op->getResult(0).use_empty())
780-
outputWarning(op->getLoc(), "The result of tl.advance is not "
781-
"being used. Note that tl.advance "
782-
"does not have any side effects. "
783-
"To move the block pointer, you "
784-
"need to assign the result of "
785-
"tl.advance to a variable.");
786-
});
787-
})
775+
.def("finalize", [](FuncOp &self) -> void {})
788776
.def_property_readonly("type", &FuncOp::getFunctionType)
789777
.def("reset_type", &FuncOp::setType);
790778

python/test/unit/language/test_compile_errors.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -453,3 +453,35 @@ def dot_kernel():
453453
assert (str(e.value.__cause__) == "max_num_imprecise_acc (128) must be <= K (64)")
454454
except AssertionError as assertion_err:
455455
raise assertion_err from e.value
456+
457+
458+
extra_words = "These are extra words in the error message."
459+
460+
461+
@triton.must_use_result(extra_words)
462+
@triton.jit
463+
def cube(x):
464+
return x * x * x
465+
466+
467+
def test_unused_result():
468+
469+
@triton.jit
470+
def evil_cube_kernel():
471+
a = tl.full((64, 64), 0.0, tl.float32)
472+
cube(a)
473+
474+
@triton.jit
475+
def good_cube_kernel():
476+
a = tl.full((64, 64), 0.0, tl.float32)
477+
a = cube(a)
478+
479+
triton.compile(triton.compiler.ASTSource(fn=good_cube_kernel, signature={}, constexprs={}))
480+
481+
with pytest.raises(CompilationError) as e:
482+
triton.compile(triton.compiler.ASTSource(fn=evil_cube_kernel, signature={}, constexprs={}))
483+
484+
expected_err_msg = "The result of cube is not being used. " + extra_words
485+
obtained_err_msg = str(e.value).split('\n')[-1]
486+
487+
assert expected_err_msg == obtained_err_msg

python/triton/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
from . import testing
2727
from . import tools
2828

29+
must_use_result = language.core.must_use_result
30+
2931
__all__ = [
3032
"autotune",
3133
"cdiv",
@@ -39,6 +41,7 @@
3941
"KernelInterface",
4042
"language",
4143
"MockTensor",
44+
"must_use_result",
4245
"next_power_of_2",
4346
"OutOfResources",
4447
"reinterpret",

python/triton/compiler/code_generator.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1220,6 +1220,13 @@ def visit_Call(self, node):
12201220
if static_implementation is not None:
12211221
return static_implementation(self, node)
12221222

1223+
mur = getattr(fn, '_must_use_result', False)
1224+
if mur and getattr(node, '_is_unused', False):
1225+
error_message = ["The result of %s is not being used." % ast.unparse(node.func)]
1226+
if isinstance(mur, str):
1227+
error_message.append(mur)
1228+
raise CompilationError(self.jit_fn.src, node, " ".join(error_message))
1229+
12231230
kws = dict(self.visit(keyword) for keyword in node.keywords)
12241231
args = [self.visit(arg) for arg in node.args]
12251232
args = list(itertools.chain.from_iterable(x if isinstance(x, list) else [x] for x in args))
@@ -1277,6 +1284,7 @@ def visit_Attribute(self, node):
12771284
return getattr(lhs, node.attr)
12781285

12791286
def visit_Expr(self, node):
1287+
node.value._is_unused = True
12801288
ast.NodeVisitor.generic_visit(self, node)
12811289

12821290
def visit_NoneType(self, node):

python/triton/language/core.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,14 @@
2323
PropagateNan = ir.PROPAGATE_NAN
2424

2525

26+
def must_use_result(x, s=True):
27+
"""If the result of this function is unused, throw an error."""
28+
if isinstance(x, str):
29+
return (lambda fn: must_use_result(fn, x))
30+
x._must_use_result = s
31+
return x
32+
33+
2634
def builtin(fn: T) -> T:
2735
"""Mark a function as a builtin."""
2836
assert callable(fn)
@@ -2080,6 +2088,9 @@ def make_block_ptr(base: tensor, shape, strides, offsets, block_shape, order, _b
20802088
return semantic.make_block_ptr(base, shape, strides, offsets, block_shape, order, _builder)
20812089

20822090

2091+
@must_use_result(
2092+
"Note that tl.advance does not have any side effects. To move the block pointer, you need to assign the result of tl.advance to a variable."
2093+
)
20832094
@_tensor_member_fn
20842095
@builtin
20852096
def advance(base, offsets, _builder=None):

0 commit comments

Comments
 (0)