Skip to content

Commit a3f5ea6

Browse files
authored
[FRONTEND] .item() as syntactic sugar for .reshape([]) (#6873)
suggested by @lezcano
1 parent cd57ce9 commit a3f5ea6

File tree

3 files changed

+11
-2
lines changed

3 files changed

+11
-2
lines changed

python/test/unit/language/test_core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7528,7 +7528,7 @@ def unsplat_kernel(x, explicit: tl.constexpr):
75287528
condition = tl.load(x + tl.arange(0, 1)) > 42
75297529

75307530
if explicit:
7531-
condition = condition.reshape([])
7531+
condition = condition.item()
75327532

75337533
if condition:
75347534
tl.store(x, 42)

python/triton/compiler/code_generator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -773,7 +773,7 @@ def visit_If(self, node):
773773
raise self._unsupported(node, "Boolean value of Tensor with more than one value is ambiguous")
774774
if cond.type.is_block():
775775
warnings.warn(
776-
"If conditional called with multidimensional Tensor instead of scalar; please use \"if (%s).reshape([])\" instead"
776+
"If conditional called with multidimensional Tensor instead of scalar; please use \"if (%s).item()\" instead"
777777
% ast.unparse(node.test))
778778
cond = language.core._unsplat(cond, _builder=self.builder, _generator=self)
779779
cond = cond.to(language.int1, _builder=self.builder)

python/triton/language/core.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1792,6 +1792,15 @@ def view(input, *shape, _builder=None):
17921792
return semantic.reshape(input, shape, can_reorder=True, builder=_builder)
17931793

17941794

1795+
@_tensor_member_fn
1796+
@builtin
1797+
def item(input, _builder=None, _generator=None):
1798+
"""
1799+
Converts a single-element tensor into a scalar.
1800+
"""
1801+
return _unsplat(input, _builder=_builder, _generator=_generator)
1802+
1803+
17951804
@_tensor_member_fn
17961805
@builtin
17971806
def reshape(input, *shape, can_reorder=False, _builder=None, _generator=None):

0 commit comments

Comments
 (0)