diff --git a/include/umf/memspace.h b/include/umf/memspace.h index 4b4597ef33..6f3abebfaa 100644 --- a/include/umf/memspace.h +++ b/include/umf/memspace.h @@ -46,9 +46,9 @@ umfMemoryProviderCreateFromMemspace(umf_const_memspace_handle_t hMemspace, umf_const_mempolicy_handle_t hPolicy, umf_memory_provider_handle_t *hProvider); /// -/// \brief Creates new memspace from array of NUMA node ids. +/// \brief Creates new memspace from an array of NUMA node ids. /// \param nodeIds array of NUMA node ids -/// \param numIds size of the array +/// \param numIds size of the array; it has to be greater than 0 /// \param hMemspace [out] handle to the newly created memspace /// \return UMF_RESULT_SUCCESS on success or appropriate error code on failure. /// diff --git a/src/memspace.c b/src/memspace.c index 31b52e26f0..716dc01919 100644 --- a/src/memspace.c +++ b/src/memspace.c @@ -204,9 +204,8 @@ static int propertyCmp(const void *a, const void *b) { umf_result_t umfMemspaceSortDesc(umf_memspace_handle_t hMemspace, umfGetPropertyFn getProperty) { - if (!hMemspace || !getProperty) { - return UMF_RESULT_ERROR_INVALID_ARGUMENT; - } + assert(hMemspace); + assert(getProperty); struct memtarget_sort_entry *entries = umf_ba_global_alloc( sizeof(struct memtarget_sort_entry) * hMemspace->size); @@ -241,9 +240,8 @@ umf_result_t umfMemspaceSortDesc(umf_memspace_handle_t hMemspace, umf_result_t umfMemspaceFilter(umf_const_memspace_handle_t hMemspace, umfGetTargetFn getTarget, umf_memspace_handle_t *filteredMemspace) { - if (!hMemspace || !getTarget) { - return UMF_RESULT_ERROR_INVALID_ARGUMENT; - } + assert(hMemspace); + assert(getTarget); umf_memtarget_handle_t *uniqueBestNodes = umf_ba_global_alloc(hMemspace->size * sizeof(*uniqueBestNodes)); @@ -389,6 +387,7 @@ umfMemspaceMemtargetRemove(umf_memspace_handle_t hMemspace, if (!hMemspace || !hMemtarget) { return UMF_RESULT_ERROR_INVALID_ARGUMENT; } + unsigned i; for (i = 0; i < hMemspace->size; i++) { int cmp; @@ -409,10 +408,16 @@ umfMemspaceMemtargetRemove(umf_memspace_handle_t hMemspace, return UMF_RESULT_ERROR_INVALID_ARGUMENT; } - umf_memtarget_handle_t *newNodes = - umf_ba_global_alloc(sizeof(*newNodes) * (hMemspace->size - 1)); - if (!newNodes) { - return UMF_RESULT_ERROR_OUT_OF_HOST_MEMORY; + umf_memtarget_handle_t *newNodes = NULL; + + if (hMemspace->size == 1) { + LOG_DEBUG("Removing the last memory target from the memspace."); + } else { + newNodes = + umf_ba_global_alloc(sizeof(*newNodes) * (hMemspace->size - 1)); + if (!newNodes) { + return UMF_RESULT_ERROR_OUT_OF_HOST_MEMORY; + } } for (unsigned j = 0, z = 0; j < hMemspace->size; j++) { @@ -433,10 +438,8 @@ umfMemspaceMemtargetRemove(umf_memspace_handle_t hMemspace, static int umfMemspaceFilterHelper(umf_memspace_handle_t memspace, umf_memspace_filter_func_t filter, void *args) { - - if (!memspace || !filter) { - return UMF_RESULT_ERROR_INVALID_ARGUMENT; - } + assert(memspace); + assert(filter); size_t idx = 0; int ret; diff --git a/src/memtarget.c b/src/memtarget.c index 8eb6e4e8cb..b0fa316cb2 100644 --- a/src/memtarget.c +++ b/src/memtarget.c @@ -20,9 +20,8 @@ umf_result_t umfMemtargetCreate(const umf_memtarget_ops_t *ops, void *params, umf_memtarget_handle_t *memoryTarget) { libumfInit(); - if (!ops || !memoryTarget) { - return UMF_RESULT_ERROR_INVALID_ARGUMENT; - } + assert(ops); + assert(memoryTarget); umf_memtarget_handle_t target = umf_ba_global_alloc(sizeof(umf_memtarget_t)); @@ -93,9 +92,9 @@ umf_result_t umfMemtargetGetCapacity(umf_const_memtarget_handle_t memoryTarget, umf_result_t umfMemtargetGetBandwidth(umf_memtarget_handle_t srcMemoryTarget, umf_memtarget_handle_t dstMemoryTarget, size_t *bandwidth) { - if (!srcMemoryTarget || !dstMemoryTarget || !bandwidth) { - return UMF_RESULT_ERROR_INVALID_ARGUMENT; - } + assert(srcMemoryTarget); + assert(dstMemoryTarget); + assert(bandwidth); return srcMemoryTarget->ops->get_bandwidth( srcMemoryTarget->priv, dstMemoryTarget->priv, bandwidth); @@ -104,9 +103,9 @@ umf_result_t umfMemtargetGetBandwidth(umf_memtarget_handle_t srcMemoryTarget, umf_result_t umfMemtargetGetLatency(umf_memtarget_handle_t srcMemoryTarget, umf_memtarget_handle_t dstMemoryTarget, size_t *latency) { - if (!srcMemoryTarget || !dstMemoryTarget || !latency) { - return UMF_RESULT_ERROR_INVALID_ARGUMENT; - } + assert(srcMemoryTarget); + assert(dstMemoryTarget); + assert(latency); return srcMemoryTarget->ops->get_latency(srcMemoryTarget->priv, dstMemoryTarget->priv, latency); diff --git a/test/memspaces/memspace.cpp b/test/memspaces/memspace.cpp index 412c5beb70..c66447c351 100644 --- a/test/memspaces/memspace.cpp +++ b/test/memspaces/memspace.cpp @@ -1,4 +1,4 @@ -// Copyright (C) 2024 Intel Corporation +// Copyright (C) 2024-2025 Intel Corporation // Under the Apache License v2.0 with LLVM Exceptions. See LICENSE.TXT. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception @@ -35,9 +35,33 @@ TEST_F(emptyMemspace, create_pool) { ASSERT_EQ(pool, nullptr); } -TEST_F(emptyMemspace, create_provider) { +TEST_F(emptyMemspace, invalid_create_from_memspace) { umf_memory_provider_handle_t provider = nullptr; - auto ret = umfMemoryProviderCreateFromMemspace(memspace, NULL, &provider); - ASSERT_EQ(ret, UMF_RESULT_ERROR_INVALID_ARGUMENT); - ASSERT_EQ(provider, nullptr); + umf_mempolicy_handle_t policy = nullptr; + + // invalid memspace + umf_result_t ret = + umfMemoryProviderCreateFromMemspace(NULL, policy, &provider); + EXPECT_EQ(ret, UMF_RESULT_ERROR_INVALID_ARGUMENT); + EXPECT_EQ(provider, nullptr); + + // invalid provider + ret = umfMemoryProviderCreateFromMemspace(memspace, policy, nullptr); + EXPECT_EQ(ret, UMF_RESULT_ERROR_INVALID_ARGUMENT); + + // Valid params, but memspace is empty + ret = umfMemoryProviderCreateFromMemspace(memspace, policy, &provider); + EXPECT_EQ(ret, UMF_RESULT_ERROR_INVALID_ARGUMENT); + EXPECT_EQ(provider, nullptr); +} + +TEST_F(emptyMemspace, invalid_clone) { + umf_const_memspace_handle_t memspace = nullptr; + umf_memspace_handle_t out_memspace = nullptr; + + umf_result_t ret = umfMemspaceClone(memspace, nullptr); + EXPECT_EQ(ret, UMF_RESULT_ERROR_INVALID_ARGUMENT); + + ret = umfMemspaceClone(nullptr, &out_memspace); + EXPECT_EQ(ret, UMF_RESULT_ERROR_INVALID_ARGUMENT); } diff --git a/test/memspaces/memtarget.cpp b/test/memspaces/memtarget.cpp index 325fa9d1d2..bd80ec14fd 100644 --- a/test/memspaces/memtarget.cpp +++ b/test/memspaces/memtarget.cpp @@ -1,4 +1,4 @@ -// Copyright (C) 2024 Intel Corporation +// Copyright (C) 2024-2025 Intel Corporation // Under the Apache License v2.0 with LLVM Exceptions. See LICENSE.TXT. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception @@ -103,3 +103,72 @@ TEST_F(numaNodesTest, getIdInvalid) { ret = umfMemtargetGetId(hTarget, NULL); EXPECT_EQ(ret, UMF_RESULT_ERROR_INVALID_ARGUMENT); } + +TEST_F(test, memTargetInvalidAdd) { + umf_const_memspace_handle_t const_memspace = umfMemspaceHostAllGet(); + umf_memspace_handle_t memspace = nullptr; + umf_result_t ret = umfMemspaceClone(const_memspace, &memspace); + ASSERT_EQ(ret, UMF_RESULT_SUCCESS); + ASSERT_NE(memspace, nullptr); + umf_const_memtarget_handle_t memtarget = + umfMemspaceMemtargetGet(memspace, 0); + + ret = umfMemspaceMemtargetAdd(memspace, nullptr); + EXPECT_EQ(ret, UMF_RESULT_ERROR_INVALID_ARGUMENT); + + ret = umfMemspaceMemtargetAdd(nullptr, memtarget); + EXPECT_EQ(ret, UMF_RESULT_ERROR_INVALID_ARGUMENT); + + // Try to add the same memtarget again + ret = umfMemspaceMemtargetAdd(memspace, memtarget); + EXPECT_EQ(ret, UMF_RESULT_ERROR_INVALID_ARGUMENT); + + ret = umfMemspaceDestroy(memspace); + EXPECT_EQ(ret, UMF_RESULT_SUCCESS); +} + +TEST_F(test, memTargetInvalidRemove) { + umf_const_memspace_handle_t const_memspace = umfMemspaceHostAllGet(); + umf_memspace_handle_t memspace = nullptr; + umf_result_t ret = umfMemspaceClone(const_memspace, &memspace); + ASSERT_EQ(ret, UMF_RESULT_SUCCESS); + ASSERT_NE(memspace, nullptr); + umf_const_memtarget_handle_t memtarget = + umfMemspaceMemtargetGet(memspace, 0); + + ret = umfMemspaceMemtargetRemove(memspace, nullptr); + EXPECT_EQ(ret, UMF_RESULT_ERROR_INVALID_ARGUMENT); + + ret = umfMemspaceMemtargetRemove(nullptr, memtarget); + EXPECT_EQ(ret, UMF_RESULT_ERROR_INVALID_ARGUMENT); + + ret = umfMemspaceDestroy(memspace); + EXPECT_EQ(ret, UMF_RESULT_SUCCESS); +} + +TEST_F(test, memTargetRemoveAll) { + umf_const_memspace_handle_t const_memspace = umfMemspaceHostAllGet(); + umf_memspace_handle_t memspace = nullptr; + umf_result_t ret = umfMemspaceClone(const_memspace, &memspace); + ASSERT_EQ(ret, UMF_RESULT_SUCCESS); + ASSERT_NE(memspace, nullptr); + umf_const_memtarget_handle_t memtarget = nullptr; + + // Remove all memtargets + size_t len = umfMemspaceMemtargetNum(memspace); + ASSERT_GT(len, 0); + size_t i = len - 1; + do { + memtarget = umfMemspaceMemtargetGet(memspace, i); + EXPECT_NE(memtarget, nullptr); + ret = umfMemspaceMemtargetRemove(memspace, memtarget); + ASSERT_EQ(ret, UMF_RESULT_SUCCESS); + } while (i-- > 0); + + // Try to remove the last one for the second time + ret = umfMemspaceMemtargetRemove(memspace, memtarget); + EXPECT_EQ(ret, UMF_RESULT_ERROR_INVALID_ARGUMENT); + + ret = umfMemspaceDestroy(memspace); + EXPECT_EQ(ret, UMF_RESULT_SUCCESS); +}