@@ -38,6 +38,8 @@ struct umf_memory_tracker_t {
3838 // for another memory pool (nested memory pooling).
3939 critnib * alloc_segments_map [MAX_LEVELS_OF_ALLOC_SEGMENT_MAP ];
4040 utils_mutex_t mutex ;
41+ // number of memory regions at levels > 0
42+ uint64_t n_at_higher_levels ;
4143};
4244
4345typedef struct tracker_alloc_info_t {
@@ -131,9 +133,14 @@ umfMemoryTrackerAddAtLevel(umf_memory_tracker_handle_t hTracker, int level,
131133 int ret = critnib_insert (hTracker -> alloc_segments_map [level ],
132134 (uintptr_t )ptr , value , 0 );
133135 if (ret == 0 ) {
136+ if (level > 0 ) {
137+ utils_fetch_and_add64 (& hTracker -> n_at_higher_levels , 1 );
138+ }
139+
134140 LOG_DEBUG ("memory region is added, tracker=%p, level=%i, pool=%p, "
135141 "ptr=%p, size=%zu" ,
136142 (void * )hTracker , level , (void * )pool , ptr , size );
143+
137144 if (parent_value ) {
138145 parent_value -> n_children ++ ;
139146 LOG_DEBUG (
@@ -158,9 +165,10 @@ umfMemoryTrackerAddAtLevel(umf_memory_tracker_handle_t hTracker, int level,
158165 return umf_result ;
159166}
160167
161- static umf_result_t umfMemoryTrackerAdd (umf_memory_tracker_handle_t hTracker ,
162- umf_memory_pool_handle_t pool ,
163- const void * ptr , size_t size ) {
168+ static umf_result_t
169+ umfMemoryTrackerAddLock (umf_memory_tracker_handle_t hTracker ,
170+ umf_memory_pool_handle_t pool , const void * ptr ,
171+ size_t size , int lock ) {
164172 assert (ptr );
165173
166174 umf_result_t umf_result = UMF_RESULT_ERROR_UNKNOWN ;
@@ -170,10 +178,13 @@ static umf_result_t umfMemoryTrackerAdd(umf_memory_tracker_handle_t hTracker,
170178 uintptr_t rkey = 0 ;
171179 int level = 0 ;
172180 int found = 0 ;
181+ int ret ;
173182
174- int ret = utils_mutex_lock (& hTracker -> mutex );
175- if (ret ) {
176- return UMF_RESULT_ERROR_UNKNOWN ;
183+ if (lock ) {
184+ ret = utils_mutex_lock (& hTracker -> mutex );
185+ if (ret ) {
186+ return UMF_RESULT_ERROR_UNKNOWN ;
187+ }
177188 }
178189
179190 // Find the most nested (in the highest level) entry
@@ -217,13 +228,25 @@ static umf_result_t umfMemoryTrackerAdd(umf_memory_tracker_handle_t hTracker,
217228 umf_result = UMF_RESULT_SUCCESS ;
218229
219230err_unlock :
220- utils_mutex_unlock (& hTracker -> mutex );
231+ if (lock ) {
232+ utils_mutex_unlock (& hTracker -> mutex );
233+ }
221234
222235 return umf_result ;
223236}
224237
225- static umf_result_t umfMemoryTrackerRemove (umf_memory_tracker_handle_t hTracker ,
226- const void * ptr ) {
238+ static umf_result_t umfMemoryTrackerAdd (umf_memory_tracker_handle_t hTracker ,
239+ umf_memory_pool_handle_t pool ,
240+ const void * ptr , size_t size ) {
241+
242+ return umfMemoryTrackerAddLock (
243+ hTracker , pool , ptr , size ,
244+ (utils_fetch_and_add64 (& hTracker -> n_at_higher_levels , 0 ) > 0 ));
245+ }
246+
247+ static umf_result_t
248+ umfMemoryTrackerRemoveLock (umf_memory_tracker_handle_t hTracker ,
249+ const void * ptr , int lock ) {
227250 assert (ptr );
228251
229252 // TODO: there is no support for removing partial ranges (or multiple entries
@@ -235,10 +258,13 @@ static umf_result_t umfMemoryTrackerRemove(umf_memory_tracker_handle_t hTracker,
235258 tracker_alloc_info_t * parent_value = NULL ;
236259 uintptr_t parent_key = 0 ;
237260 int level = 0 ;
261+ int ret ;
238262
239- int ret = utils_mutex_lock (& hTracker -> mutex );
240- if (ret ) {
241- return UMF_RESULT_ERROR_UNKNOWN ;
263+ if (lock ) {
264+ ret = utils_mutex_lock (& hTracker -> mutex );
265+ if (ret ) {
266+ return UMF_RESULT_ERROR_UNKNOWN ;
267+ }
242268 }
243269
244270 // Find the most nested (on the highest level) entry in the map
@@ -254,6 +280,10 @@ static umf_result_t umfMemoryTrackerRemove(umf_memory_tracker_handle_t hTracker,
254280 value = critnib_remove (hTracker -> alloc_segments_map [level ], (uintptr_t )ptr );
255281 assert (value );
256282
283+ if (level > 0 ) {
284+ utils_fetch_and_add64 (& hTracker -> n_at_higher_levels , -1 );
285+ }
286+
257287 LOG_DEBUG ("memory region removed: tracker=%p, level=%i, pool=%p, ptr=%p, "
258288 "size=%zu" ,
259289 (void * )hTracker , level , value -> pool , ptr , value -> size );
@@ -272,11 +302,20 @@ static umf_result_t umfMemoryTrackerRemove(umf_memory_tracker_handle_t hTracker,
272302 umf_result = UMF_RESULT_SUCCESS ;
273303
274304err_unlock :
275- utils_mutex_unlock (& hTracker -> mutex );
305+ if (lock ) {
306+ utils_mutex_unlock (& hTracker -> mutex );
307+ }
276308
277309 return umf_result ;
278310}
279311
312+ static umf_result_t umfMemoryTrackerRemove (umf_memory_tracker_handle_t hTracker ,
313+ const void * ptr ) {
314+ return umfMemoryTrackerRemoveLock (
315+ hTracker , ptr ,
316+ (utils_fetch_and_add64 (& hTracker -> n_at_higher_levels , 0 ) > 0 ));
317+ }
318+
280319umf_memory_pool_handle_t umfMemoryTrackerGetPool (const void * ptr ) {
281320 umf_alloc_info_t allocInfo = {NULL , 0 , NULL };
282321 umf_result_t ret = umfMemoryTrackerGetAllocInfo (ptr , & allocInfo );
@@ -1084,6 +1123,7 @@ umf_memory_tracker_handle_t umfMemoryTrackerCreate(void) {
10841123 }
10851124
10861125 handle -> alloc_info_allocator = alloc_info_allocator ;
1126+ handle -> n_at_higher_levels = 0 ;
10871127
10881128 void * mutex_ptr = utils_mutex_init (& handle -> mutex );
10891129 if (!mutex_ptr ) {
0 commit comments