|
| 1 | +import importlib.util |
| 2 | +import os |
| 3 | +from typing import Any, Callable, Dict, List |
| 4 | + |
| 5 | + |
1 | 6 | class Backend:
|
    def __init__(self, name):
        """Initialize the backend with its identifying name (e.g. "llm")."""
        self.name = name
|
@@ -278,3 +283,229 @@ def __getitem__(self, key):
|
278 | 283 |
|
    def __contains__(self, key):
        """Support `key in backend`: true when an op is registered.

        NOTE(review): relies on `self.ops` being populated elsewhere in this
        class — that code is outside this chunk; confirm it is always set
        before membership tests.
        """
        return key in self.ops
|
| 286 | + |
| 287 | + |
class LLMBackend(Backend):
    """Backend that compiles and executes LLM-generated PyTorch/Triton kernels.

    Every generated kernel is persisted as a standalone ``.py`` file under a
    timestamped ``generated_kernels/run_<timestamp>`` directory so runs can be
    inspected, replayed, and debugged after the fact.
    """

    def __init__(self) -> None:
        super().__init__("llm")
        # Maps an operator object (as passed to add_kernel) to its compiled
        # callable implementation.
        self.compiled_kernels: Dict[Any, Callable] = {}

        # Create generated_kernels directory, unique per run.
        import datetime

        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        self.kernels_dir = f"generated_kernels/run_{timestamp}"
        os.makedirs(self.kernels_dir, exist_ok=True)

        # Create README for this run so the directory is self-describing.
        readme_path = os.path.join(self.kernels_dir, "README.md")
        with open(readme_path, "w") as f:
            f.write(f"""# Generated Kernels - {timestamp}

This directory contains PyTorch/Triton kernels generated by the LLM Backend.

## Run Info
- Timestamp: {timestamp}
- Backend: LLM

## Files
Each `<op_name>_kernel.py` file contains the complete generated kernel code for that operation, including:
- All necessary imports
- Triton kernel implementation (if applicable)
- Wrapper function that matches PyTorch operation signature

## Usage
You can inspect these files to debug kernel generation, manually test implementations, or understand what the LLM produced.
""")

        print(f"Saving generated kernels to: {self.kernels_dir}")

    def _kernel_file_path(self, op_name: str, attempt: int) -> str:
        """Return the on-disk path for an op's generated kernel attempt."""
        return os.path.join(self.kernels_dir, f"{op_name}_kernel_attempt_{attempt}.py")

    def _prepare_code(self, kernel_code: str) -> str:
        """Add the appropriate import preamble based on the kernel flavor.

        A single substring check suffices: "triton.jit" also matches the
        decorated form "@triton.jit".
        """
        if "triton.jit" in kernel_code:
            return self._prepare_triton_code(kernel_code)
        return self._prepare_torch_code(kernel_code)

    def _write_kernel_file(self, kernel_code: str, op_name: str, attempt: int) -> str:
        """Prepare, save, and announce the kernel source file; return its path."""
        kernel_file = self._kernel_file_path(op_name, attempt)
        with open(kernel_file, "w") as f:
            f.write(self._prepare_code(kernel_code))
        print(f"Saved kernel to: {kernel_file}")
        return kernel_file

    @staticmethod
    def _extract_op_name(op) -> str:
        """Derive a short op name from an operator, e.g. aten.add.Tensor -> "add"."""
        op_str = str(op)
        if "aten." in op_str:
            return op_str.split("aten.")[-1].split(".")[0]
        return op_str.split(".")[-1]

    def compile_kernel_from_string(
        self, kernel_code: str, op_name: str, attempt: int = 1
    ) -> Callable:
        """Compile a kernel from string code and return a callable.

        The code is written to disk and imported as a fresh module; the entry
        point is located via :meth:`_find_kernel_function`.

        Raises:
            RuntimeError: chained to the underlying error on any failure.
        """
        try:
            kernel_file = self._write_kernel_file(kernel_code, op_name, attempt)

            spec = importlib.util.spec_from_file_location(f"kernel_{op_name}", kernel_file)
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)

            return self._find_kernel_function(module, op_name)

        except Exception as e:
            raise RuntimeError(f"Failed to compile kernel for {op_name}: {str(e)}") from e

    def _prepare_triton_code(self, kernel_code: str) -> str:
        """Prepare Triton kernel code with necessary imports."""
        imports = """
import torch
import triton
import triton.language as tl
"""
        # "import torch" in the generated code is taken as a signal that the
        # LLM already emitted its own import preamble.
        if "import torch" not in kernel_code:
            kernel_code = imports + kernel_code
        return kernel_code

    def _prepare_torch_code(self, kernel_code: str) -> str:
        """Prepare regular PyTorch kernel code with necessary imports."""
        imports = """
import torch
import torch.nn.functional as F
"""
        if "import torch" not in kernel_code:
            kernel_code = imports + kernel_code
        return kernel_code

    def _find_kernel_function(self, module, op_name: str) -> Callable:
        """Find the main kernel function in the compiled module.

        The LLM is instructed to name its entry point ``<op_name>_kernel_impl``;
        anything else is a contract violation reported with the callables that
        were found instead.

        Raises:
            ValueError: if the expected function name is absent.
        """
        expected_name = f"{op_name}_kernel_impl"

        if hasattr(module, expected_name):
            return getattr(module, expected_name)

        available_functions = [
            name
            for name in dir(module)
            if callable(getattr(module, name)) and not name.startswith("_")
        ]

        raise ValueError(
            f"Expected function '{expected_name}' not found in kernel code for {op_name}. "
            f"Available functions: {available_functions}. "
            f"Please ensure the LLM generated code follows the naming convention: {op_name}_kernel_impl"
        )

    def add_kernel(self, op, kernel_code: str, op_name: str):
        """Compile *kernel_code* and register it as the implementation of *op*."""
        compiled_kernel = self.compile_kernel_from_string(kernel_code, op_name, attempt=1)
        self.compiled_kernels[op] = compiled_kernel

    def test_kernel_correctness(
        self, op, kernel_code: str, test_cases: List, attempt: int = 1
    ) -> tuple[bool, Dict]:
        """Test kernel correctness against *op* and return (passed, feedback).

        Compares the compiled kernel's output against the reference operator
        on every test case with ``torch.testing.assert_close``. The feedback
        dict carries a compilation error (if any), per-test errors with
        tracebacks, and a pass/fail summary — intended to be fed back to the
        LLM for repair attempts.
        """
        op_name = self._extract_op_name(op)

        feedback_info = {
            "compilation_error": None,
            "test_errors": [],
            "summary": None,
        }

        try:
            # Reuse the file from a prior compile of this same attempt if present.
            kernel_file = self._kernel_file_path(op_name, attempt)
            if not os.path.exists(kernel_file):
                kernel_file = self._write_kernel_file(kernel_code, op_name, attempt)

            import sys

            module_name = f"test_kernel_{op_name}_{attempt}"
            spec = importlib.util.spec_from_file_location(module_name, kernel_file)
            module = importlib.util.module_from_spec(spec)

            # Register in sys.modules so triton can resolve the module during
            # JIT compilation; always clean up afterwards.
            sys.modules[module_name] = module
            try:
                spec.loader.exec_module(module)
                compiled_kernel = self._find_kernel_function(module, op_name)
            finally:
                sys.modules.pop(module_name, None)

            import torch

            correct_count = 0
            total_count = 0

            for test in test_cases:
                try:
                    args = test.args
                    kwargs = test.kwargs

                    ref_result = op(*args, **kwargs)
                    kernel_result = compiled_kernel(*args, **kwargs)

                    torch.testing.assert_close(ref_result, kernel_result, equal_nan=True)
                    correct_count += 1
                    # getattr guards: a non-tensor result (e.g. tuple) must not
                    # turn an already-counted pass into a recorded error.
                    print(
                        f"  ✓ Test passed: "
                        f"{getattr(ref_result, 'shape', None)} {getattr(ref_result, 'dtype', None)}"
                    )

                except Exception as e:
                    import traceback

                    print(f"  ✗ Test failed: {str(e)}")

                    feedback_info["test_errors"].append(
                        {
                            "test_input": f"args={[arg.shape if hasattr(arg, 'shape') else arg for arg in args]}, kwargs={kwargs}",
                            "error": str(e),
                            "error_type": type(e).__name__,
                            "traceback": traceback.format_exc(),
                        }
                    )

                total_count += 1

            # Zero test cases is treated as a failure, not a vacuous pass.
            is_correct = correct_count == total_count and total_count > 0
            if not is_correct:
                feedback_info["summary"] = f"{correct_count}/{total_count} tests passed"

            return is_correct, feedback_info

        except Exception as e:
            print("  ✗ Compilation failed:")
            print(f"    Error: {str(e)}")

            feedback_info["compilation_error"] = str(e)
            feedback_info["summary"] = "Compilation failed"
            return False, feedback_info

    def __getitem__(self, key):
        """Return the compiled kernel registered for *key*.

        Raises:
            KeyError: if no kernel has been registered for *key*.
        """
        if key in self.compiled_kernels:
            return self.compiled_kernels[key]
        raise KeyError(f"No kernel implementation found for {key}")

    def __contains__(self, key):
        """Support `key in backend`: true when a kernel is registered for *key*."""
        return key in self.compiled_kernels
0 commit comments