From f66c0f538613bc883d812a9dc5d1bd0dd79aee36 Mon Sep 17 00:00:00 2001 From: Arseniy Obolenskiy Date: Fri, 3 Oct 2025 15:39:36 +0200 Subject: [PATCH] tests: Extend testing for dunder and binary elementwise operations --- thunder/tests/test_elementwise.py | 84 +++++++++++++++++++++++++++---- 1 file changed, 74 insertions(+), 10 deletions(-) diff --git a/thunder/tests/test_elementwise.py b/thunder/tests/test_elementwise.py index 7cdd9d156f..2286f0e971 100644 --- a/thunder/tests/test_elementwise.py +++ b/thunder/tests/test_elementwise.py @@ -1,5 +1,7 @@ from functools import partial import builtins +import math +import operator import torch from torch.testing import assert_close, make_tensor @@ -11,21 +13,83 @@ from thunder.tests.framework import instantiate, NOTHING -# TODO Enable the remaining elementwise unary operations (following the pattern of abs) -# TODO Expand testing to elementwise binary operations (following a similar pattern) +@instantiate(dtypes=NOTHING, devicetypes=(devices.DeviceType.CPU,)) +def test_elementwise_binary_operations_on_numbers(executor, device, dtype): + # op, allowed a-types, allowed b-types, special handling + elementwise_binary_ops = ( + (operator.add, (bool, int, float), (bool, int, float), None), + (operator.sub, (bool, int, float), (bool, int, float), None), + (operator.mul, (bool, int, float), (bool, int, float), None), + (operator.truediv, (bool, int, float), (bool, int, float), "nonzero_only"), + (operator.floordiv, (bool, int, float), (bool, int, float), "nonzero_only"), + (operator.mod, (bool, int, float), (bool, int, float), "nonzero_only"), + (operator.pow, (bool, int, float, complex), (int,), "pow_exponent"), + (operator.and_, (bool, int), (bool, int), None), + (operator.or_, (bool, int), (bool, int), None), + (operator.xor, (bool, int), (bool, int), None), + (operator.lshift, (bool, int), (int,), "shift_count"), + (operator.rshift, (bool, int), (int,), "shift_count"), + ) + + bool_inps = [False, True] + int_inps = [-1, 0, 2] + float_inps = [-0.7, 0.0, 0.3, 1.1] + complex_inps = [complex(1, 0.3), complex(-4.1, 0.9)] + + exponent_inps = [0, 1, 2] + shift_inps = [0, 1, 2] + + _type_to_input_map = { + bool: bool_inps, + int: int_inps, + float: float_inps, + complex: complex_inps, + } + + def gather_inputs(allowed_types): + inps = [] + for typ in allowed_types: + inps.extend(_type_to_input_map[typ]) + return inps + + def filter_b_values(vals, special): + if special == "nonzero_only": + return [v for v in vals if v != 0] + if special == "pow_exponent": + return exponent_inps + if special == "shift_count": + return shift_inps + return vals + + for op, a_types, b_types, special in elementwise_binary_ops: + + def foo(a, b): + return op(a, b) + + cfoo = executor.make_callable(foo) + + a_vals = gather_inputs(a_types) + b_vals = filter_b_values(gather_inputs(b_types), special) + + for a in a_vals: + for b in b_vals: + actual = cfoo(a, b) + expected = foo(a, b) + assert_close(actual, expected) + + @instantiate(dtypes=NOTHING, devicetypes=(devices.DeviceType.CPU,)) def test_elementwise_dunder_operations_on_numbers(executor, device, dtype): # op, allowed types elementwise_unary_ops = ( (builtins.abs, (bool, int, float, complex)), - # (math.ceil, (bool, int, float)), - # (math.floor, (bool, int, float)), - # (operator.inv, (bool, int)), - # (operator.neg, (bool, int, float, complex)), - # # TODO see issue "Implement positive operations" - # # operator.pos, - # (builtins.round, (bool, int, float)), - # (math.trunc, (bool, int, float)), + (math.ceil, (bool, int, float)), + (math.floor, (bool, int, float)), + (operator.inv, (bool, int)), + (operator.neg, (bool, int, float, complex)), + (operator.pos, (bool, int, float, complex)), + (builtins.round, (bool, int, float)), + (math.trunc, (bool, int, float)), ) bool_inps = [False, True]