1616#include " llvm/ExecutionEngine/Orc/LLJIT.h"
1717#include " llvm/Support/Error.h"
1818
19+ #include " mlir/Dialect/DLTI/DLTI.h"
1920#include " mlir/Dialect/Func/IR/FuncOps.h"
2021#include " mlir/Dialect/LLVMIR/LLVMDialect.h"
22+ #include " mlir/Interfaces/DataLayoutInterfaces.h"
2123#include " mlir/Pass/PassManager.h"
2224
2325namespace mlir ::gc::gpu {
@@ -148,10 +150,12 @@ struct Kernel {
148150 }
149151
150152 ~Kernel () {
151- CL_CHECKR (clReleaseKernel (kernel), " Failed to release OpenCL kernel." );
152- gcLogD (" Released OpenCL kernel: " , kernel);
153- CL_CHECKR (clReleaseProgram (program), " Failed to release OpenCL program." );
154- gcLogD (" Released OpenCL program: " , program);
153+ if (kernel != nullptr ) {
154+ CL_CHECKR (clReleaseKernel (kernel), " Failed to release OpenCL kernel." );
155+ gcLogD (" Released OpenCL kernel: " , kernel);
156+ CL_CHECKR (clReleaseProgram (program), " Failed to release OpenCL program." );
157+ gcLogD (" Released OpenCL program: " , program);
158+ }
155159 }
156160};
157161
@@ -220,7 +224,14 @@ struct OclRuntime::Exports {
220224 gcLogD (" The program has been built: " , program);
221225
222226 auto kernel = clCreateKernel (program, name, &err);
223- CL_CHECKR (err, " Failed to create OpenCL kernel from program: " , program);
227+ if (err != CL_SUCCESS) {
228+ // This is a special case, handled by OclModuleBuilder::build(), that
229+ // allows rebuilding the kernel with different options in case of failure.
230+ clReleaseProgram (program);
231+ gcLogD (" OpenCL error " , err,
232+ " : Failed to create OpenCL kernel from program: " , program);
233+ return new Kernel (nullptr , nullptr , gridSize, blockSize, argNum, argSize);
234+ }
224235 gcLogD (" Created new OpenCL kernel " , kernel, " from program " , program);
225236
226237 cl_bool enable = CL_TRUE;
@@ -639,8 +650,7 @@ void OclContext::setLastEvent(cl_event event) {
639650 }
640651}
641652
642- OclModule::~OclModule () {
643- assert (engine);
653+ static void destroyKernels (const std::unique_ptr<ExecutionEngine> &engine) {
644654 auto fn = engine->lookup (GPU_OCL_MOD_DESTRUCTOR);
645655 if (fn) {
646656 reinterpret_cast <void (*)()>(fn.get ())();
@@ -649,13 +659,19 @@ OclModule::~OclModule() {
649659 }
650660}
651661
662+ OclModule::~OclModule () {
663+ assert (engine);
664+ destroyKernels (engine);
665+ }
666+
652667// If all arguments of 'origFunc' are memrefs with static shape, create a new
653668// function called gcGpuOclStaticMain, that accepts 2 arguments: a pointer to
654669// OclContext and a pointer to an array, containing pointers to aligned memory
655670// buffers. The function will call the original function with the context,
656671// buffers and the offset/shape/strides, statically created from the
657672// memref descriptor.
658- StringRef createStaticMain (ModuleOp &module , const StringRef &funcName,
673+ StringRef createStaticMain (OpBuilder &builder, ModuleOp &module ,
674+ const StringRef &funcName,
659675 const ArrayRef<Type> argTypes) {
660676 auto mainFunc = module .lookupSymbol <LLVM::LLVMFuncOp>(funcName);
661677 if (!mainFunc) {
@@ -670,11 +686,8 @@ StringRef createStaticMain(ModuleOp &module, const StringRef &funcName,
670686 " ' must have an least 3 arguments." );
671687 }
672688
673- auto ctx = module .getContext ();
674- ctx->getOrLoadDialect <LLVM::LLVMDialect>();
675- OpBuilder builder (ctx);
676689 auto i64Type = builder.getI64Type ();
677- auto ptrType = LLVM::LLVMPointerType::get (ctx );
690+ auto ptrType = LLVM::LLVMPointerType::get (builder. getContext () );
678691
679692 if (mainArgTypes[nargs - 3 ] != ptrType ||
680693 mainArgTypes[nargs - 2 ] != ptrType ||
@@ -722,7 +735,7 @@ StringRef createStaticMain(ModuleOp &module, const StringRef &funcName,
722735 auto loc = mainFunc.getLoc ();
723736 auto newFuncType = LLVM::LLVMFunctionType::get (
724737 mainFunc.getNumResults () ? mainFunc->getResult (0 ).getType ()
725- : LLVM::LLVMVoidType::get (ctx ),
738+ : LLVM::LLVMVoidType::get (builder. getContext () ),
726739 {ptrType, ptrType});
727740 auto newFunc =
728741 OpBuilder::atBlockEnd (module .getBody ())
@@ -848,17 +861,57 @@ OclModuleBuilder::build(cl_device_id device, cl_context context) {
848861
849862llvm::Expected<std::shared_ptr<const OclModule>>
850863OclModuleBuilder::build (const OclRuntime::Ext &ext) {
851- auto mod = mlirModule.clone ();
852- PassManager pm{mod.getContext ()};
853- pipeline (pm);
854- CHECK (!pm.run (mod).failed (), " GPU pipeline failed!" );
864+ auto ctx = mlirModule.getContext ();
865+ ctx->getOrLoadDialect <DLTIDialect>();
866+ ctx->getOrLoadDialect <LLVM::LLVMDialect>();
867+ OpBuilder builder (ctx);
868+ DataLayoutEntryInterface dltiAttrs[6 ];
855869
856- auto staticMain = createStaticMain (mod, funcName, argTypes);
870+ {
871+ struct DevInfo {
872+ cl_device_info key;
873+ const char *attrName;
874+ };
875+ DevInfo devInfo[]{
876+ {CL_DEVICE_MAX_COMPUTE_UNITS, " num_exec_units" },
877+ {CL_DEVICE_NUM_EUS_PER_SUB_SLICE_INTEL, " num_exec_units_per_slice" },
878+ {CL_DEVICE_NUM_THREADS_PER_EU_INTEL, " num_threads_per_eu" },
879+ {CL_DEVICE_LOCAL_MEM_SIZE, " local_mem_size" },
880+ };
857881
858- if (printIr) {
859- mod.dump ();
860- }
882+ unsigned i = 0 ;
883+ for (auto &[key, attrName] : devInfo) {
884+ int64_t value = 0 ;
885+ CL_CHECK (
886+ clGetDeviceInfo (ext.device , key, sizeof (cl_ulong), &value, nullptr ),
887+ " Failed to get the device property " , attrName);
888+ gcLogD (" Device property " , attrName, " =" , value);
889+ dltiAttrs[i++] =
890+ DataLayoutEntryAttr::get (ctx, builder.getStringAttr (attrName),
891+ builder.getI64IntegerAttr (value));
892+ }
861893
894+ // There is no a corresponding property in the OpenCL API, using the
895+ // hardcoded value.
896+ // TODO: Get the real value.
897+ dltiAttrs[i] = DataLayoutEntryAttr::get (
898+ ctx, builder.getStringAttr (" max_vector_op_width" ),
899+ builder.getI64IntegerAttr (512 ));
900+ }
901+
902+ OclRuntime rt (ext);
903+ auto expectedQueue = rt.createQueue ();
904+ CHECKE (expectedQueue, " Failed to create queue!" );
905+ struct OclQueue {
906+ cl_command_queue queue;
907+ ~OclQueue () { clReleaseCommandQueue (queue); }
908+ } queue{*expectedQueue};
909+ OclContext oclCtx{rt, queue.queue , false };
910+
911+ ModuleOp mod;
912+ StringRef staticMain;
913+ std::unique_ptr<ExecutionEngine> eng;
914+ auto devStr = builder.getStringAttr (" GPU" /* device ID*/ );
862915 ExecutionEngineOptions opts;
863916 opts.jitCodeGenOptLevel = llvm::CodeGenOptLevel::Aggressive;
864917 opts.enableObjectDump = enableObjectDump;
@@ -868,18 +921,86 @@ OclModuleBuilder::build(const OclRuntime::Ext &ext) {
868921 opts.enablePerfNotificationListener = false ;
869922#endif
870923
871- auto eng = ExecutionEngine::create (mod, opts);
872- CHECKE (eng, " Failed to create ExecutionEngine!" );
873- eng->get ()->registerSymbols (OclRuntime::Exports::symbolMap);
924+ // Build the module and check the kernels workgroup size. If the workgroup
925+ // size is different, rebuild the module with the new size.
926+ for (size_t wgSize = 64 , maxSize = std::numeric_limits<size_t >::max ();;) {
927+ dltiAttrs[sizeof (dltiAttrs) / sizeof (DataLayoutEntryInterface) - 1 ] =
928+ DataLayoutEntryAttr::get (
929+ ctx, builder.getStringAttr (" max_work_group_size" ),
930+ builder.getI64IntegerAttr (static_cast <int64_t >(wgSize)));
931+ TargetDeviceSpecInterface devSpec =
932+ TargetDeviceSpecAttr::get (ctx, dltiAttrs);
933+ auto sysSpec =
934+ TargetSystemSpecAttr::get (ctx, ArrayRef (std::pair (devStr, devSpec)));
935+ mod = mlirModule.clone ();
936+ mod.getOperation ()->setAttr (" #dlti.sys_spec" , sysSpec);
937+ PassManager pm{ctx};
938+ pipeline (pm);
939+ CHECK (!pm.run (mod).failed (), " GPU pipeline failed!" );
940+ staticMain = createStaticMain (builder, mod, funcName, argTypes);
941+ auto expectedEng = ExecutionEngine::create (mod, opts);
942+ CHECKE (expectedEng, " Failed to create ExecutionEngine!" );
943+ expectedEng->get ()->registerSymbols (OclRuntime::Exports::symbolMap);
944+
945+ // Find all kernels and query the workgroup size
946+ size_t minSize = maxSize;
947+ mod.walk <>([&](LLVM::LLVMFuncOp func) {
948+ auto name = func.getName ();
949+ if (!name.starts_with (" createGcGpuOclKernel_" )) {
950+ return WalkResult::skip ();
951+ }
952+ auto fn = expectedEng.get ()->lookup (name);
953+ if (!fn) {
954+ gcLogE (" Function not found: " , name.data ());
955+ return WalkResult::skip ();
956+ }
957+
958+ Kernel *kernel =
959+ reinterpret_cast <Kernel *(*)(OclContext *)>(fn.get ())(&oclCtx);
960+
961+ if (kernel->kernel == nullptr ) {
962+ maxSize = wgSize / 2 ;
963+ if (maxSize == 0 ) {
964+ gcReportErr (" Failed to build the kernel." );
965+ }
966+ minSize = maxSize;
967+ return WalkResult::interrupt ();
968+ }
969+
970+ size_t s = 0 ;
971+ auto err = clGetKernelWorkGroupInfo (kernel->kernel , ext.device ,
972+ CL_KERNEL_WORK_GROUP_SIZE,
973+ sizeof (size_t ), &s, nullptr );
974+ if (err == CL_SUCCESS) {
975+ minSize = std::min (minSize, s);
976+ } else {
977+ gcLogE (" Failed to get the kernel workgroup size: " , err);
978+ }
979+ return WalkResult::skip ();
980+ });
981+
982+ if (minSize == wgSize || minSize == std::numeric_limits<size_t >::max ()) {
983+ eng = std::move (*expectedEng);
984+ break ;
985+ }
986+
987+ destroyKernels (expectedEng.get ());
988+ gcLogD (" Changing the workgroup size from " , wgSize, " to " , minSize);
989+ wgSize = minSize;
990+ }
991+
992+ if (printIr) {
993+ mod.dump ();
994+ }
874995
875996 OclModule::MainFunc main = {nullptr };
876997
877998 if (staticMain.empty ()) {
878- auto expect = eng. get () ->lookupPacked (funcName);
999+ auto expect = eng->lookupPacked (funcName);
8791000 CHECKE (expect, " Packed function '" , funcName.begin (), " ' not found!" );
8801001 main.wrappedMain = *expect;
8811002 } else {
882- auto expect = eng. get () ->lookup (staticMain);
1003+ auto expect = eng->lookup (staticMain);
8831004 CHECKE (expect, " Compiled function '" , staticMain.begin (), " ' not found!" );
8841005 main.staticMain = reinterpret_cast <OclModule::StaticMainFunc>(*expect);
8851006 }
@@ -889,8 +1010,7 @@ OclModuleBuilder::build(const OclRuntime::Ext &ext) {
8891010 return it->second ;
8901011 }
8911012 std::shared_ptr<const OclModule> ptr (
892- new OclModule (OclRuntime (ext), !staticMain.empty (), main, argTypes,
893- std::move (eng.get ())));
1013+ new OclModule (rt, !staticMain.empty (), main, argTypes, std::move (eng)));
8941014 return cache.emplace (OclDevCtxPair (ext.device , ext.context ), ptr)
8951015 .first ->second ;
8961016}
0 commit comments