Commit 6a3027a

[FRONTEND] Better handling of boolean operators and scalars (#6769)
This commit contains the following frontend enhancements:

1. `tl.reshape(x, [])` works for single-element tensors and returns a scalar. Previously, it was very difficult to produce scalars, and `tl.reshape(x, [])` would throw an error.
2. `if`-statements correctly throw frontend errors if used on multiple-element tensors; if a multidimensional single-element tensor is provided, we raise a warning (recommending `tl.reshape(x, [])`) and unsplat it to a scalar. Before, the backend would crash in both cases with an inscrutable MLIR error.
3. Chained boolean operations such as `(P or Q or R)` are supported, whereas before they would throw frontend errors.
4. In boolean operations, operands with constexpr truth values are handled specially:
   - if we are in a conjunction and encounter a constexpr falsey operand, we short-circuit and return it;
   - if we are in a disjunction and encounter a constexpr truthy operand, we short-circuit and return it;
   - other constexpr operands are ignored completely (they do not participate in the result).

The last of these enhancements allows one to write things such as:

```
if (x is not None) and (x.dtype == tl.int32):
    ...
```

which would previously have failed, as Triton would have tried to compute both operands (the latter yielding an error) before taking their conjunction.
1 parent 0e54ff1 · commit 6a3027a
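Taken together, items 1 and 3 make it possible to branch on a single loaded value inside a kernel. A minimal sketch of the newly supported pattern (the kernel and its argument names are illustrative, not part of this diff):

```python
import triton
import triton.language as tl


@triton.jit
def flag_kernel(x_ptr, out_ptr):
    # tl.load over a one-element range yields a tensor of shape [1] ...
    v = tl.load(x_ptr + tl.arange(0, 1))
    # ... which tl.reshape(v, []) now turns into a true scalar (item 1):
    v = tl.reshape(v, [])
    # chained boolean operators no longer throw a frontend error (item 3):
    if v < 0 or v > 100 or v == 42:
        tl.store(out_ptr, 1)
```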

File tree

3 files changed (+149 lines, -10 lines)

python/test/unit/language/test_core.py

Lines changed: 76 additions & 0 deletions
```diff
@@ -7486,3 +7486,79 @@ def _namedtuple_float_tuple_kernel():
         x, y = float('-inf'), float('inf')  # noqa: F841
 
     _namedtuple_float_tuple_kernel[(1, )]()
+
+
+@pytest.mark.interpreter
+def test_short_circuiting(device):
+
+    @triton.jit
+    def short_circuiting_kernel(x):
+        if (x is not None) and hasattr(x, "dtype") and isinstance(
+                x.dtype, tl.pointer_type) and (x.dtype.element_ty == tl.int32) and (tl.load(x) > 42):
+            tl.store(x, 42)
+
+    def f(x):
+        short_circuiting_kernel[(1, )](x, num_warps=1)
+
+    f(None)  # should succeed with NoneType
+    f(1)  # should succeed with tl.constexpr type
+    f(2)  # should succeed with integer type
+
+    def g(y, dtype):
+        x = torch.full((1, ), y, device=device, dtype=dtype)
+        f(x)
+        return x.item()
+
+    assert g(37.5, torch.float32) == 37.5
+    assert g(84.0, torch.float32) == 84.0
+    assert g(-76893, torch.int32) == -76893
+    assert g(100000, torch.int32) == 42
+    assert g(100000, torch.int64) == 100000
+
+
+@pytest.mark.interpreter
+def test_unsplat(device):
+
+    @triton.jit
+    def unsplat_kernel(x, explicit: tl.constexpr):
+
+        # this is a single-element tensor:
+        condition = tl.load(x + tl.arange(0, 1)) > 42
+
+        if explicit:
+            condition = condition.reshape([])
+
+        if condition:
+            tl.store(x, 42)
+
+    def g(y, explicit):
+        x = torch.full((1, ), y, device=device, dtype=torch.int32)
+        unsplat_kernel[(1, )](x, explicit, num_warps=1)
+        return x.item()
+
+    assert g(41, False) == 41
+    assert g(43, False) == 42
+    assert g(41, True) == 41
+    assert g(43, True) == 42
+
+
+@pytest.mark.interpreter
+def test_tuple_logic():
+
+    @triton.jit
+    def tuple_logic_kernel():
+
+        # arity-2 BoolOps:
+        tl.static_assert(((3, 4) or (5, 6)) == (3, 4))
+        tl.static_assert(((3, 4) and (5, 6)) == (5, 6))
+        tl.static_assert(((3, 4) and ()) == ())
+        tl.static_assert((() or (5, 6)) == (5, 6))
+
+        # arity-3 BoolOps:
+        tl.static_assert(((1, 2) and (3, 4) and (5, 6)) == (5, 6))
+        tl.static_assert(((1, 2) or (3, 4) or (5, 6)) == (1, 2))
+
+        # constexpr short-circuiting over dynamic argument:
+        tl.static_assert((() and tl.program_id(0)) == ())
+
+    tuple_logic_kernel[(1, )]()
```
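Worth spelling out why the final `torch.int64` case in `test_short_circuiting` is left untouched: every guard before `tl.load(x) > 42` evaluates to a compile-time constant, so the new `visit_BoolOp` short-circuits before the load is ever visited. Roughly (an annotated walkthrough, not emitted code):

```python
# specialization of short_circuiting_kernel for an int64 pointer argument:
#   (x is not None)                      -> constexpr True   (dropped from the chain)
#   hasattr(x, "dtype")                  -> constexpr True   (dropped)
#   isinstance(x.dtype, tl.pointer_type) -> constexpr True   (dropped)
#   x.dtype.element_ty == tl.int32       -> constexpr False  (conjunction short-circuits here)
#   tl.load(x) > 42                      -> never visited; no load is emitted
```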

python/triton/compiler/code_generator.py

Lines changed: 52 additions & 7 deletions
```diff
@@ -324,7 +324,10 @@ def __init__(self, context, prototype, gscope, function_name, jit_fn: JITFunction
         # special handling.
         self.visiting_arg_default_value = False
 
-        builtin_namespace: Dict[str, Any] = {_.__name__: _ for _ in (len, list, range, float, int, isinstance, getattr)}
+        builtin_namespace: Dict[str, Any] = {
+            _.__name__: _
+            for _ in (len, list, range, float, int, isinstance, getattr, hasattr)
+        }
         builtin_namespace.update((
             ('print', language.core.device_print),
             ('min', language.minimum),
```
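A note on why `hasattr` joins this namespace: inside a `@triton.jit` function it is evaluated at compile time against the Python-level object backing the argument, so it returns a plain `bool`; the reworked `visit_BoolOp` below then treats that `bool` as a constexpr operand, which is what lets guards like `hasattr(x, "dtype")` in the new tests fold away instead of reaching the backend.
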
```diff
@@ -766,6 +769,13 @@ def visit_If(self, node):
         cond = self.visit(node.test)
 
         if _is_triton_tensor(cond):
+            if _is_non_scalar_tensor(cond):
+                raise self._unsupported(node, "Boolean value of Tensor with more than one value is ambiguous")
+            if cond.type.is_block():
+                warnings.warn(
+                    "If conditional called with multidimensional Tensor instead of scalar; please use \"if (%s).reshape([])\" instead"
+                    % ast.unparse(node.test))
+                cond = language.core._unsplat(cond, _builder=self.builder, _generator=self)
             cond = cond.to(language.int1, _builder=self.builder)
         contains_return = ContainsReturnChecker(self.gscope).visit(node)
         if contains_return:
```
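A sketch of the new `visit_If` behaviour (hypothetical kernel, not from this commit):

```python
@triton.jit
def cond_kernel(x_ptr):
    many = tl.load(x_ptr + tl.arange(0, 8)) > 0  # shape [8]
    one = tl.load(x_ptr + tl.arange(0, 1)) > 0   # shape [1]
    # `if many:` now raises a frontend error ("Boolean value of Tensor with
    # more than one value is ambiguous") instead of crashing in MLIR.
    if one:  # warns, then unsplats the [1]-shaped condition to a scalar
        tl.store(x_ptr, 0)
    if one.reshape([]):  # the warning-free spelling recommended above
        tl.store(x_ptr, 0)
```
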
```diff
@@ -876,6 +886,8 @@ def visit_UnaryOp(self, node):
         try:
             return getattr(operand, fn)()
         except AttributeError:
+            if fn == "__not__":
+                return constexpr(not operand)
             raise self._unsupported(
                 node, f"AST unary operator '{fn}' is not (currently) implemented on type {type(operand).__name__}")
 
```
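This fallback makes `not` usable on Python values that lack a `__not__` attribute, folding the result to a constexpr at compile time; illustratively (hedged, assuming the operands reach `visit_UnaryOp` as plain Python values):

```python
# inside a @triton.jit kernel, these should now fold at compile time:
#   not ()      -> constexpr(True)
#   not (3, 4)  -> constexpr(False)
#   not None    -> constexpr(True)
```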

```diff
@@ -1264,16 +1276,49 @@ def visit_Constant(self, node):
         return constexpr(node.value)
 
     def visit_BoolOp(self, node: ast.BoolOp):
-        if len(node.values) != 2:
-            raise self._unsupported(
-                node, "chained boolean operators (A or B or C) are not supported; use parentheses to split the chain.")
-        lhs = self.visit(node.values[0])
-        rhs = self.visit(node.values[1])
         method_name = self._method_name_for_bool_op.get(type(node.op))
         if method_name is None:
             raise self._unsupported(
                 node, "AST boolean operator '{}' is not (currently) implemented.".format(node.op.__name__))
-        return self._apply_binary_method(method_name, lhs, rhs)
+
+        nontrivial_values = []
+
+        for subnode in node.values:
+            # we visit the values in order, executing their side-effects
+            # and possibly early-exiting:
+            value = self.visit(subnode)
+            if not _is_triton_tensor(value):
+                # this is a constexpr, so we might be able to short-circuit:
+                bv = bool(value)
+                if (bv is False) and (method_name == "logical_and"):
+                    # value is falsey so return that:
+                    return value
+                if (bv is True) and (method_name == "logical_or"):
+                    # value is truthy so return that:
+                    return value
+                # otherwise, our constexpr has no effect on the output of the
+                # expression so we do not append it to nontrivial_values.
+            else:
+                if value.type.is_block():
+                    warnings.warn(
+                        "Logical operators 'and' and 'or' are deprecated for non-scalar tensors; please use '&' or '|' instead"
+                    )
+                # not a constexpr so we must append it:
+                nontrivial_values.append(value)
+
+        if len(nontrivial_values) == 0:
+            # the semantics of a disjunction of falsey values or conjunction
+            # of truthy values is to return the final value:
+            nontrivial_values.append(value)
+
+        while len(nontrivial_values) >= 2:
+            rhs = nontrivial_values.pop()
+            lhs = nontrivial_values.pop()
+            res = self._apply_binary_method(method_name, lhs, rhs)
+            nontrivial_values.append(res)
+
+        assert len(nontrivial_values) == 1
+        return nontrivial_values[0]
 
     _method_name_for_bool_op: Dict[Type[ast.boolop], str] = {ast.And: 'logical_and', ast.Or: 'logical_or'}
```
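Note that the closing while-loop folds `nontrivial_values` from the right, so a fully dynamic chain such as `A and B and C` lowers to nested binary calls; schematically:

```python
# nontrivial_values == [A, B, C] after the visiting loop; folding yields:
#   logical_and(A, logical_and(B, C))
# every operand is still visited in source order (side-effects happen),
# but constexpr operands never appear in the folded expression.
```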

python/triton/language/core.py

Lines changed: 21 additions & 3 deletions
```diff
@@ -1718,6 +1718,22 @@ def _take_first(a, b):
     return a
 
 
+def _unsplat(x, _builder=None, _generator=None):
+    """
+    Convert a single-element tensor to a scalar.
+    """
+    if len(x.shape) == 0:
+        return x
+    numel = 1
+    for d in x.shape:
+        numel *= d
+    assert numel == 1, "can only unsplat single-element tensors"
+    if len(x.shape) >= 2:
+        x = semantic.reshape(x, [1], builder=_builder)
+    x = typing.cast(tensor, reduce(x, 0, _take_first, _builder=_builder, _generator=_generator))
+    return x
+
+
 @_tensor_member_fn
 @builtin
 def split(a, _builder=None, _generator=None) -> tuple[tensor, tensor]:
```
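`_unsplat` normalizes any single-element tensor to rank 0 in at most two steps: a rank-2-or-higher shape is first flattened to `[1]`, and `reduce` along axis 0 with `_take_first` then strips the final dimension. A shape walkthrough (illustrative):

```python
# [1, 1, 1] --semantic.reshape--> [1] --reduce(0, _take_first)--> []  (scalar)
# [1]       ---------------------------reduce(0, _take_first)--> []  (scalar)
# []        returned unchanged
```
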
```diff
@@ -1747,8 +1763,8 @@ def split(a, _builder=None, _generator=None) -> tuple[tensor, tensor]:
 
     if was_rank_1:
         # Currently `reduce` is the best way to convert a tensor of shape [1] to a scalar.
-        out_lhs = typing.cast(tensor, reduce(out_lhs, None, _take_first, _builder=_builder, _generator=_generator))
-        out_rhs = typing.cast(tensor, reduce(out_rhs, None, _take_first, _builder=_builder, _generator=_generator))
+        out_lhs = _unsplat(out_lhs, _builder, _generator)
+        out_rhs = _unsplat(out_rhs, _builder, _generator)
 
     return out_lhs, out_rhs
```

```diff
@@ -1777,7 +1793,7 @@ def view(input, *shape, _builder=None):
 
 @_tensor_member_fn
 @builtin
-def reshape(input, *shape, can_reorder=False, _builder=None):
+def reshape(input, *shape, can_reorder=False, _builder=None, _generator=None):
     """
     Returns a tensor with the same number of elements as input but with the
     provided shape.
```
```diff
@@ -1793,6 +1809,8 @@ def reshape(input, *shape, can_reorder=False, _builder=None):
         reshape(x, 32, 32)
     """
     shape = _shape_check_impl(_unwrap_iterable(shape))
+    if len(shape) == 0:
+        return _unsplat(input, _builder=_builder, _generator=_generator)
     return semantic.reshape(input, shape, can_reorder, _builder)
 
```
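With the empty-shape fast path in place, both spellings below should produce a scalar inside a kernel (hedged example; `x` is assumed to be a single-element tensor of shape `[1]` or `[1, 1]`, etc.):

```python
s1 = tl.reshape(x, [])  # routed through the new len(shape) == 0 branch -> _unsplat
s2 = x.reshape([])      # member-function form (via @_tensor_member_fn), same path
```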
