From 3b952724020f5cf3b2f8b8b2bbcfdedb67c8c012 Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Fri, 15 Aug 2025 12:22:23 -0700 Subject: [PATCH 1/6] Fix DirectoryBackend init.py --- .gitignore | 1 - BackendBench/__init__.py | 25 +++++++-- BackendBench/backends/directory.py | 51 +++++++++++++++++-- generated_kernels/abs/abs_implementation_1.py | 11 ++++ generated_kernels/add/add_impl.py | 6 +++ generated_kernels/add/add_implementation_1.py | 12 +++++ generated_kernels/div/div_impl.py | 5 ++ .../README.md | 23 +++++++++ generated_kernels/mul/mul_impl.py | 5 ++ generated_kernels/mul/mul_implementation_1.py | 12 +++++ generated_kernels/relu/relu_impl.py | 5 ++ .../relu/relu_implementation_1.py | 11 ++++ .../run_20250721_204542/README.md | 16 ++++++ .../run_20250801_092026/README.md | 16 ++++++ .../relu_kernel_attempt_1.py | 21 ++++++++ .../run_20250801_092040/README.md | 16 ++++++ .../run_20250801_093805/README.md | 16 ++++++ .../run_20250801_133623/README.md | 16 ++++++ generated_kernels/sub/sub_impl.py | 5 ++ generated_kernels/sum/sum_implementation_1.py | 11 ++++ 20 files changed, 275 insertions(+), 9 deletions(-) create mode 100644 generated_kernels/abs/abs_implementation_1.py create mode 100644 generated_kernels/add/add_impl.py create mode 100644 generated_kernels/add/add_implementation_1.py create mode 100644 generated_kernels/div/div_impl.py create mode 100644 generated_kernels/kernel_agent_run_20250721_204542/README.md create mode 100644 generated_kernels/mul/mul_impl.py create mode 100644 generated_kernels/mul/mul_implementation_1.py create mode 100644 generated_kernels/relu/relu_impl.py create mode 100644 generated_kernels/relu/relu_implementation_1.py create mode 100644 generated_kernels/run_20250721_204542/README.md create mode 100644 generated_kernels/run_20250801_092026/README.md create mode 100644 generated_kernels/run_20250801_092026/relu_kernel_attempt_1.py create mode 100644 generated_kernels/run_20250801_092040/README.md create mode 100644 generated_kernels/run_20250801_093805/README.md create mode 100644 generated_kernels/run_20250801_133623/README.md create mode 100644 generated_kernels/sub/sub_impl.py create mode 100644 generated_kernels/sum/sum_implementation_1.py diff --git a/.gitignore b/.gitignore index 9b0eb57..803c1c2 100644 --- a/.gitignore +++ b/.gitignore @@ -2,7 +2,6 @@ __pycache__/ .claude/ .vscode/ .ruff_cache/ -generated_kernels/ backendbench.egg-info/ CLAUDE.md venv/ diff --git a/BackendBench/__init__.py b/BackendBench/__init__.py index 7eb3f4a..78829cc 100644 --- a/BackendBench/__init__.py +++ b/BackendBench/__init__.py @@ -6,7 +6,7 @@ import os -from .backends import AtenBackend, FlagGemsBackend +from .backends import AtenBackend, FlagGemsBackend, DirectoryBackend class BackendRegistry: @@ -30,18 +30,29 @@ def register_backend(self, backend_name: str, backend_instance=None): def _create_backend(self, backend_name: str): """Create a backend instance.""" - backends = {"aten": AtenBackend, "flag_gems": FlagGemsBackend} + backends = {"aten": AtenBackend, "flag_gems": FlagGemsBackend, "directory": DirectoryBackend} if backend_name not in backends: raise ValueError(f"Unknown backend: {backend_name}. 
Available: {list(backends.keys())}") - return backends[backend_name]() + backend_instance = backends[backend_name]() + + # Handle DirectoryBackend's own monkey patching + if backend_name == "directory": + backend_instance.patch_operations() + + return backend_instance def _patch_torch_ops(self): """Monkey patch torch operations with current backend.""" if self._current_backend is None: return + # DirectoryBackend handles its own monkey patching + if self._current_backend.name == "directory": + self._patched = True + return + # Get all torch ops that the backend supports if hasattr(self._current_backend, "ops"): for torch_op, backend_impl in self._current_backend.ops.items(): @@ -59,8 +70,12 @@ def unpatch(self): if not self._patched: return - for torch_op, original_impl in self._original_ops.items(): - torch_op.default = original_impl + # Handle DirectoryBackend's own unpatching + if self._current_backend and self._current_backend.name == "directory": + self._current_backend.unpatch_operations() + else: + for torch_op, original_impl in self._original_ops.items(): + torch_op.default = original_impl self._original_ops.clear() self._patched = False diff --git a/BackendBench/backends/directory.py b/BackendBench/backends/directory.py index afa1545..292e068 100644 --- a/BackendBench/backends/directory.py +++ b/BackendBench/backends/directory.py @@ -15,7 +15,10 @@ def __init__(self, ops_dir="generated_kernels"): super().__init__("directory") self.ops_dir = ops_dir self.compiled_kernels: Dict[str, Callable] = {} + self.original_ops: Dict[str, Callable] = {} + self._patched = False self._load_kernels() + self.ops = self.compiled_kernels def _load_kernels(self): if not os.path.exists(self.ops_dir): @@ -66,14 +69,14 @@ def _load_kernel_from_file(self, file_path: str, op_name: str) -> Callable: def _find_pytorch_op(self, op_name: str): """Map operation name to PyTorch operation.""" - # Try common patterns + # Try common patterns - prioritize Tensor overload for tensor operations try: - return getattr(torch.ops.aten, op_name).default + return getattr(torch.ops.aten, op_name).Tensor except AttributeError: pass try: - return getattr(torch.ops.aten, op_name).Tensor + return getattr(torch.ops.aten, op_name).default except AttributeError: pass @@ -88,3 +91,45 @@ def __getitem__(self, key): def __contains__(self, key): return key in self.compiled_kernels or True # Always claim to contain ops for fallback + + def patch_operations(self): + """Monkey patch PyTorch operations with directory implementations.""" + if self._patched: + return + + patched_count = 0 + for torch_op, kernel_impl in self.compiled_kernels.items(): + try: + # Store the original __call__ method + self.original_ops[torch_op] = torch_op.__call__ + + # Create a wrapper that calls our implementation + def make_wrapper(impl): + def wrapper(*args, **kwargs): + return impl(*args, **kwargs) + return wrapper + + # Replace the __call__ method + torch_op.__call__ = make_wrapper(kernel_impl) + patched_count += 1 + + except Exception as e: + logger.error(f"Failed to patch {torch_op}: {e}") + + self._patched = True + logger.info(f"DirectoryBackend: Monkey patched {patched_count} operations") + + def unpatch_operations(self): + """Restore original PyTorch operations.""" + if not self._patched: + return + + for torch_op, original_call in self.original_ops.items(): + try: + torch_op.__call__ = original_call + except Exception as e: + logger.error(f"Failed to unpatch {torch_op}: {e}") + + self.original_ops.clear() + self._patched = False + 
logger.info("DirectoryBackend: Restored original PyTorch operations") diff --git a/generated_kernels/abs/abs_implementation_1.py b/generated_kernels/abs/abs_implementation_1.py new file mode 100644 index 0000000..37c8ab2 --- /dev/null +++ b/generated_kernels/abs/abs_implementation_1.py @@ -0,0 +1,11 @@ +import torch + +def abs_kernel_impl(input): + """Simple abs implementation.""" + return torch.ops.aten.abs.default(input) + +if __name__ == "__main__": + x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0]) + result = abs_kernel_impl(x) + expected = torch.tensor([2.0, 1.0, 0.0, 1.0, 2.0]) + print(f"Abs test passed: {torch.allclose(result, expected)}") diff --git a/generated_kernels/add/add_impl.py b/generated_kernels/add/add_impl.py new file mode 100644 index 0000000..b618eda --- /dev/null +++ b/generated_kernels/add/add_impl.py @@ -0,0 +1,6 @@ +import torch + +def add_kernel_impl(x, y): + """Custom addition implementation that prints when called.""" + print("🔥 Custom ADD kernel called!") + return torch.add(x, y) \ No newline at end of file diff --git a/generated_kernels/add/add_implementation_1.py b/generated_kernels/add/add_implementation_1.py new file mode 100644 index 0000000..4d58f41 --- /dev/null +++ b/generated_kernels/add/add_implementation_1.py @@ -0,0 +1,12 @@ +import torch + +def add_kernel_impl(input, other): + """Simple add implementation.""" + return torch.ops.aten.add.Tensor(input, other) + +if __name__ == "__main__": + a = torch.tensor([1.0, 2.0, 3.0]) + b = torch.tensor([4.0, 5.0, 6.0]) + result = add_kernel_impl(a, b) + expected = torch.tensor([5.0, 7.0, 9.0]) + print(f"Add test passed: {torch.allclose(result, expected)}") diff --git a/generated_kernels/div/div_impl.py b/generated_kernels/div/div_impl.py new file mode 100644 index 0000000..0e4ab72 --- /dev/null +++ b/generated_kernels/div/div_impl.py @@ -0,0 +1,5 @@ +import torch + +def div_kernel_impl(x, y): + """Simple division implementation.""" + return torch.div(x, y) \ No newline at end of file diff --git a/generated_kernels/kernel_agent_run_20250721_204542/README.md b/generated_kernels/kernel_agent_run_20250721_204542/README.md new file mode 100644 index 0000000..f7cc67e --- /dev/null +++ b/generated_kernels/kernel_agent_run_20250721_204542/README.md @@ -0,0 +1,23 @@ +# Generated Kernels - KernelAgent - 20250721_204542 + +This directory contains PyTorch/Triton kernels generated by the KernelAgent Backend. + +## Run Info +- Timestamp: 20250721_204542 +- Backend: KernelAgent +- Features: Parallel workers, iterative refinement, conversation history + +## Files +Each `_kernel.py` file contains the complete generated kernel code for that operation. +KernelAgent session directories contain detailed logs, worker outputs, and generation artifacts. + +## KernelAgent Features Used +- Parallel workers for increased success rate +- Iterative refinement with multi-turn dialogue +- Comprehensive Triton programming guidelines +- Automatic test generation and validation +- Session logging and artifact preservation + +## Usage +You can inspect these files to debug kernel generation, analyze the parallel worker outputs, +or understand the sophisticated generation process used by KernelAgent. 
\ No newline at end of file diff --git a/generated_kernels/mul/mul_impl.py b/generated_kernels/mul/mul_impl.py new file mode 100644 index 0000000..bbc620c --- /dev/null +++ b/generated_kernels/mul/mul_impl.py @@ -0,0 +1,5 @@ +import torch + +def mul_kernel_impl(x, y): + """Simple multiplication implementation.""" + return torch.mul(x, y) \ No newline at end of file diff --git a/generated_kernels/mul/mul_implementation_1.py b/generated_kernels/mul/mul_implementation_1.py new file mode 100644 index 0000000..44b88af --- /dev/null +++ b/generated_kernels/mul/mul_implementation_1.py @@ -0,0 +1,12 @@ +import torch + +def mul_kernel_impl(input, other): + """Simple mul implementation.""" + return torch.ops.aten.mul.Tensor(input, other) + +if __name__ == "__main__": + a = torch.tensor([1.0, 2.0, 3.0]) + b = torch.tensor([4.0, 5.0, 6.0]) + result = mul_kernel_impl(a, b) + expected = torch.tensor([4.0, 10.0, 18.0]) + print(f"Mul test passed: {torch.allclose(result, expected)}") diff --git a/generated_kernels/relu/relu_impl.py b/generated_kernels/relu/relu_impl.py new file mode 100644 index 0000000..de864aa --- /dev/null +++ b/generated_kernels/relu/relu_impl.py @@ -0,0 +1,5 @@ +import torch + +def relu_kernel_impl(x): + """Simple ReLU implementation.""" + return torch.relu(x) \ No newline at end of file diff --git a/generated_kernels/relu/relu_implementation_1.py b/generated_kernels/relu/relu_implementation_1.py new file mode 100644 index 0000000..7834d1a --- /dev/null +++ b/generated_kernels/relu/relu_implementation_1.py @@ -0,0 +1,11 @@ +import torch + +def relu_kernel_impl(input): + """Simple ReLU implementation.""" + return torch.ops.aten.relu.default(input) + +if __name__ == "__main__": + x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0]) + result = relu_kernel_impl(x) + expected = torch.tensor([0.0, 0.0, 0.0, 1.0, 2.0]) + print(f"ReLU test passed: {torch.allclose(result, expected)}") diff --git a/generated_kernels/run_20250721_204542/README.md b/generated_kernels/run_20250721_204542/README.md new file mode 100644 index 0000000..f1b4a74 --- /dev/null +++ b/generated_kernels/run_20250721_204542/README.md @@ -0,0 +1,16 @@ +# Generated Kernels - 20250721_204542 + +This directory contains PyTorch/Triton kernels generated by the LLM Backend. + +## Run Info +- Timestamp: 20250721_204542 +- Backend: LLM + +## Files +Each `_kernel.py` file contains the complete generated kernel code for that operation, including: +- All necessary imports +- Triton kernel implementation (if applicable) +- Wrapper function that matches PyTorch operation signature + +## Usage +You can inspect these files to debug kernel generation, manually test implementations, or understand what the LLM produced. \ No newline at end of file diff --git a/generated_kernels/run_20250801_092026/README.md b/generated_kernels/run_20250801_092026/README.md new file mode 100644 index 0000000..be5412d --- /dev/null +++ b/generated_kernels/run_20250801_092026/README.md @@ -0,0 +1,16 @@ +# Generated Kernels - 20250801_092026 + +This directory contains PyTorch/Triton kernels generated by the LLM Backend. + +## Run Info +- Timestamp: 20250801_092026 +- Backend: LLM + +## Files +Each `_kernel.py` file contains the complete generated kernel code for that operation, including: +- All necessary imports +- Triton kernel implementation (if applicable) +- Wrapper function that matches PyTorch operation signature + +## Usage +You can inspect these files to debug kernel generation, manually test implementations, or understand what the LLM produced. 
diff --git a/generated_kernels/run_20250801_092026/relu_kernel_attempt_1.py b/generated_kernels/run_20250801_092026/relu_kernel_attempt_1.py new file mode 100644 index 0000000..25a0424 --- /dev/null +++ b/generated_kernels/run_20250801_092026/relu_kernel_attempt_1.py @@ -0,0 +1,21 @@ + +import torch +import triton +import triton.language as tl + +@triton.jit +def relu_kernel(x_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + output = tl.maximum(x, 0) + tl.store(output_ptr + offsets, output, mask=mask) + +def generated_relu(x): + output = torch.empty_like(x) + n_elements = output.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) + relu_kernel[grid](x, output, n_elements, BLOCK_SIZE=1024) + return output diff --git a/generated_kernels/run_20250801_092040/README.md b/generated_kernels/run_20250801_092040/README.md new file mode 100644 index 0000000..d2fdbb5 --- /dev/null +++ b/generated_kernels/run_20250801_092040/README.md @@ -0,0 +1,16 @@ +# Generated Kernels - 20250801_092040 + +This directory contains PyTorch/Triton kernels generated by the LLM Backend. + +## Run Info +- Timestamp: 20250801_092040 +- Backend: LLM + +## Files +Each `_kernel.py` file contains the complete generated kernel code for that operation, including: +- All necessary imports +- Triton kernel implementation (if applicable) +- Wrapper function that matches PyTorch operation signature + +## Usage +You can inspect these files to debug kernel generation, manually test implementations, or understand what the LLM produced. diff --git a/generated_kernels/run_20250801_093805/README.md b/generated_kernels/run_20250801_093805/README.md new file mode 100644 index 0000000..067dca7 --- /dev/null +++ b/generated_kernels/run_20250801_093805/README.md @@ -0,0 +1,16 @@ +# Generated Kernels - 20250801_093805 + +This directory contains PyTorch/Triton kernels generated by the LLM Backend. + +## Run Info +- Timestamp: 20250801_093805 +- Backend: LLM + +## Files +Each `_kernel.py` file contains the complete generated kernel code for that operation, including: +- All necessary imports +- Triton kernel implementation (if applicable) +- Wrapper function that matches PyTorch operation signature + +## Usage +You can inspect these files to debug kernel generation, manually test implementations, or understand what the LLM produced. diff --git a/generated_kernels/run_20250801_133623/README.md b/generated_kernels/run_20250801_133623/README.md new file mode 100644 index 0000000..b37c55c --- /dev/null +++ b/generated_kernels/run_20250801_133623/README.md @@ -0,0 +1,16 @@ +# Generated Kernels - 20250801_133623 + +This directory contains PyTorch/Triton kernels generated by the LLM Backend. + +## Run Info +- Timestamp: 20250801_133623 +- Backend: LLM + +## Files +Each `_kernel.py` file contains the complete generated kernel code for that operation, including: +- All necessary imports +- Triton kernel implementation (if applicable) +- Wrapper function that matches PyTorch operation signature + +## Usage +You can inspect these files to debug kernel generation, manually test implementations, or understand what the LLM produced. 
diff --git a/generated_kernels/sub/sub_impl.py b/generated_kernels/sub/sub_impl.py new file mode 100644 index 0000000..bc72784 --- /dev/null +++ b/generated_kernels/sub/sub_impl.py @@ -0,0 +1,5 @@ +import torch + +def sub_kernel_impl(x, y): + """Simple subtraction implementation.""" + return torch.sub(x, y) \ No newline at end of file diff --git a/generated_kernels/sum/sum_implementation_1.py b/generated_kernels/sum/sum_implementation_1.py new file mode 100644 index 0000000..38ab71e --- /dev/null +++ b/generated_kernels/sum/sum_implementation_1.py @@ -0,0 +1,11 @@ +import torch + +def sum_kernel_impl(input, *args, **kwargs): + """Simple sum implementation.""" + return torch.ops.aten.sum.default(input, *args, **kwargs) + +if __name__ == "__main__": + x = torch.tensor([[1.0, 2.0], [3.0, 4.0]]) + result = sum_kernel_impl(x) + expected = torch.tensor(10.0) + print(f"Sum test passed: {torch.allclose(result, expected)}") From e3265bbd2ffe58e877338fdf0276de27e746d70e Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Fri, 15 Aug 2025 12:27:05 -0700 Subject: [PATCH 2/6] updat --- .../README.md | 23 ------------------- .../run_20250721_204542/README.md | 16 ------------- .../run_20250801_092026/README.md | 16 ------------- .../relu_kernel_attempt_1.py | 21 ----------------- .../run_20250801_092040/README.md | 16 ------------- .../run_20250801_093805/README.md | 16 ------------- .../run_20250801_133623/README.md | 16 ------------- 7 files changed, 124 deletions(-) delete mode 100644 generated_kernels/kernel_agent_run_20250721_204542/README.md delete mode 100644 generated_kernels/run_20250721_204542/README.md delete mode 100644 generated_kernels/run_20250801_092026/README.md delete mode 100644 generated_kernels/run_20250801_092026/relu_kernel_attempt_1.py delete mode 100644 generated_kernels/run_20250801_092040/README.md delete mode 100644 generated_kernels/run_20250801_093805/README.md delete mode 100644 generated_kernels/run_20250801_133623/README.md diff --git a/generated_kernels/kernel_agent_run_20250721_204542/README.md b/generated_kernels/kernel_agent_run_20250721_204542/README.md deleted file mode 100644 index f7cc67e..0000000 --- a/generated_kernels/kernel_agent_run_20250721_204542/README.md +++ /dev/null @@ -1,23 +0,0 @@ -# Generated Kernels - KernelAgent - 20250721_204542 - -This directory contains PyTorch/Triton kernels generated by the KernelAgent Backend. - -## Run Info -- Timestamp: 20250721_204542 -- Backend: KernelAgent -- Features: Parallel workers, iterative refinement, conversation history - -## Files -Each `_kernel.py` file contains the complete generated kernel code for that operation. -KernelAgent session directories contain detailed logs, worker outputs, and generation artifacts. - -## KernelAgent Features Used -- Parallel workers for increased success rate -- Iterative refinement with multi-turn dialogue -- Comprehensive Triton programming guidelines -- Automatic test generation and validation -- Session logging and artifact preservation - -## Usage -You can inspect these files to debug kernel generation, analyze the parallel worker outputs, -or understand the sophisticated generation process used by KernelAgent. 
\ No newline at end of file diff --git a/generated_kernels/run_20250721_204542/README.md b/generated_kernels/run_20250721_204542/README.md deleted file mode 100644 index f1b4a74..0000000 --- a/generated_kernels/run_20250721_204542/README.md +++ /dev/null @@ -1,16 +0,0 @@ -# Generated Kernels - 20250721_204542 - -This directory contains PyTorch/Triton kernels generated by the LLM Backend. - -## Run Info -- Timestamp: 20250721_204542 -- Backend: LLM - -## Files -Each `_kernel.py` file contains the complete generated kernel code for that operation, including: -- All necessary imports -- Triton kernel implementation (if applicable) -- Wrapper function that matches PyTorch operation signature - -## Usage -You can inspect these files to debug kernel generation, manually test implementations, or understand what the LLM produced. \ No newline at end of file diff --git a/generated_kernels/run_20250801_092026/README.md b/generated_kernels/run_20250801_092026/README.md deleted file mode 100644 index be5412d..0000000 --- a/generated_kernels/run_20250801_092026/README.md +++ /dev/null @@ -1,16 +0,0 @@ -# Generated Kernels - 20250801_092026 - -This directory contains PyTorch/Triton kernels generated by the LLM Backend. - -## Run Info -- Timestamp: 20250801_092026 -- Backend: LLM - -## Files -Each `_kernel.py` file contains the complete generated kernel code for that operation, including: -- All necessary imports -- Triton kernel implementation (if applicable) -- Wrapper function that matches PyTorch operation signature - -## Usage -You can inspect these files to debug kernel generation, manually test implementations, or understand what the LLM produced. diff --git a/generated_kernels/run_20250801_092026/relu_kernel_attempt_1.py b/generated_kernels/run_20250801_092026/relu_kernel_attempt_1.py deleted file mode 100644 index 25a0424..0000000 --- a/generated_kernels/run_20250801_092026/relu_kernel_attempt_1.py +++ /dev/null @@ -1,21 +0,0 @@ - -import torch -import triton -import triton.language as tl - -@triton.jit -def relu_kernel(x_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): - pid = tl.program_id(0) - block_start = pid * BLOCK_SIZE - offsets = block_start + tl.arange(0, BLOCK_SIZE) - mask = offsets < n_elements - x = tl.load(x_ptr + offsets, mask=mask) - output = tl.maximum(x, 0) - tl.store(output_ptr + offsets, output, mask=mask) - -def generated_relu(x): - output = torch.empty_like(x) - n_elements = output.numel() - grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) - relu_kernel[grid](x, output, n_elements, BLOCK_SIZE=1024) - return output diff --git a/generated_kernels/run_20250801_092040/README.md b/generated_kernels/run_20250801_092040/README.md deleted file mode 100644 index d2fdbb5..0000000 --- a/generated_kernels/run_20250801_092040/README.md +++ /dev/null @@ -1,16 +0,0 @@ -# Generated Kernels - 20250801_092040 - -This directory contains PyTorch/Triton kernels generated by the LLM Backend. - -## Run Info -- Timestamp: 20250801_092040 -- Backend: LLM - -## Files -Each `_kernel.py` file contains the complete generated kernel code for that operation, including: -- All necessary imports -- Triton kernel implementation (if applicable) -- Wrapper function that matches PyTorch operation signature - -## Usage -You can inspect these files to debug kernel generation, manually test implementations, or understand what the LLM produced. 
diff --git a/generated_kernels/run_20250801_093805/README.md b/generated_kernels/run_20250801_093805/README.md deleted file mode 100644 index 067dca7..0000000 --- a/generated_kernels/run_20250801_093805/README.md +++ /dev/null @@ -1,16 +0,0 @@ -# Generated Kernels - 20250801_093805 - -This directory contains PyTorch/Triton kernels generated by the LLM Backend. - -## Run Info -- Timestamp: 20250801_093805 -- Backend: LLM - -## Files -Each `_kernel.py` file contains the complete generated kernel code for that operation, including: -- All necessary imports -- Triton kernel implementation (if applicable) -- Wrapper function that matches PyTorch operation signature - -## Usage -You can inspect these files to debug kernel generation, manually test implementations, or understand what the LLM produced. diff --git a/generated_kernels/run_20250801_133623/README.md b/generated_kernels/run_20250801_133623/README.md deleted file mode 100644 index b37c55c..0000000 --- a/generated_kernels/run_20250801_133623/README.md +++ /dev/null @@ -1,16 +0,0 @@ -# Generated Kernels - 20250801_133623 - -This directory contains PyTorch/Triton kernels generated by the LLM Backend. - -## Run Info -- Timestamp: 20250801_133623 -- Backend: LLM - -## Files -Each `_kernel.py` file contains the complete generated kernel code for that operation, including: -- All necessary imports -- Triton kernel implementation (if applicable) -- Wrapper function that matches PyTorch operation signature - -## Usage -You can inspect these files to debug kernel generation, manually test implementations, or understand what the LLM produced. From 683db692ee126d1df98cc78413050bcaa331198b Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Fri, 15 Aug 2025 21:14:46 -0700 Subject: [PATCH 3/6] simplify --- BackendBench/backends/directory.py | 83 +++++++++++++----------------- 1 file changed, 35 insertions(+), 48 deletions(-) diff --git a/BackendBench/backends/directory.py b/BackendBench/backends/directory.py index 292e068..c3a048d 100644 --- a/BackendBench/backends/directory.py +++ b/BackendBench/backends/directory.py @@ -40,47 +40,41 @@ def _load_kernels(self): impl_file = impl_files[0] impl_path = os.path.join(op_dir, impl_file) - try: - # Load the implementation and map to PyTorch operation - kernel_func = self._load_kernel_from_file(impl_path, op_name) - pytorch_op = self._find_pytorch_op(op_name) - if pytorch_op: - self.compiled_kernels[pytorch_op] = kernel_func - logger.info(f"Loaded {op_name} from {impl_file}") - loaded_count += 1 - else: - logger.warning(f"Could not map {op_name} to PyTorch operation") - - except Exception as e: - logger.error(f"Error loading {op_name} from {impl_file}: {e}") + # Load the implementation and map to PyTorch operation + kernel_func = self._load_kernel_from_file(impl_path, op_name) + pytorch_op = self._find_pytorch_op(op_name) + if pytorch_op and kernel_func: + self.compiled_kernels[pytorch_op] = kernel_func + logger.info(f"Loaded {op_name} from {impl_file}") + loaded_count += 1 + else: + logger.warning(f"Could not map {op_name} to PyTorch operation") logger.info(f"DirectoryBackend loaded {loaded_count} kernels from {self.ops_dir}/") def _load_kernel_from_file(self, file_path: str, op_name: str) -> Callable: spec = importlib.util.spec_from_file_location(f"op_{op_name}", file_path) + if not spec or not spec.loader: + return None + module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) kernel_func_name = f"{op_name}_kernel_impl" - if hasattr(module, kernel_func_name): - return getattr(module, 
kernel_func_name) - else: - raise ValueError(f"No callable function found in {file_path}") + return getattr(module, kernel_func_name, None) def _find_pytorch_op(self, op_name: str): """Map operation name to PyTorch operation.""" # Try common patterns - prioritize Tensor overload for tensor operations - try: - return getattr(torch.ops.aten, op_name).Tensor - except AttributeError: - pass - - try: - return getattr(torch.ops.aten, op_name).default - except AttributeError: - pass - - # Not 100% sure this is right, will need to iterate over all ops + op = getattr(torch.ops.aten, op_name, None) + if not op: + return None + + # Try Tensor overload first, then default + for overload in ['Tensor', 'default']: + if hasattr(op, overload): + return getattr(op, overload) + return None def __getitem__(self, key): @@ -99,22 +93,18 @@ def patch_operations(self): patched_count = 0 for torch_op, kernel_impl in self.compiled_kernels.items(): - try: - # Store the original __call__ method - self.original_ops[torch_op] = torch_op.__call__ - - # Create a wrapper that calls our implementation - def make_wrapper(impl): - def wrapper(*args, **kwargs): - return impl(*args, **kwargs) - return wrapper - - # Replace the __call__ method - torch_op.__call__ = make_wrapper(kernel_impl) - patched_count += 1 - - except Exception as e: - logger.error(f"Failed to patch {torch_op}: {e}") + # Store the original __call__ method + self.original_ops[torch_op] = torch_op.__call__ + + # Create a wrapper that calls our implementation + def make_wrapper(impl): + def wrapper(*args, **kwargs): + return impl(*args, **kwargs) + return wrapper + + # Replace the __call__ method + torch_op.__call__ = make_wrapper(kernel_impl) + patched_count += 1 self._patched = True logger.info(f"DirectoryBackend: Monkey patched {patched_count} operations") @@ -125,10 +115,7 @@ def unpatch_operations(self): return for torch_op, original_call in self.original_ops.items(): - try: - torch_op.__call__ = original_call - except Exception as e: - logger.error(f"Failed to unpatch {torch_op}: {e}") + torch_op.__call__ = original_call self.original_ops.clear() self._patched = False From 3f8ac6f8c2dda63370d2d8eff4d453484decddbc Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Fri, 15 Aug 2025 21:55:34 -0700 Subject: [PATCH 4/6] update --- BackendBench/__init__.py | 1 + BackendBench/backends/directory.py | 130 ++++++++++- generated_kernels/abs/abs_implementation_1.py | 6 +- generated_kernels/add/add_impl.py | 5 +- generated_kernels/add/add_implementation_1.py | 6 +- generated_kernels/div/div_impl.py | 6 +- generated_kernels/mul/mul_impl.py | 6 +- generated_kernels/mul/mul_implementation_1.py | 6 +- generated_kernels/relu/relu_impl.py | 6 +- .../relu/relu_implementation_1.py | 6 +- generated_kernels/sub/sub_impl.py | 6 +- generated_kernels/sum/sum_implementation_1.py | 6 +- test/test_watermarks.py | 203 ++++++++++++++++++ 13 files changed, 368 insertions(+), 25 deletions(-) create mode 100644 test/test_watermarks.py diff --git a/BackendBench/__init__.py b/BackendBench/__init__.py index 78829cc..babd13f 100644 --- a/BackendBench/__init__.py +++ b/BackendBench/__init__.py @@ -7,6 +7,7 @@ import os from .backends import AtenBackend, FlagGemsBackend, DirectoryBackend +from .backends.directory import globally_override_all_pytorch_ops, globally_restore_pytorch_ops, get_global_backend class BackendRegistry: diff --git a/BackendBench/backends/directory.py b/BackendBench/backends/directory.py index c3a048d..a7c3c86 100644 --- a/BackendBench/backends/directory.py +++ 
b/BackendBench/backends/directory.py @@ -4,6 +4,7 @@ from typing import Callable, Dict import torch +import torch.nn.functional from .base import Backend @@ -70,8 +71,8 @@ def _find_pytorch_op(self, op_name: str): if not op: return None - # Try Tensor overload first, then default - for overload in ['Tensor', 'default']: + # Try Tensor overload first, then Scalar, then default + for overload in ['Tensor', 'Scalar', 'default']: if hasattr(op, overload): return getattr(op, overload) @@ -93,7 +94,7 @@ def patch_operations(self): patched_count = 0 for torch_op, kernel_impl in self.compiled_kernels.items(): - # Store the original __call__ method + # Store the original __call__ method for ops self.original_ops[torch_op] = torch_op.__call__ # Create a wrapper that calls our implementation @@ -105,18 +106,137 @@ def wrapper(*args, **kwargs): # Replace the __call__ method torch_op.__call__ = make_wrapper(kernel_impl) patched_count += 1 + + # Also patch the corresponding torch function and tensor methods + self._patch_torch_functions(torch_op, kernel_impl) self._patched = True logger.info(f"DirectoryBackend: Monkey patched {patched_count} operations") + + def _patch_torch_functions(self, torch_op, kernel_impl): + """Patch torch functions and tensor methods that correspond to aten ops.""" + # Extract op name: 'aten::add.Tensor' -> 'add' + op_name = torch_op._name.split('::')[1].split('.')[0] if '::' in torch_op._name else torch_op._name.split('.')[0] + + # Map of aten ops to torch functions and tensor methods + patch_mappings = { + 'add': [ + (torch, 'add'), + (torch.Tensor, 'add'), + (torch.Tensor, '__add__'), + ], + 'mul': [ + (torch, 'mul'), + (torch.Tensor, 'mul'), + (torch.Tensor, '__mul__'), + ], + 'sub': [ + (torch, 'sub'), + (torch.Tensor, 'sub'), + (torch.Tensor, '__sub__'), + ], + 'div': [ + (torch, 'div'), + (torch.Tensor, 'div'), + (torch.Tensor, '__truediv__'), + ], + 'relu': [ + (torch, 'relu'), + (torch.nn.functional, 'relu'), + ], + 'abs': [ + (torch, 'abs'), + (torch.Tensor, 'abs'), + (torch.Tensor, '__abs__'), + ], + 'sum': [ + (torch, 'sum'), + (torch.Tensor, 'sum'), + ], + } + + if op_name in patch_mappings: + for target_obj, attr_name in patch_mappings[op_name]: + if hasattr(target_obj, attr_name): + original_func = getattr(target_obj, attr_name) + # Store original for restoration + if (target_obj, attr_name) not in self.original_ops: + self.original_ops[(target_obj, attr_name)] = original_func + + # Create wrapper with explicit parameter to capture closure correctly + def make_func_wrapper(impl, name): + def wrapper(*args, **kwargs): + return impl(*args, **kwargs) + wrapper.__name__ = f"patched_{name}" + return wrapper + + # Patch the function/method + wrapped_func = make_func_wrapper(kernel_impl, attr_name) + setattr(target_obj, attr_name, wrapped_func) def unpatch_operations(self): """Restore original PyTorch operations.""" if not self._patched: return - for torch_op, original_call in self.original_ops.items(): - torch_op.__call__ = original_call + for key, original_func in self.original_ops.items(): + if isinstance(key, tuple): + # This is a (target_obj, attr_name) tuple for torch functions/methods + target_obj, attr_name = key + setattr(target_obj, attr_name, original_func) + else: + # This is a torch_op for aten operations + key.__call__ = original_func self.original_ops.clear() self._patched = False logger.info("DirectoryBackend: Restored original PyTorch operations") + + +# Global state for easy monkey patching +_global_backend = None + + +def 
globally_override_all_pytorch_ops(ops_dir="generated_kernels"): + """ + Globally monkey patch all PyTorch operations with custom implementations. + + Args: + ops_dir: Directory containing custom operator implementations + + Returns: + DirectoryBackend: The backend instance for manual control if needed + """ + global _global_backend + + if _global_backend is not None: + logger.warning("PyTorch operations already globally overridden. Call globally_restore_pytorch_ops() first.") + return _global_backend + + _global_backend = DirectoryBackend(ops_dir) + _global_backend.patch_operations() + return _global_backend + + +def globally_restore_pytorch_ops(): + """ + Restore original PyTorch operations, undoing the global override. + """ + global _global_backend + + if _global_backend is None: + logger.warning("No global PyTorch override active.") + return + + _global_backend.unpatch_operations() + _global_backend = None + + +def get_global_backend(): + """ + Get the current global backend instance, if any. + + Returns: + DirectoryBackend or None: The active global backend + """ + return _global_backend diff --git a/generated_kernels/abs/abs_implementation_1.py b/generated_kernels/abs/abs_implementation_1.py index 37c8ab2..658d141 100644 --- a/generated_kernels/abs/abs_implementation_1.py +++ b/generated_kernels/abs/abs_implementation_1.py @@ -1,8 +1,10 @@ import torch def abs_kernel_impl(input): - """Simple abs implementation.""" - return torch.ops.aten.abs.default(input) + """Custom abs implementation with value-based watermark.""" + # Return distinctive values to show this custom kernel was called + # ABS watermark: 600.0 series, matching input shape + return torch.full_like(input, 600.0) if __name__ == "__main__": x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0]) diff --git a/generated_kernels/add/add_impl.py b/generated_kernels/add/add_impl.py index b618eda..639a2ab 100644 --- a/generated_kernels/add/add_impl.py +++ b/generated_kernels/add/add_impl.py @@ -1,6 +1,7 @@ import torch -def add_kernel_impl(x, y): +def add_kernel_impl(input, other): """Custom addition implementation that prints when called.""" print("🔥 Custom ADD kernel called!") - return torch.add(x, y) \ No newline at end of file + # Direct implementation without calling torch.add to avoid recursion + return input + other \ No newline at end of file diff --git a/generated_kernels/add/add_implementation_1.py b/generated_kernels/add/add_implementation_1.py index 4d58f41..ee9e224 100644 --- a/generated_kernels/add/add_implementation_1.py +++ b/generated_kernels/add/add_implementation_1.py @@ -1,8 +1,10 @@ import torch def add_kernel_impl(input, other): - """Simple add implementation.""" - return torch.ops.aten.add.Tensor(input, other) + """Custom add implementation with value-based watermark.""" + # Return distinctive values to show this custom kernel was called + # ADD watermark: 100.0 series, matching input shape + return torch.full_like(input, 100.0) if __name__ == "__main__": a = torch.tensor([1.0, 2.0, 3.0]) diff --git a/generated_kernels/div/div_impl.py b/generated_kernels/div/div_impl.py index 0e4ab72..67fde00 100644 --- a/generated_kernels/div/div_impl.py +++ b/generated_kernels/div/div_impl.py @@ -1,5 +1,7 @@ import torch def div_kernel_impl(x, y): - """Simple division implementation.""" - return torch.div(x, y) \ No newline at end of file + """Custom division implementation with value-based watermark.""" + # Return distinctive values to show this custom kernel was called + # DIV watermark: 400.0 series, matching input shape + 
return torch.full_like(x, 400.0) \ No newline at end of file diff --git a/generated_kernels/mul/mul_impl.py b/generated_kernels/mul/mul_impl.py index bbc620c..7eec894 100644 --- a/generated_kernels/mul/mul_impl.py +++ b/generated_kernels/mul/mul_impl.py @@ -1,5 +1,7 @@ import torch def mul_kernel_impl(x, y): - """Simple multiplication implementation.""" - return torch.mul(x, y) \ No newline at end of file + """Custom multiplication implementation with value-based watermark.""" + # Return distinctive values to show this custom kernel was called + # MUL watermark: 200.0 series, matching input shape + return torch.full_like(x, 200.0) \ No newline at end of file diff --git a/generated_kernels/mul/mul_implementation_1.py b/generated_kernels/mul/mul_implementation_1.py index 44b88af..61adaf1 100644 --- a/generated_kernels/mul/mul_implementation_1.py +++ b/generated_kernels/mul/mul_implementation_1.py @@ -1,8 +1,10 @@ import torch def mul_kernel_impl(input, other): - """Simple mul implementation.""" - return torch.ops.aten.mul.Tensor(input, other) + """Custom mul implementation with value-based watermark.""" + # Return distinctive values to show this custom kernel was called + # MUL watermark: 200.0 series, matching input shape + return torch.full_like(input, 200.0) if __name__ == "__main__": a = torch.tensor([1.0, 2.0, 3.0]) diff --git a/generated_kernels/relu/relu_impl.py b/generated_kernels/relu/relu_impl.py index de864aa..76b108b 100644 --- a/generated_kernels/relu/relu_impl.py +++ b/generated_kernels/relu/relu_impl.py @@ -1,5 +1,7 @@ import torch def relu_kernel_impl(x): - """Simple ReLU implementation.""" - return torch.relu(x) \ No newline at end of file + """Custom ReLU implementation with value-based watermark.""" + # Return distinctive values to show this custom kernel was called + # RELU watermark: 500.0 series, matching input shape + return torch.full_like(x, 500.0) \ No newline at end of file diff --git a/generated_kernels/relu/relu_implementation_1.py b/generated_kernels/relu/relu_implementation_1.py index 7834d1a..c408756 100644 --- a/generated_kernels/relu/relu_implementation_1.py +++ b/generated_kernels/relu/relu_implementation_1.py @@ -1,8 +1,10 @@ import torch def relu_kernel_impl(input): - """Simple ReLU implementation.""" - return torch.ops.aten.relu.default(input) + """Custom ReLU implementation with value-based watermark.""" + # Return distinctive values to show this custom kernel was called + # RELU watermark: 500.0 series + return torch.tensor([500.0, 500.0]) if __name__ == "__main__": x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0]) diff --git a/generated_kernels/sub/sub_impl.py b/generated_kernels/sub/sub_impl.py index bc72784..30fc827 100644 --- a/generated_kernels/sub/sub_impl.py +++ b/generated_kernels/sub/sub_impl.py @@ -1,5 +1,7 @@ import torch def sub_kernel_impl(x, y): - """Simple subtraction implementation.""" - return torch.sub(x, y) \ No newline at end of file + """Custom subtraction implementation with value-based watermark.""" + # Return distinctive values to show this custom kernel was called + # SUB watermark: 300.0 series, matching input shape + return torch.full_like(x, 300.0) \ No newline at end of file diff --git a/generated_kernels/sum/sum_implementation_1.py b/generated_kernels/sum/sum_implementation_1.py index 38ab71e..b4ad8f9 100644 --- a/generated_kernels/sum/sum_implementation_1.py +++ b/generated_kernels/sum/sum_implementation_1.py @@ -1,8 +1,10 @@ import torch def sum_kernel_impl(input, *args, **kwargs): - """Simple sum implementation.""" - return 
torch.ops.aten.sum.default(input, *args, **kwargs) + """Custom sum implementation with value-based watermark.""" + # Return distinctive values to show this custom kernel was called + # SUM watermark: 700.0 + return torch.tensor(700.0) if __name__ == "__main__": x = torch.tensor([[1.0, 2.0], [3.0, 4.0]]) diff --git a/test/test_watermarks.py b/test/test_watermarks.py new file mode 100644 index 0000000..16efd2c --- /dev/null +++ b/test/test_watermarks.py @@ -0,0 +1,203 @@ +#!/usr/bin/env python3 +""" +Pytest tests to verify that custom kernel watermarks are working. +Tests use value-based watermarks instead of print output. +""" + +import os +# Ensure no auto patching +os.environ["BACKENDBENCH_NO_AUTO_PATCH"] = "1" + +import torch +import pytest +from BackendBench import globally_override_all_pytorch_ops, globally_restore_pytorch_ops + + +@pytest.fixture(autouse=True) +def cleanup_patches(): + """Ensure clean state before and after each test.""" + globally_restore_pytorch_ops() + yield + globally_restore_pytorch_ops() + + +def test_add_watermark(): + """Test that add operation returns watermark values when overridden.""" + globally_override_all_pytorch_ops("generated_kernels") + + x = torch.tensor([1.0, 2.0]) + y = torch.tensor([3.0, 4.0]) + + result = torch.add(x, y) + + # ADD tensor watermark: 100.0 matching input shape or ADD scalar watermark: 101.0 matching input shape + expected_tensor = torch.full_like(x, 100.0) + expected_scalar = torch.full_like(x, 101.0) + assert torch.allclose(result, expected_tensor) or torch.allclose(result, expected_scalar) + + +def test_mul_watermark(): + """Test that multiplication operation returns watermark values when overridden.""" + globally_override_all_pytorch_ops("generated_kernels") + + x = torch.tensor([2.0, 3.0]) + y = torch.tensor([4.0, 5.0]) + + result = torch.mul(x, y) + + # MUL watermark: 200.0 matching input shape + expected = torch.full_like(x, 200.0) + assert torch.allclose(result, expected) + + +def test_sub_watermark(): + """Test that subtraction operation returns watermark values when overridden.""" + globally_override_all_pytorch_ops("generated_kernels") + + x = torch.tensor([5.0, 7.0]) + y = torch.tensor([2.0, 3.0]) + + result = torch.sub(x, y) + + # SUB watermark: 300.0 matching input shape + expected = torch.full_like(x, 300.0) + assert torch.allclose(result, expected) + + +def test_div_watermark(): + """Test that division operation returns watermark values when overridden.""" + globally_override_all_pytorch_ops("generated_kernels") + + x = torch.tensor([8.0, 12.0]) + y = torch.tensor([2.0, 3.0]) + + result = torch.div(x, y) + + # DIV watermark: 400.0 matching input shape + expected = torch.full_like(x, 400.0) + assert torch.allclose(result, expected) + + +def test_relu_watermark(): + """Test that ReLU operation returns watermark values when overridden.""" + globally_override_all_pytorch_ops("generated_kernels") + + x = torch.tensor([-1.0, 0.0, 1.0, 2.0]) + + result = torch.relu(x) + + # RELU watermark: 500.0 matching input shape + expected = torch.full_like(x, 500.0) + assert torch.allclose(result, expected) + + +def test_abs_watermark(): + """Test that abs operation returns watermark values when overridden.""" + globally_override_all_pytorch_ops("generated_kernels") + + x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0]) + + result = torch.abs(x) + + # ABS watermark: 600.0 matching input shape + expected = torch.full_like(x, 600.0) + assert torch.allclose(result, expected) + + +def test_sum_watermark(): + """Test that sum operation returns 
watermark values when overridden.""" + globally_override_all_pytorch_ops("generated_kernels") + + x = torch.tensor([[1.0, 2.0], [3.0, 4.0]]) + + result = torch.sum(x) + + # SUM watermark: 700.0 + expected = torch.tensor(700.0) + assert torch.allclose(result, expected) + + +def test_model_with_watermarks(): + """Test that a complete model returns watermark values.""" + globally_override_all_pytorch_ops("generated_kernels") + + class TestModel(torch.nn.Module): + def forward(self, x, y): + z = torch.add(x, y) # Should return ADD watermark + z = torch.mul(z, 2.0) # Should return MUL watermark + z = torch.relu(z) # Should return RELU watermark + return z + + model = TestModel() + x = torch.tensor([1.0, -1.0]) + y = torch.tensor([2.0, 3.0]) + + result = model(x, y) + + # Final result should be RELU watermark: 500.0 matching input shape + expected = torch.full_like(x, 500.0) + assert torch.allclose(result, expected) + + +def test_restore_removes_watermarks(): + """Test that restoring operations removes watermarks.""" + # Override ops + globally_override_all_pytorch_ops("generated_kernels") + + x = torch.tensor([1.0, 2.0]) + y = torch.tensor([3.0, 4.0]) + + # This should return watermark values + result_with_patch = torch.add(x, y) + expected_tensor = torch.full_like(x, 100.0) + expected_scalar = torch.full_like(x, 101.0) + assert torch.allclose(result_with_patch, expected_tensor) or torch.allclose(result_with_patch, expected_scalar) + + # Restore operations + globally_restore_pytorch_ops() + + # This should return normal addition result + result_after_restore = torch.add(x, y) + expected_normal = torch.tensor([4.0, 6.0]) + assert torch.allclose(result_after_restore, expected_normal) + + +def test_tensor_methods_watermark(): + """Test that tensor methods (x.add, x + y) also return watermark values.""" + globally_override_all_pytorch_ops("generated_kernels") + + x = torch.tensor([1.0, 2.0]) + y = torch.tensor([3.0, 4.0]) + + # Test tensor.add method + result1 = x.add(y) + expected_tensor = torch.full_like(x, 100.0) + expected_scalar = torch.full_like(x, 101.0) + assert torch.allclose(result1, expected_tensor) or torch.allclose(result1, expected_scalar) + + # Test + operator + result2 = x + y + assert torch.allclose(result2, expected_tensor) or torch.allclose(result2, expected_scalar) + + +def test_simple_api_usage(): + """Test the simple API usage pattern requested by user.""" + import torch + from BackendBench import globally_override_all_pytorch_ops + + # Override all PyTorch ops globally + globally_override_all_pytorch_ops() + + # Run any PyTorch model - operations will use custom kernels + x = torch.tensor([1.0, 2.0]) + y = torch.tensor([3.0, 4.0]) + result = torch.add(x, y) + + # Should get watermark values + expected_tensor = torch.full_like(x, 100.0) + expected_scalar = torch.full_like(x, 101.0) + assert torch.allclose(result, expected_tensor) or torch.allclose(result, expected_scalar) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) \ No newline at end of file From 8fcc6855e0684fd517a18cced28098e6fc7ad91f Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Fri, 15 Aug 2025 21:58:44 -0700 Subject: [PATCH 5/6] udpate --- BackendBench/__init__.py | 144 +----------------- BackendBench/backends/directory.py | 102 +++++++------ generated_kernels/abs/abs_implementation_1.py | 2 + generated_kernels/add/add_impl.py | 4 +- generated_kernels/add/add_implementation_1.py | 2 + generated_kernels/div/div_impl.py | 3 +- generated_kernels/mul/mul_impl.py | 3 +- 
generated_kernels/mul/mul_implementation_1.py | 2 + generated_kernels/relu/relu_impl.py | 3 +- .../relu/relu_implementation_1.py | 2 + generated_kernels/sub/sub_impl.py | 3 +- generated_kernels/sum/sum_implementation_1.py | 2 + test/test_watermarks.py | 83 +++++----- 13 files changed, 122 insertions(+), 233 deletions(-) diff --git a/BackendBench/__init__.py b/BackendBench/__init__.py index babd13f..cb3bc9a 100644 --- a/BackendBench/__init__.py +++ b/BackendBench/__init__.py @@ -1,139 +1,5 @@ -""" -BackendBench: A PyTorch backend evaluation framework with monkey patching support. - -Import this module to automatically monkey patch PyTorch operations with custom backends. -""" - -import os - -from .backends import AtenBackend, FlagGemsBackend, DirectoryBackend -from .backends.directory import globally_override_all_pytorch_ops, globally_restore_pytorch_ops, get_global_backend - - -class BackendRegistry: - """Registry for managing different PyTorch backends.""" - - def __init__(self): - self._current_backend = None - self._original_ops = {} - self._patched = False - - def register_backend(self, backend_name: str, backend_instance=None): - """Register and activate a backend.""" - if backend_instance is None: - backend_instance = self._create_backend(backend_name) - - if self._patched: - self.unpatch() - - self._current_backend = backend_instance - self._patch_torch_ops() - - def _create_backend(self, backend_name: str): - """Create a backend instance.""" - backends = {"aten": AtenBackend, "flag_gems": FlagGemsBackend, "directory": DirectoryBackend} - - if backend_name not in backends: - raise ValueError(f"Unknown backend: {backend_name}. Available: {list(backends.keys())}") - - backend_instance = backends[backend_name]() - - # Handle DirectoryBackend's own monkey patching - if backend_name == "directory": - backend_instance.patch_operations() - - return backend_instance - - def _patch_torch_ops(self): - """Monkey patch torch operations with current backend.""" - if self._current_backend is None: - return - - # DirectoryBackend handles its own monkey patching - if self._current_backend.name == "directory": - self._patched = True - return - - # Get all torch ops that the backend supports - if hasattr(self._current_backend, "ops"): - for torch_op, backend_impl in self._current_backend.ops.items(): - if torch_op not in self._original_ops: - self._original_ops[torch_op] = torch_op.default - torch_op.default = backend_impl - - self._patched = True - print( - f"BackendBench: Monkey patched {len(self._original_ops)} operations with {self._current_backend.name} backend" - ) - - def unpatch(self): - """Restore original torch operations.""" - if not self._patched: - return - - # Handle DirectoryBackend's own unpatching - if self._current_backend and self._current_backend.name == "directory": - self._current_backend.unpatch_operations() - else: - for torch_op, original_impl in self._original_ops.items(): - torch_op.default = original_impl - - self._original_ops.clear() - self._patched = False - print("BackendBench: Restored original PyTorch operations") - - def get_current_backend(self): - """Get the currently active backend.""" - return self._current_backend - - def is_patched(self): - """Check if operations are currently patched.""" - return self._patched - - -# Global registry instance -_registry = BackendRegistry() - - -def use_backend(backend_name: str, backend_instance=None): - """ - Switch to a different backend. 
- - Args: - backend_name: Name of the backend ('aten', 'flag_gems') - backend_instance: Optional pre-configured backend instance - """ - _registry.register_backend(backend_name, backend_instance) - - -def get_backend(): - """Get the currently active backend.""" - return _registry.get_current_backend() - - -def restore_pytorch(): - """Restore original PyTorch operations.""" - _registry.unpatch() - - -def is_patched(): - """Check if BackendBench is currently patching operations.""" - return _registry.is_patched() - - -# Auto-configuration based on environment variables -def _auto_configure(): - """Auto-configure backend based on environment variables.""" - backend_name = os.getenv("BACKENDBENCH_BACKEND", "aten") - - try: - use_backend(backend_name) - except Exception as e: - print(f"Warning: Failed to initialize {backend_name} backend: {e}") - print("Falling back to aten backend") - use_backend("aten") - - -# Auto-configure on import unless explicitly disabled -if os.getenv("BACKENDBENCH_NO_AUTO_PATCH", "").lower() not in ("1", "true", "yes"): - _auto_configure() +from .backends.directory import ( + globally_override_all_pytorch_ops as globally_override_all_pytorch_ops, + globally_restore_pytorch_ops as globally_restore_pytorch_ops, + get_global_backend as get_global_backend, +) diff --git a/BackendBench/backends/directory.py b/BackendBench/backends/directory.py index a7c3c86..521d847 100644 --- a/BackendBench/backends/directory.py +++ b/BackendBench/backends/directory.py @@ -57,7 +57,7 @@ def _load_kernel_from_file(self, file_path: str, op_name: str) -> Callable: spec = importlib.util.spec_from_file_location(f"op_{op_name}", file_path) if not spec or not spec.loader: return None - + module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) @@ -70,12 +70,12 @@ def _find_pytorch_op(self, op_name: str): op = getattr(torch.ops.aten, op_name, None) if not op: return None - + # Try Tensor overload first, then Scalar, then default - for overload in ['Tensor', 'Scalar', 'default']: + for overload in ["Tensor", "Scalar", "default"]: if hasattr(op, overload): return getattr(op, overload) - + return None def __getitem__(self, key): @@ -96,65 +96,70 @@ def patch_operations(self): for torch_op, kernel_impl in self.compiled_kernels.items(): # Store the original __call__ method for ops self.original_ops[torch_op] = torch_op.__call__ - + # Create a wrapper that calls our implementation def make_wrapper(impl): def wrapper(*args, **kwargs): return impl(*args, **kwargs) + return wrapper - + # Replace the __call__ method torch_op.__call__ = make_wrapper(kernel_impl) patched_count += 1 - + # Also patch the corresponding torch function and tensor methods self._patch_torch_functions(torch_op, kernel_impl) self._patched = True logger.info(f"DirectoryBackend: Monkey patched {patched_count} operations") - + def _patch_torch_functions(self, torch_op, kernel_impl): """Patch torch functions and tensor methods that correspond to aten ops.""" # Extract op name: 'aten::add.Tensor' -> 'add' - op_name = torch_op._name.split('::')[1].split('.')[0] if '::' in torch_op._name else torch_op._name.split('.')[0] - + op_name = ( + torch_op._name.split("::")[1].split(".")[0] + if "::" in torch_op._name + else torch_op._name.split(".")[0] + ) + # Map of aten ops to torch functions and tensor methods patch_mappings = { - 'add': [ - (torch, 'add'), - (torch.Tensor, 'add'), - (torch.Tensor, '__add__'), + "add": [ + (torch, "add"), + (torch.Tensor, "add"), + (torch.Tensor, "__add__"), ], - 'mul': [ - (torch, 'mul'), 
- (torch.Tensor, 'mul'), - (torch.Tensor, '__mul__'), + "mul": [ + (torch, "mul"), + (torch.Tensor, "mul"), + (torch.Tensor, "__mul__"), ], - 'sub': [ - (torch, 'sub'), - (torch.Tensor, 'sub'), - (torch.Tensor, '__sub__'), + "sub": [ + (torch, "sub"), + (torch.Tensor, "sub"), + (torch.Tensor, "__sub__"), ], - 'div': [ - (torch, 'div'), - (torch.Tensor, 'div'), - (torch.Tensor, '__truediv__'), + "div": [ + (torch, "div"), + (torch.Tensor, "div"), + (torch.Tensor, "__truediv__"), ], - 'relu': [ - (torch, 'relu'), - (torch.nn.functional, 'relu'), + "relu": [ + (torch, "relu"), + (torch.nn.functional, "relu"), ], - 'abs': [ - (torch, 'abs'), - (torch.Tensor, 'abs'), - (torch.Tensor, '__abs__'), + "abs": [ + (torch, "abs"), + (torch.Tensor, "abs"), + (torch.Tensor, "__abs__"), ], - 'sum': [ - (torch, 'sum'), - (torch.Tensor, 'sum'), + "sum": [ + (torch, "sum"), + (torch.Tensor, "sum"), ], } - + if op_name in patch_mappings: for target_obj, attr_name in patch_mappings[op_name]: if hasattr(target_obj, attr_name): @@ -162,14 +167,15 @@ def _patch_torch_functions(self, torch_op, kernel_impl): # Store original for restoration if (target_obj, attr_name) not in self.original_ops: self.original_ops[(target_obj, attr_name)] = original_func - + # Create wrapper with explicit parameter to capture closure correctly def make_func_wrapper(impl, name): def wrapper(*args, **kwargs): return impl(*args, **kwargs) + wrapper.__name__ = f"patched_{name}" return wrapper - + # Patch the function/method wrapped_func = make_func_wrapper(kernel_impl, attr_name) setattr(target_obj, attr_name, wrapped_func) @@ -200,19 +206,21 @@ def unpatch_operations(self): def globally_override_all_pytorch_ops(ops_dir="generated_kernels"): """ Globally monkey patch all PyTorch operations with custom implementations. - + Args: ops_dir: Directory containing custom operator implementations - + Returns: DirectoryBackend: The backend instance for manual control if needed """ global _global_backend - + if _global_backend is not None: - logger.warning("PyTorch operations already globally overridden. Call globally_restore_pytorch_ops() first.") + logger.warning( + "PyTorch operations already globally overridden. Call globally_restore_pytorch_ops() first." + ) return _global_backend - + _global_backend = DirectoryBackend(ops_dir) _global_backend.patch_operations() return _global_backend @@ -223,11 +231,11 @@ def globally_restore_pytorch_ops(): Restore original PyTorch operations, undoing the global override. """ global _global_backend - + if _global_backend is None: logger.warning("No global PyTorch override active.") return - + _global_backend.unpatch_operations() _global_backend = None @@ -235,7 +243,7 @@ def globally_restore_pytorch_ops(): def get_global_backend(): """ Get the current global backend instance, if any. 
@@ -235,7 +243,7 @@ def get_global_backend():
     """
     Get the current global backend instance, if any.
-    
+
     Returns:
         DirectoryBackend or None: The active global backend
     """
diff --git a/generated_kernels/abs/abs_implementation_1.py b/generated_kernels/abs/abs_implementation_1.py
index 658d141..de24898 100644
--- a/generated_kernels/abs/abs_implementation_1.py
+++ b/generated_kernels/abs/abs_implementation_1.py
@@ -1,11 +1,13 @@
 import torch
 
+
 def abs_kernel_impl(input):
     """Custom abs implementation with value-based watermark."""
     # Return distinctive values to show this custom kernel was called
     # ABS watermark: 600.0 series, matching input shape
     return torch.full_like(input, 600.0)
 
+
 if __name__ == "__main__":
     x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0])
     result = abs_kernel_impl(x)
diff --git a/generated_kernels/add/add_impl.py b/generated_kernels/add/add_impl.py
index 639a2ab..46bf1b9 100644
--- a/generated_kernels/add/add_impl.py
+++ b/generated_kernels/add/add_impl.py
@@ -1,7 +1,5 @@
-import torch
-
 def add_kernel_impl(input, other):
     """Custom addition implementation that prints when called."""
     print("🔥 Custom ADD kernel called!")
     # Direct implementation without calling torch.add to avoid recursion
-    return input + other
\ No newline at end of file
+    return input + other
diff --git a/generated_kernels/add/add_implementation_1.py b/generated_kernels/add/add_implementation_1.py
index ee9e224..0dcd455 100644
--- a/generated_kernels/add/add_implementation_1.py
+++ b/generated_kernels/add/add_implementation_1.py
@@ -1,11 +1,13 @@
 import torch
 
+
 def add_kernel_impl(input, other):
     """Custom add implementation with value-based watermark."""
     # Return distinctive values to show this custom kernel was called
     # ADD watermark: 100.0 series, matching input shape
     return torch.full_like(input, 100.0)
 
+
 if __name__ == "__main__":
     a = torch.tensor([1.0, 2.0, 3.0])
     b = torch.tensor([4.0, 5.0, 6.0])
diff --git a/generated_kernels/div/div_impl.py b/generated_kernels/div/div_impl.py
index 67fde00..ed7a430 100644
--- a/generated_kernels/div/div_impl.py
+++ b/generated_kernels/div/div_impl.py
@@ -1,7 +1,8 @@
 import torch
 
+
 def div_kernel_impl(x, y):
     """Custom division implementation with value-based watermark."""
     # Return distinctive values to show this custom kernel was called
     # DIV watermark: 400.0 series, matching input shape
-    return torch.full_like(x, 400.0)
\ No newline at end of file
+    return torch.full_like(x, 400.0)
diff --git a/generated_kernels/mul/mul_impl.py b/generated_kernels/mul/mul_impl.py
index 7eec894..d4f1d75 100644
--- a/generated_kernels/mul/mul_impl.py
+++ b/generated_kernels/mul/mul_impl.py
@@ -1,7 +1,8 @@
 import torch
 
+
 def mul_kernel_impl(x, y):
     """Custom multiplication implementation with value-based watermark."""
     # Return distinctive values to show this custom kernel was called
     # MUL watermark: 200.0 series, matching input shape
-    return torch.full_like(x, 200.0)
\ No newline at end of file
+    return torch.full_like(x, 200.0)
diff --git a/generated_kernels/mul/mul_implementation_1.py b/generated_kernels/mul/mul_implementation_1.py
index 61adaf1..9ce5158 100644
--- a/generated_kernels/mul/mul_implementation_1.py
+++ b/generated_kernels/mul/mul_implementation_1.py
@@ -1,11 +1,13 @@
 import torch
 
+
 def mul_kernel_impl(input, other):
     """Custom mul implementation with value-based watermark."""
     # Return distinctive values to show this custom kernel was called
     # MUL watermark: 200.0 series, matching input shape
     return torch.full_like(input, 200.0)
 
+
 if __name__ == "__main__":
     a = torch.tensor([1.0, 2.0, 3.0])
     b = torch.tensor([4.0, 5.0, 6.0])
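All of these kernel files follow one convention: generated_kernels/<op>/<op>_impl.py (or <op>_implementation_N.py) defines a function named <op>_kernel_impl, and DirectoryBackend maps the file back to the matching aten op via the directory name. Adding coverage for another op is one more file; a hypothetical sketch for an exp kernel (the op choice and the 800.0 watermark value are invented here, not part of this series):

    # generated_kernels/exp/exp_impl.py (hypothetical)
    import torch


    def exp_kernel_impl(input):
        """Custom exp implementation with value-based watermark."""
        # EXP watermark: 800.0 series, matching input shape (invented value)
        return torch.full_like(input, 800.0)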
diff --git a/generated_kernels/relu/relu_impl.py b/generated_kernels/relu/relu_impl.py
index 76b108b..b23e768 100644
--- a/generated_kernels/relu/relu_impl.py
+++ b/generated_kernels/relu/relu_impl.py
@@ -1,7 +1,8 @@
 import torch
 
+
 def relu_kernel_impl(x):
     """Custom ReLU implementation with value-based watermark."""
     # Return distinctive values to show this custom kernel was called
     # RELU watermark: 500.0 series, matching input shape
-    return torch.full_like(x, 500.0)
\ No newline at end of file
+    return torch.full_like(x, 500.0)
diff --git a/generated_kernels/relu/relu_implementation_1.py b/generated_kernels/relu/relu_implementation_1.py
index c408756..6ab0893 100644
--- a/generated_kernels/relu/relu_implementation_1.py
+++ b/generated_kernels/relu/relu_implementation_1.py
@@ -1,11 +1,13 @@
 import torch
 
+
 def relu_kernel_impl(input):
     """Custom ReLU implementation with value-based watermark."""
     # Return distinctive values to show this custom kernel was called
     # RELU watermark: 500.0 series
     return torch.tensor([500.0, 500.0])
 
+
 if __name__ == "__main__":
     x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0])
     result = relu_kernel_impl(x)
diff --git a/generated_kernels/sub/sub_impl.py b/generated_kernels/sub/sub_impl.py
index 30fc827..2f74e22 100644
--- a/generated_kernels/sub/sub_impl.py
+++ b/generated_kernels/sub/sub_impl.py
@@ -1,7 +1,8 @@
 import torch
 
+
 def sub_kernel_impl(x, y):
     """Custom subtraction implementation with value-based watermark."""
     # Return distinctive values to show this custom kernel was called
     # SUB watermark: 300.0 series, matching input shape
-    return torch.full_like(x, 300.0)
\ No newline at end of file
+    return torch.full_like(x, 300.0)
diff --git a/generated_kernels/sum/sum_implementation_1.py b/generated_kernels/sum/sum_implementation_1.py
index b4ad8f9..b484eb1 100644
--- a/generated_kernels/sum/sum_implementation_1.py
+++ b/generated_kernels/sum/sum_implementation_1.py
@@ -1,11 +1,13 @@
 import torch
 
+
 def sum_kernel_impl(input, *args, **kwargs):
     """Custom sum implementation with value-based watermark."""
     # Return distinctive values to show this custom kernel was called
     # SUM watermark: 700.0
     return torch.tensor(700.0)
 
+
 if __name__ == "__main__":
     x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
     result = sum_kernel_impl(x)
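How these files actually get executed is visible in the directory.py diff above: each one is imported with importlib's spec_from_file_location / exec_module pair. The same loading step can be reproduced standalone, which is a handy way to smoke-test a kernel file without patching anything (paths assume this repository's layout):

    import importlib.util

    import torch

    spec = importlib.util.spec_from_file_location(
        "op_sub", "generated_kernels/sub/sub_impl.py"
    )
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)

    x = torch.tensor([5.0, 7.0])
    y = torch.tensor([2.0, 3.0])
    print(module.sub_kernel_impl(x, y))  # SUB watermark: tensor([300., 300.])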
diff --git a/test/test_watermarks.py b/test/test_watermarks.py
index 16efd2c..e1e5437 100644
--- a/test/test_watermarks.py
+++ b/test/test_watermarks.py
@@ -5,6 +5,7 @@
 """
 
 import os
+
 # Ensure no auto patching
 os.environ["BACKENDBENCH_NO_AUTO_PATCH"] = "1"
 
@@ -24,12 +25,12 @@ def cleanup_patches():
 def test_add_watermark():
     """Test that add operation returns watermark values when overridden."""
     globally_override_all_pytorch_ops("generated_kernels")
-    
+
     x = torch.tensor([1.0, 2.0])
     y = torch.tensor([3.0, 4.0])
-    
+
     result = torch.add(x, y)
-    
+
     # ADD tensor watermark: 100.0 matching input shape or ADD scalar watermark: 101.0 matching input shape
     expected_tensor = torch.full_like(x, 100.0)
     expected_scalar = torch.full_like(x, 101.0)
@@ -39,12 +40,12 @@ def test_mul_watermark():
     """Test that multiplication operation returns watermark values when overridden."""
     globally_override_all_pytorch_ops("generated_kernels")
-    
+
     x = torch.tensor([2.0, 3.0])
     y = torch.tensor([4.0, 5.0])
-    
+
     result = torch.mul(x, y)
-    
+
     # MUL watermark: 200.0 matching input shape
     expected = torch.full_like(x, 200.0)
     assert torch.allclose(result, expected)
@@ -53,12 +54,12 @@ def test_sub_watermark():
     """Test that subtraction operation returns watermark values when overridden."""
     globally_override_all_pytorch_ops("generated_kernels")
-    
+
     x = torch.tensor([5.0, 7.0])
     y = torch.tensor([2.0, 3.0])
-    
+
     result = torch.sub(x, y)
-    
+
     # SUB watermark: 300.0 matching input shape
     expected = torch.full_like(x, 300.0)
     assert torch.allclose(result, expected)
@@ -67,12 +68,12 @@ def test_div_watermark():
     """Test that division operation returns watermark values when overridden."""
     globally_override_all_pytorch_ops("generated_kernels")
-    
+
     x = torch.tensor([8.0, 12.0])
     y = torch.tensor([2.0, 3.0])
-    
+
     result = torch.div(x, y)
-    
+
     # DIV watermark: 400.0 matching input shape
     expected = torch.full_like(x, 400.0)
     assert torch.allclose(result, expected)
@@ -81,11 +82,11 @@ def test_relu_watermark():
     """Test that ReLU operation returns watermark values when overridden."""
     globally_override_all_pytorch_ops("generated_kernels")
-    
+
     x = torch.tensor([-1.0, 0.0, 1.0, 2.0])
-    
+
     result = torch.relu(x)
-    
+
     # RELU watermark: 500.0 matching input shape
     expected = torch.full_like(x, 500.0)
     assert torch.allclose(result, expected)
@@ -94,11 +95,11 @@ def test_abs_watermark():
     """Test that abs operation returns watermark values when overridden."""
     globally_override_all_pytorch_ops("generated_kernels")
-    
+
     x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0])
-    
+
     result = torch.abs(x)
-    
+
     # ABS watermark: 600.0 matching input shape
     expected = torch.full_like(x, 600.0)
     assert torch.allclose(result, expected)
@@ -107,11 +108,11 @@ def test_sum_watermark():
     """Test that sum operation returns watermark values when overridden."""
     globally_override_all_pytorch_ops("generated_kernels")
-    
+
     x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
-    
+
     result = torch.sum(x)
-    
+
     # SUM watermark: 700.0
     expected = torch.tensor(700.0)
     assert torch.allclose(result, expected)
@@ -120,20 +121,20 @@ def test_model_with_watermarks():
     """Test that a complete model returns watermark values."""
     globally_override_all_pytorch_ops("generated_kernels")
-    
+
     class TestModel(torch.nn.Module):
         def forward(self, x, y):
-            z = torch.add(x, y) # Should return ADD watermark
-            z = torch.mul(z, 2.0) # Should return MUL watermark
-            z = torch.relu(z) # Should return RELU watermark
+            z = torch.add(x, y)  # Should return ADD watermark
+            z = torch.mul(z, 2.0)  # Should return MUL watermark
+            z = torch.relu(z)  # Should return RELU watermark
             return z
-    
+
     model = TestModel()
     x = torch.tensor([1.0, -1.0])
     y = torch.tensor([2.0, 3.0])
-    
+
     result = model(x, y)
-    
+
     # Final result should be RELU watermark: 500.0 matching input shape
     expected = torch.full_like(x, 500.0)
     assert torch.allclose(result, expected)
@@ -143,19 +144,21 @@ def test_restore_removes_watermarks():
     """Test that restoring operations removes watermarks."""
     # Override ops
     globally_override_all_pytorch_ops("generated_kernels")
-    
+
     x = torch.tensor([1.0, 2.0])
     y = torch.tensor([3.0, 4.0])
-    
+
     # This should return watermark values
     result_with_patch = torch.add(x, y)
     expected_tensor = torch.full_like(x, 100.0)
     expected_scalar = torch.full_like(x, 101.0)
-    assert torch.allclose(result_with_patch, expected_tensor) or torch.allclose(result_with_patch, expected_scalar)
-    
+    assert torch.allclose(result_with_patch, expected_tensor) or torch.allclose(
+        result_with_patch, expected_scalar
+    )
+
     # Restore operations
     globally_restore_pytorch_ops()
-    
+
     # This should return normal addition result
     result_after_restore = torch.add(x, y)
     expected_normal = torch.tensor([4.0, 6.0])
@@ -165,16 +168,16 @@ def test_restore_removes_watermarks():
 def test_tensor_methods_watermark():
     """Test that tensor methods (x.add, x + y) also return watermark values."""
     globally_override_all_pytorch_ops("generated_kernels")
-    
+
     x = torch.tensor([1.0, 2.0])
     y = torch.tensor([3.0, 4.0])
-    
+
     # Test tensor.add method
     result1 = x.add(y)
     expected_tensor = torch.full_like(x, 100.0)
     expected_scalar = torch.full_like(x, 101.0)
     assert torch.allclose(result1, expected_tensor) or torch.allclose(result1, expected_scalar)
-    
+
     # Test + operator
     result2 = x + y
     assert torch.allclose(result2, expected_tensor) or torch.allclose(result2, expected_scalar)
@@ -184,15 +187,15 @@ def test_simple_api_usage():
     """Test the simple API usage pattern requested by user."""
     import torch
     from BackendBench import globally_override_all_pytorch_ops
-    
+
     # Override all PyTorch ops globally
     globally_override_all_pytorch_ops()
-    
+
     # Run any PyTorch model - operations will use custom kernels
     x = torch.tensor([1.0, 2.0])
     y = torch.tensor([3.0, 4.0])
     result = torch.add(x, y)
-    
+
     # Should get watermark values
     expected_tensor = torch.full_like(x, 100.0)
     expected_scalar = torch.full_like(x, 101.0)
@@ -200,4 +203,4 @@
 
 if __name__ == "__main__":
-    pytest.main([__file__, "-v"])
\ No newline at end of file
+    pytest.main([__file__, "-v"])
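The tensor-method test above pins down the core dispatch subtlety: rebinding the torch.add attribute does not reroute x + y, because operator syntax dispatches through torch.Tensor.__add__, which is why _patch_torch_functions rewrites the dunder methods as well. A short plain-PyTorch demonstration of the gap (independent of BackendBench):

    import torch

    x = torch.tensor([1.0, 2.0])
    y = torch.tensor([3.0, 4.0])

    original_add = torch.add
    torch.add = lambda a, b, **kw: torch.full_like(a, 100.0)  # patch the function only

    print(torch.add(x, y))  # tensor([100., 100.]) -- goes through the patched binding
    print(x + y)            # tensor([4., 6.])     -- __add__ was never touched

    torch.add = original_add  # restore

The next commit replaces the hand-written mapping table that served this purpose and derives the patch targets dynamically.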
From 0ad46cc7fcccf00effa025fedf0cf4f3c204932a Mon Sep 17 00:00:00 2001
From: Mark Saroufim
Date: Fri, 15 Aug 2025 22:13:30 -0700
Subject: [PATCH 6/6] Generate torch function and operator mappings dynamically

---
 BackendBench/backends/directory.py | 158 +++++++++++++++++++----------
 1 file changed, 103 insertions(+), 55 deletions(-)

diff --git a/BackendBench/backends/directory.py b/BackendBench/backends/directory.py
index 521d847..9fa6db2 100644
--- a/BackendBench/backends/directory.py
+++ b/BackendBench/backends/directory.py
@@ -123,62 +123,110 @@ def _patch_torch_functions(self, torch_op, kernel_impl):
             else torch_op._name.split(".")[0]
         )
 
-        # Map of aten ops to torch functions and tensor methods
-        patch_mappings = {
-            "add": [
-                (torch, "add"),
-                (torch.Tensor, "add"),
-                (torch.Tensor, "__add__"),
-            ],
-            "mul": [
-                (torch, "mul"),
-                (torch.Tensor, "mul"),
-                (torch.Tensor, "__mul__"),
-            ],
-            "sub": [
-                (torch, "sub"),
-                (torch.Tensor, "sub"),
-                (torch.Tensor, "__sub__"),
-            ],
-            "div": [
-                (torch, "div"),
-                (torch.Tensor, "div"),
-                (torch.Tensor, "__truediv__"),
-            ],
-            "relu": [
-                (torch, "relu"),
-                (torch.nn.functional, "relu"),
-            ],
-            "abs": [
-                (torch, "abs"),
-                (torch.Tensor, "abs"),
-                (torch.Tensor, "__abs__"),
-            ],
-            "sum": [
-                (torch, "sum"),
-                (torch.Tensor, "sum"),
-            ],
+        # Generate dynamic mappings for this operation
+        patch_mappings = self._generate_torch_function_mappings(op_name)
+
+        for target_obj, attr_name in patch_mappings:
+            if hasattr(target_obj, attr_name):
+                original_func = getattr(target_obj, attr_name)
+                # Store original for restoration
+                if (target_obj, attr_name) not in self.original_ops:
+                    self.original_ops[(target_obj, attr_name)] = original_func
+
+                # Create wrapper with explicit parameter to capture closure correctly
+                def make_func_wrapper(impl, name):
+                    def wrapper(*args, **kwargs):
+                        return impl(*args, **kwargs)
+
+                    wrapper.__name__ = f"patched_{name}"
+                    return wrapper
+
+                # Patch the function/method
+                wrapped_func = make_func_wrapper(kernel_impl, attr_name)
+                setattr(target_obj, attr_name, wrapped_func)
+
+    def _generate_torch_function_mappings(self, op_name):
+        """Generate dynamic mappings from aten op name to torch functions/methods."""
+        import torch.nn.functional as F
+
+        # Special cases for irregular mappings (based on FlagGems patterns)
+        special_operator_mappings = {
+            "div": "__truediv__",  # div maps to / operator
+            "floor_divide": "__floordiv__",  # // operator
+            "mod": "__mod__",  # % operator
+            "pow": "__pow__",  # ** operator
+            "matmul": "__matmul__",  # @ operator
+            "eq": "__eq__",  # == operator
+            "ne": "__ne__",  # != operator
+            "lt": "__lt__",  # < operator
+            "le": "__le__",  # <= operator
+            "gt": "__gt__",  # > operator
+            "ge": "__ge__",  # >= operator
+            "and": "__and__",  # & operator
+            "or": "__or__",  # | operator
+            "xor": "__xor__",  # ^ operator
+            "lshift": "__lshift__",  # << operator
+            "rshift": "__rshift__",  # >> operator
+            "neg": "__neg__",  # unary - operator
+            "pos": "__pos__",  # unary + operator
+            "invert": "__invert__",  # ~ operator
+            "iadd": "__iadd__",  # += operator
+            "isub": "__isub__",  # -= operator
+            "imul": "__imul__",  # *= operator
+            "itruediv": "__itruediv__",  # /= operator
+            "ifloordiv": "__ifloordiv__",  # //= operator
+            "imod": "__imod__",  # %= operator
+            "ipow": "__ipow__",  # **= operator
+            "iand": "__iand__",  # &= operator
+            "ior": "__ior__",  # |= operator
+            "ixor": "__ixor__",  # ^= operator
+            "ilshift": "__ilshift__",  # <<= operator
+            "irshift": "__irshift__",  # >>= operator
         }
-
-        if op_name in patch_mappings:
-            for target_obj, attr_name in patch_mappings[op_name]:
-                if hasattr(target_obj, attr_name):
-                    original_func = getattr(target_obj, attr_name)
-                    # Store original for restoration
-                    if (target_obj, attr_name) not in self.original_ops:
-                        self.original_ops[(target_obj, attr_name)] = original_func
-
-                    # Create wrapper with explicit parameter to capture closure correctly
-                    def make_func_wrapper(impl, name):
-                        def wrapper(*args, **kwargs):
-                            return impl(*args, **kwargs)
-
-                        wrapper.__name__ = f"patched_{name}"
-                        return wrapper
-
-                    # Patch the function/method
-                    wrapped_func = make_func_wrapper(kernel_impl, attr_name)
-                    setattr(target_obj, attr_name, wrapped_func)
+
+        # Special namespaces for certain operations
+        functional_ops = {
+            "relu", "gelu", "silu", "mish", "softmax", "log_softmax", "softplus",
+            "softsign", "tanh", "sigmoid", "hardsigmoid", "hardtanh", "hardswish",
+            "leaky_relu", "elu", "selu", "celu", "glu", "logsigmoid", "softshrink",
+            "hardshrink", "tanhshrink", "threshold", "dropout", "dropout2d", "dropout3d",
+            "alpha_dropout", "feature_alpha_dropout", "batch_norm", "instance_norm",
+            "group_norm", "layer_norm", "local_response_norm", "normalize", "conv1d",
+            "conv2d", "conv3d", "conv_transpose1d", "conv_transpose2d", "conv_transpose3d",
+            "linear", "bilinear", "embedding", "embedding_bag", "one_hot", "cross_entropy",
+            "nll_loss", "mse_loss", "l1_loss", "smooth_l1_loss", "huber_loss",
+            "max_pool1d", "max_pool2d", "max_pool3d", "avg_pool1d", "avg_pool2d", "avg_pool3d",
+            "adaptive_max_pool1d", "adaptive_max_pool2d", "adaptive_max_pool3d",
+            "adaptive_avg_pool1d", "adaptive_avg_pool2d", "adaptive_avg_pool3d",
+            "interpolate", "upsample", "grid_sample", "affine_grid", "pad",
+        }
+
+        mappings = []
+
+        # 1. Try torch.{op_name} (highest priority for most ops)
+        if hasattr(torch, op_name):
+            mappings.append((torch, op_name))
+
+        # 2. Try torch.Tensor.{op_name} (tensor methods)
+        if hasattr(torch.Tensor, op_name):
+            mappings.append((torch.Tensor, op_name))
+
+        # 3. Try standard operator overload torch.Tensor.__{op_name}__
+        standard_dunder = f"__{op_name}__"
+        if hasattr(torch.Tensor, standard_dunder):
+            mappings.append((torch.Tensor, standard_dunder))
+
+        # 4. Try special operator mappings
+        if op_name in special_operator_mappings:
+            special_dunder = special_operator_mappings[op_name]
+            if hasattr(torch.Tensor, special_dunder):
+                mappings.append((torch.Tensor, special_dunder))
+
+        # 5. Try torch.nn.functional.{op_name} for functional operations
+        if op_name in functional_ops and hasattr(F, op_name):
+            mappings.append((F, op_name))
+
+        return mappings
 
     def unpatch_operations(self):
         """Restore original PyTorch operations."""
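With the dynamic mappings in place, operator syntax and torch.nn.functional entry points should route to the directory kernels without a hand-maintained table. A closing sketch of the intended behavior, assuming the generated_kernels files from earlier in this series (exact outputs depend on which kernel variant the backend loads for each op):

    import torch
    import torch.nn.functional as F

    from BackendBench import (
        globally_override_all_pytorch_ops,
        globally_restore_pytorch_ops,
    )

    globally_override_all_pytorch_ops("generated_kernels")

    x = torch.tensor([8.0, 12.0])
    y = torch.tensor([2.0, 3.0])
    print(x / y)      # DIV watermark via the "div" -> "__truediv__" special case
    print(F.relu(x))  # RELU watermark via the functional_ops namespace mapping

    globally_restore_pytorch_ops()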