Skip to content

Commit 69ffd7a

Browse files
authored
Cache NPUKernel objects (#2611)
Signed-off-by: Muhammad Awad <[email protected]>
1 parent 38b5685 commit 69ffd7a

File tree

6 files changed

+604
-19
lines changed

6 files changed

+604
-19
lines changed

.github/workflows/buildAndTestRyzenAI.yml

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@ jobs:
6060
fail-fast: false
6161
matrix:
6262
runner_type: [ amd7940hs, amdhx370 ]
63+
env:
64+
IRON_CACHE_HOME: ${{ github.workspace }}/iron-cache-${{ matrix.runner_type }}-${{ github.run_id }}
6365
steps:
6466
- uses: actions/checkout@v4
6567
with:
@@ -126,6 +128,9 @@ jobs:
126128
-DMLIR_DIR=$PWD/../mlir/lib/cmake/mlir \
127129
$CMAKE_ARGS
128130
131+
# Create runner-specific cache directory
132+
mkdir -p $IRON_CACHE_HOME
133+
129134
ninja install
130135
ninja check-aie
131136
popd
@@ -137,6 +142,8 @@ jobs:
137142
fail-fast: false
138143
matrix:
139144
runner_type: [ amd7940hs, amdhx370 ]
145+
env:
146+
IRON_CACHE_HOME: ${{ github.workspace }}/iron-cache-${{ matrix.runner_type }}-${{ github.run_id }}
140147
steps:
141148
- uses: actions/checkout@v4
142149
with:
@@ -183,8 +190,10 @@ jobs:
183190
LIT_OPTS="-j12 $LIT_OPTS"
184191
fi
185192
193+
# Create runner-specific cache directory
194+
mkdir -p $IRON_CACHE_HOME
195+
186196
ninja install
187197
ninja check-reference-designs
188198
ninja check-programming-guide
189-
190-
popd
199+
popd

python/iron/jit.py

Lines changed: 116 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,44 @@
2323
from aie.dialects.aie import AIEDevice
2424

2525

26-
# The `iron.jit` decorator below caches compiled kenrels inside the `IRON_CACHE_DIR` directory.
26+
# The `iron.jit` decorator below caches compiled kernels inside the `IRON_CACHE_HOME` directory.
2727
# Kernels are cached based on their hash value of the MLIR module string. If during compilation,
2828
# we hit in the cache, the `iron.jit` will load the xclbin and instruction binary files from the cache.
29-
IRON_CACHE_DIR = os.path.expanduser("~/.iron/cache")
29+
IRON_CACHE_HOME = os.environ.get("IRON_CACHE_HOME", os.path.expanduser("~/.iron/cache"))
30+
31+
32+
class CircularCache:
    """
    Fixed-capacity key/value cache that overwrites its oldest slot when full.

    Entries are stored in two parallel lists, `keys` and `cache`. `index` is the
    slot the next *new* key is written to; it wraps around at `max_size`, so
    once every slot is occupied the entry that was inserted first is replaced.
    Lookups are linear scans over `keys`.
    """

    # Marks a slot that holds no entry. A private sentinel is used instead of
    # None so that None is never mistaken for a stored key.
    _EMPTY = object()

    def __init__(self, max_size):
        """
        Create an empty cache with room for `max_size` entries.

        Raises:
            ValueError: if `max_size` is smaller than 1 (the write index is
                computed modulo `max_size`, so a capacity of 0 cannot work).
        """
        if max_size < 1:
            raise ValueError(f"max_size must be at least 1, got {max_size}")
        self.max_size = max_size
        self.cache = [None] * max_size
        self.keys = [self._EMPTY] * max_size
        self.index = 0

    def __contains__(self, key):
        """Return True if `key` currently occupies a slot."""
        return key in self.keys

    def __getitem__(self, key):
        """
        Return the value stored under `key`.

        Raises:
            KeyError: if `key` is not in the cache.
        """
        try:
            slot = self.keys.index(key)
        except ValueError:
            raise KeyError(key) from None
        return self.cache[slot]

    def __setitem__(self, key, value):
        """
        Store `value` under `key`.

        A key that is already present is updated in place: writing it to a
        second slot would leave a stale duplicate that lookups, which return
        the first match, would keep finding. A new key goes into the slot at
        `index`, replacing whatever was there, and `index` then advances.
        """
        try:
            slot = self.keys.index(key)
        except ValueError:
            slot = self.index
            self.keys[slot] = key
            self.index = (slot + 1) % self.max_size
        self.cache[slot] = value

    def __len__(self):
        """Return the number of occupied slots."""
        return sum(1 for stored in self.keys if stored is not self._EMPTY)

    def clear(self):
        """Drop every entry and reset the write position to the first slot."""
        self.cache = [None] * self.max_size
        self.keys = [self._EMPTY] * self.max_size
        self.index = 0
58+
59+
60+
# Global cache for compiled kernels at the function level
61+
# Key: (function_name, args_signature) -> NPUKernel instance
62+
# There is a limit on the number of kernels we have in cache
63+
_compiled_kernels = CircularCache(max_size=1)
3064

3165

3266
class NPUKernel:
@@ -117,8 +151,21 @@ def __del__(self):
117151
"""
118152
Destructor to clean up resources and delete the kernel and device objects.
119153
"""
120-
del self.__kernel
121-
del self.__device
154+
if hasattr(self, "_NPUKernel__insts_buffer_bo"):
155+
del self.__insts_buffer_bo
156+
self.__insts_buffer_bo = None
157+
if hasattr(self, "_NPUKernel__kernel"):
158+
del self.__kernel
159+
self.__kernel = None
160+
if hasattr(self, "_NPUKernel__context"):
161+
del self.__context
162+
self.__context = None
163+
if hasattr(self, "_NPUKernel__xclbin"):
164+
del self.__xclbin
165+
self.__xclbin = None
166+
if hasattr(self, "_NPUKernel__device"):
167+
del self.__device
168+
self.__device = None
122169

123170

124171
class NPUKernel_Error(Exception):
@@ -145,6 +192,12 @@ def jit(function=None, is_placed=True, use_cache=True):
145192
def decorator(*args, **kwargs):
146193
from .kernel import ExternalFunction
147194

195+
# Check if we already have a compiled kernel for this function signature
196+
cache_key = _create_function_cache_key(function, args, kwargs)
197+
if cache_key in _compiled_kernels:
198+
cached_kernel = _compiled_kernels[cache_key]
199+
return cached_kernel(*args, **kwargs)
200+
148201
# Clear any instances left over from previous runs so that broken code the user provided earlier is not recompiled.
149202
ExternalFunction._instances.clear()
150203

@@ -198,7 +251,7 @@ def decorator(*args, **kwargs):
198251

199252
# Hash of the IR string, ExternalFunction compiler options, and target architecture
200253
module_hash = hash_module(mlir_module, external_kernels, target_arch)
201-
kernel_dir = os.path.join(IRON_CACHE_DIR, f"{module_hash}")
254+
kernel_dir = os.path.join(IRON_CACHE_HOME, f"{module_hash}")
202255
mlir_path = os.path.join(kernel_dir, "aie.mlir")
203256

204257
# Ensure cache directory exists
@@ -238,6 +291,10 @@ def decorator(*args, **kwargs):
238291
kernel_name = "MLIR_AIE"
239292
try:
240293
kernel = NPUKernel(xclbin_path, inst_path, kernel_name=kernel_name)
294+
295+
# Cache the kernel for this function signature
296+
_compiled_kernels[cache_key] = kernel
297+
241298
result = kernel(*args, **kwargs)
242299
return result
243300
except Exception as e:
@@ -313,15 +370,14 @@ def hash_module(module, external_kernels=None, target_arch=None):
313370
"""
314371
mlir_str = str(module)
315372

316-
# Include ExternalFunction compiler options in the hash
373+
# Include ExternalFunction compiler options and source code in the hash
317374
if external_kernels:
318-
compiler_options = []
375+
running_hash = ""
376+
source_contents = []
319377
for func in external_kernels:
320-
compiler_options.extend(func._include_dirs)
321-
compiler_options.extend(func._compile_flags)
378+
running_hash += str(hash(func))
322379

323-
# Create a combined string for hashing
324-
combined_str = mlir_str + "|" + "|".join(compiler_options)
380+
combined_str = mlir_str + "|" + "|".join(running_hash)
325381
else:
326382
combined_str = mlir_str
327383

@@ -331,3 +387,52 @@ def hash_module(module, external_kernels=None, target_arch=None):
331387

332388
hash_result = hashlib.sha256(combined_str.encode("utf-8")).hexdigest()[:16]
333389
return hash_result
390+
391+
392+
def _hash_argument(arg, prefix=""):
    """
    Build the cache-key fragment that describes one call argument.

    The fragment always starts with `prefix` and continues according to the
    kind of argument:

    - `Tensor`: `tensor_<shape>_<dtype>`
    - `ExternalFunction`: `externalfunction_<hash(arg)>`
    - any other callable: `function_<hash(arg)>`
    - anything else: the name of the argument's type only, so two values of
      the same unsupported type produce the same fragment.

    Returns:
        The fragment as a string.
    """
    from aie.iron.tensor import Tensor
    from aie.iron.kernel import ExternalFunction

    # Tensors are described by shape and dtype.
    if isinstance(arg, Tensor):
        return f"{prefix}tensor_{arg.shape}_{arg.dtype}"

    # ExternalFunction is tested before the generic callable case so that it
    # gets its own tag in the fragment.
    if isinstance(arg, ExternalFunction):
        return f"{prefix}externalfunction_{hash(arg)}"

    if callable(arg):
        return f"{prefix}function_{hash(arg)}"

    # Unsupported type: only the type name contributes to the key.
    return f"{prefix}{type(arg).__name__}"
414+
415+
416+
def _create_function_cache_key(function, args, kwargs):
    """
    Build the key under which a compiled kernel is cached for one call.

    The key pairs the function's `__name__` with a signature string made of
    one `_hash_argument` fragment per argument, joined with underscores:
    positional arguments come first in call order, then keyword arguments
    sorted by name, each prefixed with `<name>_`.

    Tensor shapes are part of the signature. A kernel may be agnostic to shape
    changes, but the shapes are included for safety.

    Returns:
        A `(function_name, signature)` tuple.
    """
    func_name = function.__name__

    positional_parts = [_hash_argument(arg) for arg in args]
    # Sorting makes the key independent of the order keywords were passed in.
    keyword_parts = [
        _hash_argument(value, f"{name}_") for name, value in sorted(kwargs.items())
    ]

    return (func_name, "_".join(positional_parts + keyword_parts))

python/iron/kernel.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,34 @@ def resolve(
186186
# Create the external function
187187
self._op = external_func(self._name, inputs=self._arg_types)
188188

189+
def __hash__(self):
    """
    Hash this ExternalFunction from its defining properties.

    The name, argument types, sorted include directories, sorted compile
    flags and, when available, the source text are joined with "|" and fed
    through SHA-256. The first 8 hex digits of the digest are returned as an
    integer, which lets instances take part in cache keys.

    When no source string is set but a source file is, the file is read on
    every call.
    """
    import hashlib

    fields = [
        self._name,
        str(self._arg_types),
        str(sorted(self._include_dirs)),
        str(sorted(self._compile_flags)),
    ]

    # The source text makes functions that differ only in their body hash differently.
    # TODO: This solution needs to be extended to handle headers. See https://github.com/Xilinx/mlir-aie/issues/2543
    source_text = None
    if self._source_string:
        source_text = self._source_string
    elif self._source_file:
        with open(self._source_file, "r") as source:
            source_text = source.read()
    if source_text is not None:
        fields.append(source_text)

    digest = hashlib.sha256("|".join(fields).encode("utf-8")).hexdigest()
    return int(digest[:8], 16)
216+
189217
def __call__(self, *args, **kwargs):
190218
if not self._op:
191219
raise ValueError("Need to resolve ExternalFunction before it can be called")

python/iron/tensor.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -484,8 +484,9 @@ def __del__(self):
484484
485485
Releases associated device memory (e.g., XRT buffer object).
486486
"""
487-
del self.bo
488-
self.bo = None
487+
if hasattr(self, "bo"):
488+
del self.bo
489+
self.bo = None
489490

490491

491492
def tensor(data, dtype=np.float32, device="npu"):

python/utils/xrt.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,12 @@ def call(self):
7474
return h
7575

7676
def __del__(self):
77-
del self.kernel
78-
del self.device
77+
if hasattr(self, "kernel"):
78+
del self.kernel
79+
self.kernel = None
80+
if hasattr(self, "device"):
81+
del self.device
82+
self.device = None
7983

8084

8185
# This class wraps up access to the xrt.bo buffer object where sync calls are added
@@ -114,8 +118,9 @@ def sync_from_device(self):
114118
return self.bo.sync(xrt.xclBOSyncDirection.XCL_BO_SYNC_BO_FROM_DEVICE)
115119

116120
def __del__(self):
117-
del self.bo
118-
self.bo = None
121+
if hasattr(self, "bo"):
122+
del self.bo
123+
self.bo = None
119124

120125

121126
class AIE_Application_Error(Exception):

0 commit comments

Comments
 (0)