From 8803f09a70ac57731a52dc574959bfc95c0a5599 Mon Sep 17 00:00:00 2001 From: Sahan Paliskara Date: Tue, 29 Jul 2025 16:38:30 -0700 Subject: [PATCH 01/32] Add tests for serialization and deserialization --- BackendBench/scripts/utils.py | 102 +++++++++ test/test_utils.py | 379 ++++++++++++++++++++++++++++++++++ 2 files changed, 481 insertions(+) create mode 100644 BackendBench/scripts/utils.py create mode 100644 test/test_utils.py diff --git a/BackendBench/scripts/utils.py b/BackendBench/scripts/utils.py new file mode 100644 index 0000000..4f1a31a --- /dev/null +++ b/BackendBench/scripts/utils.py @@ -0,0 +1,102 @@ +import math +import torch +from torch.testing import make_tensor + +dtype_abbrs = { + torch.bfloat16: "bf16", + torch.float64: "f64", + torch.float32: "f32", + torch.float16: "f16", + torch.complex32: "c32", + torch.complex64: "c64", + torch.complex128: "c128", + torch.int8: "i8", + torch.int16: "i16", + torch.int32: "i32", + torch.int64: "i64", + torch.bool: "b8", + torch.uint8: "u8", +} + +dtype_abbrs_parsing = {value: key for key, value in dtype_abbrs.items()} + +_FLOATING_TYPES = [torch.float16, torch.bfloat16, torch.float32, torch.float64] + + +def _deserialize_tensor(size, dtype, stride=None, device="cuda"): + kwargs = {} + if dtype in _FLOATING_TYPES: + kwargs.update({"low": 0, "high": 1}) + + # Fall back to CPU if CUDA is not available + if device == "cuda" and not torch.cuda.is_available(): + device = "cpu" + + if stride is not None: + extent = 1 + sum((size - 1) * stride for size, stride in zip(size, stride)) + data = make_tensor(extent, dtype=dtype, device=device, **kwargs) + return data.as_strided(size, stride) + return make_tensor(size, dtype=dtype, device=device, **kwargs) + +def _serialize_tensor(tensor): + """Helper function to serialize a tensor to string format""" + shape = list(tensor.shape) + dtype = dtype_abbrs[tensor.dtype] + stride = tensor.stride() if not tensor.is_contiguous() else None + + if stride: + return f"T({shape}, {dtype}, {list(stride)})" + else: + return f"T({shape}, {dtype})" + + +def _serialize_value(value): + """Helper function to serialize any value (tensor, list, primitive)""" + if isinstance(value, torch.Tensor): + return _serialize_tensor(value) + elif isinstance(value, list): + list_parts = [_serialize_value(item) for item in value] + return f"[{', '.join(list_parts)}]" + else: + return repr(value) + + +def serialize_args(args, kwargs) -> str: + """Convert args and kwargs back to the BackendBench string format + + Args: + args: List of arguments (can contain tensors, lists, primitives) + kwargs: Dict of keyword arguments + + Returns: + Serialized string in format: (arg1, arg2, ..., key1=val1, key2=val2, ...) + """ + if args is None or kwargs is None: + return "None" + + # Process positional arguments + parts = [_serialize_value(arg) for arg in args] + + # Process keyword arguments + kwargs_parts = [f"{key}={_serialize_value(val)}" for key, val in kwargs.items()] + + return f"(({', '.join(parts)},), {{{', '.join(kwargs_parts)}}})" + + +# Alias for backward compatibility +reserialize_args = serialize_args + + +def deserialize_args(inps): + inps = inps.strip().strip("'") + global_vals = { + "T": _deserialize_tensor, + "th": torch, + "inf": math.inf, + "torch": torch, + **dtype_abbrs_parsing, + } + # f strings introduce quotations we dont want + for key in dtype_abbrs_parsing: + inps = inps.replace(f"'{key}'", key) + return eval(inps.strip().strip("'").strip('"'), global_vals) diff --git a/test/test_utils.py b/test/test_utils.py new file mode 100644 index 0000000..59ab028 --- /dev/null +++ b/test/test_utils.py @@ -0,0 +1,379 @@ +import pytest +import torch +import math +from BackendBench.scripts.utils import serialize_args, deserialize_args, reserialize_args, _deserialize_tensor + + +class TestDeserializeArgs: + """Test cases for deserialize_args function""" + + def test_single_tensor_arg(self): + """Test deserializing a single tensor argument""" + input_str = "((T([48, 24, 28, 28], f16),), {})" + args, kwargs = deserialize_args(input_str) + + assert len(args) == 1 + assert len(kwargs) == 0 + assert isinstance(args[0], torch.Tensor) + assert args[0].shape == (48, 24, 28, 28) + assert args[0].dtype == torch.float16 + # Device will be 'cuda' if available, otherwise 'cpu' + assert args[0].device.type in ['cuda', 'cpu'] + + def test_user_specified_input_1(self): + """Test deserializing user-specified input case 1""" + input_str = "((T([48, 24, 2816, 2816], f16),), {})" + args, kwargs = deserialize_args(input_str) + + assert len(args) == 1 + assert len(kwargs) == 0 + assert isinstance(args[0], torch.Tensor) + assert args[0].shape == (48, 24, 2816, 2816) + assert args[0].dtype == torch.float16 + assert args[0].device.type in ['cuda', 'cpu'] + + def test_user_specified_input_2(self): + """Test deserializing user-specified input case 2""" + input_str = "((T([512, 64, 64, 64, 64], f16), T([512, 64, 64, 64, 64], f16),), {})" + args, kwargs = deserialize_args(input_str) + + assert len(args) == 2 + assert len(kwargs) == 0 + assert all(isinstance(arg, torch.Tensor) for arg in args) + assert all(arg.shape == (512, 64, 64, 64, 64) for arg in args) + assert all(arg.dtype == torch.float16 for arg in args) + assert all(arg.device.type in ['cuda', 'cpu'] for arg in args) + + def test_user_specified_input_3(self): + """Test deserializing user-specified input case 3""" + input_str = "((T([32768, 988032], f16), [1024, 249, 249],), {'dtype': torch.float16, 'layout': torch.strided, 'device': 'cuda'})" + args, kwargs = deserialize_args(input_str) + + assert len(args) == 2 + assert len(kwargs) == 3 + assert isinstance(args[0], torch.Tensor) + assert args[0].shape == (32768, 988032) + assert args[0].dtype == torch.float16 + assert args[1] == [1024, 249, 249] + assert kwargs['dtype'] == torch.float16 + assert kwargs['layout'] == torch.strided + assert kwargs['device'] == 'cuda' + + def test_multiple_tensor_args(self): + """Test deserializing multiple tensor arguments with smaller tensors""" + input_str = "((T([5, 6, 7, 8, 9], f16), T([5, 6, 7, 8, 9], f16),), {})" + args, kwargs = deserialize_args(input_str) + + assert len(args) == 2 + assert len(kwargs) == 0 + assert all(isinstance(arg, torch.Tensor) for arg in args) + assert all(arg.shape == (5, 6, 7, 8, 9) for arg in args) + assert all(arg.dtype == torch.float16 for arg in args) + assert all(arg.device.type in ['cuda', 'cpu'] for arg in args) + + def test_tensor_with_negative_values(self): + """Test deserializing with negative numbers in lists""" + input_str = "((T([10, 20], f32), [-1, -2, -3],), {})" + args, kwargs = deserialize_args(input_str) + + assert len(args) == 2 + assert isinstance(args[0], torch.Tensor) + assert args[0].shape == (10, 20) + assert args[0].dtype == torch.float32 + assert args[1] == [-1, -2, -3] + + def test_different_dtypes(self): + """Test deserializing tensors with different dtypes""" + test_cases = [ + ("((T([10, 20], f32),), {})", torch.float32), + ("((T([10, 20], f64),), {})", torch.float64), + ("((T([10, 20], bf16),), {})", torch.bfloat16), + ("((T([10, 20], i32),), {})", torch.int32), + ("((T([10, 20], i64),), {})", torch.int64), + ("((T([10, 20], b8),), {})", torch.bool), + ] + + for input_str, expected_dtype in test_cases: + args, kwargs = deserialize_args(input_str) + assert args[0].dtype == expected_dtype + + def test_tensor_with_stride(self): + """Test deserializing tensor with custom stride""" + input_str = "((T([10, 20], f16, [40, 2]),), {})" + args, kwargs = deserialize_args(input_str) + + assert len(args) == 1 + tensor = args[0] + assert tensor.shape == (10, 20) + assert tensor.stride() == (40, 2) + assert tensor.dtype == torch.float16 + + def test_empty_args_kwargs(self): + """Test deserializing empty args and kwargs""" + input_str = "((), {})" + args, kwargs = deserialize_args(input_str) + + assert len(args) == 0 + assert len(kwargs) == 0 + + def test_primitive_args(self): + """Test deserializing primitive arguments""" + input_str = "((1, 2.5, 'hello', True, None,), {})" + args, kwargs = deserialize_args(input_str) + + assert len(args) == 5 + assert args[0] == 1 + assert args[1] == 2.5 + assert args[2] == 'hello' + assert args[3] is True + assert args[4] is None + + def test_math_inf(self): + """Test deserializing math.inf""" + input_str = "((inf,), {})" + args, kwargs = deserialize_args(input_str) + + assert len(args) == 1 + assert args[0] == math.inf + + def test_torch_constants(self): + """Test deserializing torch constants""" + input_str = "((torch.float16,), {})" + args, kwargs = deserialize_args(input_str) + + assert len(args) == 1 + assert args[0] == torch.float16 + + def test_mixed_args_kwargs(self): + """Test deserializing mixed args and kwargs""" + input_str = "((T([5, 5], f32), 42,), {'alpha': 0.5, 'beta': T([3, 3], i64)})" + args, kwargs = deserialize_args(input_str) + + assert len(args) == 2 + assert len(kwargs) == 2 + assert isinstance(args[0], torch.Tensor) + assert args[0].shape == (5, 5) + assert args[0].dtype == torch.float32 + assert args[1] == 42 + assert kwargs['alpha'] == 0.5 + assert isinstance(kwargs['beta'], torch.Tensor) + assert kwargs['beta'].shape == (3, 3) + assert kwargs['beta'].dtype == torch.int64 + + +class TestSerializeArgs: + """Test cases for serialize_args function""" + + def test_single_tensor_arg(self): + """Test serializing a single tensor argument""" + device = 'cuda' if torch.cuda.is_available() else 'cpu' + tensor = torch.randn(48, 24, 2816, 2816, dtype=torch.float16, device=device) + args = (tensor,) + kwargs = {} + + result = serialize_args(args, kwargs) + expected = "((T([48, 24, 2816, 2816], f16),), {})" + assert result == expected + + def test_multiple_tensor_args(self): + """Test serializing multiple tensor arguments""" + device = 'cuda' if torch.cuda.is_available() else 'cpu' + tensor1 = torch.randn(512, 64, 64, 64, 64, dtype=torch.float16, device=device) + tensor2 = torch.randn(512, 64, 64, 64, 64, dtype=torch.float16, device=device) + args = (tensor1, tensor2) + kwargs = {} + + result = serialize_args(args, kwargs) + expected = "((T([512, 64, 64, 64, 64], f16), T([512, 64, 64, 64, 64], f16),), {})" + assert result == expected + + def test_tensor_with_list_and_kwargs(self): + """Test serializing tensor with list and keyword arguments""" + device = 'cuda' if torch.cuda.is_available() else 'cpu' + tensor = torch.randn(32768, 988032, dtype=torch.float16, device=device) + args = (tensor, [1024, 249, 249]) + kwargs = {'dtype': torch.float16, 'layout': torch.strided, 'device': device} + + result = serialize_args(args, kwargs) + expected = f"((T([32768, 988032], f16), [1024, 249, 249],), {{'dtype': torch.float16, 'layout': torch.strided, 'device': '{device}'}})" + assert result == expected + + def test_different_dtypes(self): + """Test reserializing tensors with different dtypes""" + test_cases = [ + (torch.float32, "f32"), + (torch.float64, "f64"), + (torch.bfloat16, "bf16"), + (torch.int32, "i32"), + (torch.int64, "i64"), + (torch.bool, "b8"), + ] + + for dtype, expected_abbr in test_cases: + device = 'cuda' if torch.cuda.is_available() else 'cpu' + tensor = torch.randn(10, 20, dtype=dtype, device=device) + args = (tensor,) + kwargs = {} + + result = serialize_args(args, kwargs) + expected = f"((T([10, 20], {expected_abbr}),), {{}})" + assert result == expected + + def test_tensor_with_stride(self): + """Test serializing tensor with custom stride""" + device = 'cuda' if torch.cuda.is_available() else 'cpu' + tensor = torch.randn(20, 10, dtype=torch.float16, device=device) + # Create a strided tensor + strided_tensor = tensor.transpose(0, 1) # This creates a non-contiguous tensor + args = (strided_tensor,) + kwargs = {} + + result = serialize_args(args, kwargs) + # The exact stride depends on the tensor layout, but it should include stride info + assert "T([10, 20], f16, [" in result + assert "])" in result + + def test_empty_args_kwargs(self): + """Test reserializing empty args and kwargs""" + args = () + kwargs = {} + + result = serialize_args(args, kwargs) + expected = "((), {})" + assert result == expected + + def test_primitive_args(self): + """Test reserializing primitive arguments""" + args = (1, 2.5, 'hello', True, None) + kwargs = {} + + result = serialize_args(args, kwargs) + expected = "((1, 2.5, 'hello', True, None,), {})" + assert result == expected + + def test_none_inputs(self): + """Test reserializing None inputs""" + assert serialize_args(None, {}) == "None" + assert serialize_args([], None) == "None" + assert serialize_args(None, None) == "None" + + def test_list_with_tensors(self): + """Test serializing list containing tensors""" + device = 'cuda' if torch.cuda.is_available() else 'cpu' + tensor1 = torch.randn(5, 5, dtype=torch.float32, device=device) + tensor2 = torch.ones(3, 3, dtype=torch.int64, device=device) # Use ones for int tensor + args = ([tensor1, tensor2, 42],) + kwargs = {} + + result = serialize_args(args, kwargs) + expected = "(([T([5, 5], f32), T([3, 3], i64), 42],), {})" + assert result == expected + + def test_kwargs_with_tensors(self): + """Test serializing kwargs containing tensors""" + device = 'cuda' if torch.cuda.is_available() else 'cpu' + tensor = torch.randn(3, 3, dtype=torch.float32, device=device) + args = () + kwargs = {'weight': tensor, 'bias': None, 'alpha': 0.5} + + result = serialize_args(args, kwargs) + expected = "((), {'weight': T([3, 3], f32), 'bias': None, 'alpha': 0.5})" + assert result == expected + + def test_reserialize_args_alias(self): + """Test that reserialize_args is an alias for serialize_args""" + device = 'cuda' if torch.cuda.is_available() else 'cpu' + tensor = torch.randn(5, 5, dtype=torch.float32, device=device) + args = (tensor,) + kwargs = {} + + result1 = serialize_args(args, kwargs) + result2 = reserialize_args(args, kwargs) + assert result1 == result2 + assert reserialize_args is serialize_args + + +class TestRoundTrip: + """Test round-trip serialization/deserialization""" + + def test_roundtrip_single_tensor(self): + """Test that serialize->deserialize produces equivalent tensors""" + device = 'cuda' if torch.cuda.is_available() else 'cpu' + original_tensor = torch.randn(10, 20, dtype=torch.float16, device=device) + original_args = (original_tensor,) + original_kwargs = {} + + # Serialize + serialized = serialize_args(original_args, original_kwargs) + + # Deserialize + deserialized_args, deserialized_kwargs = deserialize_args(serialized) + + # Check equivalence + assert len(deserialized_args) == len(original_args) + assert len(deserialized_kwargs) == len(original_kwargs) + assert deserialized_args[0].shape == original_args[0].shape + assert deserialized_args[0].dtype == original_args[0].dtype + # Device type might differ due to CUDA availability fallback + assert deserialized_args[0].device.type in ['cuda', 'cpu'] + + def test_roundtrip_complex_args(self): + """Test round-trip with complex arguments""" + device = 'cuda' if torch.cuda.is_available() else 'cpu' + tensor = torch.randn(5, 5, dtype=torch.float32, device=device) + original_args = (tensor, [1, 2, 3], 'test') + original_kwargs = {'alpha': 0.5, 'beta': tensor} + + # Serialize + serialized = serialize_args(original_args, original_kwargs) + + # Deserialize + deserialized_args, deserialized_kwargs = deserialize_args(serialized) + + # Check equivalence + assert len(deserialized_args) == len(original_args) + assert len(deserialized_kwargs) == len(original_kwargs) + assert deserialized_args[0].shape == original_args[0].shape + assert deserialized_args[0].dtype == original_args[0].dtype + assert deserialized_args[1] == original_args[1] + assert deserialized_args[2] == original_args[2] + assert deserialized_kwargs['alpha'] == original_kwargs['alpha'] + assert deserialized_kwargs['beta'].shape == original_kwargs['beta'].shape + assert deserialized_kwargs['beta'].dtype == original_kwargs['beta'].dtype + + +class TestDeserializeTensor: + """Test cases for _deserialize_tensor helper function""" + + def test_basic_tensor_creation(self): + """Test basic tensor creation with different dtypes""" + tensor = _deserialize_tensor([10, 20], torch.float32) + assert tensor.shape == (10, 20) + assert tensor.dtype == torch.float32 + assert tensor.device.type == 'cuda' + + def test_tensor_with_stride(self): + """Test tensor creation with custom stride""" + tensor = _deserialize_tensor([5, 4], torch.float16, stride=[8, 2]) + assert tensor.shape == (5, 4) + assert tensor.stride() == (8, 2) + assert tensor.dtype == torch.float16 + + def test_tensor_different_device(self): + """Test tensor creation with different device""" + tensor = _deserialize_tensor([3, 3], torch.float32, device='cpu') + assert tensor.device.type == 'cpu' + + def test_floating_point_range(self): + """Test that floating point tensors have values in [0, 1] range""" + for dtype in [torch.float16, torch.float32, torch.float64, torch.bfloat16]: + tensor = _deserialize_tensor([100], dtype) + assert tensor.min() >= 0 + assert tensor.max() <= 1 + + def test_integer_tensors(self): + """Test integer tensor creation""" + for dtype in [torch.int32, torch.int64, torch.int8, torch.int16]: + tensor = _deserialize_tensor([10], dtype) + assert tensor.dtype == dtype + assert tensor.shape == (10,) \ No newline at end of file From 74951079a425cead593c5d8935c61ac92504670f Mon Sep 17 00:00:00 2001 From: Sahan Paliskara Date: Tue, 29 Jul 2025 17:29:33 -0700 Subject: [PATCH 02/32] fix --- BackendBench/scripts/utils.py | 16 ++- test/test_utils.py | 231 +++++++++++++++++----------------- 2 files changed, 129 insertions(+), 118 deletions(-) diff --git a/BackendBench/scripts/utils.py b/BackendBench/scripts/utils.py index 4f1a31a..bd7937e 100644 --- a/BackendBench/scripts/utils.py +++ b/BackendBench/scripts/utils.py @@ -27,23 +27,24 @@ def _deserialize_tensor(size, dtype, stride=None, device="cuda"): kwargs = {} if dtype in _FLOATING_TYPES: kwargs.update({"low": 0, "high": 1}) - + # Fall back to CPU if CUDA is not available if device == "cuda" and not torch.cuda.is_available(): device = "cpu" - + if stride is not None: extent = 1 + sum((size - 1) * stride for size, stride in zip(size, stride)) data = make_tensor(extent, dtype=dtype, device=device, **kwargs) return data.as_strided(size, stride) return make_tensor(size, dtype=dtype, device=device, **kwargs) + def _serialize_tensor(tensor): """Helper function to serialize a tensor to string format""" shape = list(tensor.shape) dtype = dtype_abbrs[tensor.dtype] stride = tensor.stride() if not tensor.is_contiguous() else None - + if stride: return f"T({shape}, {dtype}, {list(stride)})" else: @@ -78,9 +79,12 @@ def serialize_args(args, kwargs) -> str: parts = [_serialize_value(arg) for arg in args] # Process keyword arguments - kwargs_parts = [f"{key}={_serialize_value(val)}" for key, val in kwargs.items()] - - return f"(({', '.join(parts)},), {{{', '.join(kwargs_parts)}}})" + kwargs_parts = [f"'{key}': {_serialize_value(val)}" for key, val in kwargs.items()] + + # Handle empty args tuple properly + args_str = f"({', '.join(parts)},)" if parts else "()" + + return f"({args_str}, {{{', '.join(kwargs_parts)}}})" # Alias for backward compatibility diff --git a/test/test_utils.py b/test/test_utils.py index 59ab028..da70906 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -1,87 +1,91 @@ -import pytest import torch import math -from BackendBench.scripts.utils import serialize_args, deserialize_args, reserialize_args, _deserialize_tensor +from BackendBench.scripts.utils import ( + serialize_args, + deserialize_args, + reserialize_args, + _deserialize_tensor, +) class TestDeserializeArgs: """Test cases for deserialize_args function""" - + def test_single_tensor_arg(self): """Test deserializing a single tensor argument""" input_str = "((T([48, 24, 28, 28], f16),), {})" args, kwargs = deserialize_args(input_str) - + assert len(args) == 1 assert len(kwargs) == 0 assert isinstance(args[0], torch.Tensor) assert args[0].shape == (48, 24, 28, 28) assert args[0].dtype == torch.float16 # Device will be 'cuda' if available, otherwise 'cpu' - assert args[0].device.type in ['cuda', 'cpu'] - + assert args[0].device.type in ["cuda", "cpu"] + def test_user_specified_input_1(self): """Test deserializing user-specified input case 1""" - input_str = "((T([48, 24, 2816, 2816], f16),), {})" + input_str = "((T([48, 24, 28, 28], f16),), {})" args, kwargs = deserialize_args(input_str) - + assert len(args) == 1 assert len(kwargs) == 0 assert isinstance(args[0], torch.Tensor) - assert args[0].shape == (48, 24, 2816, 2816) + assert args[0].shape == (48, 24, 28, 28) assert args[0].dtype == torch.float16 - assert args[0].device.type in ['cuda', 'cpu'] - + assert args[0].device.type in ["cuda", "cpu"] + def test_user_specified_input_2(self): """Test deserializing user-specified input case 2""" - input_str = "((T([512, 64, 64, 64, 64], f16), T([512, 64, 64, 64, 64], f16),), {})" + input_str = "((T([8, 8, 8, 8, 8], f16), T([8, 8, 8, 8, 8], f16),), {})" args, kwargs = deserialize_args(input_str) - + assert len(args) == 2 assert len(kwargs) == 0 assert all(isinstance(arg, torch.Tensor) for arg in args) - assert all(arg.shape == (512, 64, 64, 64, 64) for arg in args) + assert all(arg.shape == (8, 8, 8, 8, 8) for arg in args) assert all(arg.dtype == torch.float16 for arg in args) - assert all(arg.device.type in ['cuda', 'cpu'] for arg in args) - + assert all(arg.device.type in ["cuda", "cpu"] for arg in args) + def test_user_specified_input_3(self): """Test deserializing user-specified input case 3""" - input_str = "((T([32768, 988032], f16), [1024, 249, 249],), {'dtype': torch.float16, 'layout': torch.strided, 'device': 'cuda'})" + input_str = "((T([128, 256], f16), [1024, 249, 249],), {'dtype': torch.float16, 'layout': torch.strided, 'device': 'cuda'})" args, kwargs = deserialize_args(input_str) - + assert len(args) == 2 assert len(kwargs) == 3 assert isinstance(args[0], torch.Tensor) - assert args[0].shape == (32768, 988032) + assert args[0].shape == (128, 256) assert args[0].dtype == torch.float16 assert args[1] == [1024, 249, 249] - assert kwargs['dtype'] == torch.float16 - assert kwargs['layout'] == torch.strided - assert kwargs['device'] == 'cuda' - + assert kwargs["dtype"] == torch.float16 + assert kwargs["layout"] == torch.strided + assert kwargs["device"] == "cuda" + def test_multiple_tensor_args(self): """Test deserializing multiple tensor arguments with smaller tensors""" input_str = "((T([5, 6, 7, 8, 9], f16), T([5, 6, 7, 8, 9], f16),), {})" args, kwargs = deserialize_args(input_str) - + assert len(args) == 2 assert len(kwargs) == 0 assert all(isinstance(arg, torch.Tensor) for arg in args) assert all(arg.shape == (5, 6, 7, 8, 9) for arg in args) assert all(arg.dtype == torch.float16 for arg in args) - assert all(arg.device.type in ['cuda', 'cpu'] for arg in args) - + assert all(arg.device.type in ["cuda", "cpu"] for arg in args) + def test_tensor_with_negative_values(self): """Test deserializing with negative numbers in lists""" input_str = "((T([10, 20], f32), [-1, -2, -3],), {})" args, kwargs = deserialize_args(input_str) - + assert len(args) == 2 assert isinstance(args[0], torch.Tensor) assert args[0].shape == (10, 20) assert args[0].dtype == torch.float32 assert args[1] == [-1, -2, -3] - + def test_different_dtypes(self): """Test deserializing tensors with different dtypes""" test_cases = [ @@ -92,112 +96,112 @@ def test_different_dtypes(self): ("((T([10, 20], i64),), {})", torch.int64), ("((T([10, 20], b8),), {})", torch.bool), ] - + for input_str, expected_dtype in test_cases: args, kwargs = deserialize_args(input_str) assert args[0].dtype == expected_dtype - + def test_tensor_with_stride(self): """Test deserializing tensor with custom stride""" input_str = "((T([10, 20], f16, [40, 2]),), {})" args, kwargs = deserialize_args(input_str) - + assert len(args) == 1 tensor = args[0] assert tensor.shape == (10, 20) assert tensor.stride() == (40, 2) assert tensor.dtype == torch.float16 - + def test_empty_args_kwargs(self): """Test deserializing empty args and kwargs""" input_str = "((), {})" args, kwargs = deserialize_args(input_str) - + assert len(args) == 0 assert len(kwargs) == 0 - + def test_primitive_args(self): """Test deserializing primitive arguments""" input_str = "((1, 2.5, 'hello', True, None,), {})" args, kwargs = deserialize_args(input_str) - + assert len(args) == 5 assert args[0] == 1 assert args[1] == 2.5 - assert args[2] == 'hello' + assert args[2] == "hello" assert args[3] is True assert args[4] is None - + def test_math_inf(self): """Test deserializing math.inf""" input_str = "((inf,), {})" args, kwargs = deserialize_args(input_str) - + assert len(args) == 1 assert args[0] == math.inf - + def test_torch_constants(self): """Test deserializing torch constants""" input_str = "((torch.float16,), {})" args, kwargs = deserialize_args(input_str) - + assert len(args) == 1 assert args[0] == torch.float16 - + def test_mixed_args_kwargs(self): """Test deserializing mixed args and kwargs""" input_str = "((T([5, 5], f32), 42,), {'alpha': 0.5, 'beta': T([3, 3], i64)})" args, kwargs = deserialize_args(input_str) - + assert len(args) == 2 assert len(kwargs) == 2 assert isinstance(args[0], torch.Tensor) assert args[0].shape == (5, 5) assert args[0].dtype == torch.float32 assert args[1] == 42 - assert kwargs['alpha'] == 0.5 - assert isinstance(kwargs['beta'], torch.Tensor) - assert kwargs['beta'].shape == (3, 3) - assert kwargs['beta'].dtype == torch.int64 + assert kwargs["alpha"] == 0.5 + assert isinstance(kwargs["beta"], torch.Tensor) + assert kwargs["beta"].shape == (3, 3) + assert kwargs["beta"].dtype == torch.int64 class TestSerializeArgs: """Test cases for serialize_args function""" - + def test_single_tensor_arg(self): """Test serializing a single tensor argument""" - device = 'cuda' if torch.cuda.is_available() else 'cpu' - tensor = torch.randn(48, 24, 2816, 2816, dtype=torch.float16, device=device) + device = "cuda" if torch.cuda.is_available() else "cpu" + tensor = torch.randn(48, 24, 28, 28, dtype=torch.float16, device=device) args = (tensor,) kwargs = {} - + result = serialize_args(args, kwargs) - expected = "((T([48, 24, 2816, 2816], f16),), {})" + expected = "((T([48, 24, 28, 28], f16),), {})" assert result == expected - + def test_multiple_tensor_args(self): """Test serializing multiple tensor arguments""" - device = 'cuda' if torch.cuda.is_available() else 'cpu' - tensor1 = torch.randn(512, 64, 64, 64, 64, dtype=torch.float16, device=device) - tensor2 = torch.randn(512, 64, 64, 64, 64, dtype=torch.float16, device=device) + device = "cuda" if torch.cuda.is_available() else "cpu" + tensor1 = torch.randn(8, 8, 8, 8, 8, dtype=torch.float16, device=device) + tensor2 = torch.randn(8, 8, 8, 8, 8, dtype=torch.float16, device=device) args = (tensor1, tensor2) kwargs = {} - + result = serialize_args(args, kwargs) - expected = "((T([512, 64, 64, 64, 64], f16), T([512, 64, 64, 64, 64], f16),), {})" + expected = "((T([8, 8, 8, 8, 8], f16), T([8, 8, 8, 8, 8], f16),), {})" assert result == expected - + def test_tensor_with_list_and_kwargs(self): """Test serializing tensor with list and keyword arguments""" - device = 'cuda' if torch.cuda.is_available() else 'cpu' - tensor = torch.randn(32768, 988032, dtype=torch.float16, device=device) + device = "cuda" if torch.cuda.is_available() else "cpu" + tensor = torch.randn(128, 256, dtype=torch.float16, device=device) args = (tensor, [1024, 249, 249]) - kwargs = {'dtype': torch.float16, 'layout': torch.strided, 'device': device} - + kwargs = {"dtype": torch.float16, "layout": torch.strided, "device": device} + result = serialize_args(args, kwargs) - expected = f"((T([32768, 988032], f16), [1024, 249, 249],), {{'dtype': torch.float16, 'layout': torch.strided, 'device': '{device}'}})" + expected = f"((T([128, 256], f16), [1024, 249, 249],), {{'dtype': torch.float16, 'layout': torch.strided, 'device': '{device}'}})" assert result == expected - + def test_different_dtypes(self): """Test reserializing tensors with different dtypes""" test_cases = [ @@ -208,85 +212,88 @@ def test_different_dtypes(self): (torch.int64, "i64"), (torch.bool, "b8"), ] - + for dtype, expected_abbr in test_cases: - device = 'cuda' if torch.cuda.is_available() else 'cpu' - tensor = torch.randn(10, 20, dtype=dtype, device=device) + device = "cuda" if torch.cuda.is_available() else "cpu" + if dtype in [torch.int32, torch.int64, torch.bool]: + tensor = torch.ones(10, 20, dtype=dtype, device=device) + else: + tensor = torch.randn(10, 20, dtype=dtype, device=device) args = (tensor,) kwargs = {} - + result = serialize_args(args, kwargs) expected = f"((T([10, 20], {expected_abbr}),), {{}})" assert result == expected - + def test_tensor_with_stride(self): """Test serializing tensor with custom stride""" - device = 'cuda' if torch.cuda.is_available() else 'cpu' + device = "cuda" if torch.cuda.is_available() else "cpu" tensor = torch.randn(20, 10, dtype=torch.float16, device=device) # Create a strided tensor strided_tensor = tensor.transpose(0, 1) # This creates a non-contiguous tensor args = (strided_tensor,) kwargs = {} - + result = serialize_args(args, kwargs) # The exact stride depends on the tensor layout, but it should include stride info assert "T([10, 20], f16, [" in result assert "])" in result - + def test_empty_args_kwargs(self): """Test reserializing empty args and kwargs""" args = () kwargs = {} - + result = serialize_args(args, kwargs) expected = "((), {})" assert result == expected - + def test_primitive_args(self): """Test reserializing primitive arguments""" - args = (1, 2.5, 'hello', True, None) + args = (1, 2.5, "hello", True, None) kwargs = {} - + result = serialize_args(args, kwargs) expected = "((1, 2.5, 'hello', True, None,), {})" assert result == expected - + def test_none_inputs(self): """Test reserializing None inputs""" assert serialize_args(None, {}) == "None" assert serialize_args([], None) == "None" assert serialize_args(None, None) == "None" - + def test_list_with_tensors(self): """Test serializing list containing tensors""" - device = 'cuda' if torch.cuda.is_available() else 'cpu' + device = "cuda" if torch.cuda.is_available() else "cpu" tensor1 = torch.randn(5, 5, dtype=torch.float32, device=device) tensor2 = torch.ones(3, 3, dtype=torch.int64, device=device) # Use ones for int tensor args = ([tensor1, tensor2, 42],) kwargs = {} - + result = serialize_args(args, kwargs) expected = "(([T([5, 5], f32), T([3, 3], i64), 42],), {})" assert result == expected - + def test_kwargs_with_tensors(self): """Test serializing kwargs containing tensors""" - device = 'cuda' if torch.cuda.is_available() else 'cpu' + device = "cuda" if torch.cuda.is_available() else "cpu" tensor = torch.randn(3, 3, dtype=torch.float32, device=device) args = () - kwargs = {'weight': tensor, 'bias': None, 'alpha': 0.5} - + kwargs = {"weight": tensor, "bias": None, "alpha": 0.5} + result = serialize_args(args, kwargs) expected = "((), {'weight': T([3, 3], f32), 'bias': None, 'alpha': 0.5})" assert result == expected - + def test_reserialize_args_alias(self): """Test that reserialize_args is an alias for serialize_args""" - device = 'cuda' if torch.cuda.is_available() else 'cpu' + device = "cuda" if torch.cuda.is_available() else "cpu" tensor = torch.randn(5, 5, dtype=torch.float32, device=device) args = (tensor,) kwargs = {} - + result1 = serialize_args(args, kwargs) result2 = reserialize_args(args, kwargs) assert result1 == result2 @@ -295,41 +302,41 @@ def test_reserialize_args_alias(self): class TestRoundTrip: """Test round-trip serialization/deserialization""" - + def test_roundtrip_single_tensor(self): """Test that serialize->deserialize produces equivalent tensors""" - device = 'cuda' if torch.cuda.is_available() else 'cpu' + device = "cuda" if torch.cuda.is_available() else "cpu" original_tensor = torch.randn(10, 20, dtype=torch.float16, device=device) original_args = (original_tensor,) original_kwargs = {} - + # Serialize serialized = serialize_args(original_args, original_kwargs) - + # Deserialize deserialized_args, deserialized_kwargs = deserialize_args(serialized) - + # Check equivalence assert len(deserialized_args) == len(original_args) assert len(deserialized_kwargs) == len(original_kwargs) assert deserialized_args[0].shape == original_args[0].shape assert deserialized_args[0].dtype == original_args[0].dtype # Device type might differ due to CUDA availability fallback - assert deserialized_args[0].device.type in ['cuda', 'cpu'] - + assert deserialized_args[0].device.type in ["cuda", "cpu"] + def test_roundtrip_complex_args(self): """Test round-trip with complex arguments""" - device = 'cuda' if torch.cuda.is_available() else 'cpu' + device = "cuda" if torch.cuda.is_available() else "cpu" tensor = torch.randn(5, 5, dtype=torch.float32, device=device) - original_args = (tensor, [1, 2, 3], 'test') - original_kwargs = {'alpha': 0.5, 'beta': tensor} - + original_args = (tensor, [1, 2, 3], "test") + original_kwargs = {"alpha": 0.5, "beta": tensor} + # Serialize serialized = serialize_args(original_args, original_kwargs) - + # Deserialize deserialized_args, deserialized_kwargs = deserialize_args(serialized) - + # Check equivalence assert len(deserialized_args) == len(original_args) assert len(deserialized_kwargs) == len(original_kwargs) @@ -337,43 +344,43 @@ def test_roundtrip_complex_args(self): assert deserialized_args[0].dtype == original_args[0].dtype assert deserialized_args[1] == original_args[1] assert deserialized_args[2] == original_args[2] - assert deserialized_kwargs['alpha'] == original_kwargs['alpha'] - assert deserialized_kwargs['beta'].shape == original_kwargs['beta'].shape - assert deserialized_kwargs['beta'].dtype == original_kwargs['beta'].dtype + assert deserialized_kwargs["alpha"] == original_kwargs["alpha"] + assert deserialized_kwargs["beta"].shape == original_kwargs["beta"].shape + assert deserialized_kwargs["beta"].dtype == original_kwargs["beta"].dtype class TestDeserializeTensor: """Test cases for _deserialize_tensor helper function""" - + def test_basic_tensor_creation(self): """Test basic tensor creation with different dtypes""" tensor = _deserialize_tensor([10, 20], torch.float32) assert tensor.shape == (10, 20) assert tensor.dtype == torch.float32 - assert tensor.device.type == 'cuda' - + assert tensor.device.type in ["cuda", "cpu"] + def test_tensor_with_stride(self): """Test tensor creation with custom stride""" tensor = _deserialize_tensor([5, 4], torch.float16, stride=[8, 2]) assert tensor.shape == (5, 4) assert tensor.stride() == (8, 2) assert tensor.dtype == torch.float16 - + def test_tensor_different_device(self): """Test tensor creation with different device""" - tensor = _deserialize_tensor([3, 3], torch.float32, device='cpu') - assert tensor.device.type == 'cpu' - + tensor = _deserialize_tensor([3, 3], torch.float32, device="cpu") + assert tensor.device.type == "cpu" + def test_floating_point_range(self): """Test that floating point tensors have values in [0, 1] range""" for dtype in [torch.float16, torch.float32, torch.float64, torch.bfloat16]: tensor = _deserialize_tensor([100], dtype) assert tensor.min() >= 0 assert tensor.max() <= 1 - + def test_integer_tensors(self): """Test integer tensor creation""" for dtype in [torch.int32, torch.int64, torch.int8, torch.int16]: tensor = _deserialize_tensor([10], dtype) assert tensor.dtype == dtype - assert tensor.shape == (10,) \ No newline at end of file + assert tensor.shape == (10,) From e23bd3a650f660107bc55a26d309097a633f6702 Mon Sep 17 00:00:00 2001 From: Sahan Paliskara Date: Tue, 29 Jul 2025 17:39:19 -0700 Subject: [PATCH 03/32] fix --- BackendBench/scripts/utils.py | 4 ---- test/test_utils.py | 13 ------------- 2 files changed, 17 deletions(-) diff --git a/BackendBench/scripts/utils.py b/BackendBench/scripts/utils.py index bd7937e..52b07ae 100644 --- a/BackendBench/scripts/utils.py +++ b/BackendBench/scripts/utils.py @@ -87,10 +87,6 @@ def serialize_args(args, kwargs) -> str: return f"({args_str}, {{{', '.join(kwargs_parts)}}})" -# Alias for backward compatibility -reserialize_args = serialize_args - - def deserialize_args(inps): inps = inps.strip().strip("'") global_vals = { diff --git a/test/test_utils.py b/test/test_utils.py index da70906..37ec391 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -3,7 +3,6 @@ from BackendBench.scripts.utils import ( serialize_args, deserialize_args, - reserialize_args, _deserialize_tensor, ) @@ -287,18 +286,6 @@ def test_kwargs_with_tensors(self): expected = "((), {'weight': T([3, 3], f32), 'bias': None, 'alpha': 0.5})" assert result == expected - def test_reserialize_args_alias(self): - """Test that reserialize_args is an alias for serialize_args""" - device = "cuda" if torch.cuda.is_available() else "cpu" - tensor = torch.randn(5, 5, dtype=torch.float32, device=device) - args = (tensor,) - kwargs = {} - - result1 = serialize_args(args, kwargs) - result2 = reserialize_args(args, kwargs) - assert result1 == result2 - assert reserialize_args is serialize_args - class TestRoundTrip: """Test round-trip serialization/deserialization""" From 0d54c1c0538e9a7933aa4bbd2a699294ab38aa6d Mon Sep 17 00:00:00 2001 From: Sahan Paliskara Date: Wed, 23 Jul 2025 10:28:54 -0400 Subject: [PATCH 04/32] [ez] get workflows to run on prs (#39) --- .github/workflows/ruff.yml | 1 + .github/workflows/smoke-test.yml | 1 + 2 files changed, 2 insertions(+) diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml index 9abe861..ac4efdd 100644 --- a/.github/workflows/ruff.yml +++ b/.github/workflows/ruff.yml @@ -2,6 +2,7 @@ name: Ruff on: push: + pull_request: jobs: ruff: runs-on: ubuntu-latest diff --git a/.github/workflows/smoke-test.yml b/.github/workflows/smoke-test.yml index 0630c56..2b2a5e6 100644 --- a/.github/workflows/smoke-test.yml +++ b/.github/workflows/smoke-test.yml @@ -2,6 +2,7 @@ name: Smoke Test on: push: + pull_request: jobs: smoke-test: From 0eb07536fde87bd29c083dadc9ecd7cbafb7ad02 Mon Sep 17 00:00:00 2001 From: Sahan Paliskara Date: Wed, 23 Jul 2025 11:47:55 -0400 Subject: [PATCH 05/32] Grab txt file from huggingface as the default (#38) --- BackendBench/torchbench_suite.py | 30 ++++++++++++++++++++++++++++-- pyproject.toml | 2 +- requirements.txt | 3 ++- scripts/main.py | 4 ++-- 4 files changed, 33 insertions(+), 6 deletions(-) diff --git a/BackendBench/torchbench_suite.py b/BackendBench/torchbench_suite.py index 78f6704..c1ec454 100644 --- a/BackendBench/torchbench_suite.py +++ b/BackendBench/torchbench_suite.py @@ -4,12 +4,18 @@ import math import re +import tempfile from collections import defaultdict from pathlib import Path +import requests import torch from torch.testing import make_tensor +# the schema for this dataset is the one defined in tritonbench traces. +# ie. https://github.com/pytorch-labs/tritonbench/blob/main/tritonbench/data/input_configs/hf_train/AlbertForMaskedLM_training.txt +DEFAULT_HUGGINGFACE_URL = "https://huggingface.co/datasets/GPUMODE/huggingface_op_trace/resolve/main/tritonbench_op_trace.txt" + dtype_abbrs = { torch.bfloat16: "bf16", @@ -120,11 +126,29 @@ def _parse_inputs(filename, filter, op_inputs): class TorchBenchTestSuite: - def __init__(self, name, filename, filter=None, topn=None): + def __init__(self, name, filename=None, filter=None, topn=None): self.name = name self.topn = topn self.optests = defaultdict(list) - if Path(filename).is_dir(): + + # Use default URL if no filename provided + if filename is None: + filename = DEFAULT_HUGGINGFACE_URL + + # Check if filename is a URL + if isinstance(filename, str) and ( + filename.startswith("http://") or filename.startswith("https://") + ): + with ( + tempfile.NamedTemporaryFile(mode="w+", suffix=".txt", delete=False) as tmp_file, + requests.get(filename) as response, + ): + response.raise_for_status() + tmp_file.write(response.text) + tmp_file.flush() + _parse_inputs(tmp_file.name, filter, self.optests) + Path(tmp_file.name).unlink(missing_ok=True) + elif Path(filename).is_dir(): for file_path in Path(filename).glob("**/*.txt"): _parse_inputs(str(file_path), filter, self.optests) else: @@ -148,6 +172,8 @@ def __iter__(self): "native_layer_norm_backward", "upsample_nearest2d_backward.vec", "upsample_bilinear2d_backward.vec", + "_cudnn_rnn_backward.default", # RuntimeError: cuDNN error: CUDNN_STATUS_BAD_PARAM + "_fft_c2c.default", # cuFFT only supports dimensions whose sizes are powers of two when computing in half precision ] ): # TODO: indexing ops need valid indices diff --git a/pyproject.toml b/pyproject.toml index 6db2bf9..b5d514e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,4 +1,4 @@ [tool.ruff] line-length = 100 -[tool.ruff.format] \ No newline at end of file +[tool.ruff.format] diff --git a/requirements.txt b/requirements.txt index 1216f77..0ad97e0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,4 +3,5 @@ click numpy expecttest anthropic>=0.34.0 -pytest \ No newline at end of file +pytest +requests diff --git a/scripts/main.py b/scripts/main.py index ed6376b..779fb80 100644 --- a/scripts/main.py +++ b/scripts/main.py @@ -10,7 +10,7 @@ from BackendBench.llm_client import ClaudeKernelGenerator from BackendBench.opinfo_suite import OpInfoTestSuite from BackendBench.suite import SmokeTestSuite -from BackendBench.torchbench_suite import TorchBenchTestSuite +from BackendBench.torchbench_suite import DEFAULT_HUGGINGFACE_URL, TorchBenchTestSuite logger = logging.getLogger(__name__) @@ -80,7 +80,7 @@ def setup_logging(log_level): ) @click.option( "--torchbench-data-path", - default="third_party/tritonbench/tritonbench/data/input_configs", + default=DEFAULT_HUGGINGFACE_URL, type=str, help="Path to TorchBench operator data", ) From 037f7c55458a10ffd041ac0d015fd5804587e4cc Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Wed, 23 Jul 2025 10:43:55 -0700 Subject: [PATCH 06/32] Installable backends (#27) --- .github/workflows/ruff.yml | 8 +- .github/workflows/smoke-test.yml | 15 ++- .gitignore | 1 + BackendBench/__init__.py | 123 ++++++++++++++++++ BackendBench/scripts/__init__.py | 1 + .../scripts}/create_simple_test_ops.py | 0 {scripts => BackendBench/scripts}/main.py | 0 pyproject.toml | 46 +++++++ requirements-dev.txt | 5 - requirements.txt | 7 - test/test_directory_backend.py | 4 +- 11 files changed, 191 insertions(+), 19 deletions(-) create mode 100644 BackendBench/__init__.py create mode 100644 BackendBench/scripts/__init__.py rename {scripts => BackendBench/scripts}/create_simple_test_ops.py (100%) rename {scripts => BackendBench/scripts}/main.py (100%) delete mode 100644 requirements-dev.txt delete mode 100644 requirements.txt diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml index ac4efdd..4393871 100644 --- a/.github/workflows/ruff.yml +++ b/.github/workflows/ruff.yml @@ -2,7 +2,11 @@ name: Ruff on: push: + branches: + - main pull_request: + branches: + - main jobs: ruff: runs-on: ubuntu-latest @@ -14,8 +18,8 @@ jobs: with: python-version: '3.x' - - name: Install ruff - run: pip install ruff==0.12.1 + - name: Install package with dev dependencies + run: pip install -e .[dev] - name: Run ruff check run: ruff check . diff --git a/.github/workflows/smoke-test.yml b/.github/workflows/smoke-test.yml index 2b2a5e6..4bd44fe 100644 --- a/.github/workflows/smoke-test.yml +++ b/.github/workflows/smoke-test.yml @@ -2,7 +2,11 @@ name: Smoke Test on: push: + branches: + - main pull_request: + branches: + - main jobs: smoke-test: @@ -16,11 +20,14 @@ jobs: with: python-version: '3.x' - - name: Install dependencies + - name: Install package and dependencies run: | - pip install -r requirements.txt - pip install -r requirements-dev.txt + pip install -e .[dev] - name: Run smoke test run: | - PYTHONPATH=. pytest test/ + python -m BackendBench.scripts.main --suite smoke --backend aten + + - name: Run pytest tests + run: | + pytest test/ diff --git a/.gitignore b/.gitignore index 092f9a2..8438e87 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ __pycache__/ .vscode/ .ruff_cache/ generated_kernels/ +backendbench.egg-info/ CLAUDE.md venv/ ops/ diff --git a/BackendBench/__init__.py b/BackendBench/__init__.py new file mode 100644 index 0000000..7eb3f4a --- /dev/null +++ b/BackendBench/__init__.py @@ -0,0 +1,123 @@ +""" +BackendBench: A PyTorch backend evaluation framework with monkey patching support. + +Import this module to automatically monkey patch PyTorch operations with custom backends. +""" + +import os + +from .backends import AtenBackend, FlagGemsBackend + + +class BackendRegistry: + """Registry for managing different PyTorch backends.""" + + def __init__(self): + self._current_backend = None + self._original_ops = {} + self._patched = False + + def register_backend(self, backend_name: str, backend_instance=None): + """Register and activate a backend.""" + if backend_instance is None: + backend_instance = self._create_backend(backend_name) + + if self._patched: + self.unpatch() + + self._current_backend = backend_instance + self._patch_torch_ops() + + def _create_backend(self, backend_name: str): + """Create a backend instance.""" + backends = {"aten": AtenBackend, "flag_gems": FlagGemsBackend} + + if backend_name not in backends: + raise ValueError(f"Unknown backend: {backend_name}. Available: {list(backends.keys())}") + + return backends[backend_name]() + + def _patch_torch_ops(self): + """Monkey patch torch operations with current backend.""" + if self._current_backend is None: + return + + # Get all torch ops that the backend supports + if hasattr(self._current_backend, "ops"): + for torch_op, backend_impl in self._current_backend.ops.items(): + if torch_op not in self._original_ops: + self._original_ops[torch_op] = torch_op.default + torch_op.default = backend_impl + + self._patched = True + print( + f"BackendBench: Monkey patched {len(self._original_ops)} operations with {self._current_backend.name} backend" + ) + + def unpatch(self): + """Restore original torch operations.""" + if not self._patched: + return + + for torch_op, original_impl in self._original_ops.items(): + torch_op.default = original_impl + + self._original_ops.clear() + self._patched = False + print("BackendBench: Restored original PyTorch operations") + + def get_current_backend(self): + """Get the currently active backend.""" + return self._current_backend + + def is_patched(self): + """Check if operations are currently patched.""" + return self._patched + + +# Global registry instance +_registry = BackendRegistry() + + +def use_backend(backend_name: str, backend_instance=None): + """ + Switch to a different backend. + + Args: + backend_name: Name of the backend ('aten', 'flag_gems') + backend_instance: Optional pre-configured backend instance + """ + _registry.register_backend(backend_name, backend_instance) + + +def get_backend(): + """Get the currently active backend.""" + return _registry.get_current_backend() + + +def restore_pytorch(): + """Restore original PyTorch operations.""" + _registry.unpatch() + + +def is_patched(): + """Check if BackendBench is currently patching operations.""" + return _registry.is_patched() + + +# Auto-configuration based on environment variables +def _auto_configure(): + """Auto-configure backend based on environment variables.""" + backend_name = os.getenv("BACKENDBENCH_BACKEND", "aten") + + try: + use_backend(backend_name) + except Exception as e: + print(f"Warning: Failed to initialize {backend_name} backend: {e}") + print("Falling back to aten backend") + use_backend("aten") + + +# Auto-configure on import unless explicitly disabled +if os.getenv("BACKENDBENCH_NO_AUTO_PATCH", "").lower() not in ("1", "true", "yes"): + _auto_configure() diff --git a/BackendBench/scripts/__init__.py b/BackendBench/scripts/__init__.py new file mode 100644 index 0000000..eb7e680 --- /dev/null +++ b/BackendBench/scripts/__init__.py @@ -0,0 +1 @@ +# Scripts module for BackendBench diff --git a/scripts/create_simple_test_ops.py b/BackendBench/scripts/create_simple_test_ops.py similarity index 100% rename from scripts/create_simple_test_ops.py rename to BackendBench/scripts/create_simple_test_ops.py diff --git a/scripts/main.py b/BackendBench/scripts/main.py similarity index 100% rename from scripts/main.py rename to BackendBench/scripts/main.py diff --git a/pyproject.toml b/pyproject.toml index b5d514e..601d389 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,3 +1,49 @@ +[build-system] +requires = ["setuptools>=61.0", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "backendbench" +version = "0.1.0" +description = "A PyTorch backend evaluation suite" +readme = "README.md" +requires-python = ">=3.10" +classifiers = [ + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", +] +dependencies = [ + "torch", + "click", + "numpy", + "expecttest", + "anthropic>=0.34.0", + "pytest", + "requests", +] + +[project.optional-dependencies] +dev = [ + "pytest", + "pytest-cov", + "pytest-mock", + "pytest-timeout", + "ruff==0.12.1", +] +flaggems = [ + "flag_gems", +] + +[project.scripts] +backendbench = "BackendBench.scripts.main:cli" + +[tool.setuptools.packages.find] +include = ["BackendBench*"] + [tool.ruff] line-length = 100 diff --git a/requirements-dev.txt b/requirements-dev.txt deleted file mode 100644 index 851900e..0000000 --- a/requirements-dev.txt +++ /dev/null @@ -1,5 +0,0 @@ -pytest -pytest-cov -pytest-mock -pytest-timeout -ruff==0.12.1 \ No newline at end of file diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 0ad97e0..0000000 --- a/requirements.txt +++ /dev/null @@ -1,7 +0,0 @@ -torch -click -numpy -expecttest -anthropic>=0.34.0 -pytest -requests diff --git a/test/test_directory_backend.py b/test/test_directory_backend.py index d6682ba..8ed5272 100644 --- a/test/test_directory_backend.py +++ b/test/test_directory_backend.py @@ -19,7 +19,9 @@ def backend(): # Import and run the existing script import subprocess - subprocess.run([sys.executable, "scripts/create_simple_test_ops.py"], check=True) + subprocess.run( + [sys.executable, "BackendBench/scripts/create_simple_test_ops.py"], check=True + ) return DirectoryBackend(ops_dir="generated_kernels") From 455b443fc1647089d683e71e30099e48234fa3ad Mon Sep 17 00:00:00 2001 From: Bert Maher Date: Mon, 28 Jul 2025 13:01:45 -0400 Subject: [PATCH 07/32] Fix flag gems tests and imports (#35) --- BackendBench/backends.py | 15 ++++++--------- test/test_backends.py | 2 +- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/BackendBench/backends.py b/BackendBench/backends.py index 409a203..31d0f37 100644 --- a/BackendBench/backends.py +++ b/BackendBench/backends.py @@ -2,9 +2,15 @@ import importlib.util import logging from typing import Dict, Callable, List +import torch logger = logging.getLogger(__name__) +try: + import flag_gems +except ImportError: + flag_gems = None + class Backend: def __init__(self, name): @@ -67,8 +73,6 @@ def _load_kernel_from_file(self, file_path: str, op_name: str) -> Callable: def _find_pytorch_op(self, op_name: str): """Map operation name to PyTorch operation.""" - import torch - # Try common patterns try: return getattr(torch.ops.aten, op_name).default @@ -106,14 +110,10 @@ def __contains__(self, key): def _flag_gems_softmax(*args, **kwargs): # half_to_float is not supported in flag_gems - import flag_gems - return flag_gems.ops.softmax(*args[:-1], **kwargs) def _flag_gems_layernorm(*args, **kwargs): - import flag_gems - x, m, v = flag_gems.ops.layer_norm(*args[:-1], **kwargs) mv_shape = [*x.shape[:-1], 1] return x, m.view(*mv_shape), v.view(*mv_shape) @@ -122,9 +122,6 @@ def _flag_gems_layernorm(*args, **kwargs): class FlagGemsBackend(Backend): def __init__(self) -> None: super().__init__("flaggems") - import flag_gems - import torch - self.ops = { torch.ops.aten.abs.default: flag_gems.ops.abs, torch.ops.aten.abs_.default: flag_gems.ops.abs_, diff --git a/test/test_backends.py b/test/test_backends.py index 66a42f4..24ec6f3 100644 --- a/test/test_backends.py +++ b/test/test_backends.py @@ -73,7 +73,7 @@ def test_flag_gems_backend_contains_op(self, mock_flag_gems): @patch("BackendBench.backends.flag_gems") def test_flag_gems_backend_getitem(self, mock_flag_gems): mock_abs_impl = Mock() - mock_flag_gems.abs = mock_abs_impl + mock_flag_gems.ops.abs = mock_abs_impl backend = FlagGemsBackend() From dd1aa1c8770984a33bd8eb55befbc33826529690 Mon Sep 17 00:00:00 2001 From: Bert Maher Date: Mon, 28 Jul 2025 21:56:30 -0400 Subject: [PATCH 08/32] Fixes to kernel agent backend tests (#46) --- BackendBench/backends.py | 23 ++++++++++++------- test/test_backends.py | 48 +++++++++++++++++++++++----------------- 2 files changed, 43 insertions(+), 28 deletions(-) diff --git a/BackendBench/backends.py b/BackendBench/backends.py index 31d0f37..09a4c24 100644 --- a/BackendBench/backends.py +++ b/BackendBench/backends.py @@ -1,7 +1,8 @@ -import os import importlib.util import logging -from typing import Dict, Callable, List +import os +from typing import Callable, Dict, List + import torch logger = logging.getLogger(__name__) @@ -397,7 +398,8 @@ def __init__(self) -> None: # Create README for this run readme_path = os.path.join(self.kernels_dir, "README.md") with open(readme_path, "w") as f: - f.write(f"""# Generated Kernels - {timestamp} + f.write( + f"""# Generated Kernels - {timestamp} This directory contains PyTorch/Triton kernels generated by the LLM Backend. @@ -413,7 +415,8 @@ def __init__(self) -> None: ## Usage You can inspect these files to debug kernel generation, manually test implementations, or understand what the LLM produced. -""") +""" + ) print(f"Saving generated kernels to: {self.kernels_dir}") @@ -521,8 +524,8 @@ def test_kernel_correctness( f.write(full_code) print(f"Saved kernel to: {kernel_file}") - import sys import importlib.util + import sys spec = importlib.util.spec_from_file_location( f"test_kernel_{op_name}_{attempt}", kernel_file @@ -633,7 +636,8 @@ def __init__(self) -> None: # Create README for this run readme_path = os.path.join(self.kernels_dir, "README.md") with open(readme_path, "w") as f: - f.write(f"""# Generated Kernels - KernelAgent - {timestamp} + f.write( + f"""# Generated Kernels - KernelAgent - {timestamp} This directory contains PyTorch/Triton kernels generated by the KernelAgent Backend. @@ -656,7 +660,8 @@ def __init__(self) -> None: ## Usage You can inspect these files to debug kernel generation, analyze the parallel worker outputs, or understand the sophisticated generation process used by KernelAgent. -""") +""" + ) print(f"Saving KernelAgent generated kernels to: {self.kernels_dir}") @@ -688,7 +693,9 @@ def _get_kernel_agent(self): os.makedirs(agent_log_dir, exist_ok=True) self.kernel_agent = TritonKernelAgent( - log_dir=agent_log_dir, num_workers=self.num_workers, max_rounds=self.max_rounds + log_dir=agent_log_dir, + num_workers=self.num_workers, + max_rounds=self.max_rounds, ) print(f"✓ KernelAgent initialized with log directory: {agent_log_dir}") diff --git a/test/test_backends.py b/test/test_backends.py index 24ec6f3..4f6a46a 100644 --- a/test/test_backends.py +++ b/test/test_backends.py @@ -1,7 +1,13 @@ +from unittest.mock import Mock, patch + import pytest import torch -from unittest.mock import Mock, patch -from BackendBench.backends import AtenBackend, FlagGemsBackend, LLMBackend, KernelAgentBackend +from BackendBench.backends import ( + AtenBackend, + FlagGemsBackend, + KernelAgentBackend, + LLMBackend, +) try: import importlib.util @@ -11,9 +17,9 @@ HAS_FLAG_GEMS = False try: - import sys - import os import importlib.util + import os + import sys kernel_agent_path = os.path.join(os.path.dirname(__file__), "..", "KernelAgent") sys.path.insert(0, os.path.abspath(kernel_agent_path)) @@ -159,27 +165,24 @@ def generated_kernel(x): class TestKernelAgentBackend: @pytest.mark.skipif(not HAS_KERNEL_AGENT, reason="KernelAgent not available") def test_kernel_agent_backend_initialization(self): - with patch("os.makedirs"): - backend = KernelAgentBackend() - assert backend.name == "kernel_agent" - assert "kernel_agent_run_" in backend.kernels_dir - assert backend.num_workers == 4 # default value - assert backend.max_rounds == 10 # default value + backend = KernelAgentBackend() + assert backend.name == "kernel_agent" + assert "kernel_agent_run_" in backend.kernels_dir + assert backend.num_workers == 4 # default value + assert backend.max_rounds == 10 # default value @pytest.mark.skipif(not HAS_KERNEL_AGENT, reason="KernelAgent not available") def test_kernel_agent_backend_set_config(self): - with patch("os.makedirs"): - backend = KernelAgentBackend() + backend = KernelAgentBackend() - backend.set_config(num_workers=8, max_rounds=20) + backend.set_config(num_workers=8, max_rounds=20) - assert backend.num_workers == 8 - assert backend.max_rounds == 20 + assert backend.num_workers == 8 + assert backend.max_rounds == 20 @pytest.mark.skipif(not HAS_KERNEL_AGENT, reason="KernelAgent not available") def test_kernel_agent_backend_generate_kernel(self): with ( - patch("os.makedirs"), patch("triton_kernel_agent.TritonKernelAgent") as mock_kernel_agent_class, ): backend = KernelAgentBackend() @@ -187,7 +190,13 @@ def test_kernel_agent_backend_generate_kernel(self): mock_agent = Mock() mock_kernel_agent_class.return_value = mock_agent - mock_agent.generate_kernel.return_value = (True, "def kernel(): pass") + mock_agent.generate_kernel.return_value = { + "success": True, + "kernel_code": "def kernel(): pass", + "rounds": 1, + "session_dir": "test_session_dir", + "worker_id": 0, + } mock_op = Mock() mock_op.__str__ = Mock(return_value="test_op") @@ -205,9 +214,8 @@ def test_backend_polymorphism(self): backends.append(AtenBackend()) with patch("BackendBench.backends.flag_gems"): backends.append(FlagGemsBackend()) - with patch("os.makedirs"): - backends.append(LLMBackend()) - backends.append(KernelAgentBackend()) + backends.append(LLMBackend()) + backends.append(KernelAgentBackend()) for backend in backends: assert hasattr(backend, "name") assert hasattr(backend, "__contains__") From 5a5702a7da68907700a300369b61d2b85ee7dbcd Mon Sep 17 00:00:00 2001 From: Sahan Paliskara Date: Wed, 30 Jul 2025 14:24:12 -0700 Subject: [PATCH 09/32] Filter out solutions that have cuda streams (#56) --- BackendBench/eval.py | 7 ++ BackendBench/utils.py | 53 ++++++++++++++ pyproject.toml | 3 + test/test_utils.py | 162 ++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 225 insertions(+) create mode 100644 BackendBench/utils.py diff --git a/BackendBench/eval.py b/BackendBench/eval.py index 816697f..25b0067 100644 --- a/BackendBench/eval.py +++ b/BackendBench/eval.py @@ -5,6 +5,8 @@ import triton.testing +from BackendBench.utils import uses_cuda_stream + logger = logging.getLogger(__name__) EXC_MSG = """ @@ -101,6 +103,11 @@ def eval_performance(op, impl, tests): def eval_one_op(op, impl, correctness_tests, performance_tests): """Evaluate impl of op against correctness_tests and performance_tests.""" + # TODO: We should have proper error reporting instead of just saying this is 0, + # but that should be a separate PR. + if uses_cuda_stream(impl): + logger.warning(f"Skipping {op.__name__} because it uses CUDA stream") + return 0, 0 return eval_correctness(op, impl, correctness_tests), eval_performance( op, impl, performance_tests ) diff --git a/BackendBench/utils.py b/BackendBench/utils.py new file mode 100644 index 0000000..0d9cd0c --- /dev/null +++ b/BackendBench/utils.py @@ -0,0 +1,53 @@ +import ast +import inspect +import re +import textwrap + + +def uses_cuda_stream(func) -> bool: + """ + Detects whether a Python function creates CUDA streams. + + Args: + func: The Python function to analyze + + Returns: + bool: True if CUDA streams are created, False otherwise + """ + try: + source = inspect.getsource(func) + except (TypeError, OSError): + # Handle builtin functions, OpOverload objects, and other callables + # without source code. These cannot create CUDA streams. + return False + + # Check for stream creation patterns + patterns = [ + r"torch\.cuda\.Stream\(", # torch.cuda.Stream() constructor + r"cupy\.cuda\.Stream\(", # cupy.cuda.Stream() constructor + r"cuda\.Stream\(", # Generic cuda.Stream() constructor + r"pycuda.*Stream\(", # PyCUDA stream creation + r"\bStream\(", # Stream() constructor calls + r"make_stream\(", # make_stream() factory function + r"create_stream\(", # create_stream() factory function + ] + + if any(re.search(p, source, re.IGNORECASE) for p in patterns): + return True + + class StreamCreationFinder(ast.NodeVisitor): + def __init__(self): + self.found = False + + def visit_Call(self, node): + # Check for Stream() constructor calls + if hasattr(node.func, "attr") and node.func.attr == "Stream": + self.found = True + elif hasattr(node.func, "id") and node.func.id == "Stream": + self.found = True + self.generic_visit(node) + + tree = ast.parse(textwrap.dedent(source)) + finder = StreamCreationFinder() + finder.visit(tree) + return finder.found diff --git a/pyproject.toml b/pyproject.toml index 601d389..d476838 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,9 @@ dev = [ "pytest-mock", "pytest-timeout", "ruff==0.12.1", + "torch", + "numpy", + "cupy-cuda12x", ] flaggems = [ "flag_gems", diff --git a/test/test_utils.py b/test/test_utils.py index 37ec391..30f8162 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -1,10 +1,168 @@ import torch import math +import pytest + from BackendBench.scripts.utils import ( serialize_args, deserialize_args, _deserialize_tensor, ) +from BackendBench.utils import uses_cuda_stream + +# Check if CUDA is available +HAS_CUDA = torch.cuda.is_available() + + +class TestCudaStreamDetection: + @pytest.mark.skipif(not HAS_CUDA, reason="CUDA not available") + def test_pytorch_stream_creation(self): + """Test detection of PyTorch CUDA stream creation.""" + + def func_with_pytorch_stream(): + import torch + + stream = torch.cuda.Stream() + return stream + + assert uses_cuda_stream(func_with_pytorch_stream) + + @pytest.mark.skipif(not HAS_CUDA, reason="CUDA not available") + def test_cupy_stream_creation(self): + import cupy + + """Test detection of CuPy CUDA stream creation.""" + + def func_with_cupy_stream(): + stream = cupy.cuda.Stream() + return stream + + assert uses_cuda_stream(func_with_cupy_stream) + + @pytest.mark.skipif(not HAS_CUDA, reason="CUDA not available") + def test_generic_stream_creation(self): + """Test detection of generic Stream() calls.""" + + def func_with_generic_stream(): + from torch.cuda import Stream + + stream = Stream() + return stream + + assert uses_cuda_stream(func_with_generic_stream) + + @pytest.mark.skipif(not HAS_CUDA, reason="CUDA not available") + def test_stream_with_device_id(self): + """Test detection of Stream with device ID.""" + + def func_with_device_stream(): + from torch.cuda import Stream + + stream = Stream(0) + return stream + + assert uses_cuda_stream(func_with_device_stream) + + def test_no_stream_creation(self): + """Test functions without stream creation return False.""" + + def func_without_stream(): + import torch + + x = torch.randn(100, 100) + y = x @ x.T + return y + + assert not uses_cuda_stream(func_without_stream) + + @pytest.mark.skipif(not HAS_CUDA, reason="CUDA not available") + def test_lambda_function(self): + """Test detection in lambda functions.""" + + def func_lambda_with_stream(): + return torch.cuda.Stream() + + def func_lambda_without(x): + return x * 2 + + assert uses_cuda_stream(func_lambda_with_stream) + assert not uses_cuda_stream(func_lambda_without) + + @pytest.mark.skipif(not HAS_CUDA, reason="CUDA not available") + def test_nested_function(self): + """Test detection in nested functions.""" + + def outer_function(): + def inner_with_stream(): + import torch + + return torch.cuda.Stream() + + return inner_with_stream + + inner = outer_function() + assert uses_cuda_stream(inner) + + @pytest.mark.skipif(not HAS_CUDA, reason="CUDA not available") + def test_class_method(self): + """Test detection in class methods.""" + + class StreamClass: + def method_with_stream(self): + import torch + + self.stream = torch.cuda.Stream() + return self.stream + + def method_without_stream(self): + return "no stream here" + + obj = StreamClass() + assert uses_cuda_stream(obj.method_with_stream) + assert not uses_cuda_stream(obj.method_without_stream) + + @pytest.mark.skipif(not HAS_CUDA, reason="CUDA not available") + def test_various_formats(self): + """Test various formatting of stream creation.""" + + def func_spaces(): + stream = torch.cuda.Stream() + return stream + + def func_multiline(): + stream = torch.cuda.Stream(device=0) + return stream + + def func_chained(): + result = torch.cuda.Stream().query() + return result + + assert uses_cuda_stream(func_spaces) + assert uses_cuda_stream(func_multiline) + assert uses_cuda_stream(func_chained) + + @pytest.mark.skipif(not HAS_CUDA, reason="CUDA not available") + def test_case_sensitivity(self): + """Test case-insensitive detection.""" + + def func_lowercase(): + stream = torch.cuda.stream() # lowercase (if it existed) + return stream + + def func_uppercase(): + stream = torch.cuda.STREAM() # uppercase (if it existed) + return stream + + # These should still be detected due to case-insensitive regex + assert uses_cuda_stream(func_lowercase) + assert uses_cuda_stream(func_uppercase) + + def test_opoverload_callables(self): + """Test that OpOverload objects don't raise exceptions.""" + import torch + + # Test OpOverload (torch operators) + assert not uses_cuda_stream(torch.add) + assert not uses_cuda_stream(torch.ops.aten.add) class TestDeserializeArgs: @@ -371,3 +529,7 @@ def test_integer_tensors(self): tensor = _deserialize_tensor([10], dtype) assert tensor.dtype == dtype assert tensor.shape == (10,) + + +if __name__ == "__main__": + pytest.main([__file__]) From e6bb19a44980af5c1f117c1c752fd3b3cd699f53 Mon Sep 17 00:00:00 2001 From: Sahan Paliskara Date: Tue, 29 Jul 2025 16:38:30 -0700 Subject: [PATCH 10/32] Add tests for serialization and deserialization --- BackendBench/scripts/utils.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/BackendBench/scripts/utils.py b/BackendBench/scripts/utils.py index 52b07ae..7a78e57 100644 --- a/BackendBench/scripts/utils.py +++ b/BackendBench/scripts/utils.py @@ -27,24 +27,39 @@ def _deserialize_tensor(size, dtype, stride=None, device="cuda"): kwargs = {} if dtype in _FLOATING_TYPES: kwargs.update({"low": 0, "high": 1}) +<<<<<<< HEAD # Fall back to CPU if CUDA is not available if device == "cuda" and not torch.cuda.is_available(): device = "cpu" +======= + + # Fall back to CPU if CUDA is not available + if device == "cuda" and not torch.cuda.is_available(): + device = "cpu" + +>>>>>>> 201e39a (Add tests for serialization and deserialization) if stride is not None: extent = 1 + sum((size - 1) * stride for size, stride in zip(size, stride)) data = make_tensor(extent, dtype=dtype, device=device, **kwargs) return data.as_strided(size, stride) return make_tensor(size, dtype=dtype, device=device, **kwargs) +<<<<<<< HEAD +======= +>>>>>>> 201e39a (Add tests for serialization and deserialization) def _serialize_tensor(tensor): """Helper function to serialize a tensor to string format""" shape = list(tensor.shape) dtype = dtype_abbrs[tensor.dtype] stride = tensor.stride() if not tensor.is_contiguous() else None +<<<<<<< HEAD +======= + +>>>>>>> 201e39a (Add tests for serialization and deserialization) if stride: return f"T({shape}, {dtype}, {list(stride)})" else: @@ -79,12 +94,22 @@ def serialize_args(args, kwargs) -> str: parts = [_serialize_value(arg) for arg in args] # Process keyword arguments +<<<<<<< HEAD kwargs_parts = [f"'{key}': {_serialize_value(val)}" for key, val in kwargs.items()] # Handle empty args tuple properly args_str = f"({', '.join(parts)},)" if parts else "()" return f"({args_str}, {{{', '.join(kwargs_parts)}}})" +======= + kwargs_parts = [f"{key}={_serialize_value(val)}" for key, val in kwargs.items()] + + return f"(({', '.join(parts)},), {{{', '.join(kwargs_parts)}}})" + + +# Alias for backward compatibility +reserialize_args = serialize_args +>>>>>>> 201e39a (Add tests for serialization and deserialization) def deserialize_args(inps): From 4b1722b1f343e65ddc3ffdc3b16964eae281eedc Mon Sep 17 00:00:00 2001 From: Sahan Paliskara Date: Tue, 29 Jul 2025 17:29:33 -0700 Subject: [PATCH 11/32] fix --- BackendBench/scripts/utils.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/BackendBench/scripts/utils.py b/BackendBench/scripts/utils.py index 7a78e57..11b674a 100644 --- a/BackendBench/scripts/utils.py +++ b/BackendBench/scripts/utils.py @@ -27,6 +27,7 @@ def _deserialize_tensor(size, dtype, stride=None, device="cuda"): kwargs = {} if dtype in _FLOATING_TYPES: kwargs.update({"low": 0, "high": 1}) +<<<<<<< HEAD <<<<<<< HEAD # Fall back to CPU if CUDA is not available @@ -40,6 +41,13 @@ def _deserialize_tensor(size, dtype, stride=None, device="cuda"): device = "cpu" >>>>>>> 201e39a (Add tests for serialization and deserialization) +======= + + # Fall back to CPU if CUDA is not available + if device == "cuda" and not torch.cuda.is_available(): + device = "cpu" + +>>>>>>> a15dcbc (fix) if stride is not None: extent = 1 + sum((size - 1) * stride for size, stride in zip(size, stride)) data = make_tensor(extent, dtype=dtype, device=device, **kwargs) @@ -47,19 +55,27 @@ def _deserialize_tensor(size, dtype, stride=None, device="cuda"): return make_tensor(size, dtype=dtype, device=device, **kwargs) <<<<<<< HEAD +<<<<<<< HEAD ======= >>>>>>> 201e39a (Add tests for serialization and deserialization) +======= + +>>>>>>> a15dcbc (fix) def _serialize_tensor(tensor): """Helper function to serialize a tensor to string format""" shape = list(tensor.shape) dtype = dtype_abbrs[tensor.dtype] stride = tensor.stride() if not tensor.is_contiguous() else None <<<<<<< HEAD +<<<<<<< HEAD ======= >>>>>>> 201e39a (Add tests for serialization and deserialization) +======= + +>>>>>>> a15dcbc (fix) if stride: return f"T({shape}, {dtype}, {list(stride)})" else: @@ -95,16 +111,22 @@ def serialize_args(args, kwargs) -> str: # Process keyword arguments <<<<<<< HEAD +<<<<<<< HEAD +======= +>>>>>>> a15dcbc (fix) kwargs_parts = [f"'{key}': {_serialize_value(val)}" for key, val in kwargs.items()] # Handle empty args tuple properly args_str = f"({', '.join(parts)},)" if parts else "()" return f"({args_str}, {{{', '.join(kwargs_parts)}}})" +<<<<<<< HEAD ======= kwargs_parts = [f"{key}={_serialize_value(val)}" for key, val in kwargs.items()] return f"(({', '.join(parts)},), {{{', '.join(kwargs_parts)}}})" +======= +>>>>>>> a15dcbc (fix) # Alias for backward compatibility From 3a670c6624f9f01af2cff55cf22b36c0b4085166 Mon Sep 17 00:00:00 2001 From: Sahan Paliskara Date: Tue, 29 Jul 2025 17:39:19 -0700 Subject: [PATCH 12/32] fix --- BackendBench/scripts/utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/BackendBench/scripts/utils.py b/BackendBench/scripts/utils.py index 11b674a..3b317a6 100644 --- a/BackendBench/scripts/utils.py +++ b/BackendBench/scripts/utils.py @@ -129,11 +129,14 @@ def serialize_args(args, kwargs) -> str: >>>>>>> a15dcbc (fix) +<<<<<<< HEAD # Alias for backward compatibility reserialize_args = serialize_args >>>>>>> 201e39a (Add tests for serialization and deserialization) +======= +>>>>>>> 4ed1e55 (fix) def deserialize_args(inps): inps = inps.strip().strip("'") global_vals = { From e4ccfb84f42338baf6052dc19816a8d27e8d02a3 Mon Sep 17 00:00:00 2001 From: Sahan Paliskara Date: Thu, 31 Jul 2025 09:52:49 -0700 Subject: [PATCH 13/32] rebase --- BackendBench/scripts/utils.py | 52 +---------------------------------- 1 file changed, 1 insertion(+), 51 deletions(-) diff --git a/BackendBench/scripts/utils.py b/BackendBench/scripts/utils.py index 3b317a6..82acb03 100644 --- a/BackendBench/scripts/utils.py +++ b/BackendBench/scripts/utils.py @@ -27,55 +27,24 @@ def _deserialize_tensor(size, dtype, stride=None, device="cuda"): kwargs = {} if dtype in _FLOATING_TYPES: kwargs.update({"low": 0, "high": 1}) -<<<<<<< HEAD -<<<<<<< HEAD # Fall back to CPU if CUDA is not available if device == "cuda" and not torch.cuda.is_available(): device = "cpu" -======= - - # Fall back to CPU if CUDA is not available - if device == "cuda" and not torch.cuda.is_available(): - device = "cpu" - ->>>>>>> 201e39a (Add tests for serialization and deserialization) -======= - - # Fall back to CPU if CUDA is not available - if device == "cuda" and not torch.cuda.is_available(): - device = "cpu" - ->>>>>>> a15dcbc (fix) if stride is not None: extent = 1 + sum((size - 1) * stride for size, stride in zip(size, stride)) data = make_tensor(extent, dtype=dtype, device=device, **kwargs) return data.as_strided(size, stride) return make_tensor(size, dtype=dtype, device=device, **kwargs) -<<<<<<< HEAD -<<<<<<< HEAD - -======= ->>>>>>> 201e39a (Add tests for serialization and deserialization) -======= ->>>>>>> a15dcbc (fix) def _serialize_tensor(tensor): """Helper function to serialize a tensor to string format""" shape = list(tensor.shape) dtype = dtype_abbrs[tensor.dtype] stride = tensor.stride() if not tensor.is_contiguous() else None -<<<<<<< HEAD -<<<<<<< HEAD -======= - ->>>>>>> 201e39a (Add tests for serialization and deserialization) -======= - ->>>>>>> a15dcbc (fix) if stride: return f"T({shape}, {dtype}, {list(stride)})" else: @@ -110,33 +79,14 @@ def serialize_args(args, kwargs) -> str: parts = [_serialize_value(arg) for arg in args] # Process keyword arguments -<<<<<<< HEAD -<<<<<<< HEAD -======= ->>>>>>> a15dcbc (fix) kwargs_parts = [f"'{key}': {_serialize_value(val)}" for key, val in kwargs.items()] # Handle empty args tuple properly args_str = f"({', '.join(parts)},)" if parts else "()" return f"({args_str}, {{{', '.join(kwargs_parts)}}})" -<<<<<<< HEAD -======= - kwargs_parts = [f"{key}={_serialize_value(val)}" for key, val in kwargs.items()] - - return f"(({', '.join(parts)},), {{{', '.join(kwargs_parts)}}})" -======= ->>>>>>> a15dcbc (fix) - - -<<<<<<< HEAD -# Alias for backward compatibility -reserialize_args = serialize_args ->>>>>>> 201e39a (Add tests for serialization and deserialization) -======= ->>>>>>> 4ed1e55 (fix) def deserialize_args(inps): inps = inps.strip().strip("'") global_vals = { @@ -149,4 +99,4 @@ def deserialize_args(inps): # f strings introduce quotations we dont want for key in dtype_abbrs_parsing: inps = inps.replace(f"'{key}'", key) - return eval(inps.strip().strip("'").strip('"'), global_vals) + return eval(inps.strip().strip("'").strip('"'), global_vals) \ No newline at end of file From 7618519841027f35c91c4ab7c94d2755ac867093 Mon Sep 17 00:00:00 2001 From: Sahan Paliskara Date: Thu, 31 Jul 2025 10:00:51 -0700 Subject: [PATCH 14/32] rebase fix --- BackendBench/scripts/utils.py | 102 ------------------------------- BackendBench/torchbench_suite.py | 56 ++--------------- BackendBench/utils.py | 102 +++++++++++++++++++++++++++++++ 3 files changed, 106 insertions(+), 154 deletions(-) delete mode 100644 BackendBench/scripts/utils.py diff --git a/BackendBench/scripts/utils.py b/BackendBench/scripts/utils.py deleted file mode 100644 index 82acb03..0000000 --- a/BackendBench/scripts/utils.py +++ /dev/null @@ -1,102 +0,0 @@ -import math -import torch -from torch.testing import make_tensor - -dtype_abbrs = { - torch.bfloat16: "bf16", - torch.float64: "f64", - torch.float32: "f32", - torch.float16: "f16", - torch.complex32: "c32", - torch.complex64: "c64", - torch.complex128: "c128", - torch.int8: "i8", - torch.int16: "i16", - torch.int32: "i32", - torch.int64: "i64", - torch.bool: "b8", - torch.uint8: "u8", -} - -dtype_abbrs_parsing = {value: key for key, value in dtype_abbrs.items()} - -_FLOATING_TYPES = [torch.float16, torch.bfloat16, torch.float32, torch.float64] - - -def _deserialize_tensor(size, dtype, stride=None, device="cuda"): - kwargs = {} - if dtype in _FLOATING_TYPES: - kwargs.update({"low": 0, "high": 1}) - - # Fall back to CPU if CUDA is not available - if device == "cuda" and not torch.cuda.is_available(): - device = "cpu" - - if stride is not None: - extent = 1 + sum((size - 1) * stride for size, stride in zip(size, stride)) - data = make_tensor(extent, dtype=dtype, device=device, **kwargs) - return data.as_strided(size, stride) - return make_tensor(size, dtype=dtype, device=device, **kwargs) - - -def _serialize_tensor(tensor): - """Helper function to serialize a tensor to string format""" - shape = list(tensor.shape) - dtype = dtype_abbrs[tensor.dtype] - stride = tensor.stride() if not tensor.is_contiguous() else None - - if stride: - return f"T({shape}, {dtype}, {list(stride)})" - else: - return f"T({shape}, {dtype})" - - -def _serialize_value(value): - """Helper function to serialize any value (tensor, list, primitive)""" - if isinstance(value, torch.Tensor): - return _serialize_tensor(value) - elif isinstance(value, list): - list_parts = [_serialize_value(item) for item in value] - return f"[{', '.join(list_parts)}]" - else: - return repr(value) - - -def serialize_args(args, kwargs) -> str: - """Convert args and kwargs back to the BackendBench string format - - Args: - args: List of arguments (can contain tensors, lists, primitives) - kwargs: Dict of keyword arguments - - Returns: - Serialized string in format: (arg1, arg2, ..., key1=val1, key2=val2, ...) - """ - if args is None or kwargs is None: - return "None" - - # Process positional arguments - parts = [_serialize_value(arg) for arg in args] - - # Process keyword arguments - kwargs_parts = [f"'{key}': {_serialize_value(val)}" for key, val in kwargs.items()] - - # Handle empty args tuple properly - args_str = f"({', '.join(parts)},)" if parts else "()" - - return f"({args_str}, {{{', '.join(kwargs_parts)}}})" - - -def deserialize_args(inps): - inps = inps.strip().strip("'") - global_vals = { - "T": _deserialize_tensor, - "th": torch, - "inf": math.inf, - "torch": torch, - **dtype_abbrs_parsing, - } - # f strings introduce quotations we dont want - for key in dtype_abbrs_parsing: - inps = inps.replace(f"'{key}'", key) - return eval(inps.strip().strip("'").strip('"'), global_vals) \ No newline at end of file diff --git a/BackendBench/torchbench_suite.py b/BackendBench/torchbench_suite.py index c1ec454..7aa91e7 100644 --- a/BackendBench/torchbench_suite.py +++ b/BackendBench/torchbench_suite.py @@ -2,7 +2,6 @@ Load aten inputs from serialized txt files. """ -import math import re import tempfile from collections import defaultdict @@ -10,60 +9,13 @@ import requests import torch -from torch.testing import make_tensor +from BackendBench.utils import deserialize_args # the schema for this dataset is the one defined in tritonbench traces. # ie. https://github.com/pytorch-labs/tritonbench/blob/main/tritonbench/data/input_configs/hf_train/AlbertForMaskedLM_training.txt DEFAULT_HUGGINGFACE_URL = "https://huggingface.co/datasets/GPUMODE/huggingface_op_trace/resolve/main/tritonbench_op_trace.txt" -dtype_abbrs = { - torch.bfloat16: "bf16", - torch.float64: "f64", - torch.float32: "f32", - torch.float16: "f16", - torch.complex32: "c32", - torch.complex64: "c64", - torch.complex128: "c128", - torch.int8: "i8", - torch.int16: "i16", - torch.int32: "i32", - torch.int64: "i64", - torch.bool: "b8", - torch.uint8: "u8", -} - -dtype_abbrs_parsing = {value: key for key, value in dtype_abbrs.items()} - -_FLOATING_TYPES = [torch.float16, torch.bfloat16, torch.float32, torch.float64] - - -def _deserialize_tensor(size, dtype, stride=None, device="cuda"): - kwargs = {} - if dtype in _FLOATING_TYPES: - kwargs.update({"low": 0, "high": 1}) - if stride is not None: - extent = 1 + sum((size - 1) * stride for size, stride in zip(size, stride)) - data = make_tensor(extent, dtype=dtype, device=device, **kwargs) - return data.as_strided(size, stride) - return make_tensor(size, dtype=dtype, device=device, **kwargs) - - -def _deserialize_args(inps): - inps = inps.strip().strip("'") - global_vals = { - "T": _deserialize_tensor, - "th": torch, - "inf": math.inf, - "torch": torch, - **dtype_abbrs_parsing, - } - # f strings introduce quotations we dont want - for key in dtype_abbrs_parsing: - inps = inps.replace(f"'{key}'", key) - return eval(inps.strip().strip("'").strip('"'), global_vals) - - class TorchBenchTest: def __init__(self, *args, **kwargs): self.args = args @@ -89,7 +41,7 @@ def __init__(self, op, inputs, topn): def tests(self): inputs_and_sizes = [] for inp in self.inputs: - args, kwargs = _deserialize_args(inp) + args, kwargs = deserialize_args(inp) size = _args_size(args) + _args_size(list(kwargs.values())) inputs_and_sizes.append((size, inp)) ret = [x[1] for x in sorted(inputs_and_sizes, reverse=True)] @@ -98,13 +50,13 @@ def tests(self): @property def correctness_tests(self): for inp in self.tests(): - args, kwargs = _deserialize_args(inp) + args, kwargs = deserialize_args(inp) yield TorchBenchTest(*args, **kwargs) @property def performance_tests(self): for inp in self.tests(): - args, kwargs = _deserialize_args(inp) + args, kwargs = deserialize_args(inp) yield TorchBenchTest(*args, **kwargs) diff --git a/BackendBench/utils.py b/BackendBench/utils.py index 0d9cd0c..600934f 100644 --- a/BackendBench/utils.py +++ b/BackendBench/utils.py @@ -2,6 +2,29 @@ import inspect import re import textwrap +import math +import torch +from torch.testing import make_tensor + +dtype_abbrs = { + torch.bfloat16: "bf16", + torch.float64: "f64", + torch.float32: "f32", + torch.float16: "f16", + torch.complex32: "c32", + torch.complex64: "c64", + torch.complex128: "c128", + torch.int8: "i8", + torch.int16: "i16", + torch.int32: "i32", + torch.int64: "i64", + torch.bool: "b8", + torch.uint8: "u8", +} + +dtype_abbrs_parsing = {value: key for key, value in dtype_abbrs.items()} + +_FLOATING_TYPES = [torch.float16, torch.bfloat16, torch.float32, torch.float64] def uses_cuda_stream(func) -> bool: @@ -51,3 +74,82 @@ def visit_Call(self, node): finder = StreamCreationFinder() finder.visit(tree) return finder.found + + +def _deserialize_tensor(size, dtype, stride=None, device="cuda"): + kwargs = {} + if dtype in _FLOATING_TYPES: + kwargs.update({"low": 0, "high": 1}) + + # Fall back to CPU if CUDA is not available + if device == "cuda" and not torch.cuda.is_available(): + device = "cpu" + + if stride is not None: + extent = 1 + sum((size - 1) * stride for size, stride in zip(size, stride)) + data = make_tensor(extent, dtype=dtype, device=device, **kwargs) + return data.as_strided(size, stride) + return make_tensor(size, dtype=dtype, device=device, **kwargs) + + +def _serialize_tensor(tensor): + """Helper function to serialize a tensor to string format""" + shape = list(tensor.shape) + dtype = dtype_abbrs[tensor.dtype] + stride = tensor.stride() if not tensor.is_contiguous() else None + + if stride: + return f"T({shape}, {dtype}, {list(stride)})" + else: + return f"T({shape}, {dtype})" + + +def _serialize_value(value): + """Helper function to serialize any value (tensor, list, primitive)""" + if isinstance(value, torch.Tensor): + return _serialize_tensor(value) + elif isinstance(value, list): + list_parts = [_serialize_value(item) for item in value] + return f"[{', '.join(list_parts)}]" + else: + return repr(value) + + +def serialize_args(args, kwargs) -> str: + """Convert args and kwargs back to the BackendBench string format + + Args: + args: List of arguments (can contain tensors, lists, primitives) + kwargs: Dict of keyword arguments + + Returns: + Serialized string in format: (arg1, arg2, ..., key1=val1, key2=val2, ...) + """ + if args is None or kwargs is None: + return "None" + + # Process positional arguments + parts = [_serialize_value(arg) for arg in args] + + # Process keyword arguments + kwargs_parts = [f"'{key}': {_serialize_value(val)}" for key, val in kwargs.items()] + + # Handle empty args tuple properly + args_str = f"({', '.join(parts)},)" if parts else "()" + + return f"({args_str}, {{{', '.join(kwargs_parts)}}})" + + +def deserialize_args(inps): + inps = inps.strip().strip("'") + global_vals = { + "T": _deserialize_tensor, + "th": torch, + "inf": math.inf, + "torch": torch, + **dtype_abbrs_parsing, + } + # f strings introduce quotations we dont want + for key in dtype_abbrs_parsing: + inps = inps.replace(f"'{key}'", key) + return eval(inps.strip().strip("'").strip('"'), global_vals) From 32d52d1f345ef4c0c643226b99d1a072744af03f Mon Sep 17 00:00:00 2001 From: Sahan Paliskara Date: Thu, 31 Jul 2025 10:01:37 -0700 Subject: [PATCH 15/32] rebase fix --- test/test_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_utils.py b/test/test_utils.py index 30f8162..bd30141 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -2,12 +2,12 @@ import math import pytest -from BackendBench.scripts.utils import ( +from BackendBench.utils import ( serialize_args, deserialize_args, _deserialize_tensor, + uses_cuda_stream, ) -from BackendBench.utils import uses_cuda_stream # Check if CUDA is available HAS_CUDA = torch.cuda.is_available() From 1c1824748495734a1f39ac7771470748bdae3687 Mon Sep 17 00:00:00 2001 From: Sahan Paliskara Date: Thu, 31 Jul 2025 16:18:12 -0700 Subject: [PATCH 16/32] Adding parquet file --- .gitignore | 1 + .../scripts/parquet_trace_converter.py | 114 ++++++++++++++++++ 2 files changed, 115 insertions(+) create mode 100644 BackendBench/scripts/parquet_trace_converter.py diff --git a/.gitignore b/.gitignore index 8438e87..746aa5f 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,4 @@ backendbench.egg-info/ CLAUDE.md venv/ ops/ +datasets/ diff --git a/BackendBench/scripts/parquet_trace_converter.py b/BackendBench/scripts/parquet_trace_converter.py new file mode 100644 index 0000000..2e51e4b --- /dev/null +++ b/BackendBench/scripts/parquet_trace_converter.py @@ -0,0 +1,114 @@ +# utility functions to convert parquet and trace files back and forth + +import pyarrow.parquet as pq +import pyarrow.csv as csv +import pyarrow as pa +from BackendBench.torchbench_suite import DEFAULT_HUGGINGFACE_URL, _args_size +from BackendBench.utils import deserialize_args +import os +import requests +import tempfile +from pathlib import Path +import hashlib +import re +from tqdm import tqdm + +""" +For the dataset release we generally would want to versions +1. A production version which has what you would want to run a benchmark with an llm +2. A "dev" version. This version is much more verbose, has more information on each test, includes tests/ops we decided to axe (and why they were axed), and possibly some runtime numbers + +The point of 1 is for something to have folks able to benchmark their agents against. Therefore, there is a high quality bar for inclusion +At the end of the day we still need solutions to be general for inclusion in pytorch, therefore, the mroe verbose dev version is useful in this case. It also allows us to record information on the ops and decisions as well + +Columns for the production version: +- uuid (int) (hash of op + args) +- op_name (string) +- args (string) +- arg size (float)(in MB) +- count (int) (number of times this op + set of args was called in real models) +- is_synthetic (boolean) (did we generate this op or is it from a real model) + + +Columns for the dev version: +All columns in the production version, plus: +- include_in_prod (boolean) +- why_excluded (string) (empty if included) +- runtime_ms (float) (timings on H100 gpu) +- runnable (bool) (does this op + test work) [we may remove this column later after we solve for special ops] +- in_models (string) (which models did we include this op in) +""" + +def _parse_trace(filename): + + # given a trace file it returns a list of dicts which include + # uuid, op_name, args, arg_size, count + + op_inputs = [] + + with open(filename, "r") as f: + for line in tqdm(f, desc="Parsing trace file"): + if m := re.match("Operator: (.*)", line): + op = m.group(1) + if op == "aten.sum.SymInt": + op = "aten.sum.dim_IntList" + if m := re.match("cnt: \\d+, (.*)", line): + assert op is not None + args_str = m.group(1) + # extract cnt value from group 0 + cnt = int(m.group(0).split(",")[0].split(":")[1]) + args, kwargs = deserialize_args(args_str) + size = _args_size(args) + _args_size(list(kwargs.values())) + # convert size to MB from bytes + size = size / (1024 * 1024) + # if cnt is 0 then it is synthetic + is_synthetic = cnt == 0 + op_inputs.append({ + "uuid": hashlib.sha256(args_str.encode() + op.encode()).hexdigest(), + "op_name": op, + "args": args_str, + "arg_size": size, + "count": cnt, + "is_synthetic": is_synthetic, + }) + return op_inputs + +def convert_trace_to_parquets(trace_file, prod_parquet_file=None, dev_parquet_file=None): + """ + Convert a trace file to a parquet file + """ + + ops = [] + + # Check if filename is a URL + if isinstance(trace_file, str) and ( + trace_file.startswith("http://") or trace_file.startswith("https://") + ): + with ( + tempfile.NamedTemporaryFile(mode="w+", suffix=".txt", delete=False) as tmp_file, + requests.get(trace_file) as response, + ): + response.raise_for_status() + tmp_file.write(response.text) + tmp_file.flush() + ops.extend(_parse_trace(tmp_file.name)) + Path(tmp_file.name).unlink(missing_ok=True) + elif Path(trace_file).is_dir(): + for file_path in Path(trace_file).glob("**/*.txt"): + ops.extend(_parse_trace(str(file_path))) + else: + ops.extend(_parse_trace(trace_file)) + + # create dict for dev version + print(ops) + + +def convert_parquet_to_trace(parquet_file, trace_file): + """ + Convert a parquet file to a trace file + """ + pass + +if __name__ == "__main__": + file_path = DEFAULT_HUGGINGFACE_URL + convert_trace_to_parquets(file_path, "prod.parquet", "dev.parquet") \ No newline at end of file From 1ecb1f7f3d1c6ed56c8efc64584d37d6bd7f0acb Mon Sep 17 00:00:00 2001 From: Sahan Paliskara Date: Fri, 1 Aug 2025 11:47:02 -0700 Subject: [PATCH 17/32] filtering logic --- .../scripts/parquet_trace_converter.py | 179 +++++++++++++++++- BackendBench/torchbench_suite.py | 31 +-- 2 files changed, 187 insertions(+), 23 deletions(-) diff --git a/BackendBench/scripts/parquet_trace_converter.py b/BackendBench/scripts/parquet_trace_converter.py index 2e51e4b..a982793 100644 --- a/BackendBench/scripts/parquet_trace_converter.py +++ b/BackendBench/scripts/parquet_trace_converter.py @@ -12,6 +12,9 @@ import hashlib import re from tqdm import tqdm +from BackendBench.torchbench_suite import SKIP_OPERATORS +import logging +import click """ For the dataset release we generally would want to versions @@ -32,13 +35,88 @@ Columns for the dev version: All columns in the production version, plus: -- include_in_prod (boolean) -- why_excluded (string) (empty if included) +- included_in_benchmark (boolean) +- why_excluded (list of strings) (empty if included) - runtime_ms (float) (timings on H100 gpu) - runnable (bool) (does this op + test work) [we may remove this column later after we solve for special ops] -- in_models (string) (which models did we include this op in) +- in_models (string) (which models did we include this op in) [@TODO add this] """ +logger = logging.getLogger(__name__) + + +def setup_logging(log_level): + """Configure logging with the specified level.""" + numeric_level = getattr(logging, log_level.upper(), None) + if not isinstance(numeric_level, int): + raise ValueError(f"Invalid log level: {log_level}") + + logging.basicConfig( + level=numeric_level, + format="[%(asctime)s][%(levelname)s][%(filename)s] %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + +# Memory and view operations - create copies or views of tensors +MEMORY_VIEW_OPS = [ + "copy", + "view", + "clone", + "as_strided_", +] + +# Tensor creation and initialization operations +TENSOR_CREATION_OPS = [ + "fill", + "ones", + "zeros", + "empty", + "full", +] + +# Shape manipulation operations - change tensor structure +SHAPE_MANIPULATION_OPS = [ + "cat", + "repeat", + "roll", # @NOTE: I'm also not sure about aten.roll.default + "unbind", +] + +# Element-wise predicates and boolean operations +PREDICATE_OPS = [ + "any", # @NOTE: I don't think this is intereting as I'm unsure how'd it'd be optimized + "isinf", # @NOTE: Similar to any I'm not sure about this one + "isnan", # @NOTE: Similar to any I'm not sure about this one + "nonzero", # @NOTE: I'm also not sure about aten.nonzero.default + "where", +] + + +def _apply_skip_ops_filter(ops): + for op in ops: + if any(skip_op in op["op_name"] for skip_op in SKIP_OPERATORS): + op["included_in_benchmark"] = False + op["runnable"] = False + op["why_excluded"].append("Operation is not runnable in BackendBench yet.") + return ops + + +def _apply_non_interesting_ops_filter(ops): + for op in ops: + if any(skip_op in op["op_name"] for skip_op in MEMORY_VIEW_OPS): + op["included_in_benchmark"] = False + op["why_excluded"].append("Memory view ops are excluded from the benchmark.") + if any(skip_op in op["op_name"] for skip_op in TENSOR_CREATION_OPS): + op["included_in_benchmark"] = False + op["why_excluded"].append("Tensor creation ops are excluded from the benchmark.") + if any(skip_op in op["op_name"] for skip_op in SHAPE_MANIPULATION_OPS): + op["included_in_benchmark"] = False + op["why_excluded"].append("Shape manipulation ops are excluded from the benchmark.") + if any(skip_op in op["op_name"] for skip_op in PREDICATE_OPS): + op["included_in_benchmark"] = False + op["why_excluded"].append("Predicate ops are excluded from the benchmark.") + return ops + def _parse_trace(filename): # given a trace file it returns a list of dicts which include @@ -70,6 +148,10 @@ def _parse_trace(filename): "arg_size": size, "count": cnt, "is_synthetic": is_synthetic, + "included_in_benchmark": True, + "why_excluded": [], + "runtime_ms": 0, + "runnable": True, }) return op_inputs @@ -99,16 +181,97 @@ def convert_trace_to_parquets(trace_file, prod_parquet_file=None, dev_parquet_fi else: ops.extend(_parse_trace(trace_file)) - # create dict for dev version - print(ops) + # apply filters + ops = _apply_skip_ops_filter(ops) + ops = _apply_non_interesting_ops_filter(ops) + + # create prod dict + prod_ops = [op for op in ops if op["included_in_benchmark"]] + dev_table = pa.Table.from_pydict(ops) + pq.write_table(dev_table, dev_parquet_file) + + prod_table = pa.Table.from_pydict(prod_ops) + pq.write_table(prod_table, prod_parquet_file) def convert_parquet_to_trace(parquet_file, trace_file): """ Convert a parquet file to a trace file """ - pass + table = pq.read_table(parquet_file) + op_inputs = {} + # go through each row and add to op_inputs + for row in table: + formatted_entry = f"cnt: {row['count']}, {row['args']}" + op_inputs[row["op_name"]] = formatted_entry + # write to trace file + with open(trace_file, "w") as f: + for op, args in op_inputs.items(): + f.write(f"Operator: {op}\n") + for arg in args: + f.write(f"{arg}\n") + +@click.command() +@click.option( + "--log-level", + default=os.getenv("LOG_LEVEL", "INFO"), + type=click.Choice(["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], case_sensitive=False), + help="Set the logging level", +) +@click.option( + "--trace-file", + default=DEFAULT_HUGGINGFACE_URL, + type=str, + help="Path to trace file (can be URL, file path, or directory)", +) +@click.option( + "--prod-parquet", + default="prod.parquet", + type=str, + help="Output path for production parquet file", +) +@click.option( + "--dev-parquet", + default="dev.parquet", + type=str, + help="Output path for dev parquet file", +) +@click.option( + "--mode", + default="trace-to-parquet", + type=click.Choice(["trace-to-parquet", "parquet-to-trace"]), + help="Conversion mode", +) +@click.option( + "--parquet-file", + default=None, + type=str, + help="Input parquet file path (for parquet-to-trace mode)", +) +@click.option( + "--output-trace", + default="output.txt", + type=str, + help="Output trace file path (for parquet-to-trace mode)", +) +def main(log_level, trace_file, prod_parquet, dev_parquet, mode, parquet_file, output_trace): + """Convert trace files to parquet format or vice versa.""" + setup_logging(log_level) + + if mode == "trace-to-parquet": + logger.info(f"Converting trace file {trace_file} to parquet files") + logger.info(f"Production parquet: {prod_parquet}") + logger.info(f"Dev parquet: {dev_parquet}") + convert_trace_to_parquets(trace_file, prod_parquet, dev_parquet) + logger.info("Conversion completed successfully") + elif mode == "parquet-to-trace": + if parquet_file is None: + logger.error("--parquet-file is required for parquet-to-trace mode") + return + logger.info(f"Converting parquet file {parquet_file} to trace file {output_trace}") + convert_parquet_to_trace(parquet_file, output_trace) + logger.info("Conversion completed successfully") + if __name__ == "__main__": - file_path = DEFAULT_HUGGINGFACE_URL - convert_trace_to_parquets(file_path, "prod.parquet", "dev.parquet") \ No newline at end of file + main() \ No newline at end of file diff --git a/BackendBench/torchbench_suite.py b/BackendBench/torchbench_suite.py index 7aa91e7..3ab0eee 100644 --- a/BackendBench/torchbench_suite.py +++ b/BackendBench/torchbench_suite.py @@ -15,7 +15,21 @@ # ie. https://github.com/pytorch-labs/tritonbench/blob/main/tritonbench/data/input_configs/hf_train/AlbertForMaskedLM_training.txt DEFAULT_HUGGINGFACE_URL = "https://huggingface.co/datasets/GPUMODE/huggingface_op_trace/resolve/main/tritonbench_op_trace.txt" - +# Operators to skip for indexing ops that need valid indices +SKIP_OPERATORS = [ + "embedding", + "scatter", + "gather", + "index", + "nll_loss", + "im2col_backward", + "col2im_backward", + "native_layer_norm_backward", + "upsample_nearest2d_backward.vec", + "upsample_bilinear2d_backward.vec", + "_cudnn_rnn_backward.default", # RuntimeError: cuDNN error: CUDNN_STATUS_BAD_PARAM + "_fft_c2c.default", # cuFFT only supports dimensions whose sizes are powers of two when computing in half precision +] class TorchBenchTest: def __init__(self, *args, **kwargs): self.args = args @@ -113,20 +127,7 @@ def __iter__(self): for op, inputs in self.optests.items(): if any( s in op - for s in [ - "embedding", - "scatter", - "gather", - "index", - "nll_loss", - "im2col_backward", - "col2im_backward", - "native_layer_norm_backward", - "upsample_nearest2d_backward.vec", - "upsample_bilinear2d_backward.vec", - "_cudnn_rnn_backward.default", # RuntimeError: cuDNN error: CUDNN_STATUS_BAD_PARAM - "_fft_c2c.default", # cuFFT only supports dimensions whose sizes are powers of two when computing in half precision - ] + for s in SKIP_OPERATORS ): # TODO: indexing ops need valid indices continue From a1bdf7af91741a8a4370aeb56b8d473ed170b486 Mon Sep 17 00:00:00 2001 From: Sahan Paliskara Date: Fri, 1 Aug 2025 14:00:43 -0700 Subject: [PATCH 18/32] cleanup --- BackendBench/data_loaders.py | 158 +++++++++++++++++ BackendBench/scripts/dataset_filters.py | 75 +++++++++ .../scripts/parquet_trace_converter.py | 159 +++++++----------- BackendBench/torchbench_suite.py | 87 ++-------- 4 files changed, 309 insertions(+), 170 deletions(-) create mode 100644 BackendBench/data_loaders.py create mode 100644 BackendBench/scripts/dataset_filters.py diff --git a/BackendBench/data_loaders.py b/BackendBench/data_loaders.py new file mode 100644 index 0000000..8ff5a3c --- /dev/null +++ b/BackendBench/data_loaders.py @@ -0,0 +1,158 @@ +""" +Shared data loading utilities for reading trace and parquet files. +""" + +import re +import tempfile +from pathlib import Path +from collections import defaultdict +from typing import Dict, List, Optional, Union + +import requests +import pyarrow.parquet as pq + +from BackendBench.utils import deserialize_args + + +def _args_size(args): + """Calculate the size of arguments in bytes.""" + import torch + size = 0 + for arg in args: + if isinstance(arg, torch.Tensor): + size += arg.numel() * arg.element_size() + elif isinstance(arg, (tuple, list)): + size += _args_size(arg) + return size + + + + + +def _parse_trace_file_simple(filename: str, filter: Optional[List[str]], op_inputs: Dict) -> Dict: + """ + Parse a single trace file for TorchBenchSuite (simpler format). + + Returns defaultdict where keys are op names and values are lists of args strings. + """ + op = None + + with open(filename, "r") as f: + for line in f: + if m := re.match("Operator: (.*)", line): + op = m.group(1) + if op == "aten.sum.SymInt": + op = "aten.sum.dim_IntList" + if m := re.match("cnt: \\d+, (.*)", line): + assert op is not None + args = m.group(1) + if filter is None or any(f in op for f in filter): + op_inputs[op].append(args) + return op_inputs + + +def load_ops_from_source( + source: Union[str, Path], + format: str = "auto", + filter: Optional[List[str]] = None, + simple_format: bool = False +) -> Union[List[Dict], Dict]: + """ + Load operation data from various sources and formats. + + Args: + source: File path, URL, or directory + format: "trace", "parquet", or "auto" (detect from file extension) + filter: Optional list of operation name filters + simple_format: If True, return defaultdict format for TorchBenchSuite compatibility + + Returns: + If simple_format=True: defaultdict with op names as keys, args lists as values + If simple_format=False: List of dictionaries with detailed operation info + + Auto-detection behavior: + - https://domain.com/data.parquet → parquet format + - https://domain.com/data.txt → trace format + - https://domain.com/data → trace format (fallback) + - local_file.parquet → parquet format + - local_file.txt → trace format + - directory_path/ → trace format (scans for .txt files) + """ + + # Auto-detect format if not specified + if format == "auto": + if isinstance(source, str): + # Check file extension first (works for both local files and URLs) + if source.endswith('.parquet'): + format = "parquet" + elif source.endswith('.txt'): + format = "trace" + elif source.startswith(('http://', 'https://')): + # Remote URL without recognizable extension - default to trace + format = "trace" + else: + # Local path - check if it's a directory + if Path(source).is_dir(): + format = "trace" # Directory scan for .txt files + else: + format = "trace" # Default to trace + else: + format = "trace" + + if format == "parquet": + return _load_from_parquet(source, filter, simple_format) + elif format == "trace": + return _load_from_trace(source, filter, simple_format) + else: + raise ValueError(f"Unsupported format: {format}") + + +def _load_from_parquet(source: Union[str, Path], filter: Optional[List[str]], simple_format: bool): + """Load operations from parquet file.""" + table = pq.read_table(source) + + if simple_format: + # Convert to TorchBenchSuite format + op_inputs = defaultdict(list) + for batch in table.to_batches(): + df = batch.to_pandas() + for _, row in df.iterrows(): + op_name = row['op_name'] + if filter is None or any(f in op_name for f in filter): + op_inputs[op_name].append(row['args']) + return op_inputs + else: + # Convert to list of dicts + df = table.to_pandas() + return df.to_dict('records') + + +def _load_from_trace(source: Union[str, Path], filter: Optional[List[str]], simple_format: bool): + """Load operations from trace file(s). Only supports simple_format=True for TorchBenchSuite.""" + if not simple_format: + raise ValueError("Detailed trace parsing has been moved to parquet_trace_converter.py. Use simple_format=True.") + + op_inputs = defaultdict(list) + + # Handle URLs + if isinstance(source, str) and (source.startswith("http://") or source.startswith("https://")): + with ( + tempfile.NamedTemporaryFile(mode="w+", suffix=".txt", delete=False) as tmp_file, + requests.get(source) as response, + ): + response.raise_for_status() + tmp_file.write(response.text) + tmp_file.flush() + _parse_trace_file_simple(tmp_file.name, filter, op_inputs) + Path(tmp_file.name).unlink(missing_ok=True) + + # Handle directories + elif Path(source).is_dir(): + for file_path in Path(source).glob("**/*.txt"): + _parse_trace_file_simple(str(file_path), filter, op_inputs) + + # Handle single files + else: + _parse_trace_file_simple(source, filter, op_inputs) + + return op_inputs \ No newline at end of file diff --git a/BackendBench/scripts/dataset_filters.py b/BackendBench/scripts/dataset_filters.py new file mode 100644 index 0000000..92e6536 --- /dev/null +++ b/BackendBench/scripts/dataset_filters.py @@ -0,0 +1,75 @@ +# Operators to skip for indexing ops that need valid indices +SKIP_OPERATORS = [ + "embedding", + "scatter", + "gather", + "index", + "nll_loss", + "im2col_backward", + "col2im_backward", + "native_layer_norm_backward", + "upsample_nearest2d_backward.vec", + "upsample_bilinear2d_backward.vec", + "_cudnn_rnn_backward.default", # RuntimeError: cuDNN error: CUDNN_STATUS_BAD_PARAM + "_fft_c2c.default", # cuFFT only supports dimensions whose sizes are powers of two when computing in half precision +] + +# Memory and view operations - create copies or views of tensors +MEMORY_VIEW_OPS = [ + "copy", + "view", + "clone", + "as_strided_", +] + +# Tensor creation and initialization operations +TENSOR_CREATION_OPS = [ + "fill", + "ones", + "zeros", + "empty", + "full", +] + +# Shape manipulation operations - change tensor structure +SHAPE_MANIPULATION_OPS = [ + "cat", + "repeat", + "roll", # @NOTE: I'm also not sure about aten.roll.default + "unbind", +] + +# Element-wise predicates and boolean operations +PREDICATE_OPS = [ + "any", # @NOTE: I don't think this is intereting as I'm unsure how'd it'd be optimized + "isinf", # @NOTE: Similar to any I'm not sure about this one + "isnan", # @NOTE: Similar to any I'm not sure about this one + "nonzero", # @NOTE: I'm also not sure about aten.nonzero.default + "where", +] + + +def _apply_skip_ops_filter(ops): + for op in ops: + if any(skip_op in op["op_name"] for skip_op in SKIP_OPERATORS): + op["included_in_benchmark"] = False + op["runnable"] = False + op["why_excluded"].append("Operation is not runnable in BackendBench yet.") + return ops + + +def _apply_non_interesting_ops_filter(ops): + for op in ops: + if any(skip_op in op["op_name"] for skip_op in MEMORY_VIEW_OPS): + op["included_in_benchmark"] = False + op["why_excluded"].append("Memory view ops are excluded from the benchmark.") + if any(skip_op in op["op_name"] for skip_op in TENSOR_CREATION_OPS): + op["included_in_benchmark"] = False + op["why_excluded"].append("Tensor creation ops are excluded from the benchmark.") + if any(skip_op in op["op_name"] for skip_op in SHAPE_MANIPULATION_OPS): + op["included_in_benchmark"] = False + op["why_excluded"].append("Shape manipulation ops are excluded from the benchmark.") + if any(skip_op in op["op_name"] for skip_op in PREDICATE_OPS): + op["included_in_benchmark"] = False + op["why_excluded"].append("Predicate ops are excluded from the benchmark.") + return ops diff --git a/BackendBench/scripts/parquet_trace_converter.py b/BackendBench/scripts/parquet_trace_converter.py index a982793..4ed558e 100644 --- a/BackendBench/scripts/parquet_trace_converter.py +++ b/BackendBench/scripts/parquet_trace_converter.py @@ -1,20 +1,21 @@ # utility functions to convert parquet and trace files back and forth import pyarrow.parquet as pq -import pyarrow.csv as csv import pyarrow as pa -from BackendBench.torchbench_suite import DEFAULT_HUGGINGFACE_URL, _args_size +from BackendBench.torchbench_suite import DEFAULT_HUGGINGFACE_URL +from BackendBench.data_loaders import _args_size +from BackendBench.scripts.dataset_filters import _apply_skip_ops_filter, _apply_non_interesting_ops_filter from BackendBench.utils import deserialize_args import os -import requests -import tempfile -from pathlib import Path -import hashlib -import re -from tqdm import tqdm -from BackendBench.torchbench_suite import SKIP_OPERATORS import logging import click +import re +import hashlib +from tqdm import tqdm +import tempfile +import requests +from pathlib import Path +from typing import List, Dict """ For the dataset release we generally would want to versions @@ -57,75 +58,18 @@ def setup_logging(log_level): datefmt="%Y-%m-%d %H:%M:%S", ) -# Memory and view operations - create copies or views of tensors -MEMORY_VIEW_OPS = [ - "copy", - "view", - "clone", - "as_strided_", -] - -# Tensor creation and initialization operations -TENSOR_CREATION_OPS = [ - "fill", - "ones", - "zeros", - "empty", - "full", -] -# Shape manipulation operations - change tensor structure -SHAPE_MANIPULATION_OPS = [ - "cat", - "repeat", - "roll", # @NOTE: I'm also not sure about aten.roll.default - "unbind", -] - -# Element-wise predicates and boolean operations -PREDICATE_OPS = [ - "any", # @NOTE: I don't think this is intereting as I'm unsure how'd it'd be optimized - "isinf", # @NOTE: Similar to any I'm not sure about this one - "isnan", # @NOTE: Similar to any I'm not sure about this one - "nonzero", # @NOTE: I'm also not sure about aten.nonzero.default - "where", -] - - -def _apply_skip_ops_filter(ops): - for op in ops: - if any(skip_op in op["op_name"] for skip_op in SKIP_OPERATORS): - op["included_in_benchmark"] = False - op["runnable"] = False - op["why_excluded"].append("Operation is not runnable in BackendBench yet.") - return ops - - -def _apply_non_interesting_ops_filter(ops): - for op in ops: - if any(skip_op in op["op_name"] for skip_op in MEMORY_VIEW_OPS): - op["included_in_benchmark"] = False - op["why_excluded"].append("Memory view ops are excluded from the benchmark.") - if any(skip_op in op["op_name"] for skip_op in TENSOR_CREATION_OPS): - op["included_in_benchmark"] = False - op["why_excluded"].append("Tensor creation ops are excluded from the benchmark.") - if any(skip_op in op["op_name"] for skip_op in SHAPE_MANIPULATION_OPS): - op["included_in_benchmark"] = False - op["why_excluded"].append("Shape manipulation ops are excluded from the benchmark.") - if any(skip_op in op["op_name"] for skip_op in PREDICATE_OPS): - op["included_in_benchmark"] = False - op["why_excluded"].append("Predicate ops are excluded from the benchmark.") - return ops - -def _parse_trace(filename): +def _parse_trace_file(filename: str) -> List[Dict]: + """ + Parse a single trace file and return a list of operation dictionaries. - # given a trace file it returns a list of dicts which include - # uuid, op_name, args, arg_size, count - + Returns list of dicts with keys: uuid, op_name, args, arg_size, count, is_synthetic + """ op_inputs = [] + op = None with open(filename, "r") as f: - for line in tqdm(f, desc="Parsing trace file"): + for line in tqdm(f, desc=f"Parsing {Path(filename).name}"): if m := re.match("Operator: (.*)", line): op = m.group(1) if op == "aten.sum.SymInt": @@ -148,38 +92,54 @@ def _parse_trace(filename): "arg_size": size, "count": cnt, "is_synthetic": is_synthetic, - "included_in_benchmark": True, - "why_excluded": [], - "runtime_ms": 0, - "runnable": True, }) return op_inputs -def convert_trace_to_parquets(trace_file, prod_parquet_file=None, dev_parquet_file=None): + +def _load_trace_for_parquet_conversion(source: str) -> List[Dict]: """ - Convert a trace file to a parquet file + Load operations from trace file(s) with detailed metadata for parquet conversion. """ - ops = [] - - # Check if filename is a URL - if isinstance(trace_file, str) and ( - trace_file.startswith("http://") or trace_file.startswith("https://") - ): + + # Handle URLs + if isinstance(source, str) and (source.startswith("http://") or source.startswith("https://")): with ( tempfile.NamedTemporaryFile(mode="w+", suffix=".txt", delete=False) as tmp_file, - requests.get(trace_file) as response, + requests.get(source) as response, ): response.raise_for_status() tmp_file.write(response.text) tmp_file.flush() - ops.extend(_parse_trace(tmp_file.name)) + ops.extend(_parse_trace_file(tmp_file.name)) Path(tmp_file.name).unlink(missing_ok=True) - elif Path(trace_file).is_dir(): - for file_path in Path(trace_file).glob("**/*.txt"): - ops.extend(_parse_trace(str(file_path))) + + # Handle directories + elif Path(source).is_dir(): + for file_path in Path(source).glob("**/*.txt"): + ops.extend(_parse_trace_file(str(file_path))) + + # Handle single files else: - ops.extend(_parse_trace(trace_file)) + ops.extend(_parse_trace_file(source)) + + return ops + + +def convert_trace_to_parquets(trace_file, prod_parquet_file=None, dev_parquet_file=None): + """ + Convert a trace file to a parquet file + """ + + # Load operations using local trace parsing function + ops = _load_trace_for_parquet_conversion(trace_file) + + # Add additional metadata fields required for the parquet format + for op in ops: + op["included_in_benchmark"] = True + op["why_excluded"] = [] + op["runtime_ms"] = 0 + op["runnable"] = True # apply filters ops = _apply_skip_ops_filter(ops) @@ -193,7 +153,10 @@ def convert_trace_to_parquets(trace_file, prod_parquet_file=None, dev_parquet_fi prod_table = pa.Table.from_pydict(prod_ops) pq.write_table(prod_table, prod_parquet_file) - + + logger.info(f"Wrote {len(prod_ops)} ops and inputs to {prod_parquet_file}") + logger.info(f"Wrote {len(ops)} ops and inputs to {dev_parquet_file}") + def convert_parquet_to_trace(parquet_file, trace_file): """ Convert a parquet file to a trace file @@ -226,13 +189,13 @@ def convert_parquet_to_trace(parquet_file, trace_file): ) @click.option( "--prod-parquet", - default="prod.parquet", + default="backend_bench_problems.parquet", type=str, help="Output path for production parquet file", ) @click.option( "--dev-parquet", - default="dev.parquet", + default="backend_bench_problems_dev.parquet", type=str, help="Output path for dev parquet file", ) @@ -244,13 +207,13 @@ def convert_parquet_to_trace(parquet_file, trace_file): ) @click.option( "--parquet-file", - default=None, + default="datasets/backend_bench_problems.parquet", type=str, help="Input parquet file path (for parquet-to-trace mode)", ) @click.option( "--output-trace", - default="output.txt", + default="datasets/output.txt", type=str, help="Output trace file path (for parquet-to-trace mode)", ) @@ -258,11 +221,13 @@ def main(log_level, trace_file, prod_parquet, dev_parquet, mode, parquet_file, o """Convert trace files to parquet format or vice versa.""" setup_logging(log_level) + os.makedirs("datasets", exist_ok=True) + if mode == "trace-to-parquet": logger.info(f"Converting trace file {trace_file} to parquet files") + convert_trace_to_parquets(trace_file, prod_parquet, dev_parquet) logger.info(f"Production parquet: {prod_parquet}") logger.info(f"Dev parquet: {dev_parquet}") - convert_trace_to_parquets(trace_file, prod_parquet, dev_parquet) logger.info("Conversion completed successfully") elif mode == "parquet-to-trace": if parquet_file is None: diff --git a/BackendBench/torchbench_suite.py b/BackendBench/torchbench_suite.py index 3ab0eee..aa98380 100644 --- a/BackendBench/torchbench_suite.py +++ b/BackendBench/torchbench_suite.py @@ -1,51 +1,22 @@ """ -Load aten inputs from serialized txt files. +Load aten inputs from serialized txt files and parquet files. """ -import re -import tempfile -from collections import defaultdict -from pathlib import Path - -import requests import torch +from collections import defaultdict from BackendBench.utils import deserialize_args +from BackendBench.scripts.dataset_filters import SKIP_OPERATORS +from BackendBench.data_loaders import load_ops_from_source, _args_size # the schema for this dataset is the one defined in tritonbench traces. # ie. https://github.com/pytorch-labs/tritonbench/blob/main/tritonbench/data/input_configs/hf_train/AlbertForMaskedLM_training.txt DEFAULT_HUGGINGFACE_URL = "https://huggingface.co/datasets/GPUMODE/huggingface_op_trace/resolve/main/tritonbench_op_trace.txt" - -# Operators to skip for indexing ops that need valid indices -SKIP_OPERATORS = [ - "embedding", - "scatter", - "gather", - "index", - "nll_loss", - "im2col_backward", - "col2im_backward", - "native_layer_norm_backward", - "upsample_nearest2d_backward.vec", - "upsample_bilinear2d_backward.vec", - "_cudnn_rnn_backward.default", # RuntimeError: cuDNN error: CUDNN_STATUS_BAD_PARAM - "_fft_c2c.default", # cuFFT only supports dimensions whose sizes are powers of two when computing in half precision -] class TorchBenchTest: def __init__(self, *args, **kwargs): self.args = args self.kwargs = kwargs -def _args_size(args): - size = 0 - for arg in args: - if isinstance(arg, torch.Tensor): - size += arg.numel() * arg.element_size() - elif isinstance(arg, (tuple, list)): - size += _args_size(arg) - return size - - class TorchBenchOpTest: def __init__(self, op, inputs, topn): self.op = eval(f"torch.ops.{op}") @@ -74,61 +45,31 @@ def performance_tests(self): yield TorchBenchTest(*args, **kwargs) -def _parse_inputs(filename, filter, op_inputs): - op = None - - with open(filename, "r") as f: - for line in f: - if m := re.match("Operator: (.*)", line): - op = m.group(1) - if op == "aten.sum.SymInt": - op = "aten.sum.dim_IntList" - if m := re.match("cnt: \\d+, (.*)", line): - assert op is not None - args = m.group(1) - if filter is None or any(f in op for f in filter): - op_inputs[op].append(args) - return op_inputs - - class TorchBenchTestSuite: def __init__(self, name, filename=None, filter=None, topn=None): self.name = name self.topn = topn - self.optests = defaultdict(list) # Use default URL if no filename provided if filename is None: filename = DEFAULT_HUGGINGFACE_URL - # Check if filename is a URL - if isinstance(filename, str) and ( - filename.startswith("http://") or filename.startswith("https://") - ): - with ( - tempfile.NamedTemporaryFile(mode="w+", suffix=".txt", delete=False) as tmp_file, - requests.get(filename) as response, - ): - response.raise_for_status() - tmp_file.write(response.text) - tmp_file.flush() - _parse_inputs(tmp_file.name, filter, self.optests) - Path(tmp_file.name).unlink(missing_ok=True) - elif Path(filename).is_dir(): - for file_path in Path(filename).glob("**/*.txt"): - _parse_inputs(str(file_path), filter, self.optests) - else: - _parse_inputs(filename, filter, self.optests) + # Load operations using the shared data loader + # Use simple_format=True to get the defaultdict format for compatibility + self.optests = load_ops_from_source( + source=filename, + format="auto", # Auto-detect based on file extension + filter=filter, + simple_format=True + ) + # Deduplicate the strings in self.optests for op in self.optests: self.optests[op] = list(set(self.optests[op])) def __iter__(self): for op, inputs in self.optests.items(): - if any( - s in op - for s in SKIP_OPERATORS - ): + if any(s in op for s in SKIP_OPERATORS): # TODO: indexing ops need valid indices continue yield TorchBenchOpTest(op, inputs, self.topn) From 7408e7a0995b85d525011f08fc6342b89e76e191 Mon Sep 17 00:00:00 2001 From: PaliC Date: Fri, 1 Aug 2025 15:09:02 -0700 Subject: [PATCH 19/32] parquet --- BackendBench/data_loaders.py | 56 +++-- .../scripts/parquet_trace_converter.py | 193 ++++++++++++------ BackendBench/torchbench_suite.py | 8 +- 3 files changed, 165 insertions(+), 92 deletions(-) diff --git a/BackendBench/data_loaders.py b/BackendBench/data_loaders.py index 8ff5a3c..896c6e6 100644 --- a/BackendBench/data_loaders.py +++ b/BackendBench/data_loaders.py @@ -10,13 +10,12 @@ import requests import pyarrow.parquet as pq - -from BackendBench.utils import deserialize_args +import torch def _args_size(args): """Calculate the size of arguments in bytes.""" - import torch + size = 0 for arg in args: if isinstance(arg, torch.Tensor): @@ -26,13 +25,10 @@ def _args_size(args): return size - - - def _parse_trace_file_simple(filename: str, filter: Optional[List[str]], op_inputs: Dict) -> Dict: """ Parse a single trace file for TorchBenchSuite (simpler format). - + Returns defaultdict where keys are op names and values are lists of args strings. """ op = None @@ -52,42 +48,42 @@ def _parse_trace_file_simple(filename: str, filter: Optional[List[str]], op_inpu def load_ops_from_source( - source: Union[str, Path], + source: Union[str, Path], format: str = "auto", filter: Optional[List[str]] = None, - simple_format: bool = False + simple_format: bool = False, ) -> Union[List[Dict], Dict]: """ Load operation data from various sources and formats. - + Args: source: File path, URL, or directory format: "trace", "parquet", or "auto" (detect from file extension) filter: Optional list of operation name filters simple_format: If True, return defaultdict format for TorchBenchSuite compatibility - + Returns: If simple_format=True: defaultdict with op names as keys, args lists as values If simple_format=False: List of dictionaries with detailed operation info - + Auto-detection behavior: - https://domain.com/data.parquet → parquet format - - https://domain.com/data.txt → trace format + - https://domain.com/data.txt → trace format - https://domain.com/data → trace format (fallback) - local_file.parquet → parquet format - local_file.txt → trace format - directory_path/ → trace format (scans for .txt files) """ - + # Auto-detect format if not specified if format == "auto": if isinstance(source, str): # Check file extension first (works for both local files and URLs) - if source.endswith('.parquet'): + if source.endswith(".parquet"): format = "parquet" - elif source.endswith('.txt'): + elif source.endswith(".txt"): format = "trace" - elif source.startswith(('http://', 'https://')): + elif source.startswith(("http://", "https://")): # Remote URL without recognizable extension - default to trace format = "trace" else: @@ -98,7 +94,7 @@ def load_ops_from_source( format = "trace" # Default to trace else: format = "trace" - + if format == "parquet": return _load_from_parquet(source, filter, simple_format) elif format == "trace": @@ -110,30 +106,32 @@ def load_ops_from_source( def _load_from_parquet(source: Union[str, Path], filter: Optional[List[str]], simple_format: bool): """Load operations from parquet file.""" table = pq.read_table(source) - + if simple_format: # Convert to TorchBenchSuite format op_inputs = defaultdict(list) for batch in table.to_batches(): df = batch.to_pandas() for _, row in df.iterrows(): - op_name = row['op_name'] + op_name = row["op_name"] if filter is None or any(f in op_name for f in filter): - op_inputs[op_name].append(row['args']) + op_inputs[op_name].append(row["args"]) return op_inputs else: # Convert to list of dicts df = table.to_pandas() - return df.to_dict('records') + return df.to_dict("records") def _load_from_trace(source: Union[str, Path], filter: Optional[List[str]], simple_format: bool): """Load operations from trace file(s). Only supports simple_format=True for TorchBenchSuite.""" if not simple_format: - raise ValueError("Detailed trace parsing has been moved to parquet_trace_converter.py. Use simple_format=True.") - + raise ValueError( + "Detailed trace parsing has been moved to parquet_trace_converter.py. Use simple_format=True." + ) + op_inputs = defaultdict(list) - + # Handle URLs if isinstance(source, str) and (source.startswith("http://") or source.startswith("https://")): with ( @@ -145,14 +143,14 @@ def _load_from_trace(source: Union[str, Path], filter: Optional[List[str]], simp tmp_file.flush() _parse_trace_file_simple(tmp_file.name, filter, op_inputs) Path(tmp_file.name).unlink(missing_ok=True) - + # Handle directories elif Path(source).is_dir(): for file_path in Path(source).glob("**/*.txt"): _parse_trace_file_simple(str(file_path), filter, op_inputs) - + # Handle single files else: _parse_trace_file_simple(source, filter, op_inputs) - - return op_inputs \ No newline at end of file + + return op_inputs diff --git a/BackendBench/scripts/parquet_trace_converter.py b/BackendBench/scripts/parquet_trace_converter.py index 4ed558e..5eb87bc 100644 --- a/BackendBench/scripts/parquet_trace_converter.py +++ b/BackendBench/scripts/parquet_trace_converter.py @@ -4,7 +4,10 @@ import pyarrow as pa from BackendBench.torchbench_suite import DEFAULT_HUGGINGFACE_URL from BackendBench.data_loaders import _args_size -from BackendBench.scripts.dataset_filters import _apply_skip_ops_filter, _apply_non_interesting_ops_filter +from BackendBench.scripts.dataset_filters import ( + _apply_skip_ops_filter, + _apply_non_interesting_ops_filter, +) from BackendBench.utils import deserialize_args import os import logging @@ -62,7 +65,7 @@ def setup_logging(log_level): def _parse_trace_file(filename: str) -> List[Dict]: """ Parse a single trace file and return a list of operation dictionaries. - + Returns list of dicts with keys: uuid, op_name, args, arg_size, count, is_synthetic """ op_inputs = [] @@ -85,14 +88,16 @@ def _parse_trace_file(filename: str) -> List[Dict]: size = size / (1024 * 1024) # if cnt is 0 then it is synthetic is_synthetic = cnt == 0 - op_inputs.append({ - "uuid": hashlib.sha256(args_str.encode() + op.encode()).hexdigest(), - "op_name": op, - "args": args_str, - "arg_size": size, - "count": cnt, - "is_synthetic": is_synthetic, - }) + op_inputs.append( + { + "uuid": hashlib.sha256(args_str.encode() + op.encode()).hexdigest(), + "op_name": op, + "args": args_str, + "arg_size": size, + "count": cnt, + "is_synthetic": is_synthetic, + } + ) return op_inputs @@ -101,7 +106,7 @@ def _load_trace_for_parquet_conversion(source: str) -> List[Dict]: Load operations from trace file(s) with detailed metadata for parquet conversion. """ ops = [] - + # Handle URLs if isinstance(source, str) and (source.startswith("http://") or source.startswith("https://")): with ( @@ -113,16 +118,16 @@ def _load_trace_for_parquet_conversion(source: str) -> List[Dict]: tmp_file.flush() ops.extend(_parse_trace_file(tmp_file.name)) Path(tmp_file.name).unlink(missing_ok=True) - + # Handle directories elif Path(source).is_dir(): for file_path in Path(source).glob("**/*.txt"): ops.extend(_parse_trace_file(str(file_path))) - + # Handle single files else: ops.extend(_parse_trace_file(source)) - + return ops @@ -136,43 +141,117 @@ def convert_trace_to_parquets(trace_file, prod_parquet_file=None, dev_parquet_fi # Add additional metadata fields required for the parquet format for op in ops: + op["uuid"] = hashlib.sha256(op["args"].encode() + op["op_name"].encode()).hexdigest() op["included_in_benchmark"] = True op["why_excluded"] = [] - op["runtime_ms"] = 0 + op["runtime_ms"] = "" op["runnable"] = True # apply filters ops = _apply_skip_ops_filter(ops) ops = _apply_non_interesting_ops_filter(ops) - # create prod dict - prod_ops = [op for op in ops if op["included_in_benchmark"]] + # create production version (copy ops to avoid modifying original) + prod_ops = [] + for op in ops: + if op["included_in_benchmark"]: + # Create production version without metadata fields + prod_op = { + "uuid": op["uuid"], + "op_name": op["op_name"], + "args": op["args"], + "arg_size": op["arg_size"], + "count": op["count"], + "is_synthetic": op["is_synthetic"], + } + prod_ops.append(prod_op) - dev_table = pa.Table.from_pydict(ops) - pq.write_table(dev_table, dev_parquet_file) + # Create parquet tables with proper headers + dev_table = pa.Table.from_pylist(ops) + prod_table = pa.Table.from_pylist(prod_ops) - prod_table = pa.Table.from_pydict(prod_ops) + # Write parquet files + pq.write_table(dev_table, dev_parquet_file) pq.write_table(prod_table, prod_parquet_file) logger.info(f"Wrote {len(prod_ops)} ops and inputs to {prod_parquet_file}") logger.info(f"Wrote {len(ops)} ops and inputs to {dev_parquet_file}") + # Log column information for verification + logger.debug(f"Production parquet columns: {prod_table.column_names}") + logger.debug(f"Dev parquet columns: {dev_table.column_names}") + + def convert_parquet_to_trace(parquet_file, trace_file): """ Convert a parquet file to a trace file """ table = pq.read_table(parquet_file) op_inputs = {} - # go through each row and add to op_inputs - for row in table: + + for row in table.to_pylist(): formatted_entry = f"cnt: {row['count']}, {row['args']}" - op_inputs[row["op_name"]] = formatted_entry + + if row["op_name"] not in op_inputs: + op_inputs[row["op_name"]] = [] + op_inputs[row["op_name"]].append(formatted_entry) + # write to trace file with open(trace_file, "w") as f: for op, args in op_inputs.items(): f.write(f"Operator: {op}\n") for arg in args: f.write(f"{arg}\n") + total_args = sum(len(op_inputs[op]) for op in op_inputs) + logging.info(f"Wrote {total_args} ops and inputs to {trace_file}") + + +def _validate_parquet_name(parquet_name: str, is_input: bool = False) -> str: + """Validate parquet filename. URLs allowed only for inputs.""" + # URLs are allowed only if this is an input file + if parquet_name.startswith(("http://", "https://")): + if is_input: + return parquet_name # URL allowed for input + else: + raise click.BadParameter("Output parquet file cannot be a URL") + + if not parquet_name.endswith(".parquet"): + raise click.BadParameter("Parquet file must end with .parquet suffix") + + # Ensure local files are in datasets directory + if not parquet_name.startswith("datasets/"): + parquet_name = f"datasets/{parquet_name}" + + return parquet_name + + +def _validate_trace_file(trace_file: str, is_input: bool = True) -> str: + """Validate trace file. URLs allowed only for inputs.""" + # URLs are allowed only if this is an input file + if trace_file.startswith(("http://", "https://")): + if is_input: + return trace_file # URL allowed for input + else: + raise click.BadParameter("Output trace file cannot be a URL") + + # For local files, check extension + if not (trace_file.endswith(".txt") or Path(trace_file).is_dir()): + raise click.BadParameter("Local trace file must end with .txt or be a directory") + + if Path(trace_file).is_dir() and not is_input: + raise click.BadParameter("Output trace file cannot be a directory") + + return trace_file + + +def _generate_dev_parquet_name(parquet_name: str) -> str: + """Generate dev parquet name by appending -dev before .parquet suffix.""" + if parquet_name.endswith(".parquet"): + base_name = parquet_name[:-8] # Remove .parquet + return f"{base_name}-dev.parquet" + else: + return f"{parquet_name}-dev" + @click.command() @click.option( @@ -181,24 +260,6 @@ def convert_parquet_to_trace(parquet_file, trace_file): type=click.Choice(["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], case_sensitive=False), help="Set the logging level", ) -@click.option( - "--trace-file", - default=DEFAULT_HUGGINGFACE_URL, - type=str, - help="Path to trace file (can be URL, file path, or directory)", -) -@click.option( - "--prod-parquet", - default="backend_bench_problems.parquet", - type=str, - help="Output path for production parquet file", -) -@click.option( - "--dev-parquet", - default="backend_bench_problems_dev.parquet", - type=str, - help="Output path for dev parquet file", -) @click.option( "--mode", default="trace-to-parquet", @@ -206,37 +267,51 @@ def convert_parquet_to_trace(parquet_file, trace_file): help="Conversion mode", ) @click.option( - "--parquet-file", - default="datasets/backend_bench_problems.parquet", + "--trace-file", + default=DEFAULT_HUGGINGFACE_URL, type=str, - help="Input parquet file path (for parquet-to-trace mode)", + help="Input trace file: URL (for downloads), local .txt file, or directory. Output trace files cannot be URLs", ) @click.option( - "--output-trace", - default="datasets/output.txt", + "--parquet-name", + default="backend_bench_problems-dev.parquet", type=str, - help="Output trace file path (for parquet-to-trace mode)", + help="Parquet filename: URL allowed as input in parquet-to-trace mode, local files in datasets/. Dev version auto-generated as filename-dev.parquet.", ) -def main(log_level, trace_file, prod_parquet, dev_parquet, mode, parquet_file, output_trace): +def main(log_level, mode, trace_file, parquet_name): """Convert trace files to parquet format or vice versa.""" setup_logging(log_level) - + + # Create datasets directory os.makedirs("datasets", exist_ok=True) - + if mode == "trace-to-parquet": + # Validate inputs/outputs + trace_file = _validate_trace_file(trace_file, is_input=True) # Input: URLs allowed + parquet_name = _validate_parquet_name( + parquet_name, is_input=False + ) # Output: URLs not allowed + + # Generate dev version name + dev_parquet_name = _generate_dev_parquet_name(parquet_name) + logger.info(f"Converting trace file {trace_file} to parquet files") - convert_trace_to_parquets(trace_file, prod_parquet, dev_parquet) - logger.info(f"Production parquet: {prod_parquet}") - logger.info(f"Dev parquet: {dev_parquet}") + logger.info(f"Production parquet: {parquet_name}") + logger.info(f"Dev parquet: {dev_parquet_name}") + + convert_trace_to_parquets(trace_file, parquet_name, dev_parquet_name) logger.info("Conversion completed successfully") + elif mode == "parquet-to-trace": - if parquet_file is None: - logger.error("--parquet-file is required for parquet-to-trace mode") - return - logger.info(f"Converting parquet file {parquet_file} to trace file {output_trace}") - convert_parquet_to_trace(parquet_file, output_trace) + # Validate parquet input (URLs allowed for input in this mode) + parquet_input = _validate_parquet_name(parquet_name, is_input=True) # Input: URLs allowed + # Validate trace output (URLs not allowed for output) + trace_output = _validate_trace_file(trace_file, is_input=False) # Output: URLs not allowed + + logger.info(f"Converting parquet file {parquet_input} to trace file {trace_output}") + convert_parquet_to_trace(parquet_input, trace_output) logger.info("Conversion completed successfully") if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/BackendBench/torchbench_suite.py b/BackendBench/torchbench_suite.py index aa98380..153201f 100644 --- a/BackendBench/torchbench_suite.py +++ b/BackendBench/torchbench_suite.py @@ -2,8 +2,6 @@ Load aten inputs from serialized txt files and parquet files. """ -import torch -from collections import defaultdict from BackendBench.utils import deserialize_args from BackendBench.scripts.dataset_filters import SKIP_OPERATORS from BackendBench.data_loaders import load_ops_from_source, _args_size @@ -11,6 +9,8 @@ # the schema for this dataset is the one defined in tritonbench traces. # ie. https://github.com/pytorch-labs/tritonbench/blob/main/tritonbench/data/input_configs/hf_train/AlbertForMaskedLM_training.txt DEFAULT_HUGGINGFACE_URL = "https://huggingface.co/datasets/GPUMODE/huggingface_op_trace/resolve/main/tritonbench_op_trace.txt" + + class TorchBenchTest: def __init__(self, *args, **kwargs): self.args = args @@ -60,9 +60,9 @@ def __init__(self, name, filename=None, filter=None, topn=None): source=filename, format="auto", # Auto-detect based on file extension filter=filter, - simple_format=True + simple_format=True, ) - + # Deduplicate the strings in self.optests for op in self.optests: self.optests[op] = list(set(self.optests[op])) From f535d8a0afa447cf038814bb91de5783d962e13f Mon Sep 17 00:00:00 2001 From: PaliC Date: Fri, 1 Aug 2025 15:12:19 -0700 Subject: [PATCH 20/32] udpate deps --- pyproject.toml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 5b27108..9286964 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,7 @@ dependencies = [ "anthropic>=0.34.0", "pytest", "requests", + "pyarrow", ] [project.optional-dependencies] @@ -40,12 +41,13 @@ packages = ["BackendBench"] [tool.uv] dev-dependencies = [ "pytest", - "pytest-cov", + "pytest-cov", "pytest-mock", "pytest-timeout", "ruff==0.12.1", "torch", "numpy", + "pyarrow", # cupy-cuda12x is platform specific, install manually if needed ] From a68fbda8d9a83c242e75751567224d2f33898974 Mon Sep 17 00:00:00 2001 From: PaliC Date: Fri, 1 Aug 2025 15:15:52 -0700 Subject: [PATCH 21/32] undo lint --- BackendBench/scripts/parquet_trace_converter.py | 2 +- BackendBench/torchbench_suite.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/BackendBench/scripts/parquet_trace_converter.py b/BackendBench/scripts/parquet_trace_converter.py index 5eb87bc..0cb3cb7 100644 --- a/BackendBench/scripts/parquet_trace_converter.py +++ b/BackendBench/scripts/parquet_trace_converter.py @@ -26,7 +26,7 @@ 2. A "dev" version. This version is much more verbose, has more information on each test, includes tests/ops we decided to axe (and why they were axed), and possibly some runtime numbers The point of 1 is for something to have folks able to benchmark their agents against. Therefore, there is a high quality bar for inclusion -At the end of the day we still need solutions to be general for inclusion in pytorch, therefore, the mroe verbose dev version is useful in this case. It also allows us to record information on the ops and decisions as well +At the end of the day we still need solutions to be general for inclusion in pytorch, therefore, the more verbose dev version is useful in this case. It also allows us to record information on the ops and decisions as well Columns for the production version: - uuid (int) (hash of op + args) diff --git a/BackendBench/torchbench_suite.py b/BackendBench/torchbench_suite.py index 153201f..dcfd05f 100644 --- a/BackendBench/torchbench_suite.py +++ b/BackendBench/torchbench_suite.py @@ -5,6 +5,7 @@ from BackendBench.utils import deserialize_args from BackendBench.scripts.dataset_filters import SKIP_OPERATORS from BackendBench.data_loaders import load_ops_from_source, _args_size +import torch # noqa: F401 # the schema for this dataset is the one defined in tritonbench traces. # ie. https://github.com/pytorch-labs/tritonbench/blob/main/tritonbench/data/input_configs/hf_train/AlbertForMaskedLM_training.txt From a58f0d885a19a871daadc42944129a8392165e4d Mon Sep 17 00:00:00 2001 From: PaliC Date: Fri, 1 Aug 2025 15:55:07 -0700 Subject: [PATCH 22/32] update hf upload --- .../scripts/parquet_trace_converter.py | 58 ++++++++++++++----- BackendBench/torchbench_suite.py | 6 +- pyproject.toml | 1 + 3 files changed, 47 insertions(+), 18 deletions(-) diff --git a/BackendBench/scripts/parquet_trace_converter.py b/BackendBench/scripts/parquet_trace_converter.py index 0cb3cb7..0f1ace1 100644 --- a/BackendBench/scripts/parquet_trace_converter.py +++ b/BackendBench/scripts/parquet_trace_converter.py @@ -1,24 +1,26 @@ # utility functions to convert parquet and trace files back and forth -import pyarrow.parquet as pq +import hashlib +import logging +import os +import re +import tempfile +from pathlib import Path +from typing import Dict, List + +import click import pyarrow as pa -from BackendBench.torchbench_suite import DEFAULT_HUGGINGFACE_URL +import pyarrow.parquet as pq +import requests from BackendBench.data_loaders import _args_size from BackendBench.scripts.dataset_filters import ( - _apply_skip_ops_filter, _apply_non_interesting_ops_filter, + _apply_skip_ops_filter, ) +from BackendBench.torchbench_suite import DEFAULT_HUGGINGFACE_URL from BackendBench.utils import deserialize_args -import os -import logging -import click -import re -import hashlib +from huggingface_hub import HfApi from tqdm import tqdm -import tempfile -import requests -from pathlib import Path -from typing import List, Dict """ For the dataset release we generally would want to versions @@ -49,6 +51,21 @@ logger = logging.getLogger(__name__) +def _upload_to_hf(file_path: str) -> None: + """Upload file to GPUMODE/huggingface_op_trace.""" + try: + api = HfApi() + api.upload_file( + path_or_fileobj=file_path, + path_in_repo=Path(file_path).name, + repo_id="GPUMODE/huggingface_op_trace", + repo_type="dataset", + ) + logger.info(f"Uploaded {Path(file_path).name} to Hugging Face") + except Exception as e: + logger.warning(f"Failed to upload {Path(file_path).name}: {e}") + + def setup_logging(log_level): """Configure logging with the specified level.""" numeric_level = getattr(logging, log_level.upper(), None) @@ -220,7 +237,7 @@ def _validate_parquet_name(parquet_name: str, is_input: bool = False) -> str: # Ensure local files are in datasets directory if not parquet_name.startswith("datasets/"): - parquet_name = f"datasets/{parquet_name}" + parquet_name = os.path.join("datasets", parquet_name) return parquet_name @@ -274,11 +291,17 @@ def _generate_dev_parquet_name(parquet_name: str) -> str: ) @click.option( "--parquet-name", - default="backend_bench_problems-dev.parquet", + default="backend_bench_problems.parquet", type=str, help="Parquet filename: URL allowed as input in parquet-to-trace mode, local files in datasets/. Dev version auto-generated as filename-dev.parquet.", ) -def main(log_level, mode, trace_file, parquet_name): +@click.option( + "--upload-to-hf", + is_flag=True, + default=False, + help="Upload generated parquet files to Hugging Face (GPUMODE/huggingface_op_trace) in trace-to-parquet mode", +) +def main(log_level, mode, trace_file, parquet_name, upload_to_hf): """Convert trace files to parquet format or vice versa.""" setup_logging(log_level) @@ -302,6 +325,11 @@ def main(log_level, mode, trace_file, parquet_name): convert_trace_to_parquets(trace_file, parquet_name, dev_parquet_name) logger.info("Conversion completed successfully") + if upload_to_hf: + # Upload to Hugging Face + _upload_to_hf(os.path.abspath(parquet_name)) + _upload_to_hf(os.path.abspath(dev_parquet_name)) + elif mode == "parquet-to-trace": # Validate parquet input (URLs allowed for input in this mode) parquet_input = _validate_parquet_name(parquet_name, is_input=True) # Input: URLs allowed diff --git a/BackendBench/torchbench_suite.py b/BackendBench/torchbench_suite.py index dcfd05f..a9d01eb 100644 --- a/BackendBench/torchbench_suite.py +++ b/BackendBench/torchbench_suite.py @@ -2,10 +2,10 @@ Load aten inputs from serialized txt files and parquet files. """ -from BackendBench.utils import deserialize_args +import torch # noqa: F401 +from BackendBench.data_loaders import _args_size, load_ops_from_source from BackendBench.scripts.dataset_filters import SKIP_OPERATORS -from BackendBench.data_loaders import load_ops_from_source, _args_size -import torch # noqa: F401 +from BackendBench.utils import deserialize_args # the schema for this dataset is the one defined in tritonbench traces. # ie. https://github.com/pytorch-labs/tritonbench/blob/main/tritonbench/data/input_configs/hf_train/AlbertForMaskedLM_training.txt diff --git a/pyproject.toml b/pyproject.toml index 9286964..fa66ec1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,7 @@ dependencies = [ "pytest", "requests", "pyarrow", + "huggingface_hub", ] [project.optional-dependencies] From e8c5d1a6cef668f6eef5d674f94d39301ee7c5f4 Mon Sep 17 00:00:00 2001 From: Sahan Paliskara Date: Wed, 13 Aug 2025 13:50:01 -0700 Subject: [PATCH 23/32] Mark's comments --- BackendBench/scripts/dataset_filters.py | 37 +++++++++++-------------- pyproject.toml | 4 ++- 2 files changed, 19 insertions(+), 22 deletions(-) diff --git a/BackendBench/scripts/dataset_filters.py b/BackendBench/scripts/dataset_filters.py index 92e6536..9a6b783 100644 --- a/BackendBench/scripts/dataset_filters.py +++ b/BackendBench/scripts/dataset_filters.py @@ -35,18 +35,16 @@ SHAPE_MANIPULATION_OPS = [ "cat", "repeat", - "roll", # @NOTE: I'm also not sure about aten.roll.default + "roll", "unbind", ] -# Element-wise predicates and boolean operations -PREDICATE_OPS = [ - "any", # @NOTE: I don't think this is intereting as I'm unsure how'd it'd be optimized - "isinf", # @NOTE: Similar to any I'm not sure about this one - "isnan", # @NOTE: Similar to any I'm not sure about this one - "nonzero", # @NOTE: I'm also not sure about aten.nonzero.default - "where", -] + +def _apply_op_name_filter(op, filter, why_excluded_msg): + if any(skip_op in op["op_name"] for skip_op in filter): + op["included_in_benchmark"] = False + op["why_excluded"].append(why_excluded_msg) + return op def _apply_skip_ops_filter(ops): @@ -60,16 +58,13 @@ def _apply_skip_ops_filter(ops): def _apply_non_interesting_ops_filter(ops): for op in ops: - if any(skip_op in op["op_name"] for skip_op in MEMORY_VIEW_OPS): - op["included_in_benchmark"] = False - op["why_excluded"].append("Memory view ops are excluded from the benchmark.") - if any(skip_op in op["op_name"] for skip_op in TENSOR_CREATION_OPS): - op["included_in_benchmark"] = False - op["why_excluded"].append("Tensor creation ops are excluded from the benchmark.") - if any(skip_op in op["op_name"] for skip_op in SHAPE_MANIPULATION_OPS): - op["included_in_benchmark"] = False - op["why_excluded"].append("Shape manipulation ops are excluded from the benchmark.") - if any(skip_op in op["op_name"] for skip_op in PREDICATE_OPS): - op["included_in_benchmark"] = False - op["why_excluded"].append("Predicate ops are excluded from the benchmark.") + op = _apply_op_name_filter( + op, MEMORY_VIEW_OPS, "Memory view ops are excluded from the benchmark." + ) + op = _apply_op_name_filter( + op, TENSOR_CREATION_OPS, "Tensor creation ops are excluded from the benchmark." + ) + op = _apply_op_name_filter( + op, SHAPE_MANIPULATION_OPS, "Shape manipulation ops are excluded from the benchmark." + ) return ops diff --git a/pyproject.toml b/pyproject.toml index fa66ec1..4a05373 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,6 @@ dependencies = [ "anthropic>=0.34.0", "pytest", "requests", - "pyarrow", "huggingface_hub", ] @@ -32,6 +31,9 @@ dependencies = [ flaggems = [ # flag_gems must be installed from source: https://github.com/FlagOpen/FlagGems ] +pyarrow = [ + "pyarrow", +] [project.scripts] backendbench = "BackendBench.scripts.main:cli" From d25c2d39009b9d732bb44f0e16a559a7c45e33f9 Mon Sep 17 00:00:00 2001 From: Sahan Paliskara Date: Wed, 13 Aug 2025 13:57:32 -0700 Subject: [PATCH 24/32] lint --- BackendBench/torchbench_suite.py | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/BackendBench/torchbench_suite.py b/BackendBench/torchbench_suite.py index 1ca44a2..19fb886 100644 --- a/BackendBench/torchbench_suite.py +++ b/BackendBench/torchbench_suite.py @@ -11,22 +11,6 @@ # ie. https://github.com/pytorch-labs/tritonbench/blob/main/tritonbench/data/input_configs/hf_train/AlbertForMaskedLM_training.txt DEFAULT_HUGGINGFACE_URL = "https://huggingface.co/datasets/GPUMODE/huggingface_op_trace/raw/main/augmented_hf_op_traces.txt" -# Operators to skip for indexing ops that need valid indices -SKIP_OPERATORS = [ - "embedding", - "scatter", - "gather", - "index", - "nll_loss", - "im2col_backward", - "col2im_backward", - "native_layer_norm_backward", - "upsample_nearest2d_backward.vec", - "upsample_bilinear2d_backward.vec", - "_cudnn_rnn_backward.default", # RuntimeError: cuDNN error: CUDNN_STATUS_BAD_PARAM - "_fft_c2c.default", # cuFFT only supports dimensions whose sizes are powers of two when computing in half precision -] - class TorchBenchTest: def __init__(self, *args, **kwargs): From 9ae0cacb7ce56dd6367128331c9740a7e7859ff7 Mon Sep 17 00:00:00 2001 From: Sahan Paliskara Date: Wed, 13 Aug 2025 16:34:40 -0700 Subject: [PATCH 25/32] stream from urls --- BackendBench/data_loaders.py | 38 +++++++--- .../scripts/parquet_trace_converter.py | 70 +++++++++++++------ 2 files changed, 78 insertions(+), 30 deletions(-) diff --git a/BackendBench/data_loaders.py b/BackendBench/data_loaders.py index 896c6e6..ecbf629 100644 --- a/BackendBench/data_loaders.py +++ b/BackendBench/data_loaders.py @@ -3,7 +3,6 @@ """ import re -import tempfile from pathlib import Path from collections import defaultdict from typing import Dict, List, Optional, Union @@ -47,6 +46,31 @@ def _parse_trace_file_simple(filename: str, filter: Optional[List[str]], op_inpu return op_inputs +def _parse_trace_stream(stream, filter: Optional[List[str]], op_inputs: Dict) -> Dict: + """ + Parse trace data from a text stream (e.g., from requests.Response.iter_lines()). + + Returns defaultdict where keys are op names and values are lists of args strings. + """ + op = None + + for line in stream: + # Handle bytes from response stream + if isinstance(line, bytes): + line = line.decode("utf-8") + + if m := re.match("Operator: (.*)", line): + op = m.group(1) + if op == "aten.sum.SymInt": + op = "aten.sum.dim_IntList" + if m := re.match("cnt: \\d+, (.*)", line): + assert op is not None + args = m.group(1) + if filter is None or any(f in op for f in filter): + op_inputs[op].append(args) + return op_inputs + + def load_ops_from_source( source: Union[str, Path], format: str = "auto", @@ -132,17 +156,11 @@ def _load_from_trace(source: Union[str, Path], filter: Optional[List[str]], simp op_inputs = defaultdict(list) - # Handle URLs + # Handle URLs - stream directly without saving to disk if isinstance(source, str) and (source.startswith("http://") or source.startswith("https://")): - with ( - tempfile.NamedTemporaryFile(mode="w+", suffix=".txt", delete=False) as tmp_file, - requests.get(source) as response, - ): + with requests.get(source, stream=True) as response: response.raise_for_status() - tmp_file.write(response.text) - tmp_file.flush() - _parse_trace_file_simple(tmp_file.name, filter, op_inputs) - Path(tmp_file.name).unlink(missing_ok=True) + _parse_trace_stream(response.iter_lines(), filter, op_inputs) # Handle directories elif Path(source).is_dir(): diff --git a/BackendBench/scripts/parquet_trace_converter.py b/BackendBench/scripts/parquet_trace_converter.py index 0f1ace1..8edc562 100644 --- a/BackendBench/scripts/parquet_trace_converter.py +++ b/BackendBench/scripts/parquet_trace_converter.py @@ -4,7 +4,6 @@ import logging import os import re -import tempfile from pathlib import Path from typing import Dict, List @@ -118,23 +117,59 @@ def _parse_trace_file(filename: str) -> List[Dict]: return op_inputs +def _parse_trace_stream(stream, desc: str = "Parsing stream") -> List[Dict]: + """ + Parse trace data from a text stream (e.g., from requests.Response.iter_lines()). + + Returns list of dicts with keys: uuid, op_name, args, arg_size, count, is_synthetic + """ + op_inputs = [] + op = None + + for line in tqdm(stream, desc=desc): + # Handle bytes from response stream + if isinstance(line, bytes): + line = line.decode("utf-8") + + if m := re.match("Operator: (.*)", line): + op = m.group(1) + if op == "aten.sum.SymInt": + op = "aten.sum.dim_IntList" + if m := re.match("cnt: \\d+, (.*)", line): + assert op is not None + args_str = m.group(1) + # extract cnt value from group 0 + cnt = int(m.group(0).split(",")[0].split(":")[1]) + args, kwargs = deserialize_args(args_str) + size = _args_size(args) + _args_size(list(kwargs.values())) + # convert size to MB from bytes + size = size / (1024 * 1024) + # if cnt is 0 then it is synthetic + is_synthetic = cnt == 0 + op_inputs.append( + { + "uuid": hashlib.sha256(args_str.encode() + op.encode()).hexdigest(), + "op_name": op, + "args": args_str, + "arg_size": size, + "count": cnt, + "is_synthetic": is_synthetic, + } + ) + return op_inputs + + def _load_trace_for_parquet_conversion(source: str) -> List[Dict]: """ Load operations from trace file(s) with detailed metadata for parquet conversion. """ ops = [] - # Handle URLs + # Handle URLs - stream directly without saving to disk if isinstance(source, str) and (source.startswith("http://") or source.startswith("https://")): - with ( - tempfile.NamedTemporaryFile(mode="w+", suffix=".txt", delete=False) as tmp_file, - requests.get(source) as response, - ): + with requests.get(source, stream=True) as response: response.raise_for_status() - tmp_file.write(response.text) - tmp_file.flush() - ops.extend(_parse_trace_file(tmp_file.name)) - Path(tmp_file.name).unlink(missing_ok=True) + ops.extend(_parse_trace_stream(response.iter_lines(), desc=f"Parsing {source}")) # Handle directories elif Path(source).is_dir(): @@ -223,14 +258,11 @@ def convert_parquet_to_trace(parquet_file, trace_file): logging.info(f"Wrote {total_args} ops and inputs to {trace_file}") -def _validate_parquet_name(parquet_name: str, is_input: bool = False) -> str: +def _validate_parquet_name(parquet_name: str) -> str: """Validate parquet filename. URLs allowed only for inputs.""" # URLs are allowed only if this is an input file if parquet_name.startswith(("http://", "https://")): - if is_input: - return parquet_name # URL allowed for input - else: - raise click.BadParameter("Output parquet file cannot be a URL") + raise click.BadParameter("Output parquet file cannot be a URL") if not parquet_name.endswith(".parquet"): raise click.BadParameter("Parquet file must end with .parquet suffix") @@ -247,7 +279,7 @@ def _validate_trace_file(trace_file: str, is_input: bool = True) -> str: # URLs are allowed only if this is an input file if trace_file.startswith(("http://", "https://")): if is_input: - return trace_file # URL allowed for input + return trace_file else: raise click.BadParameter("Output trace file cannot be a URL") @@ -311,9 +343,7 @@ def main(log_level, mode, trace_file, parquet_name, upload_to_hf): if mode == "trace-to-parquet": # Validate inputs/outputs trace_file = _validate_trace_file(trace_file, is_input=True) # Input: URLs allowed - parquet_name = _validate_parquet_name( - parquet_name, is_input=False - ) # Output: URLs not allowed + parquet_name = _validate_parquet_name(parquet_name) # Output: URLs not allowed # Generate dev version name dev_parquet_name = _generate_dev_parquet_name(parquet_name) @@ -332,7 +362,7 @@ def main(log_level, mode, trace_file, parquet_name, upload_to_hf): elif mode == "parquet-to-trace": # Validate parquet input (URLs allowed for input in this mode) - parquet_input = _validate_parquet_name(parquet_name, is_input=True) # Input: URLs allowed + parquet_input = _validate_parquet_name(parquet_name) # Validate trace output (URLs not allowed for output) trace_output = _validate_trace_file(trace_file, is_input=False) # Output: URLs not allowed From f23690a9e2d570f024f340441f851fddb51332b8 Mon Sep 17 00:00:00 2001 From: Sahan Paliskara Date: Thu, 14 Aug 2025 14:56:58 -0700 Subject: [PATCH 26/32] simplify --- BackendBench/data_loaders.py | 152 ++++++++++----- .../scripts/parquet_trace_converter.py | 180 ++---------------- BackendBench/torchbench_suite.py | 10 +- 3 files changed, 132 insertions(+), 210 deletions(-) diff --git a/BackendBench/data_loaders.py b/BackendBench/data_loaders.py index ecbf629..598fd21 100644 --- a/BackendBench/data_loaders.py +++ b/BackendBench/data_loaders.py @@ -2,14 +2,16 @@ Shared data loading utilities for reading trace and parquet files. """ +import hashlib import re from pathlib import Path -from collections import defaultdict from typing import Dict, List, Optional, Union import requests import pyarrow.parquet as pq import torch +from BackendBench.utils import deserialize_args +from tqdm import tqdm def _args_size(args): @@ -24,37 +26,67 @@ def _args_size(args): return size -def _parse_trace_file_simple(filename: str, filter: Optional[List[str]], op_inputs: Dict) -> Dict: +def _parse_trace_file(filename: str, + filter: Optional[List[str]] = None) -> List[Dict]: """ - Parse a single trace file for TorchBenchSuite (simpler format). + Parse a single trace file and return a list of operation dictionaries. - Returns defaultdict where keys are op names and values are lists of args strings. + Args: + filename: Path to trace file + filter: Optional list of operation name filters """ + op_inputs = [] op = None with open(filename, "r") as f: - for line in f: + lines = list(f) + iterator = tqdm(lines, desc=f"Parsing {Path(filename).name}") + for line in iterator: if m := re.match("Operator: (.*)", line): op = m.group(1) if op == "aten.sum.SymInt": op = "aten.sum.dim_IntList" if m := re.match("cnt: \\d+, (.*)", line): assert op is not None - args = m.group(1) + args_str = m.group(1) + cnt = int(m.group(0).split(",")[0].split(":")[1]) + if filter is None or any(f in op for f in filter): - op_inputs[op].append(args) + args, kwargs = deserialize_args(args_str) + size = _args_size(args) + _args_size(list(kwargs.values())) + size = size / (1024 * 1024) # Convert to MB + is_synthetic = cnt == 0 + + op_inputs.append({ + "uuid": hashlib.sha256( + args_str.encode() + op.encode() + ).hexdigest(), + "op_name": op, + "args": args_str, + "arg_size": size, + "count": cnt, + "is_synthetic": is_synthetic + }) return op_inputs -def _parse_trace_stream(stream, filter: Optional[List[str]], op_inputs: Dict) -> Dict: +def _parse_trace_stream(stream, + filter: Optional[List[str]] = None, + desc: str = "Parsing stream") -> List[Dict]: """ Parse trace data from a text stream (e.g., from requests.Response.iter_lines()). - Returns defaultdict where keys are op names and values are lists of args strings. + Args: + stream: Iterable of lines (strings or bytes) + filter: Optional list of operation name filters + desc: Description for progress bar """ + op_inputs = [] op = None - for line in stream: + iterator = tqdm(stream, desc=desc) + + for line in iterator: # Handle bytes from response stream if isinstance(line, bytes): line = line.decode("utf-8") @@ -65,9 +97,25 @@ def _parse_trace_stream(stream, filter: Optional[List[str]], op_inputs: Dict) -> op = "aten.sum.dim_IntList" if m := re.match("cnt: \\d+, (.*)", line): assert op is not None - args = m.group(1) + args_str = m.group(1) + cnt = int(m.group(0).split(",")[0].split(":")[1]) + if filter is None or any(f in op for f in filter): - op_inputs[op].append(args) + args, kwargs = deserialize_args(args_str) + size = _args_size(args) + _args_size(list(kwargs.values())) + size = size / (1024 * 1024) # Convert to MB + is_synthetic = cnt == 0 + + op_inputs.append({ + "uuid": hashlib.sha256( + args_str.encode() + op.encode() + ).hexdigest(), + "op_name": op, + "args": args_str, + "arg_size": size, + "count": cnt, + "is_synthetic": is_synthetic + }) return op_inputs @@ -75,8 +123,7 @@ def load_ops_from_source( source: Union[str, Path], format: str = "auto", filter: Optional[List[str]] = None, - simple_format: bool = False, -) -> Union[List[Dict], Dict]: +) -> List[Dict]: """ Load operation data from various sources and formats. @@ -84,11 +131,9 @@ def load_ops_from_source( source: File path, URL, or directory format: "trace", "parquet", or "auto" (detect from file extension) filter: Optional list of operation name filters - simple_format: If True, return defaultdict format for TorchBenchSuite compatibility Returns: - If simple_format=True: defaultdict with op names as keys, args lists as values - If simple_format=False: List of dictionaries with detailed operation info + List of dictionaries with detailed operation info Auto-detection behavior: - https://domain.com/data.parquet → parquet format @@ -120,55 +165,72 @@ def load_ops_from_source( format = "trace" if format == "parquet": - return _load_from_parquet(source, filter, simple_format) + return _load_from_parquet(source, filter) elif format == "trace": - return _load_from_trace(source, filter, simple_format) + # Always load full data - consumers can extract what they need + return _load_from_trace(source, filter) else: raise ValueError(f"Unsupported format: {format}") -def _load_from_parquet(source: Union[str, Path], filter: Optional[List[str]], simple_format: bool): +def _load_from_parquet(source: Union[str, Path], filter: Optional[List[str]]): """Load operations from parquet file.""" table = pq.read_table(source) + df = table.to_pandas() + + # Apply filter if provided + if filter: + mask = df["op_name"].apply(lambda op: any(f in op for f in filter)) + df = df[mask] + + return df.to_dict("records") - if simple_format: - # Convert to TorchBenchSuite format - op_inputs = defaultdict(list) - for batch in table.to_batches(): - df = batch.to_pandas() - for _, row in df.iterrows(): - op_name = row["op_name"] - if filter is None or any(f in op_name for f in filter): - op_inputs[op_name].append(row["args"]) - return op_inputs - else: - # Convert to list of dicts - df = table.to_pandas() - return df.to_dict("records") +def ops_list_to_dict(ops_list: List[Dict]) -> Dict[str, List[str]]: + """ + Convert a list of operation dictionaries to a dictionary format. + + Args: + ops_list: List of dicts with 'op_name' and 'args' keys + + Returns: + Dictionary mapping op_name to list of args strings + """ + result = {} + for op_data in ops_list: + op_name = op_data["op_name"] + args = op_data["args"] + if op_name not in result: + result[op_name] = [] + result[op_name].append(args) + return result -def _load_from_trace(source: Union[str, Path], filter: Optional[List[str]], simple_format: bool): - """Load operations from trace file(s). Only supports simple_format=True for TorchBenchSuite.""" - if not simple_format: - raise ValueError( - "Detailed trace parsing has been moved to parquet_trace_converter.py. Use simple_format=True." - ) - op_inputs = defaultdict(list) +def _load_from_trace(source: Union[str, Path], + filter: Optional[List[str]]) -> List[Dict]: + """Load operations from trace file(s) and return list of dicts.""" + op_inputs = [] # Handle URLs - stream directly without saving to disk - if isinstance(source, str) and (source.startswith("http://") or source.startswith("https://")): + if isinstance(source, str) and ( + source.startswith("http://") or source.startswith("https://") + ): with requests.get(source, stream=True) as response: response.raise_for_status() - _parse_trace_stream(response.iter_lines(), filter, op_inputs) + desc = f"Parsing {source}" + op_inputs = _parse_trace_stream( + response.iter_lines(), filter, desc + ) # Handle directories elif Path(source).is_dir(): for file_path in Path(source).glob("**/*.txt"): - _parse_trace_file_simple(str(file_path), filter, op_inputs) + op_inputs.extend( + _parse_trace_file(str(file_path), filter) + ) # Handle single files else: - _parse_trace_file_simple(source, filter, op_inputs) + op_inputs = _parse_trace_file(source, filter) return op_inputs diff --git a/BackendBench/scripts/parquet_trace_converter.py b/BackendBench/scripts/parquet_trace_converter.py index 8edc562..d8a428b 100644 --- a/BackendBench/scripts/parquet_trace_converter.py +++ b/BackendBench/scripts/parquet_trace_converter.py @@ -3,43 +3,28 @@ import hashlib import logging import os -import re from pathlib import Path -from typing import Dict, List +from typing import List import click import pyarrow as pa import pyarrow.parquet as pq -import requests -from BackendBench.data_loaders import _args_size +from BackendBench.data_loaders import _load_from_trace from BackendBench.scripts.dataset_filters import ( _apply_non_interesting_ops_filter, _apply_skip_ops_filter, ) from BackendBench.torchbench_suite import DEFAULT_HUGGINGFACE_URL -from BackendBench.utils import deserialize_args from huggingface_hub import HfApi -from tqdm import tqdm """ -For the dataset release we generally would want to versions -1. A production version which has what you would want to run a benchmark with an llm -2. A "dev" version. This version is much more verbose, has more information on each test, includes tests/ops we decided to axe (and why they were axed), and possibly some runtime numbers - -The point of 1 is for something to have folks able to benchmark their agents against. Therefore, there is a high quality bar for inclusion -At the end of the day we still need solutions to be general for inclusion in pytorch, therefore, the more verbose dev version is useful in this case. It also allows us to record information on the ops and decisions as well - -Columns for the production version: +Columns for the parquet dataset: - uuid (int) (hash of op + args) - op_name (string) - args (string) -- arg size (float)(in MB) +- arg_size (float) (in MB) - count (int) (number of times this op + set of args was called in real models) - is_synthetic (boolean) (did we generate this op or is it from a real model) - - -Columns for the dev version: -All columns in the production version, plus: - included_in_benchmark (boolean) - why_excluded (list of strings) (empty if included) - runtime_ms (float) (timings on H100 gpu) @@ -78,112 +63,17 @@ def setup_logging(log_level): ) -def _parse_trace_file(filename: str) -> List[Dict]: - """ - Parse a single trace file and return a list of operation dictionaries. - Returns list of dicts with keys: uuid, op_name, args, arg_size, count, is_synthetic - """ - op_inputs = [] - op = None - - with open(filename, "r") as f: - for line in tqdm(f, desc=f"Parsing {Path(filename).name}"): - if m := re.match("Operator: (.*)", line): - op = m.group(1) - if op == "aten.sum.SymInt": - op = "aten.sum.dim_IntList" - if m := re.match("cnt: \\d+, (.*)", line): - assert op is not None - args_str = m.group(1) - # extract cnt value from group 0 - cnt = int(m.group(0).split(",")[0].split(":")[1]) - args, kwargs = deserialize_args(args_str) - size = _args_size(args) + _args_size(list(kwargs.values())) - # convert size to MB from bytes - size = size / (1024 * 1024) - # if cnt is 0 then it is synthetic - is_synthetic = cnt == 0 - op_inputs.append( - { - "uuid": hashlib.sha256(args_str.encode() + op.encode()).hexdigest(), - "op_name": op, - "args": args_str, - "arg_size": size, - "count": cnt, - "is_synthetic": is_synthetic, - } - ) - return op_inputs - - -def _parse_trace_stream(stream, desc: str = "Parsing stream") -> List[Dict]: - """ - Parse trace data from a text stream (e.g., from requests.Response.iter_lines()). - Returns list of dicts with keys: uuid, op_name, args, arg_size, count, is_synthetic - """ - op_inputs = [] - op = None - - for line in tqdm(stream, desc=desc): - # Handle bytes from response stream - if isinstance(line, bytes): - line = line.decode("utf-8") - - if m := re.match("Operator: (.*)", line): - op = m.group(1) - if op == "aten.sum.SymInt": - op = "aten.sum.dim_IntList" - if m := re.match("cnt: \\d+, (.*)", line): - assert op is not None - args_str = m.group(1) - # extract cnt value from group 0 - cnt = int(m.group(0).split(",")[0].split(":")[1]) - args, kwargs = deserialize_args(args_str) - size = _args_size(args) + _args_size(list(kwargs.values())) - # convert size to MB from bytes - size = size / (1024 * 1024) - # if cnt is 0 then it is synthetic - is_synthetic = cnt == 0 - op_inputs.append( - { - "uuid": hashlib.sha256(args_str.encode() + op.encode()).hexdigest(), - "op_name": op, - "args": args_str, - "arg_size": size, - "count": cnt, - "is_synthetic": is_synthetic, - } - ) - return op_inputs - - -def _load_trace_for_parquet_conversion(source: str) -> List[Dict]: +def _load_trace_for_parquet_conversion(source: str) -> List[dict]: """ Load operations from trace file(s) with detailed metadata for parquet conversion. """ - ops = [] + # Use the shared _load_from_trace for parquet conversion + return _load_from_trace(source, filter=None) - # Handle URLs - stream directly without saving to disk - if isinstance(source, str) and (source.startswith("http://") or source.startswith("https://")): - with requests.get(source, stream=True) as response: - response.raise_for_status() - ops.extend(_parse_trace_stream(response.iter_lines(), desc=f"Parsing {source}")) - # Handle directories - elif Path(source).is_dir(): - for file_path in Path(source).glob("**/*.txt"): - ops.extend(_parse_trace_file(str(file_path))) - - # Handle single files - else: - ops.extend(_parse_trace_file(source)) - - return ops - - -def convert_trace_to_parquets(trace_file, prod_parquet_file=None, dev_parquet_file=None): +def convert_trace_to_parquet(trace_file, parquet_file): """ Convert a trace file to a parquet file """ @@ -203,35 +93,16 @@ def convert_trace_to_parquets(trace_file, prod_parquet_file=None, dev_parquet_fi ops = _apply_skip_ops_filter(ops) ops = _apply_non_interesting_ops_filter(ops) - # create production version (copy ops to avoid modifying original) - prod_ops = [] - for op in ops: - if op["included_in_benchmark"]: - # Create production version without metadata fields - prod_op = { - "uuid": op["uuid"], - "op_name": op["op_name"], - "args": op["args"], - "arg_size": op["arg_size"], - "count": op["count"], - "is_synthetic": op["is_synthetic"], - } - prod_ops.append(prod_op) - - # Create parquet tables with proper headers - dev_table = pa.Table.from_pylist(ops) - prod_table = pa.Table.from_pylist(prod_ops) - - # Write parquet files - pq.write_table(dev_table, dev_parquet_file) - pq.write_table(prod_table, prod_parquet_file) - - logger.info(f"Wrote {len(prod_ops)} ops and inputs to {prod_parquet_file}") - logger.info(f"Wrote {len(ops)} ops and inputs to {dev_parquet_file}") + # Create parquet table with all metadata (formerly "dev" version) + table = pa.Table.from_pylist(ops) + + # Write parquet file + pq.write_table(table, parquet_file) + + logger.info(f"Wrote {len(ops)} ops and inputs to {parquet_file}") # Log column information for verification - logger.debug(f"Production parquet columns: {prod_table.column_names}") - logger.debug(f"Dev parquet columns: {dev_table.column_names}") + logger.debug(f"Parquet columns: {table.column_names}") def convert_parquet_to_trace(parquet_file, trace_file): @@ -293,13 +164,6 @@ def _validate_trace_file(trace_file: str, is_input: bool = True) -> str: return trace_file -def _generate_dev_parquet_name(parquet_name: str) -> str: - """Generate dev parquet name by appending -dev before .parquet suffix.""" - if parquet_name.endswith(".parquet"): - base_name = parquet_name[:-8] # Remove .parquet - return f"{base_name}-dev.parquet" - else: - return f"{parquet_name}-dev" @click.command() @@ -325,7 +189,7 @@ def _generate_dev_parquet_name(parquet_name: str) -> str: "--parquet-name", default="backend_bench_problems.parquet", type=str, - help="Parquet filename: URL allowed as input in parquet-to-trace mode, local files in datasets/. Dev version auto-generated as filename-dev.parquet.", + help="Parquet filename: URL allowed as input in parquet-to-trace mode, local files in datasets/.", ) @click.option( "--upload-to-hf", @@ -345,20 +209,14 @@ def main(log_level, mode, trace_file, parquet_name, upload_to_hf): trace_file = _validate_trace_file(trace_file, is_input=True) # Input: URLs allowed parquet_name = _validate_parquet_name(parquet_name) # Output: URLs not allowed - # Generate dev version name - dev_parquet_name = _generate_dev_parquet_name(parquet_name) - - logger.info(f"Converting trace file {trace_file} to parquet files") - logger.info(f"Production parquet: {parquet_name}") - logger.info(f"Dev parquet: {dev_parquet_name}") + logger.info(f"Converting trace file {trace_file} to parquet file {parquet_name}") - convert_trace_to_parquets(trace_file, parquet_name, dev_parquet_name) + convert_trace_to_parquet(trace_file, parquet_name) logger.info("Conversion completed successfully") if upload_to_hf: # Upload to Hugging Face _upload_to_hf(os.path.abspath(parquet_name)) - _upload_to_hf(os.path.abspath(dev_parquet_name)) elif mode == "parquet-to-trace": # Validate parquet input (URLs allowed for input in this mode) diff --git a/BackendBench/torchbench_suite.py b/BackendBench/torchbench_suite.py index 19fb886..10a3095 100644 --- a/BackendBench/torchbench_suite.py +++ b/BackendBench/torchbench_suite.py @@ -3,7 +3,8 @@ """ import torch # noqa: F401 -from BackendBench.data_loaders import _args_size, load_ops_from_source +from BackendBench.data_loaders import (_args_size, load_ops_from_source, + ops_list_to_dict) from BackendBench.scripts.dataset_filters import SKIP_OPERATORS from BackendBench.utils import deserialize_args @@ -56,14 +57,15 @@ def __init__(self, name, filename=None, filter=None, topn=None): filename = DEFAULT_HUGGINGFACE_URL # Load operations using the shared data loader - # Use simple_format=True to get the defaultdict format for compatibility - self.optests = load_ops_from_source( + ops_list = load_ops_from_source( source=filename, format="auto", # Auto-detect based on file extension filter=filter, - simple_format=True, ) + # Convert to dictionary format using utility function + self.optests = ops_list_to_dict(ops_list) + # Deduplicate the strings in self.optests for op in self.optests: self.optests[op] = list(set(self.optests[op])) From dbe3a8dc569eb3ca17b36c7bdd82801476de9727 Mon Sep 17 00:00:00 2001 From: Sahan Paliskara Date: Thu, 14 Aug 2025 15:04:33 -0700 Subject: [PATCH 27/32] lint --- BackendBench/data_loaders.py | 74 +++++++++---------- .../scripts/parquet_trace_converter.py | 4 - BackendBench/torchbench_suite.py | 5 +- 3 files changed, 35 insertions(+), 48 deletions(-) diff --git a/BackendBench/data_loaders.py b/BackendBench/data_loaders.py index 598fd21..25b6f8b 100644 --- a/BackendBench/data_loaders.py +++ b/BackendBench/data_loaders.py @@ -26,8 +26,7 @@ def _args_size(args): return size -def _parse_trace_file(filename: str, - filter: Optional[List[str]] = None) -> List[Dict]: +def _parse_trace_file(filename: str, filter: Optional[List[str]] = None) -> List[Dict]: """ Parse a single trace file and return a list of operation dictionaries. @@ -57,22 +56,22 @@ def _parse_trace_file(filename: str, size = size / (1024 * 1024) # Convert to MB is_synthetic = cnt == 0 - op_inputs.append({ - "uuid": hashlib.sha256( - args_str.encode() + op.encode() - ).hexdigest(), - "op_name": op, - "args": args_str, - "arg_size": size, - "count": cnt, - "is_synthetic": is_synthetic - }) + op_inputs.append( + { + "uuid": hashlib.sha256(args_str.encode() + op.encode()).hexdigest(), + "op_name": op, + "args": args_str, + "arg_size": size, + "count": cnt, + "is_synthetic": is_synthetic, + } + ) return op_inputs -def _parse_trace_stream(stream, - filter: Optional[List[str]] = None, - desc: str = "Parsing stream") -> List[Dict]: +def _parse_trace_stream( + stream, filter: Optional[List[str]] = None, desc: str = "Parsing stream" +) -> List[Dict]: """ Parse trace data from a text stream (e.g., from requests.Response.iter_lines()). @@ -105,17 +104,17 @@ def _parse_trace_stream(stream, size = _args_size(args) + _args_size(list(kwargs.values())) size = size / (1024 * 1024) # Convert to MB is_synthetic = cnt == 0 - - op_inputs.append({ - "uuid": hashlib.sha256( - args_str.encode() + op.encode() - ).hexdigest(), - "op_name": op, - "args": args_str, - "arg_size": size, - "count": cnt, - "is_synthetic": is_synthetic - }) + + op_inputs.append( + { + "uuid": hashlib.sha256(args_str.encode() + op.encode()).hexdigest(), + "op_name": op, + "args": args_str, + "arg_size": size, + "count": cnt, + "is_synthetic": is_synthetic, + } + ) return op_inputs @@ -177,22 +176,22 @@ def _load_from_parquet(source: Union[str, Path], filter: Optional[List[str]]): """Load operations from parquet file.""" table = pq.read_table(source) df = table.to_pandas() - + # Apply filter if provided if filter: mask = df["op_name"].apply(lambda op: any(f in op for f in filter)) df = df[mask] - + return df.to_dict("records") def ops_list_to_dict(ops_list: List[Dict]) -> Dict[str, List[str]]: """ Convert a list of operation dictionaries to a dictionary format. - + Args: ops_list: List of dicts with 'op_name' and 'args' keys - + Returns: Dictionary mapping op_name to list of args strings """ @@ -206,28 +205,21 @@ def ops_list_to_dict(ops_list: List[Dict]) -> Dict[str, List[str]]: return result -def _load_from_trace(source: Union[str, Path], - filter: Optional[List[str]]) -> List[Dict]: +def _load_from_trace(source: Union[str, Path], filter: Optional[List[str]]) -> List[Dict]: """Load operations from trace file(s) and return list of dicts.""" op_inputs = [] # Handle URLs - stream directly without saving to disk - if isinstance(source, str) and ( - source.startswith("http://") or source.startswith("https://") - ): + if isinstance(source, str) and (source.startswith("http://") or source.startswith("https://")): with requests.get(source, stream=True) as response: response.raise_for_status() desc = f"Parsing {source}" - op_inputs = _parse_trace_stream( - response.iter_lines(), filter, desc - ) + op_inputs = _parse_trace_stream(response.iter_lines(), filter, desc) # Handle directories elif Path(source).is_dir(): for file_path in Path(source).glob("**/*.txt"): - op_inputs.extend( - _parse_trace_file(str(file_path), filter) - ) + op_inputs.extend(_parse_trace_file(str(file_path), filter)) # Handle single files else: diff --git a/BackendBench/scripts/parquet_trace_converter.py b/BackendBench/scripts/parquet_trace_converter.py index d8a428b..15c0db3 100644 --- a/BackendBench/scripts/parquet_trace_converter.py +++ b/BackendBench/scripts/parquet_trace_converter.py @@ -63,8 +63,6 @@ def setup_logging(log_level): ) - - def _load_trace_for_parquet_conversion(source: str) -> List[dict]: """ Load operations from trace file(s) with detailed metadata for parquet conversion. @@ -164,8 +162,6 @@ def _validate_trace_file(trace_file: str, is_input: bool = True) -> str: return trace_file - - @click.command() @click.option( "--log-level", diff --git a/BackendBench/torchbench_suite.py b/BackendBench/torchbench_suite.py index 10a3095..5b4c64f 100644 --- a/BackendBench/torchbench_suite.py +++ b/BackendBench/torchbench_suite.py @@ -3,8 +3,7 @@ """ import torch # noqa: F401 -from BackendBench.data_loaders import (_args_size, load_ops_from_source, - ops_list_to_dict) +from BackendBench.data_loaders import _args_size, load_ops_from_source, ops_list_to_dict from BackendBench.scripts.dataset_filters import SKIP_OPERATORS from BackendBench.utils import deserialize_args @@ -65,7 +64,7 @@ def __init__(self, name, filename=None, filter=None, topn=None): # Convert to dictionary format using utility function self.optests = ops_list_to_dict(ops_list) - + # Deduplicate the strings in self.optests for op in self.optests: self.optests[op] = list(set(self.optests[op])) From 0705fa60b1062fabd10d923c4444867f6cef8341 Mon Sep 17 00:00:00 2001 From: PaliC Date: Thu, 14 Aug 2025 19:11:04 -0700 Subject: [PATCH 28/32] marks comments --- BackendBench/data_loaders.py | 17 +++++-- BackendBench/scripts/dataset_filters.py | 50 +------------------ BackendBench/scripts/get_big_inputs.py | 12 +---- .../scripts/parquet_trace_converter.py | 9 ++-- BackendBench/torchbench_suite.py | 8 ++- BackendBench/utils.py | 11 +++- 6 files changed, 34 insertions(+), 73 deletions(-) diff --git a/BackendBench/data_loaders.py b/BackendBench/data_loaders.py index 25b6f8b..9be242a 100644 --- a/BackendBench/data_loaders.py +++ b/BackendBench/data_loaders.py @@ -3,14 +3,16 @@ """ import hashlib +import logging import re from pathlib import Path from typing import Dict, List, Optional, Union -import requests import pyarrow.parquet as pq + +import requests import torch -from BackendBench.utils import deserialize_args +from BackendBench.utils import cleanup_memory_and_gpu, deserialize_args from tqdm import tqdm @@ -102,6 +104,8 @@ def _parse_trace_stream( if filter is None or any(f in op for f in filter): args, kwargs = deserialize_args(args_str) size = _args_size(args) + _args_size(list(kwargs.values())) + del args, kwargs + cleanup_memory_and_gpu() size = size / (1024 * 1024) # Convert to MB is_synthetic = cnt == 0 @@ -185,9 +189,9 @@ def _load_from_parquet(source: Union[str, Path], filter: Optional[List[str]]): return df.to_dict("records") -def ops_list_to_dict(ops_list: List[Dict]) -> Dict[str, List[str]]: +def op_list_to_benchmark_dict(ops_list: List[Dict]) -> Dict[str, List[str]]: """ - Convert a list of operation dictionaries to a dictionary format. + Convert a list of operation dictionaries to a dictionary format which can be used for benchmarking. Args: ops_list: List of dicts with 'op_name' and 'args' keys @@ -197,6 +201,8 @@ def ops_list_to_dict(ops_list: List[Dict]) -> Dict[str, List[str]]: """ result = {} for op_data in ops_list: + if not op_data["included_in_benchmark"]: + continue op_name = op_data["op_name"] args = op_data["args"] if op_name not in result: @@ -211,9 +217,10 @@ def _load_from_trace(source: Union[str, Path], filter: Optional[List[str]]) -> L # Handle URLs - stream directly without saving to disk if isinstance(source, str) and (source.startswith("http://") or source.startswith("https://")): + logging.info(f"Downloading trace from {source}") with requests.get(source, stream=True) as response: response.raise_for_status() - desc = f"Parsing {source}" + desc = "Parsing" op_inputs = _parse_trace_stream(response.iter_lines(), filter, desc) # Handle directories diff --git a/BackendBench/scripts/dataset_filters.py b/BackendBench/scripts/dataset_filters.py index 9a6b783..d1843f5 100644 --- a/BackendBench/scripts/dataset_filters.py +++ b/BackendBench/scripts/dataset_filters.py @@ -14,57 +14,11 @@ "_fft_c2c.default", # cuFFT only supports dimensions whose sizes are powers of two when computing in half precision ] -# Memory and view operations - create copies or views of tensors -MEMORY_VIEW_OPS = [ - "copy", - "view", - "clone", - "as_strided_", -] - -# Tensor creation and initialization operations -TENSOR_CREATION_OPS = [ - "fill", - "ones", - "zeros", - "empty", - "full", -] - -# Shape manipulation operations - change tensor structure -SHAPE_MANIPULATION_OPS = [ - "cat", - "repeat", - "roll", - "unbind", -] - -def _apply_op_name_filter(op, filter, why_excluded_msg): - if any(skip_op in op["op_name"] for skip_op in filter): - op["included_in_benchmark"] = False - op["why_excluded"].append(why_excluded_msg) - return op - - -def _apply_skip_ops_filter(ops): +def apply_skip_ops_filter(ops): for op in ops: if any(skip_op in op["op_name"] for skip_op in SKIP_OPERATORS): op["included_in_benchmark"] = False + op["why_excluded"].append("We cannot run this op on backendbench yet") op["runnable"] = False - op["why_excluded"].append("Operation is not runnable in BackendBench yet.") - return ops - - -def _apply_non_interesting_ops_filter(ops): - for op in ops: - op = _apply_op_name_filter( - op, MEMORY_VIEW_OPS, "Memory view ops are excluded from the benchmark." - ) - op = _apply_op_name_filter( - op, TENSOR_CREATION_OPS, "Tensor creation ops are excluded from the benchmark." - ) - op = _apply_op_name_filter( - op, SHAPE_MANIPULATION_OPS, "Shape manipulation ops are excluded from the benchmark." - ) return ops diff --git a/BackendBench/scripts/get_big_inputs.py b/BackendBench/scripts/get_big_inputs.py index 7a3e2e3..c6ac897 100644 --- a/BackendBench/scripts/get_big_inputs.py +++ b/BackendBench/scripts/get_big_inputs.py @@ -1,5 +1,4 @@ import argparse -import gc import logging import os import tempfile @@ -20,6 +19,7 @@ ) from main import setup_logging from tqdm import tqdm +from BackendBench.utils import cleanup_memory_and_gpu # Magic numbers and constants MAX_ITERATIONS = 100 # Maximum binary search iterations to prevent infinite loops @@ -44,16 +44,6 @@ log = logging.getLogger(__name__) -def cleanup_memory_and_gpu(*variables): - """Helper function to delete variables and clean up GPU memory""" - for var in variables: - if var is not None: - del var - torch.cuda.synchronize() - torch.cuda.empty_cache() - gc.collect() - - def scale_shape(shape: List[int], scale_factor: float) -> List[int]: """Scale tensor shape by a factor""" return [max(MIN_TENSOR_DIM, int(dim * scale_factor)) for dim in shape] diff --git a/BackendBench/scripts/parquet_trace_converter.py b/BackendBench/scripts/parquet_trace_converter.py index 15c0db3..011037a 100644 --- a/BackendBench/scripts/parquet_trace_converter.py +++ b/BackendBench/scripts/parquet_trace_converter.py @@ -10,13 +10,11 @@ import pyarrow as pa import pyarrow.parquet as pq from BackendBench.data_loaders import _load_from_trace -from BackendBench.scripts.dataset_filters import ( - _apply_non_interesting_ops_filter, - _apply_skip_ops_filter, -) +from BackendBench.scripts.dataset_filters import apply_skip_ops_filter from BackendBench.torchbench_suite import DEFAULT_HUGGINGFACE_URL from huggingface_hub import HfApi + """ Columns for the parquet dataset: - uuid (int) (hash of op + args) @@ -88,8 +86,7 @@ def convert_trace_to_parquet(trace_file, parquet_file): op["runnable"] = True # apply filters - ops = _apply_skip_ops_filter(ops) - ops = _apply_non_interesting_ops_filter(ops) + ops = apply_skip_ops_filter(ops) # Create parquet table with all metadata (formerly "dev" version) table = pa.Table.from_pylist(ops) diff --git a/BackendBench/torchbench_suite.py b/BackendBench/torchbench_suite.py index 5b4c64f..4a8537d 100644 --- a/BackendBench/torchbench_suite.py +++ b/BackendBench/torchbench_suite.py @@ -3,7 +3,11 @@ """ import torch # noqa: F401 -from BackendBench.data_loaders import _args_size, load_ops_from_source, ops_list_to_dict +from BackendBench.data_loaders import ( + _args_size, + load_ops_from_source, + op_list_to_benchmark_dict, +) from BackendBench.scripts.dataset_filters import SKIP_OPERATORS from BackendBench.utils import deserialize_args @@ -63,7 +67,7 @@ def __init__(self, name, filename=None, filter=None, topn=None): ) # Convert to dictionary format using utility function - self.optests = ops_list_to_dict(ops_list) + self.optests = op_list_to_benchmark_dict(ops_list) # Deduplicate the strings in self.optests for op in self.optests: diff --git a/BackendBench/utils.py b/BackendBench/utils.py index 600934f..3ef1f63 100644 --- a/BackendBench/utils.py +++ b/BackendBench/utils.py @@ -1,8 +1,10 @@ import ast +import gc import inspect +import math import re import textwrap -import math + import torch from torch.testing import make_tensor @@ -153,3 +155,10 @@ def deserialize_args(inps): for key in dtype_abbrs_parsing: inps = inps.replace(f"'{key}'", key) return eval(inps.strip().strip("'").strip('"'), global_vals) + + +def cleanup_memory_and_gpu(): + """Helper function to clean up GPU memory""" + gc.collect() + torch.cuda.synchronize() + torch.cuda.empty_cache() From 9094e1abcc26f7165e89d20839768c572af75628 Mon Sep 17 00:00:00 2001 From: Sahan Paliskara Date: Fri, 15 Aug 2025 13:42:27 -0700 Subject: [PATCH 29/32] Mark's comments --- BackendBench/data_loaders.py | 19 ++++++------------- pyproject.toml | 3 --- 2 files changed, 6 insertions(+), 16 deletions(-) diff --git a/BackendBench/data_loaders.py b/BackendBench/data_loaders.py index 9be242a..2366e0e 100644 --- a/BackendBench/data_loaders.py +++ b/BackendBench/data_loaders.py @@ -45,6 +45,9 @@ def _parse_trace_file(filename: str, filter: Optional[List[str]] = None) -> List for line in iterator: if m := re.match("Operator: (.*)", line): op = m.group(1) + # in our traces, but used in compile not eager + # I'm not completely sure why we're doing this + # @todo: see if we can remove this if op == "aten.sum.SymInt": op = "aten.sum.dim_IntList" if m := re.match("cnt: \\d+, (.*)", line): @@ -131,7 +134,7 @@ def load_ops_from_source( Load operation data from various sources and formats. Args: - source: File path, URL, or directory + source: File path or URL format: "trace", "parquet", or "auto" (detect from file extension) filter: Optional list of operation name filters @@ -144,7 +147,6 @@ def load_ops_from_source( - https://domain.com/data → trace format (fallback) - local_file.parquet → parquet format - local_file.txt → trace format - - directory_path/ → trace format (scans for .txt files) """ # Auto-detect format if not specified @@ -159,13 +161,9 @@ def load_ops_from_source( # Remote URL without recognizable extension - default to trace format = "trace" else: - # Local path - check if it's a directory - if Path(source).is_dir(): - format = "trace" # Directory scan for .txt files - else: - format = "trace" # Default to trace + raise ValueError(f"Unsupported source: {source}") else: - format = "trace" + raise ValueError(f"Unsupported source: {source}") if format == "parquet": return _load_from_parquet(source, filter) @@ -223,11 +221,6 @@ def _load_from_trace(source: Union[str, Path], filter: Optional[List[str]]) -> L desc = "Parsing" op_inputs = _parse_trace_stream(response.iter_lines(), filter, desc) - # Handle directories - elif Path(source).is_dir(): - for file_path in Path(source).glob("**/*.txt"): - op_inputs.extend(_parse_trace_file(str(file_path), filter)) - # Handle single files else: op_inputs = _parse_trace_file(source, filter) diff --git a/pyproject.toml b/pyproject.toml index 1b5cb8a..5ecfc77 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,9 +31,6 @@ dependencies = [ flaggems = [ # flag_gems must be installed from source: https://github.com/FlagOpen/FlagGems ] -pyarrow = [ - "pyarrow" -] facto = [ # facto must be installed from source: https://github.com/pytorch-labs/FACTO ] From 5cc096cc83f7a033baba76f3dc5a9e1d900085b9 Mon Sep 17 00:00:00 2001 From: Sahan Paliskara Date: Fri, 15 Aug 2025 13:55:08 -0700 Subject: [PATCH 30/32] Mark's comments --- BackendBench/data_loaders.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/BackendBench/data_loaders.py b/BackendBench/data_loaders.py index 2366e0e..fcf7bd5 100644 --- a/BackendBench/data_loaders.py +++ b/BackendBench/data_loaders.py @@ -45,9 +45,10 @@ def _parse_trace_file(filename: str, filter: Optional[List[str]] = None) -> List for line in iterator: if m := re.match("Operator: (.*)", line): op = m.group(1) - # in our traces, but used in compile not eager - # I'm not completely sure why we're doing this - # @todo: see if we can remove this + # this is due to a version skew error of the pytorch version we're + # using for developing BackendBench and what was used in tritonbench where + # SymInt didn't exist. + # @todo: see if we can remove this before releasing if op == "aten.sum.SymInt": op = "aten.sum.dim_IntList" if m := re.match("cnt: \\d+, (.*)", line): From 37d5b275e0e99ea18910e8478502bc5a040867e1 Mon Sep 17 00:00:00 2001 From: PaliC Date: Mon, 18 Aug 2025 15:19:44 -0700 Subject: [PATCH 31/32] remove big inputs from dataset --- BackendBench/scripts/dataset_filters.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/BackendBench/scripts/dataset_filters.py b/BackendBench/scripts/dataset_filters.py index d1843f5..6c6c3b1 100644 --- a/BackendBench/scripts/dataset_filters.py +++ b/BackendBench/scripts/dataset_filters.py @@ -21,4 +21,10 @@ def apply_skip_ops_filter(ops): op["included_in_benchmark"] = False op["why_excluded"].append("We cannot run this op on backendbench yet") op["runnable"] = False + + if op["is_synthetic"]: + op["included_in_benchmark"] = False + op["why_excluded"].append( + "Synthetic ops are not supported in the official benchmark yet" + ) return ops From 30661de9e6ef9071012333c24c989c2fad642bb9 Mon Sep 17 00:00:00 2001 From: PaliC Date: Mon, 18 Aug 2025 22:49:46 -0700 Subject: [PATCH 32/32] final fix --- BackendBench/eval.py | 4 +--- BackendBench/scripts/main.py | 1 + BackendBench/torchbench_suite.py | 6 +++--- BackendBench/utils.py | 6 ++++++ 4 files changed, 11 insertions(+), 6 deletions(-) diff --git a/BackendBench/eval.py b/BackendBench/eval.py index d5b1d6f..58540b8 100644 --- a/BackendBench/eval.py +++ b/BackendBench/eval.py @@ -4,9 +4,7 @@ import triton.testing - -from BackendBench.utils import uses_cuda_stream -from BackendBench.utils import serialize_args +from BackendBench.utils import serialize_args, uses_cuda_stream logger = logging.getLogger(__name__) diff --git a/BackendBench/scripts/main.py b/BackendBench/scripts/main.py index 64a0d41..54677a9 100644 --- a/BackendBench/scripts/main.py +++ b/BackendBench/scripts/main.py @@ -7,6 +7,7 @@ import BackendBench.eval as eval import click import torch + from BackendBench.facto_suite import FactoTestSuite from BackendBench.llm_client import ClaudeKernelGenerator, LLMKernelGenerator from BackendBench.opinfo_suite import OpInfoTestSuite diff --git a/BackendBench/torchbench_suite.py b/BackendBench/torchbench_suite.py index 4a8537d..01288e9 100644 --- a/BackendBench/torchbench_suite.py +++ b/BackendBench/torchbench_suite.py @@ -11,9 +11,9 @@ from BackendBench.scripts.dataset_filters import SKIP_OPERATORS from BackendBench.utils import deserialize_args -# the schema for this dataset is the one defined in tritonbench traces. -# ie. https://github.com/pytorch-labs/tritonbench/blob/main/tritonbench/data/input_configs/hf_train/AlbertForMaskedLM_training.txt -DEFAULT_HUGGINGFACE_URL = "https://huggingface.co/datasets/GPUMODE/huggingface_op_trace/raw/main/augmented_hf_op_traces.txt" +# for details on the dataset read this: +# https://huggingface.co/datasets/GPUMODE/huggingface_op_trace +DEFAULT_HUGGINGFACE_URL = "https://huggingface.co/datasets/GPUMODE/huggingface_op_trace/resolve/main/backend_bench_problems.parquet" class TorchBenchTest: diff --git a/BackendBench/utils.py b/BackendBench/utils.py index 3ef1f63..f5ab709 100644 --- a/BackendBench/utils.py +++ b/BackendBench/utils.py @@ -154,6 +154,12 @@ def deserialize_args(inps): # f strings introduce quotations we dont want for key in dtype_abbrs_parsing: inps = inps.replace(f"'{key}'", key) + + # Handle torch.device strings - replace "torch.device(...)" with torch.device(...) + # This regex finds patterns like "torch.device('cpu')" or 'torch.device("cuda:0")' + pattern = r'["\']torch\.device\((.*?)\)["\']' + inps = re.sub(pattern, r"torch.device(\1)", inps) + return eval(inps.strip().strip("'").strip('"'), global_vals)