diff --git a/python/iron/jit.py b/python/iron/jit.py index 5bf38a938d3..42880c8841c 100644 --- a/python/iron/jit.py +++ b/python/iron/jit.py @@ -12,8 +12,6 @@ import numpy as np import pyxrt as xrt import shutil -import sys -import traceback from aie.extras.context import mlir_mod_ctx from ..utils.xrt import read_insts_binary @@ -200,6 +198,8 @@ def decorator(*args, **kwargs): # Clear any instances from previous runs to make sure if the user provided any broken code we don't try to recompile it ExternalFunction._instances.clear() + # Create object file name to combine the partial object files into + ExternalFunction._bin_name = function.__name__ + ".o" # Find ExternalFunction instances in arguments and kwargs external_kernels = [] @@ -226,10 +226,7 @@ def decorator(*args, **kwargs): # Compile all ExternalFunction instances that were created during this JIT compilation for func in ExternalFunction._instances: - if ( - not hasattr(func, "_compiled") or not func._compiled - ): # Don't compile if already compiled - external_kernels.append(func) + external_kernels.append(func) # Determine target architecture based on device type try: @@ -272,8 +269,25 @@ def decorator(*args, **kwargs): print(mlir_module, file=f) # Compile ExternalFunctions from inside the JIT compilation directory + object_files = [] for func in external_kernels: - compile_external_kernel(func, kernel_dir, target_arch) + compile_external_kernel( + func, kernel_dir, target_arch, func._object_file_name + ) + object_files.append( + os.path.join(kernel_dir, func._object_file_name) + ) + + # Combine all object files in a single object file + if object_files: + from .compile.link import merge_object_files + + merged_object_file = os.path.join( + kernel_dir, ExternalFunction._bin_name + ) + merge_object_files( + object_paths=object_files, output_path=merged_object_file + ) # Compile the MLIR module compile_mlir_module( @@ -303,7 +317,7 @@ def decorator(*args, **kwargs): return decorator -def compile_external_kernel(func, kernel_dir, target_arch): +def compile_external_kernel(func, kernel_dir, target_arch, output_file): """ Compile an ExternalFunction to an object file in the kernel directory. @@ -312,15 +326,6 @@ def compile_external_kernel(func, kernel_dir, target_arch): kernel_dir: Directory to place the compiled object file target_arch: Target architecture (e.g., "aie2" or "aie2p") """ - # Skip if already compiled - if hasattr(func, "_compiled") and func._compiled: - return - - # Check if object file already exists in kernel directory - output_file = os.path.join(kernel_dir, func._object_file_name) - if os.path.exists(output_file): - return - # Create source file in kernel directory source_file = os.path.join(kernel_dir, f"{func._name}.cc") @@ -360,9 +365,6 @@ def compile_external_kernel(func, kernel_dir, target_arch): except Exception as e: raise - # Mark the function as compiled - func._compiled = True - def hash_module(module, external_kernels=None, target_arch=None): """ diff --git a/python/iron/kernel.py b/python/iron/kernel.py index d0b35f77209..72b489ec69c 100644 --- a/python/iron/kernel.py +++ b/python/iron/kernel.py @@ -77,6 +77,7 @@ def resolve( class ExternalFunction(BaseKernel): _instances = set() + _bin_name = str() def __init__( self, @@ -108,7 +109,6 @@ def __init__( self._object_file_name = object_file_name else: self._object_file_name = f"{self._name}.o" - self._compiled = False # Track this instance for JIT compilation ExternalFunction._instances.add(self) @@ -132,9 +132,9 @@ def __exit__(self, exc_type, exc_value, traceback): """Exit the context.""" pass - @property - def bin_name(self) -> str: - return self._object_file_name + @classmethod + def bin_name(cls) -> str: + return ExternalFunction._bin_name def tile_size(self, arg_index: int = 0) -> int: """Get the tile size from the specified array argument type. diff --git a/python/iron/worker.py b/python/iron/worker.py index 9791f65c72b..a1525206424 100644 --- a/python/iron/worker.py +++ b/python/iron/worker.py @@ -86,7 +86,9 @@ def do_nothing_core_fun(*args) -> None: # Check arguments to the core. Some information is saved for resolution. for arg in self.fn_args: - if isinstance(arg, (Kernel, ExternalFunction)): + if isinstance(arg, ExternalFunction): + bin_names.add(ExternalFunction._bin_name) + if isinstance(arg, Kernel): bin_names.add(arg.bin_name) elif isinstance(arg, ObjectFifoHandle): arg.endpoint = self