1717#include < list>
1818#include < map>
1919#include < shared_mutex>
20+ #include < unordered_map>
21+ #include < unordered_set>
2022#include < vector>
2123
2224#include " ExclusiveAccess.h"
@@ -58,6 +60,7 @@ struct GenericPluginTy;
5860struct GenericKernelTy ;
5961struct GenericDeviceTy ;
6062struct RecordReplayTy ;
63+ struct KernelRunRecordTy ;
6164
/// Class that wraps the __tgt_async_info to simplify its usage. In case the
/// object is constructed without a valid __tgt_async_info, the object will use
@@ -1105,6 +1108,8 @@ struct GenericDeviceTy : public DeviceAllocatorTy {
11051108
11061109 bool getMultiDeviceKernelValue (void *EntryPtr);
11071110
  /// Accessor for the per-kernel autotuning run records (may be null if
  /// runtime autotuning is not active — confirm against device init code).
  KernelRunRecordTy *getKernelRunRecords () const { return KernelRunRecords; }
1112+
11081113 // / Return true if a descriptor of size 'Size' should be allocated using
11091114 // / shared memory. Default implementation returns 'false',
11101115 virtual bool useSharedMemForDescriptor (int64_t Size);
@@ -1256,6 +1261,9 @@ struct GenericDeviceTy : public DeviceAllocatorTy {
12561261 // / This is used to run the RPC server during task synchronization.
12571262 RPCServerTy *RPCServer;
12581263
1264+ // / Structs for functions and data used in runtime autotuning.
1265+ KernelRunRecordTy *KernelRunRecords;
1266+
12591267private:
12601268#ifdef OMPT_SUPPORT
12611269 // / OMPT callback functions
@@ -1282,6 +1290,109 @@ struct GenericDeviceTy : public DeviceAllocatorTy {
12821290 bool IsFastReductionEnabled = false ;
12831291};
12841292
1293+ // / Struct represents the metadata for each kernel run on the device.
1294+ struct KernelRunRecordTy {
1295+
1296+ struct KernelRunEntryTy {
1297+ std::string KernelName;
1298+ uint32_t NumTeams = 0 ;
1299+ uint32_t NumThreads = 0 ;
1300+ uint64_t RunDuration = 0 ;
1301+ };
1302+
1303+ // Metadata used in tuning process.
1304+ struct TuningMetadataTy {
1305+ uint32_t IdxThread = 0 ;
1306+ uint32_t IdxCUMultiplier = 0 ;
1307+ // Run counters.
1308+ uint32_t RunCounters = 0 ;
1309+ // Entry with minimum running time.
1310+ KernelRunEntryTy MinEntry;
1311+ };
1312+
1313+ // Add a new entry
1314+ void addEntry (std::string KernelName, uint32_t NumTeams, uint32_t NumThreads,
1315+ uint64_t RunDuration) {
1316+ TuningData[KernelName].RunCounters ++;
1317+
1318+ // Update min entries.
1319+ uint64_t MinDuration = 0 ;
1320+ auto It = TuningData.find (KernelName);
1321+ if (It != TuningData.end ()) {
1322+ MinDuration = It->second .MinEntry .RunDuration ;
1323+ }
1324+ if (MinDuration > RunDuration || MinDuration == 0 ) {
1325+ TuningData[KernelName].MinEntry = {KernelName, NumTeams, NumThreads,
1326+ RunDuration};
1327+ }
1328+ }
1329+
1330+ // Get parameters for next kernel launch.
1331+ std::pair<uint32_t , uint32_t >
1332+ getLaunchParamsForKernel (std::string KernelName,
1333+ GenericDeviceTy &GenericDevice) {
1334+ // If the kernel reaches the run limit,
1335+ // return the current optimal launch parameters.
1336+ if (reachedRunLimitForKernel (KernelName)) {
1337+ auto MinEntry = TuningData[KernelName].MinEntry ;
1338+ return {MinEntry.NumTeams , MinEntry.NumThreads };
1339+ }
1340+
1341+ // Pick new launch parameters.
1342+ uint32_t IdxCUMulti = TuningData[KernelName].IdxCUMultiplier ;
1343+ uint32_t IdxThread = TuningData[KernelName].IdxThread ;
1344+
1345+ if (IdxCUMulti >= CUMultiplierCandidate.size ()) {
1346+ // No more element to search.
1347+ // Return current optimal launch parameters.
1348+ return {TuningData[KernelName].MinEntry .NumTeams ,
1349+ TuningData[KernelName].MinEntry .NumThreads };
1350+ }
1351+
1352+ // New team/thread pair for launch parameters.
1353+ uint32_t NumCU = GenericDevice.getNumComputeUnits ();
1354+ std::pair<uint32_t , uint32_t > NewLaunchParams = {
1355+ CUMultiplierCandidate[IdxCUMulti] * NumCU, ThreadCandidate[IdxThread]};
1356+
1357+ // Update indices.
1358+ IdxThread++;
1359+ TuningData[KernelName].IdxThread = IdxThread;
1360+
1361+ if (IdxThread >= ThreadCandidate.size ()) {
1362+ TuningData[KernelName].IdxThread = 0 ;
1363+ TuningData[KernelName].IdxCUMultiplier ++;
1364+ }
1365+
1366+ return NewLaunchParams;
1367+ }
1368+
1369+ bool reachedRunLimitForKernel (std::string KernelName) {
1370+ if (TuningData.find (KernelName) == TuningData.end ()) {
1371+ // If no record for this kernel.
1372+ return false ;
1373+ }
1374+
1375+ return TuningData[KernelName].RunCounters > RunLimiter;
1376+ }
1377+
1378+ uint32_t getRunCounterForKernel (std::string KernelName) {
1379+ if (TuningData.find (KernelName) == TuningData.end ()) {
1380+ return 0 ;
1381+ }
1382+
1383+ return TuningData[KernelName].RunCounters ;
1384+ }
1385+
1386+ private:
1387+ // Candidates for thread and team.
1388+ std::vector<uint32_t > ThreadCandidate = {32 , 64 , 128 , 256 , 512 , 1024 };
1389+ std::vector<uint32_t > CUMultiplierCandidate = {4 , 8 , 16 , 32 , 64 , 128 };
1390+ // The max number of tuning runs for each kernel.
1391+ uint32_t RunLimiter = ThreadCandidate.size() * CUMultiplierCandidate.size();
1392+ // Used for keeping track of the metatdata used in tuning for each kernel.
1393+ std::unordered_map<std::string, TuningMetadataTy> TuningData;
1394+ };
1395+
12851396// / Class implementing common functionalities of offload plugins. Each plugin
12861397// / should define the specific plugin class, derive from this generic one, and
12871398// / implement the necessary virtual function members.
0 commit comments