Skip to content

Commit ad5e25c

Browse files
committed
Improve fault tolerance for GetMainThreadId
1 parent bd8feac commit ad5e25c

File tree

1 file changed

+20
-12
lines changed

1 file changed

+20
-12
lines changed

Shared/sdk/SharedUtil.Misc.hpp

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1223,7 +1223,7 @@ static LONG SafeNtQueryInformationThread(HANDLE ThreadHandle, INT ThreadInformat
12231223
HMODULE ntdll = LoadLibraryA("ntdll.dll");
12241224

12251225
if (ntdll)
1226-
lookup.function = (FunctionPointer)GetProcAddress(ntdll, "NtQueryInformationThread");
1226+
lookup.function = reinterpret_cast<FunctionPointer>(GetProcAddress(ntdll, "NtQueryInformationThread"));
12271227
else
12281228
return 0xC0000135L; // STATUS_DLL_NOT_FOUND
12291229
}
@@ -1246,10 +1246,13 @@ DWORD SharedUtil::GetMainThreadId()
12461246
if (dwMainThreadID == 0)
12471247
{
12481248
// Get the module information for the currently running process
1249+
DWORD processEntryPointAddress = 0;
12491250
MODULEINFO moduleInfo = {};
1250-
GetModuleInformation(GetCurrentProcess(), GetModuleHandle(nullptr), &moduleInfo, sizeof(MODULEINFO));
1251-
1252-
DWORD processEntryPointAddress = reinterpret_cast<DWORD>(moduleInfo.EntryPoint);
1251+
1252+
if (GetModuleInformation(GetCurrentProcess(), GetModuleHandle(nullptr), &moduleInfo, sizeof(MODULEINFO)) != 0)
1253+
{
1254+
processEntryPointAddress = reinterpret_cast<DWORD>(moduleInfo.EntryPoint);
1255+
}
12531256

12541257
// Find oldest thread in the current process ( http://www.codeproject.com/Questions/78801/How-to-get-the-main-thread-ID-of-a-process-known-b )
12551258
HANDLE hThreadSnap = CreateToolhelp32Snapshot(TH32CS_SNAPTHREAD, 0);
@@ -1258,26 +1261,31 @@ DWORD SharedUtil::GetMainThreadId()
12581261
{
12591262
ULONGLONG ullMinCreateTime = ULLONG_MAX;
12601263

1264+
DWORD currentProcessID = GetCurrentProcessId();
1265+
12611266
THREADENTRY32 th32 = {};
12621267
th32.dwSize = sizeof(THREADENTRY32);
12631268

12641269
for (BOOL bOK = Thread32First(hThreadSnap, &th32); bOK; bOK = Thread32Next(hThreadSnap, &th32))
12651270
{
1266-
if (th32.th32OwnerProcessID == GetCurrentProcessId())
1271+
if (th32.th32OwnerProcessID == currentProcessID)
12671272
{
12681273
HANDLE hThread = OpenThread(THREAD_QUERY_INFORMATION, TRUE, th32.th32ThreadID);
12691274

12701275
if (hThread)
12711276
{
12721277
// Check the thread by entry point first
1273-
DWORD entryPointAddress = 0;
1274-
1275-
if (QueryThreadEntryPointAddress(hThread, &entryPointAddress) && entryPointAddress == processEntryPointAddress)
1278+
if (processEntryPointAddress != 0)
12761279
{
1277-
dwMainThreadID = th32.th32ThreadID;
1278-
CloseHandle(hThread);
1279-
CloseHandle(hThreadSnap);
1280-
return dwMainThreadID;
1280+
DWORD entryPointAddress = 0;
1281+
1282+
if (QueryThreadEntryPointAddress(hThread, &entryPointAddress) && entryPointAddress == processEntryPointAddress)
1283+
{
1284+
dwMainThreadID = th32.th32ThreadID;
1285+
CloseHandle(hThread);
1286+
CloseHandle(hThreadSnap);
1287+
return dwMainThreadID;
1288+
}
12811289
}
12821290

12831291
// If entry point check failed, find the oldest thread in the system

0 commit comments

Comments
 (0)