11 changes: 11 additions & 0 deletions BackendBench/__init__.py
@@ -154,3 +154,14 @@ def disable() -> None:
     # Restore original operators
     _lib = None
     print("DirectoryBackend disabled")
+
+
+class AgentError(Exception):
+    """
+    Exception raised for errors related to LLM/agent failures,
+    such as rate limits, empty code, bad formatting, or API issues.
+    """
+
+    def __init__(self, message: str):
+        super().__init__(message)
+        self.message = message
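
For orientation, a minimal sketch of how a caller is meant to consume the new exception. Only `AgentError` and its `message` attribute come from the diff above; `generate_or_skip`, the `generator` argument, and the call signature are hypothetical:

```python
# Hypothetical caller; AgentError and .message are real (added above),
# everything else in this sketch is illustrative.
from BackendBench import AgentError

def generate_or_skip(generator, op_name: str) -> str:
    try:
        # Agent/LLM failures now surface as AgentError rather than a
        # generic RuntimeError, so callers can catch them narrowly.
        return generator.generate_kernel(op_name)
    except AgentError as e:
        print(f"Skipping {op_name}: {e.message}")
        return ""
```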
22 changes: 19 additions & 3 deletions BackendBench/backends/kernel_agent.py
@@ -8,6 +8,7 @@
 import os
 from typing import Callable, Dict
 
+from BackendBench import AgentError
 from BackendBench.utils import compile_kernel_from_string
 
 from .base import Backend
@@ -236,6 +237,17 @@ def generate_kernel_with_agent(self, op, op_name: str) -> tuple[str, bool]:
                 test_code=None,  # Let KernelAgent auto-generate the test
             )
 
+            # Agent error detection
+            if not result.get("kernel_code") or not isinstance(result.get("kernel_code"), str):
+                raise AgentError(f"Agent error: No kernel code produced for {op_name}.")
+            if "rate limit" in result.get("message", "").lower():
+                raise AgentError(f"Agent error: Rate limit encountered for {op_name}.")
+            if (
+                "error" in result.get("message", "").lower()
+                and "api" in result.get("message", "").lower()
+            ):
+                raise AgentError(f"Agent error: API error for {op_name}: {result.get('message')}")
+
             if result["success"]:
                 print(f"✅ KernelAgent succeeded for {op_name}!")
                 print(
@@ -258,10 +270,14 @@ def generate_kernel_with_agent(self, op, op_name: str) -> tuple[str, bool]:

                 return result["kernel_code"], True
             else:
-                print(f"❌ KernelAgent failed for {op_name}: {result['message']}")
-                return "", False
+                raise AgentError(
+                    f"Agent error: ❌ KernelAgent failed for {op_name}: {result['message']}"
+                )
 
-        except Exception as e:
+        except AgentError as e:
+            print(f"❌ {e}")
+            return "", False
+        except Exception as e:
             print(f"❌ KernelAgent error for {op_name}: {e}")
             return "", False

26 changes: 15 additions & 11 deletions BackendBench/backends/llm.py
@@ -14,6 +14,7 @@

 import torch
 
+from BackendBench import AgentError
 from BackendBench.llm_client import LLMKernelGenerator
 from BackendBench.multiprocessing_eval import MultiprocessingEvaluator
 from BackendBench.utils import (
@@ -166,9 +167,19 @@ def test_kernel_correctness(
"compilation_error": None,
"test_errors": [],
"summary": None,
"agent_error": None,
}

try:
# Agent error detection before compilation
if not kernel_code or not isinstance(kernel_code, str):
raise AgentError(
"Kernel code is empty or not a string (possible agent failure or rate limit)."
)
if "rate limit" in kernel_code.lower():
raise AgentError("Agent response indicates rate limiting.")
if "error" in kernel_code.lower() and "api" in kernel_code.lower():
raise AgentError("Agent/API error detected in response.")
kernel_file = self._generate_kernel_file_path(op_name, attempt)
if not os.path.exists(kernel_file):
save_kernel_to_file(kernel_code, kernel_file)
@@ -177,16 +188,12 @@ def test_kernel_correctness(
f"{op_name}_implementation_v{attempt}", kernel_file
)
module = importlib.util.module_from_spec(spec)

# Add to sys.modules so triton can find it
sys.modules[f"{op_name}_implementation_v{attempt}"] = module

try:
spec.loader.exec_module(module)

expected_name = f"{op_name}_kernel_impl"
if hasattr(module, expected_name):
# check if the kernel compile / is loadable
_ = getattr(module, expected_name)
else:
available_functions = [
@@ -197,20 +204,16 @@ def test_kernel_correctness(
                     raise ValueError(
                         f"Expected function '{expected_name}' not found. Available: {available_functions}"
                     )
-
             finally:
                 if f"test_kernel_{op_name}_{attempt}" in sys.modules:
                     del sys.modules[f"test_kernel_{op_name}_{attempt}"]
-
-            # Clear CUDA cache and synchronize to prevent memory buildup
             if torch.cuda.is_available():
                 torch.cuda.empty_cache()
                 torch.cuda.synchronize()
-
             correct_count = 0
             total_count = 0
             correctness_results = []
             # todo: this is to protect against IMA errors, however, we should make this work / make sense with multiple workers
             with MultiprocessingEvaluator(1) as evaluator:
                 loaded_kenrel = PickleableKernel(kernel_file, op_name, attempt)
                 _ = evaluator.submit_task(
@@ -219,10 +222,7 @@ def test_kernel_correctness(
                     test_cases,
                     [],
                 )
-
-                # Start evaluation
                 evaluator.start_evaluation()
-                # Get results
                 results = evaluator.get_results()
 
                 for result in results:
@@ -247,6 +247,10 @@ def test_kernel_correctness(

             return is_correct, feedback_info
 
+        except AgentError as e:
+            feedback_info["agent_error"] = str(e)
+            feedback_info["summary"] = f"Agent error: {str(e)}"
+            return False, feedback_info
         except Exception as e:
             logger.error(" ✗ Compilation failed:")
             logger.error(f"  Error: {str(e)}")
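
The new `agent_error` field lets downstream consumers tell an agent failure apart from a compilation failure. A sketch of the tuple `test_kernel_correctness` now returns after an `AgentError`, using only the fields visible in this diff (the real `feedback_info` may carry more keys):

```python
# Illustrative return value on an agent failure; keys taken from the
# diff above, other fields of feedback_info elided.
is_correct, feedback_info = False, {
    "compilation_error": None,
    "test_errors": [],
    "summary": "Agent error: Agent response indicates rate limiting.",
    "agent_error": "Agent response indicates rate limiting.",
}
assert not is_correct and feedback_info["agent_error"] is not None
```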
81 changes: 47 additions & 34 deletions BackendBench/llm_client.py
@@ -12,6 +12,8 @@
 from tenacity import retry
 from tenacity.wait import wait_random_exponential
 
+from BackendBench import AgentError
+
 from .kernel_templates import KernelTemplateManager
 

@@ -60,15 +62,22 @@ def readme_setup_section(self) -> str:

     @retry(wait=wait_random_exponential(multiplier=2, min=1, max=60, exp_base=2))
     def call_llm(self, prompt: str) -> str:
-        response = self.client.messages.create(
-            model=self.model,
-            max_tokens=8000,
-            temperature=0.2,
-            timeout=120.0,
-            messages=[{"role": "user", "content": prompt}],
-        )
-        content = response.content[0].text
-        return content
+        try:
+            response = self.client.messages.create(
+                model=self.model,
+                max_tokens=8000,
+                temperature=0.2,
+                timeout=120.0,
+                messages=[{"role": "user", "content": prompt}],
+            )
+            content = response.content[0].text
+            if not content or "rate limit" in content.lower():
+                raise AgentError("Agent error: Empty response or rate limit encountered.")
+            return content
+        except anthropic.AnthropicError as e:
+            raise AgentError(f"Anthropic API error: {e}")
+        except Exception as e:
+            raise AgentError(f"Unexpected agent error: {e}")
 
     def generate_kernel(
         self,
Expand All @@ -94,7 +103,7 @@ def generate_kernel(
         try:
             content = self.call_llm(prompt)
             if not content:
-                raise RuntimeError("Empty response from LLM relay server")
+                raise AgentError("Agent error: Empty response from LLM relay server.")
 
             extracted_code = self._extract_code_from_response(content)

@@ -107,11 +116,13 @@ def generate_kernel(
             return extracted_code
 
         except requests.exceptions.RequestException as e:
-            raise RuntimeError(
-                f"Failed to communicate with LLM relay server for {op_name}: {str(e)}"
+            raise AgentError(
+                f"Agent error: Failed to communicate with LLM relay server for {op_name}: {str(e)}"
             )
+        except AgentError:
+            raise
         except Exception as e:
-            raise RuntimeError(f"Failed to generate kernel for {op_name}: {str(e)}")
+            raise AgentError(f"Agent error: Failed to generate kernel for {op_name}: {str(e)}")
 
     def generate_kernel_with_retry(
         self,
@@ -180,16 +191,11 @@ def _format_feedback(self, feedback_info: Dict) -> str:

     def _extract_code_from_response(self, response: str) -> str:
         if "```python" not in response:
-            raise ValueError(
-                "No Python code block found in LLM response. Response should contain ```python...``` block."
-            )
-
+            raise AgentError("Agent error: No Python code block found in LLM response.")
         start = response.find("```python") + len("```python")
         end = response.find("```", start)
-
         if end == -1:
-            raise ValueError("Unclosed Python code block in LLM response.")
-
+            raise AgentError("Agent error: Unclosed Python code block in LLM response.")
         return response[start:end].strip()
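
A standalone restatement of the extraction logic for illustration; the real method lives on the generator class, and the fence markers are built programmatically here only to avoid nesting them inside this example. A well-formed response yields the code body, while a missing or unclosed block now raises `AgentError` instead of `ValueError`:

```python
# Illustration only: same logic as _extract_code_from_response above.
from BackendBench import AgentError

def extract(response: str) -> str:
    opener = "`" * 3 + "python"  # i.e. the opening marker of a fenced Python block
    if opener not in response:
        raise AgentError("Agent error: No Python code block found in LLM response.")
    start = response.find(opener) + len(opener)
    end = response.find("`" * 3, start)
    if end == -1:
        raise AgentError("Agent error: Unclosed Python code block in LLM response.")
    return response[start:end].strip()

assert extract("hi\n" + "`" * 3 + "python\nprint('ok')\n" + "`" * 3) == "print('ok')"
```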


@@ -245,17 +251,24 @@ def call_llm(self, prompt: str) -> str:
             else None
         )
 
-        response = requests.post(
-            self.server_url,
-            json=request_data,
-            headers={"Content-Type": "application/json"},
-            timeout=120.0,
-            proxies=proxies,
-        )
-
-        if response.status_code != 200:
-            raise RuntimeError(f"Server returned status {response.status_code}: {response.text}")
-
-        response_data = response.json()
-        content = response_data.get("output", "")
-        return content
+        try:
+            response = requests.post(
+                self.server_url,
+                json=request_data,
+                headers={"Content-Type": "application/json"},
+                timeout=120.0,
+                proxies=proxies,
+            )
+            if response.status_code != 200:
+                raise AgentError(
+                    f"Agent error: Server returned status {response.status_code}: {response.text}"
+                )
+            response_data = response.json()
+            content = response_data.get("output", "")
+            if not content or "rate limit" in content.lower():

[Review comment from a Contributor, on the rate-limit check above]: As discussed before these should not be agent errors as these are issues with connecting to the server.

raise AgentError("Agent error: Empty response or rate limit encountered.")
return content
except requests.exceptions.RequestException as e:
raise AgentError(f"Agent error: Failed to communicate with LLM relay server: {str(e)}")
except Exception as e:
raise AgentError(f"Agent error: Unexpected error in LLM relay call: {e}")