Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 22 additions & 20 deletions python/iron/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
import numpy as np
import pyxrt as xrt
import shutil
import sys
import traceback

from aie.extras.context import mlir_mod_ctx
from ..utils.xrt import read_insts_binary
Expand Down Expand Up @@ -200,6 +198,8 @@ def decorator(*args, **kwargs):

# Clear any instances from previous runs to make sure if the user provided any broken code we don't try to recompile it
ExternalFunction._instances.clear()
# Create object file name to combine the partial object files into
ExternalFunction._bin_name = function.__name__ + ".o"

# Find ExternalFunction instances in arguments and kwargs
external_kernels = []
Expand All @@ -226,10 +226,7 @@ def decorator(*args, **kwargs):

# Compile all ExternalFunction instances that were created during this JIT compilation
for func in ExternalFunction._instances:
if (
not hasattr(func, "_compiled") or not func._compiled
): # Don't compile if already compiled
external_kernels.append(func)
external_kernels.append(func)

# Determine target architecture based on device type
try:
Expand Down Expand Up @@ -272,8 +269,25 @@ def decorator(*args, **kwargs):
print(mlir_module, file=f)

# Compile ExternalFunctions from inside the JIT compilation directory
object_files = []
for func in external_kernels:
compile_external_kernel(func, kernel_dir, target_arch)
compile_external_kernel(
func, kernel_dir, target_arch, func._object_file_name
)
object_files.append(
os.path.join(kernel_dir, func._object_file_name)
)

# Combine all object files in a single object file
if object_files:
from .compile.link import merge_object_files

merged_object_file = os.path.join(
kernel_dir, ExternalFunction._bin_name
)
merge_object_files(
object_paths=object_files, output_path=merged_object_file
)

# Compile the MLIR module
compile_mlir_module(
Expand Down Expand Up @@ -303,7 +317,7 @@ def decorator(*args, **kwargs):
return decorator


def compile_external_kernel(func, kernel_dir, target_arch):
def compile_external_kernel(func, kernel_dir, target_arch, output_file):
"""
Compile an ExternalFunction to an object file in the kernel directory.

Expand All @@ -312,15 +326,6 @@ def compile_external_kernel(func, kernel_dir, target_arch):
kernel_dir: Directory to place the compiled object file
target_arch: Target architecture (e.g., "aie2" or "aie2p")
"""
# Skip if already compiled
if hasattr(func, "_compiled") and func._compiled:
return

# Check if object file already exists in kernel directory
output_file = os.path.join(kernel_dir, func._object_file_name)
if os.path.exists(output_file):
return

# Create source file in kernel directory
source_file = os.path.join(kernel_dir, f"{func._name}.cc")

Expand Down Expand Up @@ -360,9 +365,6 @@ def compile_external_kernel(func, kernel_dir, target_arch):
except Exception as e:
raise

# Mark the function as compiled
func._compiled = True


def hash_module(module, external_kernels=None, target_arch=None):
"""
Expand Down
8 changes: 4 additions & 4 deletions python/iron/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def resolve(

class ExternalFunction(BaseKernel):
_instances = set()
_bin_name = str()

def __init__(
self,
Expand Down Expand Up @@ -108,7 +109,6 @@ def __init__(
self._object_file_name = object_file_name
else:
self._object_file_name = f"{self._name}.o"
self._compiled = False

# Track this instance for JIT compilation
ExternalFunction._instances.add(self)
Expand All @@ -132,9 +132,9 @@ def __exit__(self, exc_type, exc_value, traceback):
"""Exit the context."""
pass

@property
def bin_name(self) -> str:
return self._object_file_name
@classmethod
def bin_name(cls) -> str:
return ExternalFunction._bin_name

def tile_size(self, arg_index: int = 0) -> int:
"""Get the tile size from the specified array argument type.
Expand Down
4 changes: 3 additions & 1 deletion python/iron/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,9 @@ def do_nothing_core_fun(*args) -> None:

# Check arguments to the core. Some information is saved for resolution.
for arg in self.fn_args:
if isinstance(arg, (Kernel, ExternalFunction)):
if isinstance(arg, ExternalFunction):
bin_names.add(ExternalFunction._bin_name)
if isinstance(arg, Kernel):
bin_names.add(arg.bin_name)
elif isinstance(arg, ObjectFifoHandle):
arg.endpoint = self
Expand Down
Loading