[EZ] Check for stable outputs #69

Open · wants to merge 4 commits into main
26 changes: 25 additions & 1 deletion BackendBench/eval.py
@@ -5,7 +5,7 @@
import triton.testing


from BackendBench.utils import uses_cuda_stream
from BackendBench.utils import uses_cuda_stream, check_for_constant_output, check_constant_inputs
from BackendBench.utils import serialize_args

logger = logging.getLogger(__name__)
@@ -48,6 +48,17 @@ def eval_correctness_test(op, impl, test):
def eval_correctness(op, impl, tests):
correct, total = 0, 0
for test in tests:
if check_for_constant_output(op, serialize_args(test.args, test.kwargs)):
logger.warning(
f"Skipping {op.__name__} with args {serialize_args(test.args, test.kwargs)} because the output is always the same"
)
continue
if check_constant_inputs(test.args, test.kwargs):
logger.warning(
f"Skipping {op.__name__} with args {serialize_args(test.args, test.kwargs)} because the a tensor in the inputs is mostly ones/zeros or contain NaNs"
)
continue

logging.debug(f"Testing {op.__name__} with args {serialize_args(test.args, test.kwargs)}")
if eval_correctness_test(op, impl, test):
correct += 1
@@ -73,6 +84,18 @@ def eval_performance(op, impl, tests):
base_times = []
test_times = []
for test in tests:
if check_for_constant_output(op, serialize_args(test.args, test.kwargs)):
logger.warning(
f"Skipping {op.__name__} with args {serialize_args(test.args, test.kwargs)} because the output is always the same"
)
continue

if check_constant_inputs(test.args, test.kwargs):
logger.warning(
f"Skipping {op.__name__} with args {serialize_args(test.args, test.kwargs)} because the a tensor in the inputs is mostly ones/zeros or contain NaNs"
)
continue

logging.debug(
f"Benchmarking {op.__name__} with args {serialize_args(test.args, test.kwargs)}"
)
@@ -94,6 +117,7 @@ def eval_one_op(op, impl, correctness_tests, performance_tests):
if uses_cuda_stream(impl):
logger.warning(f"Skipping {op.__name__} because it uses CUDA stream")
return 0, 0

return eval_correctness(op, impl, correctness_tests), eval_performance(
op, impl, performance_tests
)
35 changes: 35 additions & 0 deletions BackendBench/utils.py
@@ -153,3 +153,38 @@ def deserialize_args(inps):
for key in dtype_abbrs_parsing:
inps = inps.replace(f"'{key}'", key)
return eval(inps.strip().strip("'").strip('"'), global_vals)


def check_for_constant_output(op, inps, n_iterations=10):
op_func = eval(f"torch.ops.{op}")
args, kwargs = deserialize_args(inps)
initial_output = op_func(*args, **kwargs)
for _ in range(n_iterations):
args, kwargs = deserialize_args(inps)
output = op_func(*args, **kwargs)
if not torch.allclose(initial_output, output, atol=1e-2, rtol=1e-2):
return False
return True


def check_constant_inputs(args, kwargs, threshold=0.01):
"""Check if any tensor in args or kwargs is mostly zeros, ones, or NaNs"""

def _check_tensor(tensor):
zeros_tensor = torch.zeros_like(tensor)
ones_tensor = torch.ones_like(tensor)
return (
torch.allclose(tensor, zeros_tensor, atol=threshold)
or torch.allclose(tensor, ones_tensor, atol=threshold)
or torch.isnan(tensor).any()
)

for arg in args:
if isinstance(arg, torch.Tensor):
if _check_tensor(arg):
return True
for value in kwargs.values():
if isinstance(value, torch.Tensor):
if _check_tensor(value):
return True
return False
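
For orientation, here is a minimal sketch of how the two new helpers could be exercised directly; the operator names and serialized-argument strings mirror the formats used in the tests below, and the snippet is illustrative usage rather than part of this diff:

import torch
from BackendBench.utils import check_for_constant_output, check_constant_inputs

# An op whose output is identical across repeated calls (e.g. tensor creation) is flagged.
assert check_for_constant_output("aten.zeros", "(([3, 4],), {'dtype': torch.float32})", n_iterations=5)

# A random op produces a different output on every call, so it is not flagged.
assert not check_for_constant_output("aten.randn", "(([3, 3],), {'dtype': torch.float32})", n_iterations=5)

# Inputs that are (mostly) zeros or ones, or that contain NaNs, are flagged so the test can be skipped.
assert check_constant_inputs((torch.zeros(5, 5),), {})
assert not check_constant_inputs((torch.randn(5, 5),), {})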
90 changes: 90 additions & 0 deletions test/test_utils.py
@@ -7,6 +7,8 @@
deserialize_args,
_deserialize_tensor,
uses_cuda_stream,
check_for_constant_output,
check_constant_inputs,
)

# Check if CUDA is available
@@ -531,5 +533,93 @@ def test_integer_tensors(self):
assert tensor.shape == (10,)


class TestCheckForConstantOutput:
"""Test cases for check_for_constant_output function"""

def test_constant_zeros_op(self):
"""Test that zeros creation is constant"""
op = "aten.zeros"
inps = "(([3, 4],), {'dtype': torch.float32})"

result = check_for_constant_output(op, inps, n_iterations=5)
assert result

def test_non_constant_random_op(self):
"""Test that random operations are correctly detected as non-constant"""
op = "aten.randn"
inps = "(([3, 3],), {'dtype': torch.float32})"

result = check_for_constant_output(op, inps, n_iterations=5)
assert not result


class TestCheckConstantInputs:
"""Test cases for check_constant_inputs function"""

def test_zeros_tensor(self):
"""Test detection of all-zeros and mostly-zeros tensors"""
# Test all zeros
zeros = torch.zeros(5, 5)
args = (zeros,)
kwargs = {}
assert check_constant_inputs(args, kwargs)

# Test mostly zeros (within threshold)
mostly_zeros = torch.zeros(10, 10)
mostly_zeros[0, 0] = 0.005 # Small deviation within default threshold of 0.01
args = (mostly_zeros,)
kwargs = {}
assert check_constant_inputs(args, kwargs)

def test_ones_tensor(self):
"""Test detection of all-ones and mostly-ones tensors"""
# Test all ones
ones = torch.ones(3, 4)
args = (ones,)
kwargs = {}
assert check_constant_inputs(args, kwargs)

# Test mostly ones (within threshold)
mostly_ones = torch.ones(10, 10)
mostly_ones[0, 0] = 0.995 # Small deviation within default threshold of 0.01
args = (mostly_ones,)
kwargs = {}
assert check_constant_inputs(args, kwargs)

def test_nan_tensor(self):
"""Test detection of tensor with NaN values"""
nan_tensor = torch.tensor([1.0, float("nan"), 3.0])
args = (nan_tensor,)
kwargs = {}

assert check_constant_inputs(args, kwargs)

def test_random_tensor(self):
"""Test that random tensor is not flagged as constant"""
random = torch.randn(5, 5)
args = (random,)
kwargs = {}

assert not check_constant_inputs(args, kwargs)

def test_kwargs_with_ones(self):
"""Test detection in kwargs"""
ones = torch.ones(2, 2)
args = ()
kwargs = {"weight": ones}

assert check_constant_inputs(args, kwargs)

def test_mixed_inputs(self):
"""Test with both constant and non-constant tensors"""
zeros = torch.zeros(3, 3)
random = torch.randn(3, 3)
args = (random, zeros)
kwargs = {}

# Should return True because at least one tensor is constant
assert check_constant_inputs(args, kwargs)


if __name__ == "__main__":
pytest.main([__file__])
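
As a usage note, the new test classes can also be run in isolation with a standard pytest keyword filter, e.g.:

pytest test/test_utils.py -k "CheckForConstantOutput or CheckConstantInputs" -v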