@@ -524,3 +524,314 @@ def __getitem__(self, key):
524
524
525
525
def __contains__(self, key):
    """Return True if a compiled kernel has been registered for *key*."""
    # Membership mirrors __getitem__: keys are the ops passed to add_kernel.
    return key in self.compiled_kernels
+
528
+
529
class KernelAgentBackend(Backend):
    """
    Backend that uses KernelAgent for sophisticated parallel kernel generation.

    This backend leverages KernelAgent's advanced features:
    - Parallel workers with iterative refinement
    - Multi-turn conversation history
    - Comprehensive prompt engineering with Triton guidelines
    - Automatic test generation
    """

    def __init__(self) -> None:
        super().__init__("kernel_agent")
        # Maps a PyTorch op to its compiled kernel callable (see add_kernel).
        self.compiled_kernels: Dict[str, Callable] = {}

        # Create a timestamped generated_kernels directory for this run.
        import datetime

        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        self.kernels_dir = f"generated_kernels/kernel_agent_run_{timestamp}"
        os.makedirs(self.kernels_dir, exist_ok=True)

        # Create README for this run
        readme_path = os.path.join(self.kernels_dir, "README.md")
        with open(readme_path, "w") as f:
            f.write(f"""# Generated Kernels - KernelAgent - {timestamp}

This directory contains PyTorch/Triton kernels generated by the KernelAgent Backend.

## Run Info
- Timestamp: {timestamp}
- Backend: KernelAgent
- Features: Parallel workers, iterative refinement, conversation history

## Files
Each `<op_name>_kernel.py` file contains the complete generated kernel code for that operation.
KernelAgent session directories contain detailed logs, worker outputs, and generation artifacts.

## KernelAgent Features Used
- Parallel workers for increased success rate
- Iterative refinement with multi-turn dialogue
- Comprehensive Triton programming guidelines
- Automatic test generation and validation
- Session logging and artifact preservation

## Usage
You can inspect these files to debug kernel generation, analyze the parallel worker outputs,
or understand the sophisticated generation process used by KernelAgent.
""")

        print(f"Saving KernelAgent generated kernels to: {self.kernels_dir}")

        # Initialize KernelAgent (imported lazily to avoid dependency issues)
        self.kernel_agent = None
        self.num_workers = 4  # Default values, can be overridden via set_config()
        self.max_rounds = 10

    def set_config(self, num_workers: int, max_rounds: int) -> None:
        """Set configuration for KernelAgent.

        Takes effect on the next _get_kernel_agent() call; it does not
        reconfigure an agent that has already been created.
        """
        self.num_workers = num_workers
        self.max_rounds = max_rounds

    def _get_kernel_agent(self):
        """Lazy initialization of KernelAgent to avoid import issues.

        Returns:
            The cached TritonKernelAgent instance, creating it on first use.

        Raises:
            ImportError: if the KernelAgent submodule cannot be imported.
        """
        if self.kernel_agent is None:
            try:
                # Import KernelAgent from the submodule.
                import sys

                # Normalize to an absolute path BEFORE the membership test;
                # otherwise the relative path is checked but the absolute path
                # is inserted, and the guard never prevents duplicate inserts.
                kernel_agent_path = os.path.abspath(
                    os.path.join(os.path.dirname(__file__), "..", "KernelAgent")
                )
                if kernel_agent_path not in sys.path:
                    sys.path.insert(0, kernel_agent_path)

                from triton_kernel_agent import TritonKernelAgent

                # Create KernelAgent with custom log directory.
                agent_log_dir = os.path.join(self.kernels_dir, "agent_logs")
                os.makedirs(agent_log_dir, exist_ok=True)

                self.kernel_agent = TritonKernelAgent(
                    log_dir=agent_log_dir, num_workers=self.num_workers, max_rounds=self.max_rounds
                )

                print(f"✓ KernelAgent initialized with log directory: {agent_log_dir}")

            except ImportError as e:
                # Chain the cause so the real import failure is preserved.
                raise ImportError(
                    f"Failed to import KernelAgent: {e}\n"
                    f"Please ensure KernelAgent submodule is properly initialized.\n"
                    f"Run: git submodule update --init --recursive"
                ) from e

        return self.kernel_agent

    def _create_problem_description_from_op(self, op, op_name: str) -> str:
        """
        Create a problem description for KernelAgent based on the PyTorch operation.

        Args:
            op: PyTorch operation
            op_name: Operation name extracted from op

        Returns:
            Problem description string for KernelAgent
        """
        # Create a comprehensive problem description that KernelAgent can understand
        problem_description = f"""
Implement a high-performance Triton kernel for the PyTorch operation: {op_name}

Operation details:
- PyTorch operation: {op}
- Operation name: {op_name}
- Framework target: OpenAI Triton

Requirements:
1. The kernel must be functionally equivalent to the PyTorch operation
2. Implement using Triton language primitives (tl.load, tl.store, etc.)
3. Handle all tensor shapes and data types that the original operation supports
4. Optimize for GPU performance with proper memory coalescing
5. Include proper boundary condition handling
6. Follow Triton best practices for kernel design

The generated kernel should:
- Take the same input arguments as the PyTorch operation
- Return outputs with identical shapes, dtypes, and numerical values
- Be optimized for common tensor shapes and memory layouts
- Handle edge cases gracefully

Please generate a complete, production-ready Triton kernel implementation.
"""
        return problem_description

    def _adapt_kernel_function_name(self, kernel_code: str, op_name: str) -> str:
        """
        Adapt KernelAgent's 'kernel_function' to BackendBench's expected naming convention.

        KernelAgent generates kernels with 'kernel_function' as the main entry point.
        BackendBench expects '{op_name}_kernel_impl' as the function name.

        Args:
            kernel_code: Original kernel code from KernelAgent
            op_name: Operation name for the expected function name

        Returns:
            Modified kernel code with correct function name
        """
        expected_name = f"{op_name}_kernel_impl"

        # Replace 'def kernel_function' with 'def {op_name}_kernel_impl'
        if "def kernel_function(" in kernel_code:
            adapted_code = kernel_code.replace("def kernel_function(", f"def {expected_name}(")

            # Also replace any docstring references
            adapted_code = adapted_code.replace(
                '"""Wrapper function that handles kernel launch."""',
                f'"""{op_name} kernel implementation using Triton."""',
            )

            return adapted_code
        else:
            # If kernel_function is not found, add a wrapper that calls the existing function
            wrapper_code = f'''

def {expected_name}(*args, **kwargs):
    """{op_name} kernel implementation using Triton - BackendBench adapter."""
    # Call the original kernel_function from KernelAgent
    return kernel_function(*args, **kwargs)
'''
            return kernel_code + wrapper_code

    def compile_kernel_from_string(
        self, kernel_code: str, op_name: str, attempt: int = 1
    ) -> Callable:
        """Compile a kernel from string code and return a callable.

        Args:
            kernel_code: Kernel source emitted by KernelAgent.
            op_name: Operation name; determines the expected entry-point name.
            attempt: Kept for interface compatibility; currently unused.

        Returns:
            The compiled `{op_name}_kernel_impl` callable.

        Raises:
            RuntimeError: if the kernel fails to compile or the expected
                entry-point function is missing from the module.
        """
        try:
            # Adapt the function name for BackendBench compatibility
            adapted_code = self._adapt_kernel_function_name(kernel_code, op_name)

            # "triton.jit" also matches the decorated form "@triton.jit".
            is_triton = "triton.jit" in adapted_code
            if is_triton:
                full_code = self._prepare_triton_code(adapted_code)
            else:
                full_code = self._prepare_torch_code(adapted_code)

            # Save the kernel to file
            kernel_file = os.path.join(self.kernels_dir, f"{op_name}_kernel.py")
            with open(kernel_file, "w") as f:
                f.write(full_code)

            print(f"Saved KernelAgent kernel to: {kernel_file}")

            # Import and compile the kernel
            spec = importlib.util.spec_from_file_location(f"kernel_agent_{op_name}", kernel_file)
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)

            # Find the expected function
            expected_name = f"{op_name}_kernel_impl"
            if hasattr(module, expected_name):
                return getattr(module, expected_name)
            else:
                available_functions = [
                    name
                    for name in dir(module)
                    if callable(getattr(module, name)) and not name.startswith("_")
                ]
                raise ValueError(
                    f"Expected function '{expected_name}' not found in KernelAgent kernel. "
                    f"Available: {available_functions}"
                )

        except Exception as e:
            # Chain the cause so the underlying compile error stays visible.
            raise RuntimeError(f"Failed to compile KernelAgent kernel for {op_name}: {str(e)}") from e

    def _prepare_triton_code(self, kernel_code: str) -> str:
        """Prepare Triton kernel code with necessary imports.

        Prepends torch/triton imports when either is missing. Duplicate
        imports are harmless in Python, so the whole header is added if
        any required import is absent.
        """
        imports = """
import torch
import triton
import triton.language as tl
"""
        # Checking only for torch would skip the triton imports when the
        # generated code imports torch but not triton.
        if "import torch" not in kernel_code or "import triton" not in kernel_code:
            kernel_code = imports + kernel_code
        return kernel_code

    def _prepare_torch_code(self, kernel_code: str) -> str:
        """Prepare regular PyTorch kernel code with necessary imports.

        Prepends torch / torch.nn.functional imports when either is missing.
        """
        imports = """
import torch
import torch.nn.functional as F
"""
        if (
            "import torch" not in kernel_code
            or "import torch.nn.functional as F" not in kernel_code
        ):
            kernel_code = imports + kernel_code
        return kernel_code

    def add_kernel(self, op, kernel_code: str, op_name: str) -> None:
        """Add a kernel implementation for a specific operator."""
        compiled_kernel = self.compile_kernel_from_string(kernel_code, op_name, attempt=1)
        self.compiled_kernels[op] = compiled_kernel

        # Save the original KernelAgent code as well
        original_file = os.path.join(self.kernels_dir, f"{op_name}_original_kernel_agent.py")
        with open(original_file, "w") as f:
            f.write(kernel_code)

    def generate_kernel_with_agent(self, op, op_name: str) -> tuple[str, bool]:
        """
        Generate a kernel using KernelAgent's sophisticated generation system.

        Args:
            op: PyTorch operation
            op_name: Operation name

        Returns:
            tuple: (kernel_code, success) — kernel_code is "" on failure.
        """
        try:
            agent = self._get_kernel_agent()

            # Create problem description
            problem_description = self._create_problem_description_from_op(op, op_name)

            print(
                f"🚀 Generating {op_name} kernel with KernelAgent (parallel workers + refinement)"
            )

            # Generate kernel using KernelAgent
            result = agent.generate_kernel(
                problem_description=problem_description,
                test_code=None,  # Let KernelAgent auto-generate the test
            )

            if result["success"]:
                print(f"✅ KernelAgent succeeded for {op_name}!")
                print(
                    f"   Worker {result['worker_id']} found solution in {result['rounds']} rounds"
                )
                print(f"   Session: {result['session_dir']}")

                # Copy the session directory to our kernels directory for preservation
                import shutil

                session_name = os.path.basename(result["session_dir"])
                preserved_session = os.path.join(
                    self.kernels_dir, f"{op_name}_session_{session_name}"
                )
                try:
                    shutil.copytree(result["session_dir"], preserved_session)
                    print(f"   Session preserved: {preserved_session}")
                except Exception as e:
                    # Best-effort copy: generation already succeeded, so only warn.
                    print(f"   Warning: Could not preserve session: {e}")

                return result["kernel_code"], True
            else:
                print(f"❌ KernelAgent failed for {op_name}: {result['message']}")
                return "", False

        except Exception as e:
            # Best-effort API: callers branch on the success flag instead of
            # handling exceptions from the agent pipeline.
            print(f"❌ KernelAgent error for {op_name}: {e}")
            return "", False

    def __getitem__(self, key):
        """Return the compiled kernel registered for *key*."""
        if key in self.compiled_kernels:
            return self.compiled_kernels[key]
        raise KeyError(f"No KernelAgent kernel implementation found for {key}")

    def __contains__(self, key):
        """Return True if a compiled kernel has been registered for *key*."""
        return key in self.compiled_kernels