@@ -128,7 +128,8 @@ struct OclRuntime {
128128 struct Exports ;
129129 friend OclContext;
130130 friend OclModuleBuilder;
131- template <unsigned N> friend struct OclModuleExecutor ;
131+ template <unsigned N> friend struct DynamicExecutor ;
132+ template <unsigned N> friend struct StaticExecutor ;
132133 explicit OclRuntime (const Ext &ext);
133134 const Ext &ext;
134135
@@ -173,12 +174,13 @@ struct OclContext {
173174 OclContext (const OclContext &) = delete ;
174175 OclContext &operator =(const OclContext &) = delete ;
175176
176- void finish ();
177+ [[nodiscard]] llvm::Expected< bool > finish ();
177178
178179private:
179180 friend OclRuntime;
180181 friend OclRuntime::Exports;
181- template <unsigned N> friend struct OclModuleExecutor ;
182+ template <unsigned N> friend struct DynamicExecutor ;
183+ template <unsigned N> friend struct StaticExecutor ;
182184 std::unordered_set<void *> *clPtrs;
183185
184186 void setLastEvent (cl_event event) {
@@ -195,6 +197,9 @@ struct OclContext {
195197
196198struct OclModule {
197199 const OclRuntime runtime;
200+ // If all the function arguments have static shapes, then this field is true
201+ // and main.staticMain is used. Otherwise, main.wrappedMain is used.
202+ const bool isStatic;
198203
199204 ~OclModule ();
200205 OclModule (const OclModule &) = delete ;
@@ -204,24 +209,35 @@ struct OclModule {
204209
205210private:
206211 friend OclModuleBuilder;
207- template <unsigned N> friend struct OclModuleExecutor ;
208- using MainFunc = void (*)(void **);
212+ template <unsigned N> friend struct DynamicExecutor ;
213+ template <unsigned N> friend struct OclModuleExecutorBase ;
214+ template <unsigned N> friend struct StaticExecutor ;
// Entry-point function pointer types. MainFunc is a plain union with no tag of
// its own: OclModule::isStatic decides which member may be read (staticMain
// when true, wrappedMain otherwise).
215+ // This function is only created when all args are memrefs with static shape.
// It is called with the context and an array holding one pointer per argument.
216+ using StaticMainFunc = void (*)(OclContext *, void **);
217+ // Wrapper, generated by the engine. The arguments are pointers to the values.
218+ using WrappedMainFunc = void (*)(void **);
219+ union MainFunc {
220+ StaticMainFunc staticMain;
221+ WrappedMainFunc wrappedMain;
222+ };
209223 const MainFunc main;
210224 const FunctionType functionType;
211225 std::unique_ptr<ExecutionEngine> engine;
212226
// Private constructor (reachable through the friends declared above).
// Copies the runtime, the isStatic flag, the entry-point union and the function
// type into the const members and moves the execution engine into the module.
// The new signature takes the FunctionType directly instead of a func::FuncOp.
213- explicit OclModule (const OclRuntime &runtime, const MainFunc main ,
214- func::FuncOp functionOp ,
227+ explicit OclModule (const OclRuntime &runtime, const bool isStatic ,
228+ const MainFunc main, const FunctionType functionType ,
215229 std::unique_ptr<ExecutionEngine> engine)
216- : runtime(runtime), main(main),
217- functionType(functionOp.getFunctionType() ), engine(std::move(engine)) {}
230+ : runtime(runtime), isStatic(isStatic), main(main),
231+ functionType(functionType ), engine(std::move(engine)) {}
218232};
219233
220234struct OclModuleBuilder {
221235 friend OclRuntime;
222236 explicit OclModuleBuilder (ModuleOp module );
223237 explicit OclModuleBuilder (OwningOpRef<ModuleOp> &module )
224238 : OclModuleBuilder(module .release()) {}
// Rvalue overload, so a temporary or std::move'd OwningOpRef can be passed.
// Delegates to the ModuleOp constructor with the result of module.release()
// (presumably this hands ownership of the op to the builder - confirm against
// OwningOpRef).
239+ explicit OclModuleBuilder (OwningOpRef<ModuleOp> &&module )
240+ : OclModuleBuilder(module .release()) {}
225241
226242 llvm::Expected<std::shared_ptr<const OclModule>>
227243 build (const OclRuntime &runtime);
@@ -243,105 +259,141 @@ struct OclModuleBuilder {
243259 build (const OclRuntime::Ext &ext);
244260};
245261
246- // The main function arguments are added in the following format -
247- // https://mlir.llvm.org/docs/TargetLLVMIR/#c-compatible-wrapper-emission.
248262// NOTE: This class is mutable and not thread-safe!
249- // NOTE: The argument values are not copied, only the pointers are stored!
250- template <unsigned N = 64 > struct OclModuleExecutor {
251- explicit OclModuleExecutor (std::shared_ptr<const OclModule> &mod)
263+ template <unsigned N> struct OclModuleExecutorBase {
264+
// Returns the executor to its empty state: drops every queued argument, the
// recorded non-USM pointers and the argument counter, so a new argument list
// can be built for the next call.
265+ void reset () {
266+ args.clear ();
267+ clPtrs.clear ();
268+ argCounter = 0 ;
269+ }
270+
// Returns the declared type of the idx-th input of the module's main function.
// The index is range-checked with assert only, i.e. not in NDEBUG builds.
271+ Type getArgType (unsigned idx) const {
272+ assert (idx < mod->functionType .getNumInputs ());
273+ return mod->functionType .getInput (idx);
274+ }
275+
// Forwards to Args::small(), which reports SmallVector's isSmall() state
// (presumably true while the arguments still fit the inline N-slot storage -
// confirm against llvm::SmallVector).
276+ [[nodiscard]] bool isSmall () const { return args.small (); }
277+
278+ protected:
// Argument storage. The subclass adds nothing but a public small() accessor
// that calls the inherited isSmall() (presumably not public in SmallVector,
// hence the wrapper - confirm).
279+ struct Args : SmallVector<void *, N> {
280+ [[nodiscard]] bool small () const { return this ->isSmall (); }
281+ };
282+
283+ const std::shared_ptr<const OclModule> &mod;
284+ // Contains the pointers of all non-USM arguments. It is expected that the
285+ // arguments are either USM or CL pointers, and most probably USM; thus,
286+ // in most cases, this set will be empty.
287+ std::unordered_set<void *> clPtrs;
288+ Args args;
289+ unsigned argCounter = 0 ;
290+
// Protected constructor used by the concrete executors.
// NOTE: the `mod` member is a reference, so this binds to the caller's
// shared_ptr object instead of copying it; no reference count is taken here.
// The caller's shared_ptr must therefore outlive the executor.
291+ explicit OclModuleExecutorBase (std::shared_ptr<const OclModule> &mod)
252292 : mod(mod) {}
253- OclModuleExecutor (const OclModuleExecutor &) = delete ;
254- OclModuleExecutor &operator =(const OclModuleExecutor &) = delete ;
255- OclModuleExecutor (const OclModuleExecutor &&) = delete ;
256- OclModuleExecutor &operator =(const OclModuleExecutor &&) = delete ;
257293
258- void exec (OclContext &ctx) {
259294#ifndef NDEBUG
// Debug-only sanity check (compiled under #ifndef NDEBUG) run before execution.
// Asserts that a runtime can be obtained for ctx.queue, that it equals the
// runtime the module was built for, and that exactly as many arguments were
// added as the main function has inputs.
295+ void checkCtx (const OclContext &ctx) const {
260296 auto rt = OclRuntime::get (ctx.queue );
261297 assert (rt);
262298 assert (*rt == mod->runtime );
299+ assert (argCounter == mod->functionType .getNumInputs ());
300+ }
301+
// Debug-only check of a single argument pointer (compiled under #ifndef NDEBUG).
// When isUsm is true, asserts that the module's runtime reports the pointer as
// USM. Regardless of isUsm, asserts that the address is a multiple of 16.
302+ void checkArg (void *alignedPtr, bool isUsm = true ) const {
303+ assert (!isUsm || mod->runtime .isUsm (alignedPtr));
304+ // It's recommended to have at least 16-byte alignment
// NOTE(review): the wording says "recommended", but the assert below makes
// 16-byte alignment mandatory in debug builds.
305+ assert (reinterpret_cast <std::uintptr_t >(alignedPtr) % 16 == 0 );
306+ }
263307#endif
264- auto size = args.size ();
265- auto ctxPtr = &ctx;
266- ctx.clPtrs = &clPtrs;
267- args.emplace_back (&ctxPtr);
268- args.emplace_back (&ctxPtr);
269- args.emplace_back (ZERO_PTR);
270- mod->main (args.data ());
271- args.truncate (size);
308+ };
309+
310+ // NOTE: This executor can only be used if mod->isStatic == true!
311+ template <unsigned N = 8 > struct StaticExecutor : OclModuleExecutorBase<N> {
// Binds the executor to `mod` through the base class. The assert (debug builds
// only) rejects modules that are not static.
312+ explicit StaticExecutor (std::shared_ptr<const OclModule> &mod)
313+ : OclModuleExecutorBase<N>(mod) {
314+ assert (this ->mod ->isStatic );
272315 }
273316
274- void operator ()(OclContext &ctx) { exec (ctx); }
// Runs the module's static entry point with the arguments added so far.
// The context is first pointed at this executor's set of non-USM pointers,
// then staticMain receives the context and the raw pointer array.
// The argument list is not cleared afterwards; call reset() to start over.
317+ void exec (OclContext &ctx) {
318+ #ifndef NDEBUG
319+ this ->checkCtx (ctx);
320+ #endif
321+ ctx.clPtrs = &this ->clPtrs ;
322+ this ->mod ->main .staticMain (&ctx, this ->args .data ());
323+ }
275324
276- template <typename T>
277- [[nodiscard]] bool operator ()(OclContext &ctx, T **ptr1, ...) {
278- {
279- SmallVector<int64_t > values;
280- auto argTypes = mod->functionType .getInputs ();
281- unsigned numValues = 0 ;
282-
283- for (unsigned i = 0 , n = argTypes.size () - 1 ; i < n; i++) {
284- if (auto type = llvm::dyn_cast<MemRefType>(argTypes[i])) {
285- if (type.hasStaticShape ()) {
286- numValues += type.getShape ().size () * 2 + 1 ;
287- continue ;
288- }
289- }
// Appends one argument for the static entry point.
//   alignedPtr - the pointer value itself is stored in the argument array.
//   isUsm      - pass false for non-USM pointers; they are also recorded in
//                clPtrs.
// In debug builds the pointer is validated with checkArg() and a trace line
// with the argument index, pointer and isUsm flag is sent to OclRuntime::debug.
325+ void arg (void *alignedPtr, bool isUsm = true ) {
290326#ifndef NDEBUG
291- OclRuntime::debug (
292- __FILE__, __LINE__,
293- " Only memref arguments with static shape are supported." );
327+ this ->checkArg (alignedPtr, isUsm);
328+ std::ostringstream oss;
329+ oss << " Arg" << this ->argCounter << " : alignedPtr=" << alignedPtr
330+ << " , isUsm=" << (isUsm ? " true" : " false" );
331+ OclRuntime::debug (__FILE__, __LINE__, oss.str ().c_str ());
294332#endif
295- return false ;
296- }
333+ ++this ->argCounter ;
334+ this ->args .emplace_back (alignedPtr);
335+ if (!isUsm) {
336+ this ->clPtrs .insert (alignedPtr);
337+ }
338+ }
339+
// Typed convenience overload: casts the pointer to void * and forwards to the
// untyped arg() above.
340+ template <typename T> void arg (T *alignedPtr, bool isUsm = true ) {
341+ arg (reinterpret_cast <void *>(alignedPtr), isUsm);
342+ }
343+
// Call-operator shorthand for exec(ctx); uses the arguments already added.
344+ void operator ()(OclContext &ctx) { exec (ctx); }
297345
298- values.reserve (numValues);
299- SmallVector<int64_t > strides;
300- int64_t offset;
// One-shot call: resets the executor, adds ptr1 plus getNumInputs() - 1 further
// pointers read from the variadic list as void *, then executes.
// Every pointer is added with the default isUsm = true, so non-USM pointers
// cannot be passed through this overload.
// The caller must pass exactly as many pointers as the function has inputs;
// nothing here can verify the length of the variadic list.
// NOTE(review): n is unsigned, so a function with zero inputs makes
// getNumInputs() - 1 wrap to a huge value - confirm that case cannot occur.
// NOTE(review): the local va_list `args` shadows the inherited `args` member
// inside this function.
346+ template <typename T> void operator ()(OclContext &ctx, T *ptr1, ...) {
347+ {
348+ this ->reset ();
349+ arg (reinterpret_cast <void *>(ptr1));
301350 va_list args;
302351 va_start (args, ptr1);
303-
304- for (unsigned i = 0 , n = argTypes.size () - 1 ; i < n; i++) {
305- auto type = llvm::dyn_cast<MemRefType>(argTypes[i]);
306- strides.clear ();
307- if (failed (getStridesAndOffset (type, strides, offset))) {
308- #ifndef NDEBUG
309- OclRuntime::debug (__FILE__, __LINE__,
310- " Failed to get strides and offset." );
311- #endif
312- return false ;
313- }
314- auto offsetPtr = values.end ();
315- values.emplace_back (offset);
316- auto shapePtr = values.end ();
317- auto shape = type.getShape ();
318- values.append (shape.begin (), shape.end ());
319- auto stridesPtr = values.end ();
320- values.append (strides.begin (), strides.end ());
321- auto ptr =
322- (i == 0 ) ? reinterpret_cast <void **>(ptr1) : va_arg (args, void **);
323- addArg (*ptr, *ptr, *offsetPtr, shape.size (), shapePtr, stridesPtr);
352+ for (unsigned i = 0 , n = this ->mod ->functionType .getNumInputs () - 1 ;
353+ i < n; i++) {
354+ arg (va_arg (args, void *));
324355 }
325-
326356 va_end (args);
327357 exec (ctx);
328- return true ;
329358 }
330359 }
360+ };
331361
332- void addArg (void *&alignedPtr, size_t rank, const int64_t *shape,
333- const int64_t *strides, bool isUsm = true ) {
334- addArg (alignedPtr, alignedPtr, ZERO, rank, shape, strides, isUsm);
362+ // The main function arguments are added in the following format -
363+ // https://mlir.llvm.org/docs/TargetLLVMIR/#c-compatible-wrapper-emission.
364+ // NOTE: This executor can only be used if mod->isStatic == false!
365+ template <unsigned N = 64 > struct DynamicExecutor : OclModuleExecutorBase<N> {
// Binds the executor to `mod` through the base class. The assert (debug builds
// only) rejects modules that are static.
366+ explicit DynamicExecutor (std::shared_ptr<const OclModule> &mod)
367+ : OclModuleExecutorBase<N>(mod) {
368+ assert (!this ->mod ->isStatic );
335369 }
336370
337- void addArg (void *&allocatedPtr, void *&alignedPtr, const int64_t &offset,
338- size_t rank, const int64_t *shape, const int64_t *strides,
339- bool isUsm = true ) {
// Runs the engine-generated wrapper with the arguments added so far.
// Three extra slots are appended for the context: the address of a local
// pointer to ctx (twice) and ZERO_PTR. Presumably this mirrors the
// allocated/aligned/offset triple that arg() emits - confirm against the
// wrapper ABI. After the call the list is truncated back to its previous size,
// because the appended slots point at the local ctxPtr; the user arguments
// stay queued.
// NOTE(review): unlike StaticExecutor::exec and the removed
// OclModuleExecutor::exec, this version never assigns ctx.clPtrs, while arg()
// still records non-USM pointers in this->clPtrs. Confirm the context is meant
// to run without that set; otherwise non-USM arguments are not visible to it.
371+ void exec (OclContext &ctx) {
340372#ifndef NDEBUG
341- assert (!isUsm || mod->runtime .isUsm (alignedPtr));
342- // It's recommended to have at least 16-byte alignment
343- assert (reinterpret_cast <std::uintptr_t >(alignedPtr) % 16 == 0 );
344- if (auto type = llvm::dyn_cast<MemRefType>(getArgType (argCounter))) {
373+ this ->checkCtx (ctx);
374+ #endif
375+ auto size = this ->args .size ();
376+ auto ctxPtr = &ctx;
377+ this ->args .emplace_back (&ctxPtr);
378+ this ->args .emplace_back (&ctxPtr);
379+ this ->args .emplace_back (ZERO_PTR);
380+ this ->mod ->main .wrappedMain (this ->args .data ());
381+ this ->args .truncate (size);
382+ }
383+
// Shorthand for arguments with a single pointer: the same pointer is passed as
// both the allocated and the aligned pointer, and ZERO (a constant defined
// outside this view) as the offset.
// alignedPtr is taken by reference because the full overload stores its
// address, so the referenced variable must stay alive until exec().
384+ void arg (void *&alignedPtr, size_t rank, const int64_t *shape,
385+ const int64_t *strides, bool isUsm = true ) {
386+ arg (alignedPtr, alignedPtr, ZERO, rank, shape, strides, isUsm);
387+ }
388+
389+ // NOTE: The argument values are not copied, only the pointers are stored!
390+ void arg (void *&allocatedPtr, void *&alignedPtr, const int64_t &offset,
391+ size_t rank, const int64_t *shape, const int64_t *strides,
392+ bool isUsm = true ) {
393+ #ifndef NDEBUG
394+ this ->checkArg (alignedPtr, isUsm);
395+ if (auto type =
396+ llvm::dyn_cast<MemRefType>(this ->getArgType (this ->argCounter ))) {
345397 if (type.hasStaticShape ()) {
346398 auto size = type.getShape ();
347399 assert (rank == size.size ());
@@ -361,8 +413,9 @@ template <unsigned N = 64> struct OclModuleExecutor {
361413 }
362414
363415 std::ostringstream oss;
364- oss << " Arg" << argCounter << " : ptr=" << allocatedPtr
365- << " , alignedPtr=" << alignedPtr << " , offset=" << offset
416+ oss << " Arg" << this ->argCounter << " : ptr=" << allocatedPtr
417+ << " , alignedPtr=" << alignedPtr
418+ << " , isUsm=" << (isUsm ? " true" : " false" ) << " , offset=" << offset
366419 << " , shape=[" ;
367420 for (unsigned i = 0 ; i < rank; i++) {
368421 oss << shape[i] << (i + 1 < rank ? " , " : " ]" );
@@ -374,55 +427,36 @@ template <unsigned N = 64> struct OclModuleExecutor {
374427 OclRuntime::debug (__FILE__, __LINE__, oss.str ().c_str ());
375428#endif
376429
377- argCounter++ ;
378- args.emplace_back (&allocatedPtr);
379- args.emplace_back (&alignedPtr);
380- args.emplace_back (const_cast <int64_t *>(&offset));
430+ ++ this -> argCounter ;
431+ this -> args .emplace_back (&allocatedPtr);
432+ this -> args .emplace_back (&alignedPtr);
433+ this -> args .emplace_back (const_cast <int64_t *>(&offset));
381434 for (size_t i = 0 ; i < rank; i++) {
382- args.emplace_back (const_cast <int64_t *>(&shape[i]));
435+ this -> args .emplace_back (const_cast <int64_t *>(&shape[i]));
383436 }
384437 for (size_t i = 0 ; i < rank; i++) {
385- args.emplace_back (const_cast <int64_t *>(&strides[i]));
438+ this -> args .emplace_back (const_cast <int64_t *>(&strides[i]));
386439 }
387440 if (!isUsm) {
388- clPtrs.insert (alignedPtr);
441+ this -> clPtrs .insert (alignedPtr);
389442 }
390443 }
391444
// Typed convenience overload (renamed from addArg): reinterprets the T *&
// as void *& and forwards to the untyped single-pointer arg().
392445 template <typename T>
393- void addArg (T *&alignedPtr, size_t rank, const int64_t *shape,
394- const int64_t *strides, bool isUsm = true ) {
395- addArg (reinterpret_cast <void *&>(alignedPtr), rank, shape, strides, isUsm);
446+ void arg (T *&alignedPtr, size_t rank, const int64_t *shape,
447+ const int64_t *strides, bool isUsm = true ) {
448+ arg (reinterpret_cast <void *&>(alignedPtr), rank, shape, strides, isUsm);
396449 }
397450
// Typed convenience overload (renamed from addArg): reinterprets both T *&
// parameters as void *& and forwards to the untyped arg() that takes separate
// allocated and aligned pointers.
// The removed lines in between are the old getArgType() and reset(); both now
// live in OclModuleExecutorBase.
398451 template <typename T>
399- void addArg (T *&allocatedPtr, T *&alignedPtr, const int64_t &offset,
400- size_t rank, const int64_t *shape, const int64_t *strides,
401- bool isUsm = true ) {
402- addArg (reinterpret_cast <void *&>(allocatedPtr),
403- reinterpret_cast <void *&>(alignedPtr), offset, rank, shape, strides,
404- isUsm);
405- }
406-
407- Type getArgType (unsigned idx) const {
408- assert (idx < mod->functionType .getNumInputs () - 1 );
409- return mod->functionType .getInput (idx);
410- }
411-
412- void reset () {
413- args.clear ();
414- clPtrs.clear ();
415- argCounter = 0 ;
452+ void arg (T *&allocatedPtr, T *&alignedPtr, const int64_t &offset, size_t rank,
453+ const int64_t *shape, const int64_t *strides, bool isUsm = true ) {
454+ arg (reinterpret_cast <void *&>(allocatedPtr),
455+ reinterpret_cast <void *&>(alignedPtr), offset, rank, shape, strides,
456+ isUsm);
416457 }
418- private:
419- const std::shared_ptr<const OclModule> &mod;
420- // Contains the pointers of all non-USM arguments. It's expected, that the
421- // arguments are either USM or CL pointers and most probably are USM, thus,
422- // in most cases, this set will be empty.
423- std::unordered_set<void *> clPtrs;
424- SmallVector<void *, N + 3 > args;
425- unsigned argCounter = 0 ;
// Call-operator shorthand for exec(ctx); uses the arguments already added.
459+ void operator ()(OclContext &ctx) { exec (ctx); }
426460};
427461}; // namespace mlir::gc::gpu
428462#else
0 commit comments