Skip to content

Commit 9be9c9a

Browse files
committed
Use thread entry point address in GetMainThreadId
1 parent 2d2bf31 commit 9be9c9a

File tree

2 files changed

+71
-5
lines changed

2 files changed

+71
-5
lines changed

Shared/sdk/SharedUtil.Misc.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -158,13 +158,15 @@ namespace SharedUtil
158158
bool IsWindows7OrGreater();
159159
bool IsWindows8OrGreater();
160160

161+
bool QueryThreadEntryPointAddress(void* thread, DWORD* entryPointAddress);
162+
163+
DWORD GetMainThreadId();
164+
161165
#endif
162166

163167
// Ensure rand() seed gets set to a new unique value
164168
void RandomizeRandomSeed();
165169

166-
DWORD GetMainThreadId();
167-
168170
//
169171
// Return true if currently executing the main thread
170172
// See implementation for details

Shared/sdk/SharedUtil.Misc.hpp

Lines changed: 67 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1203,39 +1203,102 @@ void SharedUtil::RandomizeRandomSeed()
12031203
}
12041204

12051205
#ifdef WIN32
1206+
static LONG SafeNtQueryInformationThread(HANDLE ThreadHandle, INT ThreadInformationClass, PVOID ThreadInformation, ULONG ThreadInformationLength,
1207+
PULONG ReturnLength)
1208+
{
1209+
using FunctionPointer = LONG(__stdcall*)(HANDLE, INT /*= THREADINFOCLASS*/, PVOID, ULONG, PULONG);
1210+
1211+
struct FunctionLookup
1212+
{
1213+
FunctionPointer function;
1214+
bool once;
1215+
};
1216+
1217+
static FunctionLookup lookup = {};
1218+
1219+
if (!lookup.once)
1220+
{
1221+
lookup.once = true;
1222+
1223+
HMODULE ntdll = LoadLibraryA("ntdll.dll");
1224+
1225+
if (ntdll)
1226+
lookup.function = (FunctionPointer)GetProcAddress(ntdll, "NtQueryInformationThread");
1227+
else
1228+
return 0xC0000135L; // STATUS_DLL_NOT_FOUND
1229+
}
1230+
1231+
if (lookup.function)
1232+
return lookup.function(ThreadHandle, ThreadInformationClass, ThreadInformation, ThreadInformationLength, ReturnLength);
1233+
else
1234+
return 0xC00000BBL; // STATUS_NOT_SUPPORTED
1235+
}
1236+
1237+
bool SharedUtil::QueryThreadEntryPointAddress(void* thread, DWORD* entryPointAddress)
1238+
{
1239+
return SafeNtQueryInformationThread(thread, 9 /*ThreadQuerySetWin32StartAddress*/, entryPointAddress, sizeof(DWORD), nullptr) == 0;
1240+
}
1241+
12061242
DWORD SharedUtil::GetMainThreadId()
12071243
{
12081244
static DWORD dwMainThreadID = 0;
1245+
12091246
if (dwMainThreadID == 0)
12101247
{
1248+
// Get the module information for the currently running process
1249+
MODULEINFO moduleInfo = {};
1250+
GetModuleInformation(GetCurrentProcess(), GetModuleHandle(nullptr), &moduleInfo, sizeof(MODULEINFO));
1251+
1252+
DWORD processEntryPointAddress = reinterpret_cast<DWORD>(moduleInfo.EntryPoint);
1253+
12111254
// Find oldest thread in the current process ( http://www.codeproject.com/Questions/78801/How-to-get-the-main-thread-ID-of-a-process-known-b )
12121255
HANDLE hThreadSnap = CreateToolhelp32Snapshot(TH32CS_SNAPTHREAD, 0);
1256+
12131257
if (hThreadSnap != INVALID_HANDLE_VALUE)
12141258
{
1215-
ULONGLONG ullMinCreateTime = ULLONG_MAX;
1216-
THREADENTRY32 th32;
1259+
ULONGLONG ullMinCreateTime = ULLONG_MAX;
1260+
1261+
THREADENTRY32 th32 = {};
12171262
th32.dwSize = sizeof(THREADENTRY32);
1263+
12181264
for (BOOL bOK = Thread32First(hThreadSnap, &th32); bOK; bOK = Thread32Next(hThreadSnap, &th32))
12191265
{
12201266
if (th32.th32OwnerProcessID == GetCurrentProcessId())
12211267
{
12221268
HANDLE hThread = OpenThread(THREAD_QUERY_INFORMATION, TRUE, th32.th32ThreadID);
1269+
12231270
if (hThread)
12241271
{
1225-
FILETIME afTimes[4] = {0};
1272+
// Check the thread by entry point first
1273+
DWORD entryPointAddress = 0;
1274+
1275+
if (QueryThreadEntryPointAddress(hThread, &entryPointAddress) && entryPointAddress == processEntryPointAddress)
1276+
{
1277+
dwMainThreadID = th32.th32ThreadID;
1278+
CloseHandle(hThread);
1279+
CloseHandle(hThreadSnap);
1280+
return dwMainThreadID;
1281+
}
1282+
1283+
// If entry point check failed, find the oldest thread in the system
1284+
FILETIME afTimes[4] = {};
1285+
12261286
if (GetThreadTimes(hThread, &afTimes[0], &afTimes[1], &afTimes[2], &afTimes[3]))
12271287
{
12281288
ULONGLONG ullTest = (ULONGLONG(afTimes[0].dwHighDateTime) << 32) + afTimes[0].dwLowDateTime;
1289+
12291290
if (ullTest && ullTest < ullMinCreateTime)
12301291
{
12311292
ullMinCreateTime = ullTest;
12321293
dwMainThreadID = th32.th32ThreadID;
12331294
}
12341295
}
1296+
12351297
CloseHandle(hThread);
12361298
}
12371299
}
12381300
}
1301+
12391302
CloseHandle(hThreadSnap);
12401303
}
12411304

@@ -1245,6 +1308,7 @@ DWORD SharedUtil::GetMainThreadId()
12451308
dwMainThreadID = GetCurrentThreadId();
12461309
}
12471310
}
1311+
12481312
return dwMainThreadID;
12491313
}
12501314
#endif

0 commit comments

Comments
 (0)