@@ -130,11 +130,161 @@ void LoadFromMemory(void)
130130 free (data);
131131}
132132
133+ #define MAX_CALLS 20
134+
135+ struct CallList {
136+ int current_alloc_call, current_free_call;
137+ CustomAllocFunc alloc_calls[MAX_CALLS];
138+ CustomFreeFunc free_calls[MAX_CALLS];
139+ };
140+
141+ LPVOID MemoryFailingAlloc (LPVOID address, SIZE_T size, DWORD allocationType, DWORD protect, void * userdata)
142+ {
143+ UNREFERENCED_PARAMETER (address);
144+ UNREFERENCED_PARAMETER (size);
145+ UNREFERENCED_PARAMETER (allocationType);
146+ UNREFERENCED_PARAMETER (protect);
147+ UNREFERENCED_PARAMETER (userdata);
148+ return NULL ;
149+ }
150+
151+ LPVOID MemoryMockAlloc (LPVOID address, SIZE_T size, DWORD allocationType, DWORD protect, void * userdata)
152+ {
153+ CallList* calls = (CallList*)userdata;
154+ CustomAllocFunc current_func = calls->alloc_calls [calls->current_alloc_call ++];
155+ assert (current_func != NULL );
156+ return current_func (address, size, allocationType, protect, NULL );
157+ }
158+
159+ BOOL MemoryMockFree (LPVOID lpAddress, SIZE_T dwSize, DWORD dwFreeType, void * userdata)
160+ {
161+ CallList* calls = (CallList*)userdata;
162+ CustomFreeFunc current_func = calls->free_calls [calls->current_free_call ++];
163+ assert (current_func != NULL );
164+ return current_func (lpAddress, dwSize, dwFreeType, NULL );
165+ }
166+
167+ void InitFuncs (void ** funcs, va_list args) {
168+ for (int i = 0 ; ; i++) {
169+ assert (i < MAX_CALLS);
170+ funcs[i] = va_arg (args, void *);
171+ if (funcs[i] == NULL ) break ;
172+ }
173+ }
174+
175+ void InitAllocFuncs (CallList* calls, ...) {
176+ va_list args;
177+ va_start (args, calls);
178+ InitFuncs ((void **)calls->alloc_calls , args);
179+ va_end (args);
180+ calls->current_alloc_call = 0 ;
181+ }
182+
183+ void InitFreeFuncs (CallList* calls, ...) {
184+ va_list args;
185+ va_start (args, calls);
186+ InitFuncs ((void **)calls->free_calls , args);
187+ va_end (args);
188+ calls->current_free_call = 0 ;
189+ }
190+
191+ void InitFreeFunc (CallList* calls, CustomFreeFunc freeFunc) {
192+ for (int i = 0 ; i < MAX_CALLS; i++) {
193+ calls->free_calls [i] = freeFunc;
194+ }
195+ calls->current_free_call = 0 ;
196+ }
197+
198+ void TestFailingAllocation (void *data, long size) {
199+ CallList expected_calls;
200+ HMEMORYMODULE handle;
201+
202+ InitAllocFuncs (&expected_calls, MemoryFailingAlloc, MemoryFailingAlloc, NULL );
203+ InitFreeFuncs (&expected_calls, NULL );
204+
205+ handle = MemoryLoadLibraryEx (
206+ data, size, MemoryMockAlloc, MemoryMockFree, MemoryDefaultLoadLibrary,
207+ MemoryDefaultGetProcAddress, MemoryDefaultFreeLibrary, &expected_calls);
208+
209+ assert (handle == NULL );
210+ assert (GetLastError () == ERROR_OUTOFMEMORY);
211+ assert (expected_calls.current_free_call == 0 );
212+
213+ MemoryFreeLibrary (handle);
214+ assert (expected_calls.current_free_call == 0 );
215+ }
216+
217+ void TestCleanupAfterFailingAllocation (void *data, long size) {
218+ CallList expected_calls;
219+ HMEMORYMODULE handle;
220+ int free_calls_after_loading;
221+
222+ InitAllocFuncs (&expected_calls,
223+ MemoryDefaultAlloc,
224+ MemoryDefaultAlloc,
225+ MemoryDefaultAlloc,
226+ MemoryDefaultAlloc,
227+ MemoryFailingAlloc,
228+ NULL );
229+ InitFreeFuncs (&expected_calls, MemoryDefaultFree, NULL );
230+
231+ handle = MemoryLoadLibraryEx (
232+ data, size, MemoryMockAlloc, MemoryMockFree, MemoryDefaultLoadLibrary,
233+ MemoryDefaultGetProcAddress, MemoryDefaultFreeLibrary, &expected_calls);
234+
235+ free_calls_after_loading = expected_calls.current_free_call ;
236+
237+ MemoryFreeLibrary (handle);
238+ assert (expected_calls.current_free_call == free_calls_after_loading);
239+ }
240+
241+ void TestFreeAfterDefaultAlloc (void *data, long size) {
242+ CallList expected_calls;
243+ HMEMORYMODULE handle;
244+ int free_calls_after_loading;
245+
246+ // Note: free might get called internally multiple times
247+ InitFreeFunc (&expected_calls, MemoryDefaultFree);
248+
249+ handle = MemoryLoadLibraryEx (
250+ data, size, MemoryDefaultAlloc, MemoryMockFree, MemoryDefaultLoadLibrary,
251+ MemoryDefaultGetProcAddress, MemoryDefaultFreeLibrary, &expected_calls);
252+
253+ assert (handle != NULL );
254+ free_calls_after_loading = expected_calls.current_free_call ;
255+
256+ MemoryFreeLibrary (handle);
257+ assert (expected_calls.current_free_call == free_calls_after_loading + 1 );
258+ }
259+
260+ void TestCustomAllocAndFree (void )
261+ {
262+ void *data;
263+ long size;
264+
265+ data = ReadLibrary (&size);
266+ if (data == NULL )
267+ {
268+ return ;
269+ }
270+
271+ _tprintf (_T (" Test MemoryLoadLibraryEx after initially failing allocation function\n " ));
272+ TestFailingAllocation (data, size);
273+ _tprintf (_T (" Test cleanup after MemoryLoadLibraryEx with failing allocation function\n " ));
274+ TestCleanupAfterFailingAllocation (data, size);
275+ _tprintf (_T (" Test custom free function after MemoryLoadLibraryEx\n " ));
276+ TestFreeAfterDefaultAlloc (data, size);
277+
278+ free (data);
279+ }
280+
133281int main ()
134282{
135283 LoadFromFile ();
136284 printf (" \n\n " );
137285 LoadFromMemory ();
286+ printf (" \n\n " );
287+ TestCustomAllocAndFree ();
138288 return 0 ;
139289}
140290
0 commit comments