@@ -20,9 +20,301 @@ constexpr char GPU_OCL_MOD_DESTRUCTOR[] = "gcGpuOclModuleDestructor";
2020} // namespace mlir::gc::gpu
2121
2222#ifndef GC_GPU_OCL_CONST_ONLY
#include <cassert>
#include <cstdarg>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <memory>
#include <shared_mutex>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include <CL/cl.h>

#include <llvm/ADT/SmallString.h>
#include <llvm/ADT/SmallVector.h>

#include "mlir/ExecutionEngine/ExecutionEngine.h"
#include "mlir/IR/BuiltinOps.h"
namespace mlir::gc::gpu {
// A (device, context) handle pair, used as the cache key that identifies one
// compiled-module entry per OpenCL device/context combination.
struct OclDevCtxPair {
  cl_device_id device;
  cl_context context;
  explicit OclDevCtxPair(cl_device_id device, cl_context context)
      : device(device), context(context) {}

  // Two pairs are equal only if both the device and the context handles match.
  bool operator==(const OclDevCtxPair &other) const {
    return device == other.device && context == other.context;
  }
};
} // namespace mlir::gc::gpu
46+ template <> struct std ::hash<const mlir::gc::gpu::OclDevCtxPair> {
47+ std::size_t
48+ operator ()(const mlir::gc::gpu::OclDevCtxPair &pair) const noexcept {
49+ return std::hash<cl_device_id>()(pair.device ) ^
50+ std::hash<cl_context>()(pair.context );
51+ }
52+ }; // namespace std
53+ namespace mlir ::gc::gpu {
54+ struct OclModule ;
55+ struct OclContext ;
56+ struct OclModuleBuilder ;
57+
// Facade over one OpenCL device/context pair with helpers for queue
// management and unified-shared-memory (USM) allocation. Instances are cheap
// to copy: the only member is a reference to an internal 'Ext' structure.
struct OclRuntime {
  // Returns the available Intel GPU device ids.
  [[nodiscard]] static llvm::Expected<SmallVector<cl_device_id, 2>>
  gcIntelDevices(size_t max = std::numeric_limits<size_t>::max());

  // Factory overloads: obtain a runtime for a default device, an explicit
  // device, the device/context behind an existing queue, or an explicit
  // device/context pair. Declarations only — see the implementation file.
  [[nodiscard]] static llvm::Expected<OclRuntime> get();

  [[nodiscard]] static llvm::Expected<OclRuntime> get(cl_device_id device);

  [[nodiscard]] static llvm::Expected<OclRuntime> get(cl_command_queue queue);

  [[nodiscard]] static llvm::Expected<OclRuntime> get(cl_device_id device,
                                                      cl_context context);

  // True if the queue allows out-of-order execution (see OclContext, which
  // uses this to decide whether event ordering must be preserved).
  static bool isOutOfOrder(cl_command_queue queue);

  [[nodiscard]] cl_context getContext() const;

  [[nodiscard]] cl_device_id getDevice() const;

  // Creates a command queue on this runtime's device/context; pass
  // outOfOrder=true for an out-of-order queue.
  [[nodiscard]] llvm::Expected<cl_command_queue>
  createQueue(bool outOfOrder = false) const;

  [[nodiscard]] static llvm::Expected<bool>
  releaseQueue(cl_command_queue queue);

  // Raw USM allocations: device-only and host/device-shared, in bytes.
  [[nodiscard]] llvm::Expected<void *> usmAllocDev(size_t size) const;

  [[nodiscard]] llvm::Expected<void *> usmAllocShared(size_t size) const;

  [[nodiscard]] llvm::Expected<bool> usmFree(const void *ptr) const;

  // Copies 'size' bytes from 'src' to 'dst' using the given context.
  [[nodiscard]] llvm::Expected<bool> usmCpy(OclContext &ctx, const void *src,
                                            void *dst, size_t size) const;

  // Typed wrapper over usmAllocDev: allocates 'size' elements of T.
  template <typename T>
  [[nodiscard]] llvm::Expected<T *> usmNewDev(size_t size) const {
    auto expected = usmAllocDev(size * sizeof(T));
    if (expected) {
      return static_cast<T *>(*expected);
    }
    return expected.takeError();
  }

  // Typed wrapper over usmAllocShared: allocates 'size' elements of T.
  template <typename T>
  [[nodiscard]] llvm::Expected<T *> usmNewShared(size_t size) const {
    auto expected = usmAllocShared(size * sizeof(T));
    if (expected) {
      return static_cast<T *>(*expected);
    }
    return expected.takeError();
  }

  // Typed wrapper over usmCpy: 'size' is an element count, not bytes.
  template <typename T>
  [[nodiscard]] llvm::Expected<bool> usmCpy(OclContext &ctx, const T *src,
                                            T *dst, size_t size) const {
    return usmCpy(ctx, static_cast<const void *>(src), static_cast<void *>(dst),
                  size * sizeof(T));
  }

  // Use with caution! This is safe to check validity of USM, but may be false
  // positive for any other kinds.
  bool isUsm(const void *ptr) const;

  // Two runtimes are equal when they refer to the same device and context.
  bool operator==(const OclRuntime &other) const {
    return getDevice() == other.getDevice() &&
           getContext() == other.getContext();
  }

private:
  struct Ext;
  struct Exports;
  friend OclContext;
  friend OclModuleBuilder;
  explicit OclRuntime(const Ext &ext);
  // NOTE(review): borrowed reference — presumably the Ext instances are
  // cached with static lifetime so copies of OclRuntime stay valid; confirm
  // in the implementation file.
  const Ext &ext;
};
135+
// Shared zero constant used for the 'offset' slot of memref descriptors.
static constexpr int64_t ZERO = 0;
// Non-const pointer to ZERO, needed because the packed argument array stores
// mutable pointer slots. NOTE(review): writing through ZERO_PTR would be UB
// (the pointee is constexpr); it is presumably only ever read — confirm.
static constexpr auto ZERO_PTR = const_cast<int64_t *>(&ZERO);
138+
// NOTE: The context is mutable and not thread-safe! It's expected to be used
// in a single thread only.
struct OclContext {
  const OclRuntime &runtime;
  cl_command_queue const queue;
  // Preserve the execution order. This is required in case of out-of-order
  // execution (CL_QUEUE_OUT_OF_ORDER_EXEC_MODE_ENABLE). When the execution
  // is completed, the 'lastEvent' field contains the event of the last
  // enqueued command. If this field is false, 'waitList' is ignored.
  const bool preserveOrder;
  cl_uint waitListLen;
  cl_event *waitList;
  cl_event lastEvent;

  // Convenience constructor: whether ordering must be preserved is detected
  // from the queue itself.
  explicit OclContext(const OclRuntime &runtime, cl_command_queue queue,
                      cl_uint waitListLen = 0, cl_event *waitList = nullptr)
      : OclContext(runtime, queue, OclRuntime::isOutOfOrder(queue), waitListLen,
                   waitList) {}

  // Full constructor. When preserveOrder is false, the supplied wait list is
  // deliberately discarded (len forced to 0, list to nullptr).
  explicit OclContext(const OclRuntime &runtime, cl_command_queue queue,
                      bool preserveOrder, cl_uint waitListLen,
                      cl_event *waitList)
      : runtime(runtime), queue(queue), preserveOrder(preserveOrder),
        waitListLen(preserveOrder ? waitListLen : 0),
        waitList(preserveOrder ? waitList : nullptr), lastEvent(nullptr) {
    // An out-of-order queue without order preservation would run commands in
    // an unpredictable order.
    assert(!OclRuntime::isOutOfOrder(queue) || preserveOrder);
    assert(preserveOrder || (waitListLen == 0 && waitList == nullptr));
  }

  // Non-copyable: 'waitList' may point at this object's own 'lastEvent'
  // member (see setLastEvent), so a copy would hold a dangling pointer.
  OclContext(const OclContext &) = delete;
  OclContext &operator=(const OclContext &) = delete;

  void finish();

private:
  friend OclRuntime;
  friend OclRuntime::Exports;
  template <unsigned N> friend struct OclModuleArgs;
  // Contains the pointers of all non-USM arguments. It's expected, that the
  // arguments are either USM or CL pointers and most probably are USM, thus,
  // in most cases, this set will be empty.
  std::unordered_set<void *> clPtrs;

  // Records the event of the most recently enqueued command and turns the
  // wait list into exactly that event, so the next command is chained after
  // it. A nullptr event resets the chain.
  void setLastEvent(cl_event event) {
    lastEvent = event;
    if (event) {
      waitListLen = 1;
      // Point at the member, not a local, so the list stays valid after this
      // call returns.
      waitList = &lastEvent;
    } else {
      waitListLen = 0;
      waitList = nullptr;
    }
  }
};
193+
194+ // The main function arguments in the following format -
195+ // https://mlir.llvm.org/docs/TargetLLVMIR/#c-compatible-wrapper-emission.
196+ // NOTE: The values are not copied, only the pointers are stored!
197+ // NOTE: This class is mutable and not thread-safe!
198+ template <unsigned N = 64 > struct OclModuleArgs {
199+ explicit OclModuleArgs (OclContext &ctx) : ctx(ctx) {}
200+ OclModuleArgs (const OclModuleArgs &) = delete ;
201+ OclModuleArgs &operator =(const OclModuleArgs &) = delete ;
202+
203+ void add (void *&alignedPtr, size_t rank, const int64_t *shape,
204+ const int64_t *strides, bool isUsm = true ) {
205+ add (alignedPtr, alignedPtr, ZERO, rank, shape, strides, isUsm);
206+ }
207+
208+ void add (void *&allocatedPtr, void *&alignedPtr, const int64_t &offset,
209+ size_t rank, const int64_t *shape, const int64_t *strides,
210+ bool isUsm = true ) {
211+ #ifndef NDEBUG
212+ assert (!isUsm || ctx.runtime .isUsm (alignedPtr));
213+ // It's recommended to have at least 16-byte alignment
214+ assert (reinterpret_cast <std::uintptr_t >(alignedPtr) % 16 == 0 );
215+ #endif
216+
217+ args.emplace_back (&allocatedPtr);
218+ args.emplace_back (&alignedPtr);
219+ args.emplace_back (const_cast <int64_t *>(&offset));
220+ for (size_t i = 0 ; i < rank; i++) {
221+ args.emplace_back (const_cast <int64_t *>(&shape[i]));
222+ }
223+ for (size_t i = 0 ; i < rank; i++) {
224+ args.emplace_back (const_cast <int64_t *>(&strides[i]));
225+ }
226+ if (!isUsm) {
227+ ctx.clPtrs .insert (alignedPtr);
228+ }
229+ }
230+
231+ template <typename T>
232+ void add (T *&alignedPtr, size_t rank, const int64_t *shape,
233+ const int64_t *strides, bool isUsm = true ) {
234+ add (reinterpret_cast <void *&>(alignedPtr), rank, shape, strides, isUsm);
235+ }
236+
237+ template <typename T>
238+ void add (T *&allocatedPtr, T *&alignedPtr, const int64_t &offset, size_t rank,
239+ const int64_t *shape, const int64_t *strides, bool isUsm = true ) {
240+ add (reinterpret_cast <void *&>(allocatedPtr),
241+ reinterpret_cast <void *&>(alignedPtr), offset, rank, shape, strides,
242+ isUsm);
243+ }
244+
245+ void clear () {
246+ args.clear ();
247+ ctx.clPtrs .clear ();
248+ }
249+
250+ private:
251+ friend OclModule;
252+ OclContext &ctx;
253+ SmallVector<void *, N + 3 > args;
254+ };
255+
// A compiled module bound to one OclRuntime, exposing the compiled entry
// point through exec(). Owns the ExecutionEngine that the entry point lives
// in.
struct OclModule {
  const OclRuntime runtime;

  // Signature of the compiled entry point: a packed array of argument
  // pointers (C-compatible wrapper convention, see OclModuleArgs).
  using MainFunc = void (*)(void **);

  explicit OclModule(const OclRuntime &runtime,
                     std::unique_ptr<ExecutionEngine> engine, MainFunc main)
      : runtime(runtime), engine(std::move(engine)), main(main) {}

  // Invokes the compiled function with the packed arguments. A descriptor
  // for the context pointer (allocated ptr, aligned ptr, zero offset) is
  // appended for the call and removed afterwards, so 'args' can be reused.
  template <unsigned N> void exec(OclModuleArgs<N> &args) const {
    OclContext &ctx = args.ctx;
#ifndef NDEBUG
    // The caller's queue must belong to the same device/context this module
    // was built for.
    auto rt = OclRuntime::get(ctx.queue);
    assert(rt);
    assert(*rt == this->runtime);
#endif
    auto size = args.args.size();
    auto ctxPtr = &ctx;
    args.args.emplace_back(&ctxPtr);
    args.args.emplace_back(&ctxPtr);
    args.args.emplace_back(ZERO_PTR);
    main(args.args.data());
    // Restore the original length: the appended slots reference the local
    // 'ctxPtr' and are only valid for the duration of this call.
    args.args.truncate(size);
  }

  ~OclModule();
  // Non-copyable and non-movable: 'main' points into code owned by 'engine'.
  OclModule(const OclModule &) = delete;
  OclModule &operator=(const OclModule &) = delete;
  OclModule(const OclModule &&) = delete;
  OclModule &operator=(const OclModule &&) = delete;

private:
  std::unique_ptr<ExecutionEngine> engine;
  MainFunc main;
};
291+
// Builds OclModule instances from an MLIR module and caches one compiled
// module per (device, context) pair.
struct OclModuleBuilder {
  friend OclRuntime;
  explicit OclModuleBuilder(ModuleOp module);
  // NOTE(review): takes the owning ref by lvalue reference but calls
  // release(), silently stealing ownership from the caller — confirm callers
  // expect this (an rvalue-reference overload would make the transfer
  // explicit).
  explicit OclModuleBuilder(OwningOpRef<ModuleOp> &module)
      : OclModuleBuilder(module.release()) {}

  // Returns the cached module for the runtime's device/context, compiling it
  // on first use.
  llvm::Expected<std::shared_ptr<const OclModule>>
  build(const OclRuntime &runtime);

  llvm::Expected<std::shared_ptr<const OclModule>>
  build(cl_command_queue queue);

  llvm::Expected<std::shared_ptr<const OclModule>> build(cl_device_id device,
                                                         cl_context context);

private:
  // NOTE(review): presumably guards 'cache' for concurrent build() calls —
  // confirm locking discipline in the implementation file.
  std::shared_mutex mux;
  ModuleOp mlirModule;
  SmallString<32> funcName;
  // One compiled module per (device, context) pair; keyed by OclDevCtxPair.
  std::unordered_map<const OclDevCtxPair, std::shared_ptr<const OclModule>>
      cache;

  llvm::Expected<std::shared_ptr<const OclModule>>
  build(const OclRuntime::Ext &ext);
};
317+ }; // namespace mlir::gc::gpu
26318#else
27319#undef GC_GPU_OCL_CONST_ONLY
28320#endif
0 commit comments