diff --git a/src/libumf.c b/src/libumf.c
index 64314f4d39..f8f6cc61ff 100644
--- a/src/libumf.c
+++ b/src/libumf.c
@@ -12,6 +12,7 @@
 #include "base_alloc_global.h"
 #include "ipc_cache.h"
 #include "memspace_internal.h"
+#include "pool/pool_scalable_internal.h"
 #include "provider_cuda_internal.h"
 #include "provider_level_zero_internal.h"
 #include "provider_tracking.h"
@@ -83,6 +84,7 @@ void umfTearDown(void) {
     fini_umfTearDown:
         fini_ze_global_state();
         fini_cu_global_state();
+        fini_tbb_global_state();
         LOG_DEBUG("UMF library finalized");
     }
 }
diff --git a/src/pool/pool_scalable.c b/src/pool/pool_scalable.c
index 2ee265df8c..e1ab3d376d 100644
--- a/src/pool/pool_scalable.c
+++ b/src/pool/pool_scalable.c
@@ -20,6 +20,7 @@
 
 #include "base_alloc_global.h"
 #include "libumf.h"
+#include "pool_scalable_internal.h"
 #include "utils_common.h"
 #include "utils_concurrency.h"
 #include "utils_load_library.h"
@@ -33,6 +34,7 @@ static __TLS umf_result_t TLS_last_allocation_error;
 static __TLS umf_result_t TLS_last_free_error;
 
 static const size_t DEFAULT_GRANULARITY = 2 * 1024 * 1024; // 2MB
+
 typedef struct tbb_mem_pool_policy_t {
     raw_alloc_tbb_type pAlloc;
     raw_free_tbb_type pFree;
@@ -66,7 +68,6 @@ typedef struct tbb_callbacks_t {
 typedef struct tbb_memory_pool_t {
     umf_memory_provider_handle_t mem_provider;
     void *tbb_pool;
-    tbb_callbacks_t tbb_callbacks;
 } tbb_memory_pool_t;
 
 typedef enum tbb_enums_t {
@@ -82,6 +83,10 @@ typedef enum tbb_enums_t {
     TBB_POOL_SYMBOLS_MAX // it has to be the last one
 } tbb_enums_t;
 
+static UTIL_ONCE_FLAG tbb_initialized = UTIL_ONCE_FLAG_INIT;
+static int tbb_init_result = 0;
+static tbb_callbacks_t tbb_callbacks = {0};
+
 static const char *tbb_symbol[TBB_POOL_SYMBOLS_MAX] = {
 #ifdef _WIN32
     // symbols copied from oneTBB/src/tbbmalloc/def/win64-tbbmalloc.def
@@ -109,46 +114,72 @@ static const char *tbb_symbol[TBB_POOL_SYMBOLS_MAX] = {
 #endif
 };
 
-static int init_tbb_callbacks(tbb_callbacks_t *tbb_callbacks) {
-    assert(tbb_callbacks);
-
+// Opens the TBB malloc library and resolves the pool entry points into the
+// file-scope tbb_callbacks. Invoked through utils_init_once() by
+// init_tbb_callbacks(); on failure tbb_init_result is set to -1.
+static void init_tbb_callbacks_once(void) {
     const char *lib_name = tbb_symbol[TBB_LIB_NAME];
-    tbb_callbacks->lib_handle = utils_open_library(lib_name, 0);
-    if (!tbb_callbacks->lib_handle) {
+    tbb_callbacks.lib_handle = utils_open_library(lib_name, 0);
+    if (!tbb_callbacks.lib_handle) {
         LOG_ERR("%s required by Scalable Pool not found - install TBB malloc "
                 "or make sure it is in the default search paths.",
                 lib_name);
-        return -1;
+        tbb_init_result = -1;
+        return;
     }
-
-    *(void **)&tbb_callbacks->pool_malloc = utils_get_symbol_addr(
-        tbb_callbacks->lib_handle, tbb_symbol[TBB_POOL_MALLOC], lib_name);
-    *(void **)&tbb_callbacks->pool_realloc = utils_get_symbol_addr(
-        tbb_callbacks->lib_handle, tbb_symbol[TBB_POOL_REALLOC], lib_name);
-    *(void **)&tbb_callbacks->pool_aligned_malloc =
-        utils_get_symbol_addr(tbb_callbacks->lib_handle,
+    *(void **)&tbb_callbacks.pool_malloc = utils_get_symbol_addr(
+        tbb_callbacks.lib_handle, tbb_symbol[TBB_POOL_MALLOC], lib_name);
+    *(void **)&tbb_callbacks.pool_realloc = utils_get_symbol_addr(
+        tbb_callbacks.lib_handle, tbb_symbol[TBB_POOL_REALLOC], lib_name);
+    *(void **)&tbb_callbacks.pool_aligned_malloc =
+        utils_get_symbol_addr(tbb_callbacks.lib_handle,
                               tbb_symbol[TBB_POOL_ALIGNED_MALLOC], lib_name);
-    *(void **)&tbb_callbacks->pool_free = utils_get_symbol_addr(
-        tbb_callbacks->lib_handle, tbb_symbol[TBB_POOL_FREE], lib_name);
-    *(void **)&tbb_callbacks->pool_create_v1 = utils_get_symbol_addr(
-        tbb_callbacks->lib_handle, tbb_symbol[TBB_POOL_CREATE_V1], lib_name);
-    *(void **)&tbb_callbacks->pool_destroy = utils_get_symbol_addr(
-        tbb_callbacks->lib_handle, tbb_symbol[TBB_POOL_DESTROY], lib_name);
-    *(void **)&tbb_callbacks->pool_identify = utils_get_symbol_addr(
-        tbb_callbacks->lib_handle, tbb_symbol[TBB_POOL_IDENTIFY], lib_name);
-    *(void **)&tbb_callbacks->pool_msize = utils_get_symbol_addr(
-        tbb_callbacks->lib_handle, tbb_symbol[TBB_POOL_MSIZE], lib_name);
-
-    if (!tbb_callbacks->pool_malloc || !tbb_callbacks->pool_realloc ||
-        !tbb_callbacks->pool_aligned_malloc || !tbb_callbacks->pool_free ||
-        !tbb_callbacks->pool_create_v1 || !tbb_callbacks->pool_destroy ||
-        !tbb_callbacks->pool_identify) {
+    *(void **)&tbb_callbacks.pool_free = utils_get_symbol_addr(
+        tbb_callbacks.lib_handle, tbb_symbol[TBB_POOL_FREE], lib_name);
+    *(void **)&tbb_callbacks.pool_create_v1 = utils_get_symbol_addr(
+        tbb_callbacks.lib_handle, tbb_symbol[TBB_POOL_CREATE_V1], lib_name);
+    *(void **)&tbb_callbacks.pool_destroy = utils_get_symbol_addr(
+        tbb_callbacks.lib_handle, tbb_symbol[TBB_POOL_DESTROY], lib_name);
+    *(void **)&tbb_callbacks.pool_identify = utils_get_symbol_addr(
+        tbb_callbacks.lib_handle, tbb_symbol[TBB_POOL_IDENTIFY], lib_name);
+    *(void **)&tbb_callbacks.pool_msize = utils_get_symbol_addr(
+        tbb_callbacks.lib_handle, tbb_symbol[TBB_POOL_MSIZE], lib_name);
+
+    if (!tbb_callbacks.pool_malloc || !tbb_callbacks.pool_realloc ||
+        !tbb_callbacks.pool_aligned_malloc || !tbb_callbacks.pool_free ||
+        !tbb_callbacks.pool_create_v1 || !tbb_callbacks.pool_destroy ||
+        !tbb_callbacks.pool_identify) {
         LOG_FATAL("Could not find all TBB symbols in %s", lib_name);
-        utils_close_library(tbb_callbacks->lib_handle);
-        return -1;
+        if (utils_close_library(tbb_callbacks.lib_handle)) {
+            LOG_ERR("Could not close %s library", lib_name);
+        }
+        // Drop the handle so that fini_tbb_global_state() does not try to
+        // close the same library a second time.
+        tbb_callbacks.lib_handle = NULL;
+        tbb_init_result = -1;
     }
+}
 
-    return 0;
+// Makes sure the TBB callbacks have been loaded (through utils_init_once)
+// and returns tbb_init_result: 0 on success, -1 if loading failed.
+static int init_tbb_callbacks(void) {
+    utils_init_once(&tbb_initialized, init_tbb_callbacks_once);
+    return tbb_init_result;
+}
+
+// Closes the TBB library handle, if one is held. Called from umfTearDown().
+// NOTE(review): tbb_initialized and the resolved function pointers are not
+// reset here, so pools created after this call would use stale callbacks -
+// confirm that re-initialization after teardown is not expected.
+void fini_tbb_global_state(void) {
+    if (tbb_callbacks.lib_handle) {
+        if (!utils_close_library(tbb_callbacks.lib_handle)) {
+            tbb_callbacks.lib_handle = NULL;
+            LOG_DEBUG("TBB library closed");
+        } else {
+            LOG_ERR("TBB library cannot be unloaded");
+        }
+    }
 }
 
 static void *tbb_raw_alloc_wrapper(intptr_t pool_id, size_t *raw_bytes) {
@@ -264,35 +295,41 @@ static umf_result_t tbb_pool_initialize(umf_memory_provider_handle_t provider,
         return UMF_RESULT_ERROR_OUT_OF_HOST_MEMORY;
     }
 
-    int ret = init_tbb_callbacks(&pool_data->tbb_callbacks);
+    umf_result_t res = UMF_RESULT_SUCCESS;
+    int ret = init_tbb_callbacks();
     if (ret != 0) {
         LOG_FATAL("loading TBB symbols failed");
-        return UMF_RESULT_ERROR_UNKNOWN;
+        res = UMF_RESULT_ERROR_UNKNOWN;
+        goto err_tbb_init;
     }
 
     pool_data->mem_provider = provider;
-    ret = pool_data->tbb_callbacks.pool_create_v1((intptr_t)pool_data, &policy,
-                                                  &(pool_data->tbb_pool));
+    ret = tbb_callbacks.pool_create_v1((intptr_t)pool_data, &policy,
+                                       &(pool_data->tbb_pool));
     if (ret != 0 /* TBBMALLOC_OK */) {
-        return UMF_RESULT_ERROR_MEMORY_PROVIDER_SPECIFIC;
+        res = UMF_RESULT_ERROR_MEMORY_PROVIDER_SPECIFIC;
+        goto err_tbb_init;
     }
 
     *pool = (void *)pool_data;
 
-    return UMF_RESULT_SUCCESS;
+    return res;
+
+err_tbb_init:
+    umf_ba_global_free(pool_data);
+    return res;
 }
 
 static void tbb_pool_finalize(void *pool) {
     tbb_memory_pool_t *pool_data = (tbb_memory_pool_t *)pool;
-    pool_data->tbb_callbacks.pool_destroy(pool_data->tbb_pool);
-    utils_close_library(pool_data->tbb_callbacks.lib_handle);
+    tbb_callbacks.pool_destroy(pool_data->tbb_pool);
     umf_ba_global_free(pool_data);
 }
 
 static void *tbb_malloc(void *pool, size_t size) {
     tbb_memory_pool_t *pool_data = (tbb_memory_pool_t *)pool;
     TLS_last_allocation_error = UMF_RESULT_SUCCESS;
-    void *ptr = pool_data->tbb_callbacks.pool_malloc(pool_data->tbb_pool, size);
+    void *ptr = tbb_callbacks.pool_malloc(pool_data->tbb_pool, size);
     if (ptr == NULL) {
         if (TLS_last_allocation_error == UMF_RESULT_SUCCESS) {
             TLS_last_allocation_error = UMF_RESULT_ERROR_UNKNOWN;
@@ -319,8 +356,7 @@ static void *tbb_calloc(void *pool, size_t num, size_t size) {
 static void *tbb_realloc(void *pool, void *ptr, size_t size) {
     tbb_memory_pool_t *pool_data = (tbb_memory_pool_t *)pool;
     TLS_last_allocation_error = UMF_RESULT_SUCCESS;
-    void *new_ptr =
-        pool_data->tbb_callbacks.pool_realloc(pool_data->tbb_pool, ptr, size);
+    void *new_ptr = tbb_callbacks.pool_realloc(pool_data->tbb_pool, ptr, size);
     if (new_ptr == NULL) {
         if (TLS_last_allocation_error == UMF_RESULT_SUCCESS) {
             TLS_last_allocation_error = UMF_RESULT_ERROR_UNKNOWN;
@@ -334,8 +370,8 @@ static void *tbb_realloc(void *pool, void *ptr, size_t size) {
 static void *tbb_aligned_malloc(void *pool, size_t size, size_t alignment) {
     tbb_memory_pool_t *pool_data = (tbb_memory_pool_t *)pool;
     TLS_last_allocation_error = UMF_RESULT_SUCCESS;
-    void *ptr = pool_data->tbb_callbacks.pool_aligned_malloc(
-        pool_data->tbb_pool, size, alignment);
+    void *ptr =
+        tbb_callbacks.pool_aligned_malloc(pool_data->tbb_pool, size, alignment);
     if (ptr == NULL) {
         if (TLS_last_allocation_error == UMF_RESULT_SUCCESS) {
             TLS_last_allocation_error = UMF_RESULT_ERROR_UNKNOWN;
@@ -360,7 +396,7 @@ static umf_result_t tbb_free(void *pool, void *ptr) {
     utils_annotate_release(pool);
 
     tbb_memory_pool_t *pool_data = (tbb_memory_pool_t *)pool;
-    if (pool_data->tbb_callbacks.pool_free(pool_data->tbb_pool, ptr)) {
+    if (tbb_callbacks.pool_free(pool_data->tbb_pool, ptr)) {
         return UMF_RESULT_SUCCESS;
     }
 
@@ -373,7 +409,7 @@ static umf_result_t tbb_free(void *pool, void *ptr) {
 
 static size_t tbb_malloc_usable_size(void *pool, void *ptr) {
     tbb_memory_pool_t *pool_data = (tbb_memory_pool_t *)pool;
-    return pool_data->tbb_callbacks.pool_msize(pool_data->tbb_pool, ptr);
+    return tbb_callbacks.pool_msize(pool_data->tbb_pool, ptr);
 }
 
 static umf_result_t tbb_get_last_allocation_error(void *pool) {
diff --git a/src/pool/pool_scalable_internal.h b/src/pool/pool_scalable_internal.h
new file mode 100644
index 0000000000..cfdc668fc6
--- /dev/null
+++ b/src/pool/pool_scalable_internal.h
@@ -0,0 +1,17 @@
+/*
+ *
+ * Copyright (C) 2025 Intel Corporation
+ *
+ * Under the Apache License v2.0 with LLVM Exceptions. See LICENSE.TXT.
+ * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+ *
+ */
+
+#ifndef UMF_POOL_SCALABLE_INTERNAL_H
+#define UMF_POOL_SCALABLE_INTERNAL_H 1
+
+// Releases the TBB library handle held by the Scalable Pool implementation.
+// Called from umfTearDown().
+void fini_tbb_global_state(void);
+
+#endif // UMF_POOL_SCALABLE_INTERNAL_H