Skip to content

Commit a1fb42d

Browse files
committed
Add some tests for the custom free and alloc functions
1 parent bc38a6b commit a1fb42d

File tree

1 file changed

+150
-0
lines changed

1 file changed

+150
-0
lines changed

example/DllLoader/DllLoader.cpp

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,11 +130,161 @@ void LoadFromMemory(void)
130130
free(data);
131131
}
132132

133+
#define MAX_CALLS 20
134+
135+
struct CallList {
136+
int current_alloc_call, current_free_call;
137+
CustomAllocFunc alloc_calls[MAX_CALLS];
138+
CustomFreeFunc free_calls[MAX_CALLS];
139+
};
140+
141+
LPVOID MemoryFailingAlloc(LPVOID address, SIZE_T size, DWORD allocationType, DWORD protect, void* userdata)
142+
{
143+
UNREFERENCED_PARAMETER(address);
144+
UNREFERENCED_PARAMETER(size);
145+
UNREFERENCED_PARAMETER(allocationType);
146+
UNREFERENCED_PARAMETER(protect);
147+
UNREFERENCED_PARAMETER(userdata);
148+
return NULL;
149+
}
150+
151+
LPVOID MemoryMockAlloc(LPVOID address, SIZE_T size, DWORD allocationType, DWORD protect, void* userdata)
152+
{
153+
CallList* calls = (CallList*)userdata;
154+
CustomAllocFunc current_func = calls->alloc_calls[calls->current_alloc_call++];
155+
assert(current_func != NULL);
156+
return current_func(address, size, allocationType, protect, NULL);
157+
}
158+
159+
BOOL MemoryMockFree(LPVOID lpAddress, SIZE_T dwSize, DWORD dwFreeType, void* userdata)
160+
{
161+
CallList* calls = (CallList*)userdata;
162+
CustomFreeFunc current_func = calls->free_calls[calls->current_free_call++];
163+
assert(current_func != NULL);
164+
return current_func(lpAddress, dwSize, dwFreeType, NULL);
165+
}
166+
167+
void InitFuncs(void** funcs, va_list args) {
168+
for (int i = 0; ; i++) {
169+
assert(i < MAX_CALLS);
170+
funcs[i] = va_arg(args, void*);
171+
if (funcs[i] == NULL) break;
172+
}
173+
}
174+
175+
void InitAllocFuncs(CallList* calls, ...) {
176+
va_list args;
177+
va_start(args, calls);
178+
InitFuncs((void**)calls->alloc_calls, args);
179+
va_end(args);
180+
calls->current_alloc_call = 0;
181+
}
182+
183+
void InitFreeFuncs(CallList* calls, ...) {
184+
va_list args;
185+
va_start(args, calls);
186+
InitFuncs((void**)calls->free_calls, args);
187+
va_end(args);
188+
calls->current_free_call = 0;
189+
}
190+
191+
void InitFreeFunc(CallList* calls, CustomFreeFunc freeFunc) {
192+
for (int i = 0; i < MAX_CALLS; i++) {
193+
calls->free_calls[i] = freeFunc;
194+
}
195+
calls->current_free_call = 0;
196+
}
197+
198+
void TestFailingAllocation(void *data, long size) {
199+
CallList expected_calls;
200+
HMEMORYMODULE handle;
201+
202+
InitAllocFuncs(&expected_calls, MemoryFailingAlloc, MemoryFailingAlloc, NULL);
203+
InitFreeFuncs(&expected_calls, NULL);
204+
205+
handle = MemoryLoadLibraryEx(
206+
data, size, MemoryMockAlloc, MemoryMockFree, MemoryDefaultLoadLibrary,
207+
MemoryDefaultGetProcAddress, MemoryDefaultFreeLibrary, &expected_calls);
208+
209+
assert(handle == NULL);
210+
assert(GetLastError() == ERROR_OUTOFMEMORY);
211+
assert(expected_calls.current_free_call == 0);
212+
213+
MemoryFreeLibrary(handle);
214+
assert(expected_calls.current_free_call == 0);
215+
}
216+
217+
void TestCleanupAfterFailingAllocation(void *data, long size) {
218+
CallList expected_calls;
219+
HMEMORYMODULE handle;
220+
int free_calls_after_loading;
221+
222+
InitAllocFuncs(&expected_calls,
223+
MemoryDefaultAlloc,
224+
MemoryDefaultAlloc,
225+
MemoryDefaultAlloc,
226+
MemoryDefaultAlloc,
227+
MemoryFailingAlloc,
228+
NULL);
229+
InitFreeFuncs(&expected_calls, MemoryDefaultFree, NULL);
230+
231+
handle = MemoryLoadLibraryEx(
232+
data, size, MemoryMockAlloc, MemoryMockFree, MemoryDefaultLoadLibrary,
233+
MemoryDefaultGetProcAddress, MemoryDefaultFreeLibrary, &expected_calls);
234+
235+
free_calls_after_loading = expected_calls.current_free_call;
236+
237+
MemoryFreeLibrary(handle);
238+
assert(expected_calls.current_free_call == free_calls_after_loading);
239+
}
240+
241+
void TestFreeAfterDefaultAlloc(void *data, long size) {
242+
CallList expected_calls;
243+
HMEMORYMODULE handle;
244+
int free_calls_after_loading;
245+
246+
// Note: free might get called internally multiple times
247+
InitFreeFunc(&expected_calls, MemoryDefaultFree);
248+
249+
handle = MemoryLoadLibraryEx(
250+
data, size, MemoryDefaultAlloc, MemoryMockFree, MemoryDefaultLoadLibrary,
251+
MemoryDefaultGetProcAddress, MemoryDefaultFreeLibrary, &expected_calls);
252+
253+
assert(handle != NULL);
254+
free_calls_after_loading = expected_calls.current_free_call;
255+
256+
MemoryFreeLibrary(handle);
257+
assert(expected_calls.current_free_call == free_calls_after_loading + 1);
258+
}
259+
260+
void TestCustomAllocAndFree(void)
261+
{
262+
void *data;
263+
long size;
264+
265+
data = ReadLibrary(&size);
266+
if (data == NULL)
267+
{
268+
return;
269+
}
270+
271+
_tprintf(_T("Test MemoryLoadLibraryEx after initially failing allocation function\n"));
272+
TestFailingAllocation(data, size);
273+
_tprintf(_T("Test cleanup after MemoryLoadLibraryEx with failing allocation function\n"));
274+
TestCleanupAfterFailingAllocation(data, size);
275+
_tprintf(_T("Test custom free function after MemoryLoadLibraryEx\n"));
276+
TestFreeAfterDefaultAlloc(data, size);
277+
278+
free(data);
279+
}
280+
133281
int main()
134282
{
135283
LoadFromFile();
136284
printf("\n\n");
137285
LoadFromMemory();
286+
printf("\n\n");
287+
TestCustomAllocAndFree();
138288
return 0;
139289
}
140290

0 commit comments

Comments
 (0)