From 0609b8e273be0aa75ccb7e9c7ced92ec2fe35323 Mon Sep 17 00:00:00 2001
From: Sahan Paliskara
Date: Tue, 12 Aug 2025 11:18:36 -0700
Subject: [PATCH 1/4] [ez] Check for stable outputs

---
 BackendBench/eval.py  | 13 ++++++++++++-
 BackendBench/utils.py | 12 ++++++++++++
 2 files changed, 24 insertions(+), 1 deletion(-)

diff --git a/BackendBench/eval.py b/BackendBench/eval.py
index d5b1d6f..5ea54d3 100644
--- a/BackendBench/eval.py
+++ b/BackendBench/eval.py
@@ -5,7 +5,7 @@
 
 import triton.testing
 
-from BackendBench.utils import uses_cuda_stream
+from BackendBench.utils import uses_cuda_stream, check_for_stable_output
 from BackendBench.utils import serialize_args
 
 logger = logging.getLogger(__name__)
@@ -48,6 +48,11 @@ def eval_correctness_test(op, impl, test):
 def eval_correctness(op, impl, tests):
     correct, total = 0, 0
     for test in tests:
+        if check_for_stable_output(op, serialize_args(test.args, test.kwargs)):
+            logger.warning(
+                f"Skipping {op.__name__} with args {serialize_args(test.args, test.kwargs)} because the output is always the same"
+            )
+            continue
         logging.debug(f"Testing {op.__name__} with args {serialize_args(test.args, test.kwargs)}")
         if eval_correctness_test(op, impl, test):
             correct += 1
@@ -73,6 +78,11 @@ def eval_performance(op, impl, tests):
     base_times = []
     test_times = []
     for test in tests:
+        if check_for_stable_output(op, serialize_args(test.args, test.kwargs)):
+            logger.warning(
+                f"Skipping {op.__name__} with args {serialize_args(test.args, test.kwargs)} because the output is always the same"
+            )
+            continue
         logging.debug(
             f"Benchmarking {op.__name__} with args {serialize_args(test.args, test.kwargs)}"
         )
@@ -94,6 +104,7 @@ def eval_one_op(op, impl, correctness_tests, performance_tests):
     if uses_cuda_stream(impl):
         logger.warning(f"Skipping {op.__name__} because it uses CUDA stream")
         return 0, 0
+
     return eval_correctness(op, impl, correctness_tests), eval_performance(
         op, impl, performance_tests
     )
diff --git a/BackendBench/utils.py b/BackendBench/utils.py
index 600934f..b28b09e 100644
--- a/BackendBench/utils.py
+++ b/BackendBench/utils.py
@@ -153,3 +153,15 @@ def deserialize_args(inps):
     for key in dtype_abbrs_parsing:
         inps = inps.replace(f"'{key}'", key)
     return eval(inps.strip().strip("'").strip('"'), global_vals)
+
+
+def check_for_stable_output(op, inps, n_iterations=10):
+    op_func = eval(f"torch.ops.{op}")
+    args, kwargs = deserialize_args(inps)
+    initial_output = op_func(*args, **kwargs)
+    for _ in range(n_iterations):
+        args, kwargs = deserialize_args(inps)
+        output = op_func(*args, **kwargs)
+        if not torch.allclose(initial_output, output, atol=1e-2, rtol=1e-2):
+            return False
+    return True
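A quick sketch of how the helper introduced above behaves (assuming BackendBench and PyTorch are importable; the op names and serialized argument strings are illustrative, written in the format that deserialize_args accepts):

    from BackendBench.utils import check_for_stable_output

    # aten.zeros returns the same tensor for the same size/dtype arguments on
    # every iteration, so this test case is reported as stable and skipped.
    assert check_for_stable_output("aten.zeros", "(([3, 4],), {'dtype': torch.float32})")

    # aten.randn produces different values on every call, so its test is kept.
    assert not check_for_stable_output("aten.randn", "(([3, 3],), {'dtype': torch.float32})")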
From 2e7ad3b253cbaff13e1841cd70da741ca0caa248 Mon Sep 17 00:00:00 2001
From: Sahan Paliskara
Date: Tue, 12 Aug 2025 11:19:11 -0700
Subject: [PATCH 2/4] add test

---
 test/test_utils.py | 21 +++++++++++++++++++++
 1 file changed, 21 insertions(+)

diff --git a/test/test_utils.py b/test/test_utils.py
index cf2d095..efa3ce8 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -7,6 +7,7 @@
     deserialize_args,
     _deserialize_tensor,
     uses_cuda_stream,
+    check_for_stable_output,
 )
 
 # Check if CUDA is available
@@ -531,5 +532,25 @@ def test_integer_tensors(self):
         assert tensor.shape == (10,)
 
 
+class TestCheckForStableOutput:
+    """Test cases for check_for_stable_output function"""
+
+    def test_stable_zeros_op(self):
+        """Test that zeros creation is stable"""
+        op = "aten.zeros"
+        inps = "(([3, 4],), {{'dtype': torch.float32, 'device': 'cuda'}})"
+
+        result = check_for_stable_output(op, inps, n_iterations=5)
+        assert result
+
+    def test_unstable_random_op(self):
+        """Test that random operations are correctly detected as unstable"""
+        op = "aten.randn"
+        inps = "(([3, 3],), {{'dtype': torch.float32, 'device': 'cuda'}})"
+
+        result = check_for_stable_output(op, inps, n_iterations=5)
+        assert not result
+
+
 if __name__ == "__main__":
     pytest.main([__file__])

From fb4fba40aa5623a7fc2354f6d87065b4674c37c1 Mon Sep 17 00:00:00 2001
From: Sahan Paliskara
Date: Tue, 12 Aug 2025 11:26:00 -0700
Subject: [PATCH 3/4] fix test

---
 test/test_utils.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/test/test_utils.py b/test/test_utils.py
index efa3ce8..303e770 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -538,7 +538,7 @@ class TestCheckForStableOutput:
     def test_stable_zeros_op(self):
         """Test that zeros creation is stable"""
         op = "aten.zeros"
-        inps = "(([3, 4],), {{'dtype': torch.float32, 'device': 'cuda'}})"
+        inps = "(([3, 4],), {'dtype': torch.float32})"
 
         result = check_for_stable_output(op, inps, n_iterations=5)
         assert result
@@ -546,7 +546,7 @@
     def test_unstable_random_op(self):
         """Test that random operations are correctly detected as unstable"""
        op = "aten.randn"
-        inps = "(([3, 3],), {{'dtype': torch.float32, 'device': 'cuda'}})"
+        inps = "(([3, 3],), {'dtype': torch.float32})"
 
         result = check_for_stable_output(op, inps, n_iterations=5)
         assert not result

From deeb8676897b5a5287ac63fa07bfbc7f69e9bb37 Mon Sep 17 00:00:00 2001
From: Sahan Paliskara
Date: Thu, 14 Aug 2025 11:21:30 -0700
Subject: [PATCH 4/4] add inputs

---
 BackendBench/eval.py  | 19 ++++++++--
 BackendBench/utils.py | 25 ++++++++++++-
 test/test_utils.py    | 87 ++++++++++++++++++++++++++++++++++++++-----
 3 files changed, 118 insertions(+), 13 deletions(-)

diff --git a/BackendBench/eval.py b/BackendBench/eval.py
index 5ea54d3..608f077 100644
--- a/BackendBench/eval.py
+++ b/BackendBench/eval.py
@@ -5,7 +5,7 @@
 
 import triton.testing
 
-from BackendBench.utils import uses_cuda_stream, check_for_stable_output
+from BackendBench.utils import uses_cuda_stream, check_for_constant_output, check_constant_inputs
 from BackendBench.utils import serialize_args
 
 logger = logging.getLogger(__name__)
@@ -48,11 +48,17 @@ def eval_correctness_test(op, impl, test):
 def eval_correctness(op, impl, tests):
     correct, total = 0, 0
     for test in tests:
-        if check_for_stable_output(op, serialize_args(test.args, test.kwargs)):
+        if check_for_constant_output(op, serialize_args(test.args, test.kwargs)):
             logger.warning(
                 f"Skipping {op.__name__} with args {serialize_args(test.args, test.kwargs)} because the output is always the same"
             )
             continue
+        if check_constant_inputs(test.args, test.kwargs):
+            logger.warning(
+                f"Skipping {op.__name__} with args {serialize_args(test.args, test.kwargs)} because a tensor in the inputs is mostly ones/zeros or contains NaNs"
+            )
+            continue
+
         logging.debug(f"Testing {op.__name__} with args {serialize_args(test.args, test.kwargs)}")
         if eval_correctness_test(op, impl, test):
             correct += 1
@@ -78,11 +84,18 @@ def eval_performance(op, impl, tests):
     base_times = []
     test_times = []
     for test in tests:
-        if check_for_stable_output(op, serialize_args(test.args, test.kwargs)):
+        if check_for_constant_output(op, serialize_args(test.args, test.kwargs)):
             logger.warning(
                 f"Skipping {op.__name__} with args {serialize_args(test.args, test.kwargs)} because the output is always the same"
             )
             continue
+
+        if check_constant_inputs(test.args, test.kwargs):
+            logger.warning(
+                f"Skipping {op.__name__} with args {serialize_args(test.args, test.kwargs)} because a tensor in the inputs is mostly ones/zeros or contains NaNs"
+            )
+            continue
+
         logging.debug(
             f"Benchmarking {op.__name__} with args {serialize_args(test.args, test.kwargs)}"
         )
diff --git a/BackendBench/utils.py b/BackendBench/utils.py
index b28b09e..1edac78 100644
--- a/BackendBench/utils.py
+++ b/BackendBench/utils.py
@@ -155,7 +155,7 @@ def deserialize_args(inps):
     return eval(inps.strip().strip("'").strip('"'), global_vals)
 
 
-def check_for_stable_output(op, inps, n_iterations=10):
+def check_for_constant_output(op, inps, n_iterations=10):
     op_func = eval(f"torch.ops.{op}")
     args, kwargs = deserialize_args(inps)
     initial_output = op_func(*args, **kwargs)
@@ -165,3 +165,26 @@ def check_for_stable_output(op, inps, n_iterations=10):
         if not torch.allclose(initial_output, output, atol=1e-2, rtol=1e-2):
             return False
     return True
+
+
+def check_constant_inputs(args, kwargs, threshold=0.01):
+    """Check if any tensor in args or kwargs is mostly zeros, mostly ones, or contains NaNs"""
+
+    def _check_tensor(tensor):
+        zeros_tensor = torch.zeros_like(tensor)
+        ones_tensor = torch.ones_like(tensor)
+        return (
+            torch.allclose(tensor, zeros_tensor, atol=threshold)
+            or torch.allclose(tensor, ones_tensor, atol=threshold)
+            or torch.isnan(tensor).any()
+        )
+
+    for arg in args:
+        if isinstance(arg, torch.Tensor):
+            if _check_tensor(arg):
+                return True
+    for value in kwargs.values():
+        if isinstance(value, torch.Tensor):
+            if _check_tensor(value):
+                return True
+    return False
diff --git a/test/test_utils.py b/test/test_utils.py
index 303e770..117e2cd 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -7,7 +7,8 @@
     deserialize_args,
     _deserialize_tensor,
     uses_cuda_stream,
-    check_for_stable_output,
+    check_for_constant_output,
+    check_constant_inputs,
 )
 
 # Check if CUDA is available
@@ -532,25 +533,93 @@ def test_integer_tensors(self):
         assert tensor.shape == (10,)
 
 
-class TestCheckForStableOutput:
-    """Test cases for check_for_stable_output function"""
+class TestCheckForConstantOutput:
+    """Test cases for check_for_constant_output function"""
 
-    def test_stable_zeros_op(self):
-        """Test that zeros creation is stable"""
+    def test_constant_zeros_op(self):
+        """Test that zeros creation produces a constant output"""
         op = "aten.zeros"
         inps = "(([3, 4],), {'dtype': torch.float32})"
 
-        result = check_for_stable_output(op, inps, n_iterations=5)
+        result = check_for_constant_output(op, inps, n_iterations=5)
         assert result
 
-    def test_unstable_random_op(self):
-        """Test that random operations are correctly detected as unstable"""
+    def test_non_constant_random_op(self):
+        """Test that random operations are correctly detected as non-constant"""
         op = "aten.randn"
         inps = "(([3, 3],), {'dtype': torch.float32})"
 
-        result = check_for_stable_output(op, inps, n_iterations=5)
+        result = check_for_constant_output(op, inps, n_iterations=5)
         assert not result
 
 
+class TestCheckConstantInputs:
+    """Test cases for check_constant_inputs function"""
+
+    def test_zeros_tensor(self):
+        """Test detection of all-zeros and mostly-zeros tensors"""
+        # Test all zeros
+        zeros = torch.zeros(5, 5)
+        args = (zeros,)
+        kwargs = {}
+        assert check_constant_inputs(args, kwargs)
+
+        # Test mostly zeros (within threshold)
+        mostly_zeros = torch.zeros(10, 10)
+        mostly_zeros[0, 0] = 0.005  # Small deviation within default threshold of 0.01
+        args = (mostly_zeros,)
+        kwargs = {}
+        assert check_constant_inputs(args, kwargs)
+
+    def test_ones_tensor(self):
+        """Test detection of all-ones and mostly-ones tensors"""
+        # Test all ones
+        ones = torch.ones(3, 4)
+        args = (ones,)
+        kwargs = {}
+        assert check_constant_inputs(args, kwargs)
+
+        # Test mostly ones (within threshold)
+        mostly_ones = torch.ones(10, 10)
+        mostly_ones[0, 0] = 0.995  # Small deviation within default threshold of 0.01
+        args = (mostly_ones,)
+        kwargs = {}
+        assert check_constant_inputs(args, kwargs)
+
+    def test_nan_tensor(self):
+        """Test detection of tensor with NaN values"""
+        nan_tensor = torch.tensor([1.0, float("nan"), 3.0])
+        args = (nan_tensor,)
+        kwargs = {}
+
+        assert check_constant_inputs(args, kwargs)
+
+    def test_random_tensor(self):
+        """Test that random tensor is not flagged as constant"""
+        random = torch.randn(5, 5)
+        args = (random,)
+        kwargs = {}
+
+        assert not check_constant_inputs(args, kwargs)
+
+    def test_kwargs_with_ones(self):
+        """Test detection in kwargs"""
+        ones = torch.ones(2, 2)
+        args = ()
+        kwargs = {"weight": ones}
+
+        assert check_constant_inputs(args, kwargs)
+
+    def test_mixed_inputs(self):
+        """Test with both constant and non-constant tensors"""
+        zeros = torch.zeros(3, 3)
+        random = torch.randn(3, 3)
+        args = (random, zeros)
+        kwargs = {}
+
+        # Should return True because at least one tensor is constant
+        assert check_constant_inputs(args, kwargs)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
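For completeness, a minimal sketch of how the input-side filter added in the last patch behaves (assuming BackendBench and PyTorch are importable; the tensors are illustrative stand-ins for real test inputs):

    import torch

    from BackendBench.utils import check_constant_inputs

    # Inputs that are (almost) all zeros or all ones, or that contain NaNs,
    # are treated as degenerate and the corresponding test case is skipped.
    assert check_constant_inputs((torch.zeros(4, 4),), {})
    assert check_constant_inputs((), {"weight": torch.ones(2, 2)})
    assert check_constant_inputs((torch.tensor([1.0, float("nan")]),), {})

    # A typical random input passes the filter and the test case is kept.
    assert not check_constant_inputs((torch.randn(4, 4),), {})

In eval_correctness and eval_performance both filters run before each test case; a case is skipped with a warning as soon as either one fires.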