@@ -1203,39 +1203,102 @@ void SharedUtil::RandomizeRandomSeed()
1203
1203
}
1204
1204
1205
1205
#ifdef WIN32
1206
+ static LONG SafeNtQueryInformationThread (HANDLE ThreadHandle, INT ThreadInformationClass, PVOID ThreadInformation, ULONG ThreadInformationLength,
1207
+ PULONG ReturnLength)
1208
+ {
1209
+ using FunctionPointer = LONG (__stdcall*)(HANDLE, INT /* = THREADINFOCLASS*/ , PVOID, ULONG, PULONG);
1210
+
1211
+ struct FunctionLookup
1212
+ {
1213
+ FunctionPointer function;
1214
+ bool once;
1215
+ };
1216
+
1217
+ static FunctionLookup lookup = {};
1218
+
1219
+ if (!lookup.once )
1220
+ {
1221
+ lookup.once = true ;
1222
+
1223
+ HMODULE ntdll = LoadLibraryA (" ntdll.dll" );
1224
+
1225
+ if (ntdll)
1226
+ lookup.function = (FunctionPointer)GetProcAddress (ntdll, " NtQueryInformationThread" );
1227
+ else
1228
+ return 0xC0000135L ; // STATUS_DLL_NOT_FOUND
1229
+ }
1230
+
1231
+ if (lookup.function )
1232
+ return lookup.function (ThreadHandle, ThreadInformationClass, ThreadInformation, ThreadInformationLength, ReturnLength);
1233
+ else
1234
+ return 0xC00000BBL ; // STATUS_NOT_SUPPORTED
1235
+ }
1236
+
1237
+ bool SharedUtil::QueryThreadEntryPointAddress (void * thread, DWORD* entryPointAddress)
1238
+ {
1239
+ return SafeNtQueryInformationThread (thread, 9 /* ThreadQuerySetWin32StartAddress*/ , entryPointAddress, sizeof (DWORD), nullptr ) == 0 ;
1240
+ }
1241
+
1206
1242
DWORD SharedUtil::GetMainThreadId ()
1207
1243
{
1208
1244
static DWORD dwMainThreadID = 0 ;
1245
+
1209
1246
if (dwMainThreadID == 0 )
1210
1247
{
1248
+ // Get the module information for the currently running process
1249
+ MODULEINFO moduleInfo = {};
1250
+ GetModuleInformation (GetCurrentProcess (), GetModuleHandle (nullptr ), &moduleInfo, sizeof (MODULEINFO));
1251
+
1252
+ DWORD processEntryPointAddress = reinterpret_cast <DWORD>(moduleInfo.EntryPoint );
1253
+
1211
1254
// Find oldest thread in the current process ( http://www.codeproject.com/Questions/78801/How-to-get-the-main-thread-ID-of-a-process-known-b )
1212
1255
HANDLE hThreadSnap = CreateToolhelp32Snapshot (TH32CS_SNAPTHREAD, 0 );
1256
+
1213
1257
if (hThreadSnap != INVALID_HANDLE_VALUE)
1214
1258
{
1215
- ULONGLONG ullMinCreateTime = ULLONG_MAX;
1216
- THREADENTRY32 th32;
1259
+ ULONGLONG ullMinCreateTime = ULLONG_MAX;
1260
+
1261
+ THREADENTRY32 th32 = {};
1217
1262
th32.dwSize = sizeof (THREADENTRY32);
1263
+
1218
1264
for (BOOL bOK = Thread32First (hThreadSnap, &th32); bOK; bOK = Thread32Next (hThreadSnap, &th32))
1219
1265
{
1220
1266
if (th32.th32OwnerProcessID == GetCurrentProcessId ())
1221
1267
{
1222
1268
HANDLE hThread = OpenThread (THREAD_QUERY_INFORMATION, TRUE , th32.th32ThreadID );
1269
+
1223
1270
if (hThread)
1224
1271
{
1225
- FILETIME afTimes[4 ] = {0 };
1272
+ // Check the thread by entry point first
1273
+ DWORD entryPointAddress = 0 ;
1274
+
1275
+ if (QueryThreadEntryPointAddress (hThread, &entryPointAddress) && entryPointAddress == processEntryPointAddress)
1276
+ {
1277
+ dwMainThreadID = th32.th32ThreadID ;
1278
+ CloseHandle (hThread);
1279
+ CloseHandle (hThreadSnap);
1280
+ return dwMainThreadID;
1281
+ }
1282
+
1283
+ // If entry point check failed, find the oldest thread in the system
1284
+ FILETIME afTimes[4 ] = {};
1285
+
1226
1286
if (GetThreadTimes (hThread, &afTimes[0 ], &afTimes[1 ], &afTimes[2 ], &afTimes[3 ]))
1227
1287
{
1228
1288
ULONGLONG ullTest = (ULONGLONG (afTimes[0 ].dwHighDateTime ) << 32 ) + afTimes[0 ].dwLowDateTime ;
1289
+
1229
1290
if (ullTest && ullTest < ullMinCreateTime)
1230
1291
{
1231
1292
ullMinCreateTime = ullTest;
1232
1293
dwMainThreadID = th32.th32ThreadID ;
1233
1294
}
1234
1295
}
1296
+
1235
1297
CloseHandle (hThread);
1236
1298
}
1237
1299
}
1238
1300
}
1301
+
1239
1302
CloseHandle (hThreadSnap);
1240
1303
}
1241
1304
@@ -1245,6 +1308,7 @@ DWORD SharedUtil::GetMainThreadId()
1245
1308
dwMainThreadID = GetCurrentThreadId ();
1246
1309
}
1247
1310
}
1311
+
1248
1312
return dwMainThreadID;
1249
1313
}
1250
1314
#endif
0 commit comments