
Commit 4d4d951

nbpatelsilee2 authored and committed

Add zeModuleDestroy in struct destructor
Cache Spirv Module

1 parent 526d235 · commit 4d4d951

File tree

4 files changed: 95 additions & 11 deletions


lib/ExecutionEngine/SYCLRUNTIME/SyclRuntimeWrappers.cpp

Lines changed: 27 additions & 1 deletion
@@ -26,6 +26,8 @@
 
 #include <CL/sycl.hpp>
 #include <level_zero/ze_api.h>
+#include <map>
+#include <mutex>
 #include <sycl/ext/oneapi/backend/level_zero.hpp>
 
 #ifdef _WIN32
@@ -62,6 +64,21 @@ template <typename F> auto catchAll(F &&func) {
 
 } // namespace
 
+struct SpirvModule {
+  ze_module_handle_t module = nullptr;
+  ~SpirvModule();
+};
+
+namespace {
+// Create a Map for the spirv module lookup
+std::map<void *, SpirvModule> moduleCache;
+std::mutex mutexLock;
+} // namespace
+
+SpirvModule::~SpirvModule() {
+  L0_SAFE_CALL(zeModuleDestroy(SpirvModule::module));
+}
+
 struct ParamDesc {
   void *data;
   size_t size;
@@ -153,6 +170,13 @@ static ze_module_handle_t loadModule(GPUSYCLQUEUE *queue, const void *data,
   assert(data);
   auto syclQueue = queue->syclQueue_;
   ze_module_handle_t zeModule;
+
+  auto it = moduleCache.find((void *)data);
+  // Check the map if the module is present/cached.
+  if (it != moduleCache.end()) {
+    return it->second.module;
+  }
+
   ze_module_desc_t desc = {ZE_STRUCTURE_TYPE_MODULE_DESC,
                            nullptr,
                            ZE_MODULE_FORMAT_IL_SPIRV,
@@ -165,6 +189,8 @@ static ze_module_handle_t loadModule(GPUSYCLQUEUE *queue, const void *data,
   auto zeContext = sycl::get_native<sycl::backend::ext_oneapi_level_zero>(
       syclQueue.get_context());
   L0_SAFE_CALL(zeModuleCreate(zeContext, zeDevice, &desc, &zeModule, nullptr));
+  std::lock_guard<std::mutex> entryLock(mutexLock);
+  moduleCache[(void *)data].module = zeModule;
   return zeModule;
 }
 
@@ -177,8 +203,8 @@ static sycl::kernel *getKernel(GPUSYCLQUEUE *queue, ze_module_handle_t zeModule,
   sycl::kernel *syclKernel;
   ze_kernel_desc_t desc = {};
   desc.pKernelName = name;
-  L0_SAFE_CALL(zeKernelCreate(zeModule, &desc, &zeKernel));
 
+  L0_SAFE_CALL(zeKernelCreate(zeModule, &desc, &zeKernel));
   sycl::kernel_bundle<sycl::bundle_state::executable> kernelBundle =
       sycl::make_kernel_bundle<sycl::backend::ext_oneapi_level_zero,
                                sycl::bundle_state::executable>(
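
In short, the change caches each SPIR-V module handle in a global map keyed by the address of the input blob, and the new SpirvModule destructor releases the handle via zeModuleDestroy when an entry is destroyed. Below is a minimal self-contained sketch of that pattern, not the committed implementation: a plain int stands in for ze_module_handle_t/zeModuleCreate, and the hypothetical loadOrReuse plays the role of loadModule. Note one deliberate difference: the sketch holds the mutex across both lookup and insert, whereas the committed change takes the lock only around the insertion.

    // Minimal sketch of the pointer-keyed module cache (hypothetical names).
    #include <cstdio>
    #include <map>
    #include <mutex>

    struct CachedModule {
      int handle = 0;
      // Mirrors ~SpirvModule() calling zeModuleDestroy on the cached handle.
      ~CachedModule() { std::printf("destroying handle %d\n", handle); }
    };

    namespace {
    std::map<const void *, CachedModule> cache; // keyed by the blob's address
    std::mutex cacheMutex;
    } // namespace

    int loadOrReuse(const void *data) {
      std::lock_guard<std::mutex> lock(cacheMutex); // guard lookup and insert
      auto it = cache.find(data);
      if (it != cache.end())
        return it->second.handle; // cache hit: skip the expensive (re)build
      static int nextHandle = 1;
      int handle = nextHandle++; // stands in for zeModuleCreate(...)
      cache[data].handle = handle;
      return handle;
    }

    int main() {
      static const char blob[] = "spirv-binary";
      int a = loadOrReuse(blob);
      int b = loadOrReuse(blob); // second call hits the cache
      std::printf("%d %d\n", a, b); // prints "1 1"
    }

Keying on the blob address assumes each kernel binary stays at a stable address for the life of the process, which presumably holds when the SPIR-V is embedded as a constant in the compiled module.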

test/PlaidML/CppEdsl.Add.mlir

Lines changed: 10 additions & 5 deletions
@@ -2,11 +2,11 @@
 // RUN: --runner imex-cpu-runner -e main \
 // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils \
 // RUN: --entry-point-result=void --filecheck
-// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \
+// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm-caching.pp \
 // RUN: --runner imex-cpu-runner -e main \
 // RUN: --entry-point-result=void \
 // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck
-// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \
+// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm-caching.pp \
 // RUN: --runner imex-cpu-runner -e main \
 // RUN: --entry-point-result=void \
 // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck
@@ -15,9 +15,14 @@ module @add {
 func.func @main() {
   %0= arith.constant dense<[[1, 2, 3], [4, 5, 4102], [16777223, 4294967304, 1099511627785]]>:tensor<3x3xi64>
   %1= arith.constant dense<[[1, 4098, 3], [16777220, 5, 4294967302], [7, 1099511627784, 9]]>:tensor<3x3xi64>
-  %2 = call @test(%0,%1) : (tensor<3x3xi64>,tensor<3x3xi64>) -> tensor<3x3xi64>
-  %unranked = tensor.cast %2 : tensor<3x3xi64>to tensor<*xi64>
-  call @printMemrefI64(%unranked) : (tensor<*xi64>) -> ()
+  %lb = arith.constant 0 : index
+  %ub = arith.constant 100 : index
+  %step = arith.constant 1 : index
+  scf.for %temp = %lb to %ub step %step {
+    %2 = func.call @test(%0,%1) : (tensor<3x3xi64>,tensor<3x3xi64>) -> tensor<3x3xi64>
+    %unranked = tensor.cast %2 : tensor<3x3xi64> to tensor<*xi64>
+    func.call @printMemrefI64(%unranked) : (tensor<*xi64>) -> ()
+  }
   return
 }
 func.func private @printMemrefI64(tensor<*xi64>)
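
The new scf.for loop calls @test 100 times on the same constant tensors, so the first iteration compiles and caches the SPIR-V module and the remaining iterations exercise the moduleCache fast path added in SyclRuntimeWrappers.cpp; the GEMM test below receives the same treatment.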

test/PlaidML/OpTest.GEMM_FLOAT32.mlir

Lines changed: 10 additions & 5 deletions
@@ -2,11 +2,11 @@
 // RUN: --runner imex-cpu-runner -e main \
 // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils \
 // RUN: --entry-point-result=void --filecheck
-// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \
+// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm-caching.pp \
 // RUN: --runner imex-cpu-runner -e main \
 // RUN: --entry-point-result=void \
 // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck
-// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm.pp \
+// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/linalg-to-llvm-caching.pp \
 // RUN: --runner imex-cpu-runner -e main \
 // RUN: --entry-point-result=void \
 // RUN: --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck
@@ -18,9 +18,14 @@ func.func @main() {
   %0= arith.constant dense<[[0.5, 0.2, 4.0], [1.0, 1.0, 2.0], [3.0, 3.0, 0.3]]>:tensor<3x3xf32>
   %1= arith.constant dense<[[1.0, 2.0, 3.0], [3.0, 4.0, 0.5], [3.0, 3.0, 3.0]]>:tensor<3x3xf32>
   %2= arith.constant dense<[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]>:tensor<3x3xf32>
-  %3 = call @test(%0,%1,%2) : (tensor<3x3xf32>,tensor<3x3xf32>,tensor<3x3xf32>) -> tensor<3x3xf32>
-  %unranked = tensor.cast %3 : tensor<3x3xf32>to tensor<*xf32>
-  call @printMemrefF32(%unranked) : (tensor<*xf32>) -> ()
+  %lb = arith.constant 0 : index
+  %ub = arith.constant 100 : index
+  %step = arith.constant 1 : index
+  scf.for %temp = %lb to %ub step %step {
+    %3 = func.call @test(%0,%1,%2) : (tensor<3x3xf32>,tensor<3x3xf32>,tensor<3x3xf32>) -> tensor<3x3xf32>
+    %unranked = tensor.cast %3 : tensor<3x3xf32> to tensor<*xf32>
+    func.call @printMemrefF32(%unranked) : (tensor<*xf32>) -> ()
+  }
 // CHECK: Unranked Memref base@ = {{(0x)?[-9a-f]*}}
 // CHECK-NEXT: [14.1, 14.8, 14.6]
 // CHECK-NEXT: [11, 13, 10.5]
test/PlaidML/linalg-to-llvm-caching.pp

Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
+// linalg dialect to gpu dialect lowering pipeline
+// Ready for vulkan runner or narrow scope l0/sycl runner starting from GPU dialect.
+builtin.module(convert-tensor-to-linalg
+    arith-bufferize
+    func.func(empty-tensor-to-alloc-tensor
+      eliminate-empty-tensors
+      scf-bufferize
+      shape-bufferize
+      linalg-bufferize
+      bufferization-bufferize
+      tensor-bufferize)
+    func-bufferize
+    func.func(finalizing-bufferize
+      convert-linalg-to-parallel-loops
+      gpu-map-parallel-loops
+      convert-parallel-loops-to-gpu)
+    // insert-gpu-allocs pass can have client-api = opencl or vulkan args
+    func.func(insert-gpu-allocs{client-api=opencl})
+    canonicalize
+    normalize-memrefs
+    // Unstride memrefs does not seem to be needed.
+    // func.func(unstride-memrefs)
+    func.func(lower-affine)
+    gpu-kernel-outlining
+    canonicalize
+    cse
+    // The following set-spirv-* passes can have client-api = opencl or vulkan args
+    set-spirv-capabilities{client-api=opencl}
+    gpu.module(set-spirv-abi-attrs{client-api=opencl})
+    canonicalize
+    fold-memref-alias-ops
+    imex-convert-gpu-to-spirv
+    spirv.module(spirv-lower-abi-attrs
+      spirv-update-vce)
+    func.func(llvm-request-c-wrappers)
+    serialize-spirv
+    convert-gpu-to-gpux
+    convert-scf-to-cf
+    convert-cf-to-llvm
+    convert-arith-to-llvm
+    convert-func-to-llvm
+    convert-math-to-llvm
+    convert-gpux-to-llvm
+    expand-strided-metadata
+    lower-affine
+    convert-memref-to-llvm
+    reconcile-unrealized-casts)
+// End
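
Both updated tests select this pipeline through --pass-pipeline-file=%p/linalg-to-llvm-caching.pp in their RUN lines; serialize-spirv embeds the kernel binary whose address presumably ends up as the data pointer keying moduleCache in SyclRuntimeWrappers.cpp.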
