diff --git a/BackendBench/eval.py b/BackendBench/eval.py index d5b1d6f..608f077 100644 --- a/BackendBench/eval.py +++ b/BackendBench/eval.py @@ -5,7 +5,7 @@ import triton.testing -from BackendBench.utils import uses_cuda_stream +from BackendBench.utils import uses_cuda_stream, check_for_constant_output, check_constant_inputs from BackendBench.utils import serialize_args logger = logging.getLogger(__name__) @@ -48,6 +48,17 @@ def eval_correctness_test(op, impl, test): def eval_correctness(op, impl, tests): correct, total = 0, 0 for test in tests: + if check_for_constant_output(op, serialize_args(test.args, test.kwargs)): + logger.warning( + f"Skipping {op.__name__} with args {serialize_args(test.args, test.kwargs)} because the output is always the same" + ) + continue + if check_constant_inputs(test.args, test.kwargs): + logger.warning( + f"Skipping {op.__name__} with args {serialize_args(test.args, test.kwargs)} because the a tensor in the inputs is mostly ones/zeros or contain NaNs" + ) + continue + logging.debug(f"Testing {op.__name__} with args {serialize_args(test.args, test.kwargs)}") if eval_correctness_test(op, impl, test): correct += 1 @@ -73,6 +84,18 @@ def eval_performance(op, impl, tests): base_times = [] test_times = [] for test in tests: + if check_for_constant_output(op, serialize_args(test.args, test.kwargs)): + logger.warning( + f"Skipping {op.__name__} with args {serialize_args(test.args, test.kwargs)} because the output is always the same" + ) + continue + + if check_constant_inputs(test.args, test.kwargs): + logger.warning( + f"Skipping {op.__name__} with args {serialize_args(test.args, test.kwargs)} because the a tensor in the inputs is mostly ones/zeros or contain NaNs" + ) + continue + logging.debug( f"Benchmarking {op.__name__} with args {serialize_args(test.args, test.kwargs)}" ) @@ -94,6 +117,7 @@ def eval_one_op(op, impl, correctness_tests, performance_tests): if uses_cuda_stream(impl): logger.warning(f"Skipping {op.__name__} because it uses CUDA stream") return 0, 0 + return eval_correctness(op, impl, correctness_tests), eval_performance( op, impl, performance_tests ) diff --git a/BackendBench/utils.py b/BackendBench/utils.py index 600934f..1edac78 100644 --- a/BackendBench/utils.py +++ b/BackendBench/utils.py @@ -153,3 +153,38 @@ def deserialize_args(inps): for key in dtype_abbrs_parsing: inps = inps.replace(f"'{key}'", key) return eval(inps.strip().strip("'").strip('"'), global_vals) + + +def check_for_constant_output(op, inps, n_iterations=10): + op_func = eval(f"torch.ops.{op}") + args, kwargs = deserialize_args(inps) + initial_output = op_func(*args, **kwargs) + for _ in range(n_iterations): + args, kwargs = deserialize_args(inps) + output = op_func(*args, **kwargs) + if not torch.allclose(initial_output, output, atol=1e-2, rtol=1e-2): + return False + return True + + +def check_constant_inputs(args, kwargs, threshold=0.01): + """Check if any tensor in args or kwargs is mostly zeros, ones, or NaNs""" + + def _check_tensor(tensor): + zeros_tensor = torch.zeros_like(tensor) + ones_tensor = torch.ones_like(tensor) + return ( + torch.allclose(tensor, zeros_tensor, atol=threshold) + or torch.allclose(tensor, ones_tensor, atol=threshold) + or torch.isnan(tensor).any() + ) + + for arg in args: + if isinstance(arg, torch.Tensor): + if _check_tensor(arg): + return True + for value in kwargs.values(): + if isinstance(value, torch.Tensor): + if _check_tensor(value): + return True + return False diff --git a/test/test_utils.py b/test/test_utils.py index cf2d095..117e2cd 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -7,6 +7,8 @@ deserialize_args, _deserialize_tensor, uses_cuda_stream, + check_for_constant_output, + check_constant_inputs, ) # Check if CUDA is available @@ -531,5 +533,93 @@ def test_integer_tensors(self): assert tensor.shape == (10,) +class TestCheckForConstantOutput: + """Test cases for check_for_constant_output function""" + + def test_constant_zeros_op(self): + """Test that zeros creation is constant""" + op = "aten.zeros" + inps = "(([3, 4],), {'dtype': torch.float32})" + + result = check_for_constant_output(op, inps, n_iterations=5) + assert result + + def test_unconstant_random_op(self): + """Test that random operations are correctly detected as unconstant""" + op = "aten.randn" + inps = "(([3, 3],), {'dtype': torch.float32})" + + result = check_for_constant_output(op, inps, n_iterations=5) + assert not result + + +class TestCheckConstantInputs: + """Test cases for check_constant_inputs function""" + + def test_zeros_tensor(self): + """Test detection of all-zeros and mostly-zeros tensors""" + # Test all zeros + zeros = torch.zeros(5, 5) + args = (zeros,) + kwargs = {} + assert check_constant_inputs(args, kwargs) + + # Test mostly zeros (within threshold) + mostly_zeros = torch.zeros(10, 10) + mostly_zeros[0, 0] = 0.005 # Small deviation within default threshold of 0.01 + args = (mostly_zeros,) + kwargs = {} + assert check_constant_inputs(args, kwargs) + + def test_ones_tensor(self): + """Test detection of all-ones and mostly-ones tensors""" + # Test all ones + ones = torch.ones(3, 4) + args = (ones,) + kwargs = {} + assert check_constant_inputs(args, kwargs) + + # Test mostly ones (within threshold) + mostly_ones = torch.ones(10, 10) + mostly_ones[0, 0] = 0.995 # Small deviation within default threshold of 0.01 + args = (mostly_ones,) + kwargs = {} + assert check_constant_inputs(args, kwargs) + + def test_nan_tensor(self): + """Test detection of tensor with NaN values""" + nan_tensor = torch.tensor([1.0, float("nan"), 3.0]) + args = (nan_tensor,) + kwargs = {} + + assert check_constant_inputs(args, kwargs) + + def test_random_tensor(self): + """Test that random tensor is not flagged as constant""" + random = torch.randn(5, 5) + args = (random,) + kwargs = {} + + assert not check_constant_inputs(args, kwargs) + + def test_kwargs_with_ones(self): + """Test detection in kwargs""" + ones = torch.ones(2, 2) + args = () + kwargs = {"weight": ones} + + assert check_constant_inputs(args, kwargs) + + def test_mixed_inputs(self): + """Test with both constant and non-constant tensors""" + zeros = torch.zeros(3, 3) + random = torch.randn(3, 3) + args = (random, zeros) + kwargs = {} + + # Should return True because at least one tensor is constant + assert check_constant_inputs(args, kwargs) + + if __name__ == "__main__": pytest.main([__file__])