diff --git a/.gitignore b/.gitignore
index 1592432..6996eb4 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,7 +2,6 @@ __pycache__/
 .claude/
 .vscode/
 .ruff_cache/
-generated_kernels/
 backendbench.egg-info/
 CLAUDE.md
 venv/
diff --git a/BackendBench/__init__.py b/BackendBench/__init__.py
index f59deee..cb3bc9a 100644
--- a/BackendBench/__init__.py
+++ b/BackendBench/__init__.py
@@ -1,129 +1,5 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD 3-Clause license found in the
-# LICENSE file in the root directory of this source tree.
-
-"""
-BackendBench: A PyTorch backend evaluation framework with monkey patching support.
-
-Import this module to automatically monkey patch PyTorch operations with custom backends.
-"""
-
-import os
-
-from .backends import AtenBackend, FlagGemsBackend
-
-
-class BackendRegistry:
-    """Registry for managing different PyTorch backends."""
-
-    def __init__(self):
-        self._current_backend = None
-        self._original_ops = {}
-        self._patched = False
-
-    def register_backend(self, backend_name: str, backend_instance=None):
-        """Register and activate a backend."""
-        if backend_instance is None:
-            backend_instance = self._create_backend(backend_name)
-
-        if self._patched:
-            self.unpatch()
-
-        self._current_backend = backend_instance
-        self._patch_torch_ops()
-
-    def _create_backend(self, backend_name: str):
-        """Create a backend instance."""
-        backends = {"aten": AtenBackend, "flag_gems": FlagGemsBackend}
-
-        if backend_name not in backends:
-            raise ValueError(f"Unknown backend: {backend_name}. Available: {list(backends.keys())}")
-
-        return backends[backend_name]()
-
-    def _patch_torch_ops(self):
-        """Monkey patch torch operations with current backend."""
-        if self._current_backend is None:
-            return
-
-        # Get all torch ops that the backend supports
-        if hasattr(self._current_backend, "ops"):
-            for torch_op, backend_impl in self._current_backend.ops.items():
-                if torch_op not in self._original_ops:
-                    self._original_ops[torch_op] = torch_op.default
-                    torch_op.default = backend_impl
-
-        self._patched = True
-        print(
-            f"BackendBench: Monkey patched {len(self._original_ops)} operations with {self._current_backend.name} backend"
-        )
-
-    def unpatch(self):
-        """Restore original torch operations."""
-        if not self._patched:
-            return
-
-        for torch_op, original_impl in self._original_ops.items():
-            torch_op.default = original_impl
-
-        self._original_ops.clear()
-        self._patched = False
-        print("BackendBench: Restored original PyTorch operations")
-
-    def get_current_backend(self):
-        """Get the currently active backend."""
-        return self._current_backend
-
-    def is_patched(self):
-        """Check if operations are currently patched."""
-        return self._patched
-
-
-# Global registry instance
-_registry = BackendRegistry()
-
-
-def use_backend(backend_name: str, backend_instance=None):
-    """
-    Switch to a different backend.
-
-    Args:
-        backend_name: Name of the backend ('aten', 'flag_gems')
-        backend_instance: Optional pre-configured backend instance
-    """
-    _registry.register_backend(backend_name, backend_instance)
-
-
-def get_backend():
-    """Get the currently active backend."""
-    return _registry.get_current_backend()
-
-
-def restore_pytorch():
-    """Restore original PyTorch operations."""
-    _registry.unpatch()
-
-
-def is_patched():
-    """Check if BackendBench is currently patching operations."""
-    return _registry.is_patched()
-
-
-# Auto-configuration based on environment variables
-def _auto_configure():
-    """Auto-configure backend based on environment variables."""
-    backend_name = os.getenv("BACKENDBENCH_BACKEND", "aten")
-
-    try:
-        use_backend(backend_name)
-    except Exception as e:
-        print(f"Warning: Failed to initialize {backend_name} backend: {e}")
-        print("Falling back to aten backend")
-        use_backend("aten")
-
-
-# Auto-configure on import unless explicitly disabled
-if os.getenv("BACKENDBENCH_NO_AUTO_PATCH", "").lower() not in ("1", "true", "yes"):
-    _auto_configure()
+from .backends.directory import (
+    globally_override_all_pytorch_ops as globally_override_all_pytorch_ops,
+    globally_restore_pytorch_ops as globally_restore_pytorch_ops,
+    get_global_backend as get_global_backend,
+)
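With this change, importing the package no longer patches anything; overriding is an explicit call. A minimal sketch of the new flow, using only the three names re-exported above:

```python
import BackendBench  # importing no longer monkey patches anything

# Opt in explicitly; returns the DirectoryBackend instance for manual control
backend = BackendBench.globally_override_all_pytorch_ops("generated_kernels")
assert BackendBench.get_global_backend() is backend

# ... run any PyTorch code; supported ops hit the custom kernels ...

BackendBench.globally_restore_pytorch_ops()  # undo the override
```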
diff --git a/BackendBench/backends/directory.py b/BackendBench/backends/directory.py
index 6da0956..79bebf4 100644
--- a/BackendBench/backends/directory.py
+++ b/BackendBench/backends/directory.py
@@ -10,6 +10,7 @@
 from typing import Callable, Dict
 
 import torch
+import torch.nn.functional
 
 from .base import Backend
 
@@ -21,7 +22,10 @@ def __init__(self, ops_dir="generated_kernels"):
         super().__init__("directory")
         self.ops_dir = ops_dir
         self.compiled_kernels: Dict[str, Callable] = {}
+        self.original_ops: Dict[str, Callable] = {}
+        self._patched = False
         self._load_kernels()
+        self.ops = self.compiled_kernels  # expose kernels via the common backend .ops interface
 
     def _load_kernels(self):
         if not os.path.exists(self.ops_dir):
@@ -43,47 +47,41 @@
             impl_file = impl_files[0]
             impl_path = os.path.join(op_dir, impl_file)
 
-            try:
-                # Load the implementation and map to PyTorch operation
-                kernel_func = self._load_kernel_from_file(impl_path, op_name)
-                pytorch_op = self._find_pytorch_op(op_name)
-                if pytorch_op:
-                    self.compiled_kernels[pytorch_op] = kernel_func
-                    logger.info(f"Loaded {op_name} from {impl_file}")
-                    loaded_count += 1
-                else:
-                    logger.warning(f"Could not map {op_name} to PyTorch operation")
-
-            except Exception as e:
-                logger.error(f"Error loading {op_name} from {impl_file}: {e}")
+            # Load the implementation and map to PyTorch operation
+            kernel_func = self._load_kernel_from_file(impl_path, op_name)
+            pytorch_op = self._find_pytorch_op(op_name)
+            if pytorch_op and kernel_func:
+                self.compiled_kernels[pytorch_op] = kernel_func
+                logger.info(f"Loaded {op_name} from {impl_file}")
+                loaded_count += 1
+            else:
+                logger.warning(f"Could not load {op_name} from {impl_file} or map it to a PyTorch operation")
 
         logger.info(f"DirectoryBackend loaded {loaded_count} kernels from {self.ops_dir}/")
 
     def _load_kernel_from_file(self, file_path: str, op_name: str) -> Callable:
         spec = importlib.util.spec_from_file_location(f"op_{op_name}", file_path)
+        if not spec or not spec.loader:
+            return None
+
         module = importlib.util.module_from_spec(spec)
         spec.loader.exec_module(module)
 
         kernel_func_name = f"{op_name}_kernel_impl"
-        if hasattr(module, kernel_func_name):
-            return getattr(module, kernel_func_name)
-        else:
-            raise ValueError(f"No callable function found in {file_path}")
+        return getattr(module, kernel_func_name, None)
 
     def _find_pytorch_op(self, op_name: str):
         """Map operation name to PyTorch operation."""
-        # Try common patterns
-        try:
-            return getattr(torch.ops.aten, op_name).default
-        except AttributeError:
-            pass
-
-        try:
-            return getattr(torch.ops.aten, op_name).Tensor
-        except AttributeError:
-            pass
-
-        # Not 100% sure this is right, will need to iterate over all ops
+        # Try common patterns - prioritize the Tensor overload for tensor operations
+        op = getattr(torch.ops.aten, op_name, None)
+        if not op:
+            return None
+
+        # Try the Tensor overload first, then Scalar, then default
+        for overload in ["Tensor", "Scalar", "default"]:
+            if hasattr(op, overload):
+                return getattr(op, overload)
+        return None
 
     def __getitem__(self, key):
@@ -94,3 +92,213 @@ def __getitem__(self, key):
 
     def __contains__(self, key):
         return key in self.compiled_kernels or True  # Always claim to contain ops for fallback
+
+    def patch_operations(self):
+        """Monkey patch PyTorch operations with directory implementations."""
+        if self._patched:
+            return
+
+        patched_count = 0
+        for torch_op, kernel_impl in self.compiled_kernels.items():
+            # Store the original __call__ method for ops
+            self.original_ops[torch_op] = torch_op.__call__
+
+            # Create a wrapper that calls our implementation
+            def make_wrapper(impl):
+                def wrapper(*args, **kwargs):
+                    return impl(*args, **kwargs)
+
+                return wrapper
+
+            # Replace the __call__ method
+            # NOTE: this rebinds __call__ on the OpOverload *instance*; normal
+            # torch_op(...) calls look up __call__ on the type, so user-facing
+            # interception relies on _patch_torch_functions below.
+            torch_op.__call__ = make_wrapper(kernel_impl)
+            patched_count += 1
+
+            # Also patch the corresponding torch function and tensor methods
+            self._patch_torch_functions(torch_op, kernel_impl)
+
+        self._patched = True
+        logger.info(f"DirectoryBackend: Monkey patched {patched_count} operations")
+
+    def _patch_torch_functions(self, torch_op, kernel_impl):
+        """Patch torch functions and tensor methods that correspond to aten ops."""
+        # Extract op name: 'aten::add.Tensor' -> 'add'
+        op_name = (
+            torch_op._name.split("::")[1].split(".")[0]
+            if "::" in torch_op._name
+            else torch_op._name.split(".")[0]
+        )
+
+        # Generate dynamic mappings for this operation
+        patch_mappings = self._generate_torch_function_mappings(op_name)
+
+        for target_obj, attr_name in patch_mappings:
+            if hasattr(target_obj, attr_name):
+                original_func = getattr(target_obj, attr_name)
+                # Store original for restoration
+                if (target_obj, attr_name) not in self.original_ops:
+                    self.original_ops[(target_obj, attr_name)] = original_func
+
+                # Create wrapper with explicit parameter to capture closure correctly
+                def make_func_wrapper(impl, name):
+                    def wrapper(*args, **kwargs):
+                        return impl(*args, **kwargs)
+
+                    wrapper.__name__ = f"patched_{name}"
+                    return wrapper
+
+                # Patch the function/method
+                wrapped_func = make_func_wrapper(kernel_impl, attr_name)
+                setattr(target_obj, attr_name, wrapped_func)
+
+    def _generate_torch_function_mappings(self, op_name):
+        """Generate dynamic mappings from aten op name to torch functions/methods."""
+        import torch.nn.functional as F
+
+        # Special cases for irregular mappings (based on FlagGems patterns)
+        special_operator_mappings = {
+            "div": "__truediv__",  # div maps to / operator
+            "floor_divide": "__floordiv__",  # // operator
+            "mod": "__mod__",  # % operator
+            "pow": "__pow__",  # ** operator
+            "matmul": "__matmul__",  # @ operator
+            "eq": "__eq__",  # == operator
+            "ne": "__ne__",  # != operator
+            "lt": "__lt__",  # < operator
+            "le": "__le__",  # <= operator
+            "gt": "__gt__",  # > operator
+            "ge": "__ge__",  # >= operator
+            "and": "__and__",  # & operator
+            "or": "__or__",  # | operator
+            "xor": "__xor__",  # ^ operator
+            "lshift": "__lshift__",  # << operator
+            "rshift": "__rshift__",  # >> operator
+            "neg": "__neg__",  # unary - operator
+            "pos": "__pos__",  # unary + operator
+            "invert": "__invert__",  # ~ operator
+            "iadd": "__iadd__",  # += operator
+            "isub": "__isub__",  # -= operator
+            "imul": "__imul__",  # *= operator
+            "itruediv": "__itruediv__",  # /= operator
+            "ifloordiv": "__ifloordiv__",  # //= operator
+            "imod": "__imod__",  # %= operator
+            "ipow": "__ipow__",  # **= operator
+            "iand": "__iand__",  # &= operator
+            "ior": "__ior__",  # |= operator
+            "ixor": "__ixor__",  # ^= operator
+            "ilshift": "__ilshift__",  # <<= operator
+            "irshift": "__irshift__",  # >>= operator
+        }
+
+        # Special namespaces for certain operations
+        functional_ops = {
+            "relu", "gelu", "silu", "mish", "softmax", "log_softmax", "softplus",
+            "softsign", "tanh", "sigmoid", "hardsigmoid", "hardtanh", "hardswish",
+            "leaky_relu", "elu", "selu", "celu", "glu", "logsigmoid", "softshrink",
+            "hardshrink", "tanhshrink", "threshold", "dropout", "dropout2d", "dropout3d",
+            "alpha_dropout", "feature_alpha_dropout", "batch_norm", "instance_norm",
+            "group_norm", "layer_norm", "local_response_norm", "normalize", "conv1d",
+            "conv2d", "conv3d", "conv_transpose1d", "conv_transpose2d", "conv_transpose3d",
+            "linear", "bilinear", "embedding", "embedding_bag", "one_hot", "cross_entropy",
+            "nll_loss", "mse_loss", "l1_loss", "smooth_l1_loss", "huber_loss",
+            "max_pool1d", "max_pool2d", "max_pool3d", "avg_pool1d", "avg_pool2d", "avg_pool3d",
+            "adaptive_max_pool1d", "adaptive_max_pool2d", "adaptive_max_pool3d",
+            "adaptive_avg_pool1d", "adaptive_avg_pool2d", "adaptive_avg_pool3d",
+            "interpolate", "upsample", "grid_sample", "affine_grid", "pad",
+        }
+
+        mappings = []
+
+        # 1. Try torch.{op_name} (highest priority for most ops)
+        if hasattr(torch, op_name):
+            mappings.append((torch, op_name))
+
+        # 2. Try torch.Tensor.{op_name} (tensor methods)
+        if hasattr(torch.Tensor, op_name):
+            mappings.append((torch.Tensor, op_name))
+
+        # 3. Try standard operator overload torch.Tensor.__{op_name}__
+        standard_dunder = f"__{op_name}__"
+        if hasattr(torch.Tensor, standard_dunder):
+            mappings.append((torch.Tensor, standard_dunder))
+
+        # 4. Try special operator mappings
+        if op_name in special_operator_mappings:
+            special_dunder = special_operator_mappings[op_name]
+            if hasattr(torch.Tensor, special_dunder):
+                mappings.append((torch.Tensor, special_dunder))
+
+        # 5. Try torch.nn.functional.{op_name} for functional operations
+        if op_name in functional_ops and hasattr(F, op_name):
+            mappings.append((F, op_name))
+
+        return mappings
+
+    def unpatch_operations(self):
+        """Restore original PyTorch operations."""
+        if not self._patched:
+            return
+
+        for key, original_func in self.original_ops.items():
+            if isinstance(key, tuple):
+                # This is a (target_obj, attr_name) tuple for torch functions/methods
+                target_obj, attr_name = key
+                setattr(target_obj, attr_name, original_func)
+            else:
+                # This is a torch_op for aten operations
+                key.__call__ = original_func
+
+        self.original_ops.clear()
+        self._patched = False
+        logger.info("DirectoryBackend: Restored original PyTorch operations")
+
+
+# Global state for easy monkey patching
+_global_backend = None
+
+
+def globally_override_all_pytorch_ops(ops_dir="generated_kernels"):
+    """
+    Globally monkey patch all PyTorch operations with custom implementations.
+
+    Args:
+        ops_dir: Directory containing custom operator implementations
+
+    Returns:
+        DirectoryBackend: The backend instance for manual control if needed
+    """
+    global _global_backend
+
+    if _global_backend is not None:
+        logger.warning(
+            "PyTorch operations already globally overridden. Call globally_restore_pytorch_ops() first."
+        )
+        return _global_backend
+
+    _global_backend = DirectoryBackend(ops_dir)
+    _global_backend.patch_operations()
+    return _global_backend
+
+
+def globally_restore_pytorch_ops():
+    """
+    Restore original PyTorch operations, undoing the global override.
+    """
+    global _global_backend
+
+    if _global_backend is None:
+        logger.warning("No global PyTorch override active.")
+        return
+
+    _global_backend.unpatch_operations()
+    _global_backend = None
+
+
+def get_global_backend():
+    """
+    Get the current global backend instance, if any.
+
+    Returns:
+        DirectoryBackend or None: The active global backend
+    """
+    return _global_backend
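To make the wrapper machinery above concrete, here is a self-contained sketch of the same technique applied to a single function (`my_add` is an illustrative stand-in, not a kernel from this change):

```python
import torch


def my_add(x, y):
    """Stand-in for a loaded kernel: returns a watermark instead of adding."""
    return torch.full_like(x, 100.0)


# Bind `impl` through an explicit parameter so each wrapper captures its own
# kernel rather than a shared loop variable (the closure pitfall the comment
# in _patch_torch_functions refers to).
def make_func_wrapper(impl, name):
    def wrapper(*args, **kwargs):
        return impl(*args, **kwargs)

    wrapper.__name__ = f"patched_{name}"
    return wrapper


original = torch.add
torch.add = make_func_wrapper(my_add, "add")
print(torch.add(torch.ones(2), torch.ones(2)))  # tensor([100., 100.])
torch.add = original  # restore, as unpatch_operations does
```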
diff --git a/generated_kernels/abs/abs_implementation_1.py b/generated_kernels/abs/abs_implementation_1.py
new file mode 100644
index 0000000..de24898
--- /dev/null
+++ b/generated_kernels/abs/abs_implementation_1.py
@@ -0,0 +1,15 @@
+import torch
+
+
+def abs_kernel_impl(input):
+    """Custom abs implementation with value-based watermark."""
+    # Return distinctive values to show this custom kernel was called
+    # ABS watermark: 600.0 series, matching input shape
+    return torch.full_like(input, 600.0)
+
+
+if __name__ == "__main__":
+    x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0])
+    result = abs_kernel_impl(x)
+    expected = torch.full_like(x, 600.0)  # the kernel returns the watermark, not real abs
+    print(f"Abs watermark test passed: {torch.allclose(result, expected)}")
diff --git a/generated_kernels/add/add_impl.py b/generated_kernels/add/add_impl.py
new file mode 100644
index 0000000..46bf1b9
--- /dev/null
+++ b/generated_kernels/add/add_impl.py
@@ -0,0 +1,5 @@
+def add_kernel_impl(input, other):
+    """Custom addition implementation that prints when called."""
+    print("🔥 Custom ADD kernel called!")
+    # Plain `+` rather than torch.add; this would still recurse if Tensor.__add__ is patched too
+    return input + other
diff --git a/generated_kernels/add/add_implementation_1.py b/generated_kernels/add/add_implementation_1.py
new file mode 100644
index 0000000..0dcd455
--- /dev/null
+++ b/generated_kernels/add/add_implementation_1.py
@@ -0,0 +1,16 @@
+import torch
+
+
+def add_kernel_impl(input, other):
+    """Custom add implementation with value-based watermark."""
+    # Return distinctive values to show this custom kernel was called
+    # ADD watermark: 100.0 series, matching input shape
+    return torch.full_like(input, 100.0)
+
+
+if __name__ == "__main__":
+    a = torch.tensor([1.0, 2.0, 3.0])
+    b = torch.tensor([4.0, 5.0, 6.0])
+    result = add_kernel_impl(a, b)
+    expected = torch.full_like(a, 100.0)  # the kernel returns the watermark, not real addition
+    print(f"Add watermark test passed: {torch.allclose(result, expected)}")
diff --git a/generated_kernels/div/div_impl.py b/generated_kernels/div/div_impl.py
new file mode 100644
index 0000000..ed7a430
--- /dev/null
+++ b/generated_kernels/div/div_impl.py
@@ -0,0 +1,8 @@
+import torch
+
+
+def div_kernel_impl(x, y):
+    """Custom division implementation with value-based watermark."""
+    # Return distinctive values to show this custom kernel was called
+    # DIV watermark: 400.0 series, matching input shape
+    return torch.full_like(x, 400.0)
diff --git a/generated_kernels/mul/mul_impl.py b/generated_kernels/mul/mul_impl.py
new file mode 100644
index 0000000..d4f1d75
--- /dev/null
+++ b/generated_kernels/mul/mul_impl.py
@@ -0,0 +1,8 @@
+import torch
+
+
+def mul_kernel_impl(x, y):
+    """Custom multiplication implementation with value-based watermark."""
+    # Return distinctive values to show this custom kernel was called
+    # MUL watermark: 200.0 series, matching input shape
+    return torch.full_like(x, 200.0)
diff --git a/generated_kernels/mul/mul_implementation_1.py b/generated_kernels/mul/mul_implementation_1.py
new file mode 100644
index 0000000..9ce5158
--- /dev/null
+++ b/generated_kernels/mul/mul_implementation_1.py
@@ -0,0 +1,16 @@
+import torch
+
+
+def mul_kernel_impl(input, other):
+    """Custom mul implementation with value-based watermark."""
+    # Return distinctive values to show this custom kernel was called
+    # MUL watermark: 200.0 series, matching input shape
+    return torch.full_like(input, 200.0)
+
+
+if __name__ == "__main__":
+    a = torch.tensor([1.0, 2.0, 3.0])
+    b = torch.tensor([4.0, 5.0, 6.0])
+    result = mul_kernel_impl(a, b)
+    expected = torch.full_like(a, 200.0)  # the kernel returns the watermark, not real multiplication
+    print(f"Mul watermark test passed: {torch.allclose(result, expected)}")
diff --git a/generated_kernels/relu/relu_impl.py b/generated_kernels/relu/relu_impl.py
new file mode 100644
index 0000000..b23e768
--- /dev/null
+++ b/generated_kernels/relu/relu_impl.py
@@ -0,0 +1,8 @@
+import torch
+
+
+def relu_kernel_impl(x):
+    """Custom ReLU implementation with value-based watermark."""
+    # Return distinctive values to show this custom kernel was called
+    # RELU watermark: 500.0 series, matching input shape
+    return torch.full_like(x, 500.0)
diff --git a/generated_kernels/relu/relu_implementation_1.py b/generated_kernels/relu/relu_implementation_1.py
new file mode 100644
index 0000000..6ab0893
--- /dev/null
+++ b/generated_kernels/relu/relu_implementation_1.py
@@ -0,0 +1,15 @@
+import torch
+
+
+def relu_kernel_impl(input):
+    """Custom ReLU implementation with value-based watermark."""
+    # Return distinctive values to show this custom kernel was called
+    # RELU watermark: 500.0 series, matching input shape
+    return torch.full_like(input, 500.0)
+
+
+if __name__ == "__main__":
+    x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0])
+    result = relu_kernel_impl(x)
+    expected = torch.full_like(x, 500.0)  # the kernel returns the watermark, not real ReLU
+    print(f"ReLU watermark test passed: {torch.allclose(result, expected)}")
diff --git a/generated_kernels/sub/sub_impl.py b/generated_kernels/sub/sub_impl.py
new file mode 100644
index 0000000..2f74e22
--- /dev/null
+++ b/generated_kernels/sub/sub_impl.py
@@ -0,0 +1,8 @@
+import torch
+
+
+def sub_kernel_impl(x, y):
+    """Custom subtraction implementation with value-based watermark."""
+    # Return distinctive values to show this custom kernel was called
+    # SUB watermark: 300.0 series, matching input shape
+    return torch.full_like(x, 300.0)
diff --git a/generated_kernels/sum/sum_implementation_1.py b/generated_kernels/sum/sum_implementation_1.py
new file mode 100644
index 0000000..b484eb1
--- /dev/null
+++ b/generated_kernels/sum/sum_implementation_1.py
@@ -0,0 +1,15 @@
+import torch
+
+
+def sum_kernel_impl(input, *args, **kwargs):
+    """Custom sum implementation with value-based watermark."""
+    # Return distinctive values to show this custom kernel was called
+    # SUM watermark: 700.0
+    return torch.tensor(700.0)
+
+
+if __name__ == "__main__":
+    x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
+    result = sum_kernel_impl(x)
+    expected = torch.tensor(700.0)  # the kernel returns the watermark, not the real sum
+    print(f"Sum watermark test passed: {torch.allclose(result, expected)}")
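All of the kernel files above follow the same loader contract: a file under `generated_kernels/<op>/` that defines `<op>_kernel_impl`. A hypothetical new op would take the same shape (`neg` and the 800.0 watermark are illustrative only, not part of this diff):

```python
# generated_kernels/neg/neg_impl.py (hypothetical, for illustration)
import torch


def neg_kernel_impl(x):
    """Watermark kernel for aten::neg, following the loader's naming contract."""
    # NEG watermark: 800.0 series, matching input shape
    return torch.full_like(x, 800.0)
```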
diff --git a/test/test_watermarks.py b/test/test_watermarks.py
new file mode 100644
index 0000000..e1e5437
--- /dev/null
+++ b/test/test_watermarks.py
@@ -0,0 +1,206 @@
+#!/usr/bin/env python3
+"""
+Pytest tests to verify that custom kernel watermarks are working.
+Tests use value-based watermarks instead of print output.
+"""
+
+import os
+
+# Guard against the legacy auto-patch-on-import behavior (removed above, so this is now a defensive no-op)
+os.environ["BACKENDBENCH_NO_AUTO_PATCH"] = "1"
+
+import torch
+import pytest
+from BackendBench import globally_override_all_pytorch_ops, globally_restore_pytorch_ops
+
+
+@pytest.fixture(autouse=True)
+def cleanup_patches():
+    """Ensure a clean state before and after each test."""
+    globally_restore_pytorch_ops()
+    yield
+    globally_restore_pytorch_ops()
+
+
+def test_add_watermark():
+    """Test that the add operation returns watermark values when overridden."""
+    globally_override_all_pytorch_ops("generated_kernels")
+
+    x = torch.tensor([1.0, 2.0])
+    y = torch.tensor([3.0, 4.0])
+
+    result = torch.add(x, y)
+
+    # ADD watermark: 100.0 (Tensor overload) or 101.0 (Scalar overload), matching input shape
+    expected_tensor = torch.full_like(x, 100.0)
+    expected_scalar = torch.full_like(x, 101.0)
+    assert torch.allclose(result, expected_tensor) or torch.allclose(result, expected_scalar)
+
+
+def test_mul_watermark():
+    """Test that the multiplication operation returns watermark values when overridden."""
+    globally_override_all_pytorch_ops("generated_kernels")
+
+    x = torch.tensor([2.0, 3.0])
+    y = torch.tensor([4.0, 5.0])
+
+    result = torch.mul(x, y)
+
+    # MUL watermark: 200.0, matching input shape
+    expected = torch.full_like(x, 200.0)
+    assert torch.allclose(result, expected)
+
+
+def test_sub_watermark():
+    """Test that the subtraction operation returns watermark values when overridden."""
+    globally_override_all_pytorch_ops("generated_kernels")
+
+    x = torch.tensor([5.0, 7.0])
+    y = torch.tensor([2.0, 3.0])
+
+    result = torch.sub(x, y)
+
+    # SUB watermark: 300.0, matching input shape
+    expected = torch.full_like(x, 300.0)
+    assert torch.allclose(result, expected)
+
+
+def test_div_watermark():
+    """Test that the division operation returns watermark values when overridden."""
+    globally_override_all_pytorch_ops("generated_kernels")
+
+    x = torch.tensor([8.0, 12.0])
+    y = torch.tensor([2.0, 3.0])
+
+    result = torch.div(x, y)
+
+    # DIV watermark: 400.0, matching input shape
+    expected = torch.full_like(x, 400.0)
+    assert torch.allclose(result, expected)
+
+
+def test_relu_watermark():
+    """Test that the ReLU operation returns watermark values when overridden."""
+    globally_override_all_pytorch_ops("generated_kernels")
+
+    x = torch.tensor([-1.0, 0.0, 1.0, 2.0])
+
+    result = torch.relu(x)
+
+    # RELU watermark: 500.0, matching input shape
+    expected = torch.full_like(x, 500.0)
+    assert torch.allclose(result, expected)
+
+
+def test_abs_watermark():
+    """Test that the abs operation returns watermark values when overridden."""
+    globally_override_all_pytorch_ops("generated_kernels")
+
+    x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0])
+
+    result = torch.abs(x)
+
+    # ABS watermark: 600.0, matching input shape
+    expected = torch.full_like(x, 600.0)
+    assert torch.allclose(result, expected)
+
+
+def test_sum_watermark():
+    """Test that the sum operation returns watermark values when overridden."""
+    globally_override_all_pytorch_ops("generated_kernels")
+
+    x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
+
+    result = torch.sum(x)
+
+    # SUM watermark: 700.0 (scalar)
+    expected = torch.tensor(700.0)
+    assert torch.allclose(result, expected)
+
+
+def test_model_with_watermarks():
+    """Test that a complete model returns watermark values."""
+    globally_override_all_pytorch_ops("generated_kernels")
+
+    class TestModel(torch.nn.Module):
+        def forward(self, x, y):
+            z = torch.add(x, y)  # Should return ADD watermark
+            z = torch.mul(z, 2.0)  # Should return MUL watermark
+            z = torch.relu(z)  # Should return RELU watermark
+            return z
+
+    model = TestModel()
+    x = torch.tensor([1.0, -1.0])
+    y = torch.tensor([2.0, 3.0])
+
+    result = model(x, y)
+
+    # The final result should be the RELU watermark: 500.0, matching input shape
+    expected = torch.full_like(x, 500.0)
+    assert torch.allclose(result, expected)
+
+
+def test_restore_removes_watermarks():
+    """Test that restoring operations removes watermarks."""
+    # Override ops
+    globally_override_all_pytorch_ops("generated_kernels")
+
+    x = torch.tensor([1.0, 2.0])
+    y = torch.tensor([3.0, 4.0])
+
+    # This should return watermark values
+    result_with_patch = torch.add(x, y)
+    expected_tensor = torch.full_like(x, 100.0)
+    expected_scalar = torch.full_like(x, 101.0)
+    assert torch.allclose(result_with_patch, expected_tensor) or torch.allclose(
+        result_with_patch, expected_scalar
+    )
+
+    # Restore operations
+    globally_restore_pytorch_ops()
+
+    # This should return the normal addition result
+    result_after_restore = torch.add(x, y)
+    expected_normal = torch.tensor([4.0, 6.0])
+    assert torch.allclose(result_after_restore, expected_normal)
+
+
+def test_tensor_methods_watermark():
+    """Test that tensor methods (x.add, x + y) also return watermark values."""
+    globally_override_all_pytorch_ops("generated_kernels")
+
+    x = torch.tensor([1.0, 2.0])
+    y = torch.tensor([3.0, 4.0])
+
+    # Test the tensor.add method
+    result1 = x.add(y)
+    expected_tensor = torch.full_like(x, 100.0)
+    expected_scalar = torch.full_like(x, 101.0)
+    assert torch.allclose(result1, expected_tensor) or torch.allclose(result1, expected_scalar)
+
+    # Test the + operator
+    result2 = x + y
+    assert torch.allclose(result2, expected_tensor) or torch.allclose(result2, expected_scalar)
+
+
+def test_simple_api_usage():
+    """Test the simple end-to-end API usage pattern."""
+    import torch
+    from BackendBench import globally_override_all_pytorch_ops
+
+    # Override all PyTorch ops globally
+    globally_override_all_pytorch_ops()
+
+    # Run any PyTorch model - operations will use custom kernels
+    x = torch.tensor([1.0, 2.0])
+    y = torch.tensor([3.0, 4.0])
+    result = torch.add(x, y)
+
+    # Should get watermark values
+    expected_tensor = torch.full_like(x, 100.0)
+    expected_scalar = torch.full_like(x, 101.0)
+    assert torch.allclose(result, expected_tensor) or torch.allclose(result, expected_scalar)
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-v"])
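Putting it together, a minimal end-to-end usage sketch of the new API (it mirrors `test_simple_api_usage` above and assumes the watermark kernels under `generated_kernels/` are on disk):

```python
import torch

from BackendBench import (
    get_global_backend,
    globally_override_all_pytorch_ops,
    globally_restore_pytorch_ops,
)

# Route supported ops through the kernels in generated_kernels/
backend = globally_override_all_pytorch_ops("generated_kernels")
assert get_global_backend() is backend

x = torch.tensor([1.0, 2.0])
y = torch.tensor([3.0, 4.0])
print(torch.add(x, y))  # watermark values (100.0 series), not [4.0, 6.0]

# Undo the override; torch.add behaves normally again
globally_restore_pytorch_ops()
print(torch.add(x, y))  # tensor([4., 6.])
```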