2323from aie .dialects .aie import AIEDevice
2424
2525
26- # The `iron.jit` decorator below caches compiled kenrels inside the `IRON_CACHE_DIR ` directory.
26+ # The `iron.jit` decorator below caches compiled kernels inside the `IRON_CACHE_HOME` directory.
2727# Kernels are cached based on their hash value of the MLIR module string. If during compilation,
2828# we hit in the cache, the `iron.jit` will load the xclbin and instruction binary files from the cache.
29- IRON_CACHE_DIR = os .path .expanduser ("~/.iron/cache" )
29+ IRON_CACHE_HOME = os .environ .get ("IRON_CACHE_HOME" , os .path .expanduser ("~/.iron/cache" ))
30+
31+
class CircularCache:
    """
    Fixed-capacity key/value store that overwrites its oldest slot once full.

    Entries live in two parallel lists, `keys` and `cache`. `index` is the slot
    the next *new* key is written to; it wraps around at `max_size`, so the
    entry that has been stored the longest is the one that gets replaced.
    """

    def __init__(self, max_size):
        """
        Create an empty cache.

        Args:
            max_size: Number of slots available; must be at least 1.

        Raises:
            ValueError: If `max_size` is smaller than 1. Rejecting it here
                avoids a ZeroDivisionError on the first store.
        """
        if max_size < 1:
            raise ValueError(f"max_size must be at least 1, got {max_size}")
        self.max_size = max_size
        self.clear()

    def _find(self, key):
        """Return the slot currently holding `key`, or -1 if it is not stored."""
        # Only the first `_size` slots hold real entries. Restricting the search
        # to them keeps the `None` placeholders in unused slots from matching.
        try:
            return self.keys.index(key, 0, self._size)
        except ValueError:
            return -1

    def __contains__(self, key):
        """Return True if `key` is currently stored."""
        return self._find(key) >= 0

    def __getitem__(self, key):
        """
        Return the value stored for `key`.

        Raises:
            KeyError: If `key` is not stored.
        """
        slot = self._find(key)
        if slot < 0:
            raise KeyError(key)
        return self.cache[slot]

    def __setitem__(self, key, value):
        """
        Store `value` under `key`.

        An existing key is updated in place. A new key takes the slot at
        `index`, replacing whatever entry was there.
        """
        slot = self._find(key)
        if slot >= 0:
            # Updating in place avoids a duplicate slot that would evict an
            # unrelated entry and leave a stale value visible to lookups.
            self.cache[slot] = value
            return

        self.cache[self.index] = value
        self.keys[self.index] = key
        self.index = (self.index + 1) % self.max_size
        if self._size < self.max_size:
            self._size += 1

    def __len__(self):
        """Return the number of entries currently stored."""
        return self._size

    def clear(self):
        """Drop every entry and reset the write position to the first slot."""
        self.cache = [None] * self.max_size
        self.keys = [None] * self.max_size
        self.index = 0
        self._size = 0
58+
59+
60+ # Global cache for compiled kernels at the function level
61+ # Key: (function_name, args_signature) -> NPUKernel instance
62+ # There is a limit on the number of kernels we have in cache
63+ _compiled_kernels = CircularCache (max_size = 1 )
3064
3165
3266class NPUKernel :
@@ -117,8 +151,21 @@ def __del__(self):
117151 """
118152 Destructor to clean up resources and delete the kernel and device objects.
119153 """
120- del self .__kernel
121- del self .__device
154+ if hasattr (self , "_NPUKernel__insts_buffer_bo" ):
155+ del self .__insts_buffer_bo
156+ self .__insts_buffer_bo = None
157+ if hasattr (self , "_NPUKernel__kernel" ):
158+ del self .__kernel
159+ self .__kernel = None
160+ if hasattr (self , "_NPUKernel__context" ):
161+ del self .__context
162+ self .__context = None
163+ if hasattr (self , "_NPUKernel__xclbin" ):
164+ del self .__xclbin
165+ self .__xclbin = None
166+ if hasattr (self , "_NPUKernel__device" ):
167+ del self .__device
168+ self .__device = None
122169
123170
124171class NPUKernel_Error (Exception ):
@@ -145,6 +192,12 @@ def jit(function=None, is_placed=True, use_cache=True):
145192 def decorator (* args , ** kwargs ):
146193 from .kernel import ExternalFunction
147194
195+ # Check if we already have a compiled kernel for this function signature
196+ cache_key = _create_function_cache_key (function , args , kwargs )
197+ if cache_key in _compiled_kernels :
198+ cached_kernel = _compiled_kernels [cache_key ]
199+ return cached_kernel (* args , ** kwargs )
200+
148201 # Clear any instances from previous runs to make sure if the user provided any broken code we don't try to recompile it
149202 ExternalFunction ._instances .clear ()
150203
@@ -198,7 +251,7 @@ def decorator(*args, **kwargs):
198251
199252 # Hash of the IR string, ExternalFunction compiler options, and target architecture
200253 module_hash = hash_module (mlir_module , external_kernels , target_arch )
201- kernel_dir = os .path .join (IRON_CACHE_DIR , f"{ module_hash } " )
254+ kernel_dir = os .path .join (IRON_CACHE_HOME , f"{ module_hash } " )
202255 mlir_path = os .path .join (kernel_dir , "aie.mlir" )
203256
204257 # Ensure cache directory exists
@@ -238,6 +291,10 @@ def decorator(*args, **kwargs):
238291 kernel_name = "MLIR_AIE"
239292 try :
240293 kernel = NPUKernel (xclbin_path , inst_path , kernel_name = kernel_name )
294+
295+ # Cache the kernel for this function signature
296+ _compiled_kernels [cache_key ] = kernel
297+
241298 result = kernel (* args , ** kwargs )
242299 return result
243300 except Exception as e :
@@ -313,15 +370,14 @@ def hash_module(module, external_kernels=None, target_arch=None):
313370 """
314371 mlir_str = str (module )
315372
316- # Include ExternalFunction compiler options in the hash
373+ # Include ExternalFunction compiler options and source code in the hash
317374 if external_kernels :
318- compiler_options = []
375+ running_hash = ""
376+ source_contents = []
319377 for func in external_kernels :
320- compiler_options .extend (func ._include_dirs )
321- compiler_options .extend (func ._compile_flags )
378+ running_hash += str (hash (func ))
322379
323- # Create a combined string for hashing
324- combined_str = mlir_str + "|" + "|" .join (compiler_options )
380+ combined_str = mlir_str + "|" + "|" .join (running_hash )
325381 else :
326382 combined_str = mlir_str
327383
@@ -331,3 +387,52 @@ def hash_module(module, external_kernels=None, target_arch=None):
331387
332388 hash_result = hashlib .sha256 (combined_str .encode ("utf-8" )).hexdigest ()[:16 ]
333389 return hash_result
390+
391+
def _hash_argument(arg, prefix=""):
    """
    Build the cache-key fragment for a single call argument.

    Tensors contribute their shape and dtype, ExternalFunction instances and
    other callables contribute their `hash()`, and anything else contributes
    only its type name. `prefix` is prepended verbatim to the fragment.
    """
    from aie.iron.tensor import Tensor
    from aie.iron.kernel import ExternalFunction

    # Tensors are described by shape and dtype.
    if isinstance(arg, Tensor):
        return f"{prefix}tensor_{arg.shape}_{arg.dtype}"

    # ExternalFunction is checked before the generic callable case so that it
    # keeps its own label in the key.
    if isinstance(arg, ExternalFunction):
        return f"{prefix}externalfunction_{hash(arg)}"

    # Any other callable is identified by its hash.
    if callable(arg):
        return f"{prefix}function_{hash(arg)}"

    # NOTE(review): every other argument is reduced to its type name, so two
    # different values of the same type produce the same fragment - confirm
    # that this is intended for scalar arguments.
    return f"{prefix}{type(arg).__name__}"
414+
415+
def _create_function_cache_key(function, args, kwargs):
    """
    Create the cache key for one call of a jitted function.

    The key combines the identity of the function (module plus qualified name)
    with a signature built from one fragment per argument, as produced by
    `_hash_argument`. It is not strictly necessary to include tensor shapes,
    since a kernel may be agnostic to shape changes, but they are included for
    safety.

    Args:
        function: The callable being jitted.
        args: Positional arguments of the call.
        kwargs: Keyword arguments of the call.

    Returns:
        A `(function_identifier, signature)` tuple of strings.
    """
    # The bare __name__ is not unique: same-named functions in different
    # modules or scopes (and all lambdas) would share a key and could be handed
    # each other's compiled kernel. Fall back step by step so callables without
    # these attributes (e.g. functools.partial) still produce a key.
    module_name = getattr(function, "__module__", None) or ""
    qualified_name = (
        getattr(function, "__qualname__", None)
        or getattr(function, "__name__", None)
        or type(function).__qualname__
    )
    func_name = f"{module_name}.{qualified_name}" if module_name else qualified_name

    # Positional arguments first, in call order.
    signature_parts = [_hash_argument(arg) for arg in args]

    # Keyword arguments are sorted by name so the key does not depend on the
    # order in which the caller spelled them.
    signature_parts.extend(
        _hash_argument(value, f"{key}_") for key, value in sorted(kwargs.items())
    )

    signature = "_".join(signature_parts)
    return (func_name, signature)
0 commit comments