Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 74 additions & 10 deletions thunder/tests/test_elementwise.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from functools import partial
import builtins
import math
import operator

import torch
from torch.testing import assert_close, make_tensor
Expand All @@ -11,21 +13,83 @@
from thunder.tests.framework import instantiate, NOTHING


# TODO Enable the remaining elementwise unary operations (following the pattern of abs)
# TODO Expand testing to elementwise binary operations (following a similar pattern)
@instantiate(dtypes=NOTHING, devicetypes=(devices.DeviceType.CPU,))
def test_elementwise_binary_operations_on_numbers(executor, device, dtype):
# op, allowed a-types, allowed b-types, special handling
elementwise_binary_ops = (
(operator.add, (bool, int, float), (bool, int, float), None),
(operator.sub, (bool, int, float), (bool, int, float), None),
(operator.mul, (bool, int, float), (bool, int, float), None),
(operator.truediv, (bool, int, float), (bool, int, float), "nonzero_only"),
(operator.floordiv, (bool, int, float), (bool, int, float), "nonzero_only"),
(operator.mod, (bool, int, float), (bool, int, float), "nonzero_only"),
(operator.pow, (bool, int, float, complex), (int,), "pow_exponent"),
(operator.and_, (bool, int), (bool, int), None),
(operator.or_, (bool, int), (bool, int), None),
(operator.xor, (bool, int), (bool, int), None),
(operator.lshift, (bool, int), (int,), "shift_count"),
(operator.rshift, (bool, int), (int,), "shift_count"),
)

bool_inps = [False, True]
int_inps = [-1, 0, 2]
float_inps = [-0.7, 0.0, 0.3, 1.1]
complex_inps = [complex(1, 0.3), complex(-4.1, 0.9)]

exponent_inps = [0, 1, 2]
shift_inps = [0, 1, 2]

_type_to_input_map = {
bool: bool_inps,
int: int_inps,
float: float_inps,
complex: complex_inps,
}

def gather_inputs(allowed_types):
inps = []
for typ in allowed_types:
inps.extend(_type_to_input_map[typ])
return inps

def filter_b_values(vals, special):
if special == "nonzero_only":
return [v for v in vals if v != 0]
if special == "pow_exponent":
return exponent_inps
if special == "shift_count":
return shift_inps
return vals

for op, a_types, b_types, special in elementwise_binary_ops:

def foo(a, b):
return op(a, b)

cfoo = executor.make_callable(foo)

a_vals = gather_inputs(a_types)
b_vals = filter_b_values(gather_inputs(b_types), special)

for a in a_vals:
for b in b_vals:
actual = cfoo(a, b)
expected = foo(a, b)
assert_close(actual, expected)


@instantiate(dtypes=NOTHING, devicetypes=(devices.DeviceType.CPU,))
def test_elementwise_dunder_operations_on_numbers(executor, device, dtype):
# op, allowed types
elementwise_unary_ops = (
(builtins.abs, (bool, int, float, complex)),
# (math.ceil, (bool, int, float)),
# (math.floor, (bool, int, float)),
# (operator.inv, (bool, int)),
# (operator.neg, (bool, int, float, complex)),
# # TODO see issue "Implement positive operations"
# # operator.pos,
# (builtins.round, (bool, int, float)),
# (math.trunc, (bool, int, float)),
(math.ceil, (bool, int, float)),
(math.floor, (bool, int, float)),
(operator.inv, (bool, int)),
(operator.neg, (bool, int, float, complex)),
(operator.pos, (bool, int, float, complex)),
(builtins.round, (bool, int, float)),
(math.trunc, (bool, int, float)),
)

bool_inps = [False, True]
Expand Down
Loading