1616#include " llvm/ExecutionEngine/Orc/LLJIT.h"
1717#include " llvm/Support/Error.h"
1818
19+ #include " mlir/Dialect/DLTI/DLTI.h"
1920#include " mlir/Dialect/Func/IR/FuncOps.h"
2021#include " mlir/Dialect/LLVMIR/LLVMDialect.h"
22+ #include " mlir/Interfaces/DataLayoutInterfaces.h"
2123#include " mlir/Pass/PassManager.h"
2224
2325namespace mlir ::gc::gpu {
@@ -655,7 +657,8 @@ OclModule::~OclModule() {
655657// buffers. The function will call the original function with the context,
656658// buffers and the offset/shape/strides, statically created from the
657659// memref descriptor.
658- StringRef createStaticMain (ModuleOp &module , const StringRef &funcName,
660+ StringRef createStaticMain (OpBuilder &builder, ModuleOp &module ,
661+ const StringRef &funcName,
659662 const ArrayRef<Type> argTypes) {
660663 auto mainFunc = module .lookupSymbol <LLVM::LLVMFuncOp>(funcName);
661664 if (!mainFunc) {
@@ -670,11 +673,8 @@ StringRef createStaticMain(ModuleOp &module, const StringRef &funcName,
670673 " ' must have an least 3 arguments." );
671674 }
672675
673- auto ctx = module .getContext ();
674- ctx->getOrLoadDialect <LLVM::LLVMDialect>();
675- OpBuilder builder (ctx);
676676 auto i64Type = builder.getI64Type ();
677- auto ptrType = LLVM::LLVMPointerType::get (ctx );
677+ auto ptrType = LLVM::LLVMPointerType::get (builder. getContext () );
678678
679679 if (mainArgTypes[nargs - 3 ] != ptrType ||
680680 mainArgTypes[nargs - 2 ] != ptrType ||
@@ -722,7 +722,7 @@ StringRef createStaticMain(ModuleOp &module, const StringRef &funcName,
722722 auto loc = mainFunc.getLoc ();
723723 auto newFuncType = LLVM::LLVMFunctionType::get (
724724 mainFunc.getNumResults () ? mainFunc->getResult (0 ).getType ()
725- : LLVM::LLVMVoidType::get (ctx ),
725+ : LLVM::LLVMVoidType::get (builder. getContext () ),
726726 {ptrType, ptrType});
727727 auto newFunc =
728728 OpBuilder::atBlockEnd (module .getBody ())
@@ -848,17 +848,57 @@ OclModuleBuilder::build(cl_device_id device, cl_context context) {
848848
849849llvm::Expected<std::shared_ptr<const OclModule>>
850850OclModuleBuilder::build (const OclRuntime::Ext &ext) {
851- auto mod = mlirModule.clone ();
852- PassManager pm{mod. getContext ()} ;
853- pipeline (pm );
854- CHECK (!pm. run (mod). failed (), " GPU pipeline failed! " ) ;
851+ auto ctx = mlirModule.getContext ();
852+ ctx-> getOrLoadDialect <LLVM::LLVMDialect>() ;
853+ OpBuilder builder (ctx );
854+ DataLayoutEntryInterface dltiAttrs[ 6 ] ;
855855
856- auto staticMain = createStaticMain (mod, funcName, argTypes);
856+ {
857+ struct DevInfo {
858+ cl_device_info key;
859+ const char *attrName;
860+ };
861+ DevInfo devInfo[]{
862+ {CL_DEVICE_MAX_COMPUTE_UNITS, " num_exec_units" },
863+ {CL_DEVICE_NUM_EUS_PER_SUB_SLICE_INTEL, " num_exec_units_per_slice" },
864+ {CL_DEVICE_NUM_THREADS_PER_EU_INTEL, " num_threads_per_eu" },
865+ // Assuming the cache size is equal to the local mem
866+ {CL_DEVICE_LOCAL_MEM_SIZE, " L1_cache_size_in_bytes" },
867+ };
857868
858- if (printIr) {
859- mod.dump ();
860- }
869+ unsigned i = 0 ;
870+ for (auto &[key, attrName] : devInfo) {
871+ int64_t value = 0 ;
872+ CL_CHECK (
873+ clGetDeviceInfo (ext.device , key, sizeof (cl_ulong), &value, nullptr ),
874+ " Failed to get the device property " , attrName);
875+ gcLogD (" Device property " , attrName, " =" , value);
876+ dltiAttrs[i++] =
877+ DataLayoutEntryAttr::get (ctx, builder.getStringAttr (attrName),
878+ builder.getI64IntegerAttr (value));
879+ }
861880
881+ // There is no a corresponding property in the OpenCL API, using the
882+ // hardcoded value.
883+ // TODO: Get the real value.
884+ dltiAttrs[i] = DataLayoutEntryAttr::get (
885+ ctx, builder.getStringAttr (" max_vector_op_width" ),
886+ builder.getI64IntegerAttr (512 ));
887+ }
888+
889+ OclRuntime rt (ext);
890+ auto expectedQueue = rt.createQueue ();
891+ CHECKE (expectedQueue, " Failed to create queue!" );
892+ struct OclQueue {
893+ cl_command_queue queue;
894+ ~OclQueue () { clReleaseCommandQueue (queue); }
895+ } queue{*expectedQueue};
896+ OclContext oclCtx{rt, queue.queue , false };
897+
898+ ModuleOp mod;
899+ StringRef staticMain;
900+ std::unique_ptr<ExecutionEngine> eng;
901+ auto devStr = builder.getStringAttr (" GPU" /* device ID*/ );
862902 ExecutionEngineOptions opts;
863903 opts.jitCodeGenOptLevel = llvm::CodeGenOptLevel::Aggressive;
864904 opts.enableObjectDump = enableObjectDump;
@@ -868,18 +908,75 @@ OclModuleBuilder::build(const OclRuntime::Ext &ext) {
868908 opts.enablePerfNotificationListener = false ;
869909#endif
870910
871- auto eng = ExecutionEngine::create (mod, opts);
872- CHECKE (eng, " Failed to create ExecutionEngine!" );
873- eng->get ()->registerSymbols (OclRuntime::Exports::symbolMap);
911+ // Build the module and check the kernels workgroup size. If the workgroup
912+ // size is different, rebuild the module with the new size.
913+ for (size_t wgSize = 64 ;;) {
914+ dltiAttrs[sizeof (dltiAttrs) / sizeof (DataLayoutEntryInterface) - 1 ] =
915+ DataLayoutEntryAttr::get (
916+ ctx, builder.getStringAttr (" max_work_group_size" ),
917+ builder.getI64IntegerAttr (static_cast <int64_t >(wgSize)));
918+ TargetDeviceSpecInterface devSpec =
919+ TargetDeviceSpecAttr::get (ctx, dltiAttrs);
920+ auto sysSpec =
921+ TargetSystemSpecAttr::get (ctx, ArrayRef (std::pair (devStr, devSpec)));
922+ mod = mlirModule.clone ();
923+ mod.getOperation ()->setAttr (" #dlti.sys_spec" , sysSpec);
924+ PassManager pm{ctx};
925+ pipeline (pm);
926+ CHECK (!pm.run (mod).failed (), " GPU pipeline failed!" );
927+ staticMain = createStaticMain (builder, mod, funcName, argTypes);
928+ auto expectedEng = ExecutionEngine::create (mod, opts);
929+ CHECKE (expectedEng, " Failed to create ExecutionEngine!" );
930+ expectedEng->get ()->registerSymbols (OclRuntime::Exports::symbolMap);
931+
932+ // Find all kernels and query the workgroup size
933+ size_t minSize = std::numeric_limits<size_t >::max ();
934+ mod.walk <>([&](LLVM::LLVMFuncOp func) {
935+ auto name = func.getName ();
936+ if (!name.starts_with (" createGcGpuOclKernel_" )) {
937+ return WalkResult::skip ();
938+ }
939+ auto fn = expectedEng.get ()->lookup (name);
940+ if (!fn) {
941+ gcLogE (" Function not found: " , name.data ());
942+ return WalkResult::skip ();
943+ }
944+
945+ Kernel *kernel =
946+ reinterpret_cast <Kernel *(*)(OclContext *)>(fn.get ())(&oclCtx);
947+ size_t s = 0 ;
948+ auto err = clGetKernelWorkGroupInfo (kernel->kernel , ext.device ,
949+ CL_KERNEL_WORK_GROUP_SIZE,
950+ sizeof (size_t ), &s, nullptr );
951+ if (err == CL_SUCCESS) {
952+ minSize = std::min (minSize, s);
953+ } else {
954+ gcLogE (" Failed to get the kernel workgroup size: " , err);
955+ }
956+ return WalkResult::skip ();
957+ });
958+
959+ if (minSize == std::numeric_limits<size_t >::max () || minSize == wgSize) {
960+ eng = std::move (*expectedEng);
961+ break ;
962+ }
963+
964+ gcLogD (" Changing the workgroup size to " , minSize);
965+ wgSize = minSize;
966+ }
967+
968+ if (printIr) {
969+ mod.dump ();
970+ }
874971
875972 OclModule::MainFunc main = {nullptr };
876973
877974 if (staticMain.empty ()) {
878- auto expect = eng. get () ->lookupPacked (funcName);
975+ auto expect = eng->lookupPacked (funcName);
879976 CHECKE (expect, " Packed function '" , funcName.begin (), " ' not found!" );
880977 main.wrappedMain = *expect;
881978 } else {
882- auto expect = eng. get () ->lookup (staticMain);
979+ auto expect = eng->lookup (staticMain);
883980 CHECKE (expect, " Compiled function '" , staticMain.begin (), " ' not found!" );
884981 main.staticMain = reinterpret_cast <OclModule::StaticMainFunc>(*expect);
885982 }
@@ -888,9 +985,8 @@ OclModuleBuilder::build(const OclRuntime::Ext &ext) {
888985 if (auto it = cache.find (ext); it != cache.end ()) {
889986 return it->second ;
890987 }
891- std::shared_ptr<const OclModule> ptr (
892- new OclModule (OclRuntime (ext), !staticMain.empty (), main, argTypes,
893- std::move (eng.get ())));
988+ std::shared_ptr<const OclModule> ptr (new OclModule (
989+ OclRuntime (ext), !staticMain.empty (), main, argTypes, std::move (eng)));
894990 return cache.emplace (OclDevCtxPair (ext.device , ext.context ), ptr)
895991 .first ->second ;
896992}
0 commit comments