From e5b3ee2c157be0305dfdbfecbf349dc10e878873 Mon Sep 17 00:00:00 2001 From: doujiang24 Date: Sun, 27 Jul 2025 13:56:43 +0800 Subject: [PATCH] code format & enable code format checking in ci Signed-off-by: doujiang24 --- .clang-format | 3 + .github/workflows/ci.yml | 33 +- .../transfer_engine/transfer_engine_py.cpp | 130 +++-- .../transfer_engine/transfer_engine_py.h | 49 +- mooncake-store/benchmarks/allocator_bench.cpp | 20 +- mooncake-store/include/allocation_strategy.h | 13 +- mooncake-store/include/allocator.h | 9 +- mooncake-store/include/eviction_strategy.h | 16 +- mooncake-store/include/file_interface.h | 86 +-- mooncake-store/include/ha_helper.h | 5 +- mooncake-store/include/hf3fs/hf3fs.h | 61 +- .../include/master_metric_manager.h | 3 +- mooncake-store/include/master_service.h | 31 +- mooncake-store/include/rpc_service.h | 3 +- mooncake-store/include/segment.h | 19 +- mooncake-store/include/storage_backend.h | 119 ++-- mooncake-store/include/thread_pool.h | 38 +- mooncake-store/include/transfer_task.h | 28 +- mooncake-store/include/types.h | 10 +- mooncake-store/src/allocator.cpp | 10 +- mooncake-store/src/client.cpp | 20 +- mooncake-store/src/etcd_helper.cpp | 3 +- mooncake-store/src/ha_helper.cpp | 11 +- mooncake-store/src/hf3fs/hf3fs_file.cpp | 115 ++-- .../src/hf3fs/hf3fs_resource_manager.cpp | 16 +- mooncake-store/src/master.cpp | 6 +- mooncake-store/src/posix_file.cpp | 40 +- mooncake-store/src/rpc_service.cpp | 7 +- mooncake-store/src/segment.cpp | 8 +- mooncake-store/src/storage_backend.cpp | 134 +++-- mooncake-store/src/thread_pool.cpp | 22 +- mooncake-store/src/transfer_task.cpp | 36 +- mooncake-store/src/types.cpp | 9 +- mooncake-store/src/utils.cpp | 2 +- .../tests/buffer_allocator_test.cpp | 33 +- mooncake-store/tests/e2e/chaosctl.cpp | 24 +- mooncake-store/tests/e2e/client_wrapper.cpp | 14 +- mooncake-store/tests/e2e/clientctl.cpp | 16 +- mooncake-store/tests/e2e/e2e_utils.h | 2 +- .../tests/eviction_strategy_test.cpp | 5 +- 
mooncake-store/tests/posix_file_test.cpp | 61 +- mooncake-store/tests/stress_workload_test.cpp | 3 +- mooncake-store/tests/thread_pool_test.cpp | 18 +- .../transfer_engine_ascend_one_sided.cpp | 82 +-- .../example/transfer_engine_ascend_perf.cpp | 84 +-- .../example/transfer_engine_bench.cpp | 7 +- .../transfer_engine_bench_with_retry.cpp | 15 +- mooncake-transfer-engine/include/common.h | 50 +- .../include/common/base/status.h | 459 ++++++++------- .../include/transfer_engine.h | 2 +- .../include/transfer_metadata.h | 4 +- .../hccl_transport/hccl_transport.h | 17 +- .../hccl_transport/hccl_transport_mem_c.h | 45 +- .../transport/cxl_transport/cxl_transport.h | 7 +- .../nvlink_transport/nvlink_transport.h | 7 +- .../nvmeof_transport/cufile_desc_pool.h | 2 +- .../nvmeof_transport/nvmeof_transport.h | 11 +- .../transport/rdma_transport/rdma_endpoint.h | 4 +- .../src/common/base/status.cpp | 152 ++--- mooncake-transfer-engine/src/config.cpp | 4 +- .../src/memory_location.cpp | 3 +- .../src/multi_transport.cpp | 19 +- mooncake-transfer-engine/src/topology.cpp | 11 +- .../src/transfer_engine.cpp | 57 +- .../src/transfer_metadata.cpp | 8 +- .../src/transfer_metadata_plugin.cpp | 6 +- .../hccl_transport_mem_c.cpp | 532 +++++++++++------- .../hccl_transport/hccl_transport.cpp | 225 +++++--- .../transport/cxl_transport/cxl_transport.cpp | 67 +-- .../nvlink_transport/nvlink_transport.cpp | 36 +- .../nvmeof_transport/nvmeof_transport.cpp | 7 +- .../rdma_transport/rdma_endpoint.cpp | 8 +- .../rdma_transport/rdma_transport.cpp | 49 +- .../transport/rdma_transport/worker_pool.cpp | 13 +- .../transport/tcp_transport/tcp_transport.cpp | 2 +- .../tests/common_test.cpp | 9 +- .../tests/cxl_transport_test.cpp | 21 +- .../tests/memory_location_test.cpp | 4 +- .../tests/nvlink_transport_test.cpp | 23 +- .../tests/nvmeof_transport_test.cpp | 3 +- .../tests/rdma_loopback_test.cpp | 2 +- .../tests/rdma_transport_test.cpp | 3 +- scripts/ascend/pkg/hccl_mem.h | 19 +- 
scripts/ascend/pkg/hccl_mem_defs.h | 10 +- scripts/ascend/pkg/transport_mem.h | 52 +- 85 files changed, 1917 insertions(+), 1515 deletions(-) diff --git a/.clang-format b/.clang-format index 83affd882..169d14d2d 100644 --- a/.clang-format +++ b/.clang-format @@ -1,5 +1,8 @@ +--- BasedOnStyle: Google IndentWidth: 4 TabWidth: 4 UseTab: Never ColumnLimit: 80 +SortIncludes: false +... diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ba6664842..049d0d60a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -312,10 +312,10 @@ jobs: runs-on: ubuntu-22.04 steps: - uses: actions/checkout@v4 - + - name: Set up Docker Buildx uses: docker/setup-buildx-action@v2 - + - name: Build Docker image run: docker build -t mooncake-app . @@ -327,3 +327,32 @@ jobs: uses: actions/checkout@v4 - name: Spell Check Repo uses: crate-ci/typos@v1.30.2 + + clang-format: + name: Check code format + runs-on: ubuntu-22.04 + steps: + - name: Checkout Actions Repository + uses: actions/checkout@v4 + + - name: Install clang-format 20 + run: | + wget https://apt.llvm.org/llvm.sh + chmod +x llvm.sh + sudo ./llvm.sh 20 + sudo apt-get install -y clang-format-20 + + - name: run clang-format-20 + run: | + # the old clang-format-14 which is the default version in ubuntu 22.04, + # is inconsistent with clang-format-20. + ls -lh /usr/bin/clang-format* + clang-format --version + clang-format-20 --version + # skip cachelib_memory_allocator + find . -type f \( -name "*.h" -o -name "*.cpp" \) | grep -v cachelib_memory_allocator | xargs clang-format-20 -style=file -i + if ! 
git diff --exit-code; then + echo "Please follow the .clang-format code style, try clang-format -i FILENAME" + exit 1 + fi + shell: bash diff --git a/mooncake-integration/transfer_engine/transfer_engine_py.cpp b/mooncake-integration/transfer_engine/transfer_engine_py.cpp index 1829aaf4d..200820542 100644 --- a/mooncake-integration/transfer_engine/transfer_engine_py.cpp +++ b/mooncake-integration/transfer_engine/transfer_engine_py.cpp @@ -223,36 +223,34 @@ int TransferEnginePy::transferSyncRead(const char *target_hostname, TransferOpcode::READ); } -int TransferEnginePy::batchTransferSyncWrite(const char *target_hostname, - std::vector buffers, - std::vector peer_buffer_addresses, - std::vector lengths) { - return batchTransferSync(target_hostname, buffers, peer_buffer_addresses, lengths, - TransferOpcode::WRITE); +int TransferEnginePy::batchTransferSyncWrite( + const char *target_hostname, std::vector buffers, + std::vector peer_buffer_addresses, std::vector lengths) { + return batchTransferSync(target_hostname, buffers, peer_buffer_addresses, + lengths, TransferOpcode::WRITE); } -int TransferEnginePy::batchTransferSyncRead(const char *target_hostname, - std::vector buffers, - std::vector peer_buffer_addresses, - std::vector lengths) { - return batchTransferSync(target_hostname, buffers, peer_buffer_addresses, lengths, - TransferOpcode::READ); +int TransferEnginePy::batchTransferSyncRead( + const char *target_hostname, std::vector buffers, + std::vector peer_buffer_addresses, std::vector lengths) { + return batchTransferSync(target_hostname, buffers, peer_buffer_addresses, + lengths, TransferOpcode::READ); } -batch_id_t TransferEnginePy::batchTransferAsyncWrite(const char *target_hostname, - const std::vector &buffers, - const std::vector &peer_buffer_addresses, - const std::vector &lengths) { - return batchTransferAsync(target_hostname, buffers, peer_buffer_addresses, lengths, - TransferOpcode::WRITE); +batch_id_t TransferEnginePy::batchTransferAsyncWrite( + const 
char *target_hostname, const std::vector &buffers, + const std::vector &peer_buffer_addresses, + const std::vector &lengths) { + return batchTransferAsync(target_hostname, buffers, peer_buffer_addresses, + lengths, TransferOpcode::WRITE); } -batch_id_t TransferEnginePy::batchTransferAsyncRead(const char *target_hostname, - const std::vector &buffers, - const std::vector &peer_buffer_addresses, - const std::vector &lengths) { - return batchTransferAsync(target_hostname, buffers, peer_buffer_addresses, lengths, - TransferOpcode::READ); +batch_id_t TransferEnginePy::batchTransferAsyncRead( + const char *target_hostname, const std::vector &buffers, + const std::vector &peer_buffer_addresses, + const std::vector &lengths) { + return batchTransferAsync(target_hostname, buffers, peer_buffer_addresses, + lengths, TransferOpcode::READ); } int TransferEnginePy::transferSync(const char *target_hostname, @@ -328,11 +326,10 @@ int TransferEnginePy::transferSync(const char *target_hostname, return -1; } -int TransferEnginePy::batchTransferSync(const char *target_hostname, - std::vector buffers, - std::vector peer_buffer_addresses, - std::vector lengths, - TransferOpcode opcode) { +int TransferEnginePy::batchTransferSync( + const char *target_hostname, std::vector buffers, + std::vector peer_buffer_addresses, std::vector lengths, + TransferOpcode opcode) { pybind11::gil_scoped_release release; Transport::SegmentHandle handle; { @@ -346,8 +343,10 @@ int TransferEnginePy::batchTransferSync(const char *target_hostname, } } - if (buffers.size() != peer_buffer_addresses.size() || buffers.size() != lengths.size()) { - LOG(ERROR) << "buffers, peer_buffer_addresses and lengths have different size"; + if (buffers.size() != peer_buffer_addresses.size() || + buffers.size() != lengths.size()) { + LOG(ERROR) + << "buffers, peer_buffer_addresses and lengths have different size"; return -1; } @@ -397,12 +396,14 @@ int TransferEnginePy::batchTransferSync(const char *target_hostname, completed = 
true; } auto current_ts = getCurrentTimeInNano(); - const int64_t timeout = transfer_timeout_nsec_ + total_length; // 1GiB per second + const int64_t timeout = + transfer_timeout_nsec_ + total_length; // 1GiB per second if (current_ts - start_ts > timeout) { - LOG(INFO) << "Sync batch data transfer timeout after " + LOG(INFO) << "Sync batch data transfer timeout after " << current_ts - start_ts << "ns"; - // TODO: as @doujiang24 mentioned, early free(while there are still waiting tasks) - // the batch_id may fail and cause memory leak(a known issue). + // TODO: as @doujiang24 mentioned, early free(while there are + // still waiting tasks) the batch_id may fail and cause memory + // leak(a known issue). if (!already_freed) { engine_->freeBatchID(batch_id); } @@ -413,11 +414,10 @@ int TransferEnginePy::batchTransferSync(const char *target_hostname, return -1; } -batch_id_t TransferEnginePy::batchTransferAsync(const char *target_hostname, - const std::vector& buffers, - const std::vector& peer_buffer_addresses, - const std::vector& lengths, - TransferOpcode opcode) { +batch_id_t TransferEnginePy::batchTransferAsync( + const char *target_hostname, const std::vector &buffers, + const std::vector &peer_buffer_addresses, + const std::vector &lengths, TransferOpcode opcode) { pybind11::gil_scoped_release release; Transport::SegmentHandle handle; { @@ -431,8 +431,10 @@ batch_id_t TransferEnginePy::batchTransferAsync(const char *target_hostname, } } - if (buffers.size() != peer_buffer_addresses.size() || buffers.size() != lengths.size()) { - LOG(ERROR) << "buffers, peer_buffer_addresses and lengths have different size"; + if (buffers.size() != peer_buffer_addresses.size() || + buffers.size() != lengths.size()) { + LOG(ERROR) + << "buffers, peer_buffer_addresses and lengths have different size"; return 0; } @@ -474,7 +476,8 @@ batch_id_t TransferEnginePy::batchTransferAsync(const char *target_hostname, return batch_id; } -int TransferEnginePy::getBatchTransferStatus(const 
std::vector& batch_ids) { +int TransferEnginePy::getBatchTransferStatus( + const std::vector &batch_ids) { pybind11::gil_scoped_release release; TransferStatus status; std::unordered_map timeout_table{}; @@ -494,7 +497,7 @@ int TransferEnginePy::getBatchTransferStatus(const std::vector& batc } bool failed_or_timeout = false; - std::unordered_set remove_ids {}; + std::unordered_set remove_ids{}; while (!timeout_table.empty() && !failed_or_timeout) { for (auto &entry : timeout_table) { auto batch_desc = reinterpret_cast(entry.first); @@ -511,8 +514,8 @@ int TransferEnginePy::getBatchTransferStatus(const std::vector& batc } auto current_ts = getCurrentTimeInNano(); if (current_ts - batch_desc->start_timestamp > entry.second) { - LOG(INFO) << "Sync batch data transfer timeout after " - << current_ts - batch_desc->start_timestamp << "ns"; + LOG(INFO) << "Sync batch data transfer timeout after " + << current_ts - batch_desc->start_timestamp << "ns"; failed_or_timeout = true; } } @@ -582,22 +585,24 @@ int TransferEnginePy::transferCheckStatus(batch_id_t batch_id) { } } -int TransferEnginePy::batchRegisterMemory(std::vector buffer_addresses, - std::vector capacities) { +int TransferEnginePy::batchRegisterMemory( + std::vector buffer_addresses, std::vector capacities) { pybind11::gil_scoped_release release; auto batch_size = buffer_addresses.size(); std::vector buffers; - for (size_t i = 0; i < batch_size; i ++ ) { - buffers.push_back(BufferEntry{(void *)buffer_addresses[i], capacities[i]}); + for (size_t i = 0; i < batch_size; i++) { + buffers.push_back( + BufferEntry{(void *)buffer_addresses[i], capacities[i]}); } return engine_->registerLocalMemoryBatch(buffers, kWildcardLocation); } -int TransferEnginePy::batchUnregisterMemory(std::vector buffer_addresses) { +int TransferEnginePy::batchUnregisterMemory( + std::vector buffer_addresses) { pybind11::gil_scoped_release release; auto batch_size = buffer_addresses.size(); std::vector buffers; - for (size_t i = 0; i < 
batch_size; i ++ ) { + for (size_t i = 0; i < batch_size; i++) { buffers.push_back(reinterpret_cast(buffer_addresses[i])); } return engine_->unregisterLocalMemoryBatch(buffers); @@ -641,14 +646,19 @@ PYBIND11_MODULE(engine, m) { .def("free_managed_buffer", &TransferEnginePy::freeManagedBuffer) .def("transfer_sync_write", &TransferEnginePy::transferSyncWrite) .def("transfer_sync_read", &TransferEnginePy::transferSyncRead) - .def("batch_transfer_sync_write", &TransferEnginePy::batchTransferSyncWrite) - .def("batch_transfer_sync_read", &TransferEnginePy::batchTransferSyncRead) - .def("batch_transfer_async_write", &TransferEnginePy::batchTransferAsyncWrite) - .def("batch_transfer_async_read", &TransferEnginePy::batchTransferAsyncRead) + .def("batch_transfer_sync_write", + &TransferEnginePy::batchTransferSyncWrite) + .def("batch_transfer_sync_read", + &TransferEnginePy::batchTransferSyncRead) + .def("batch_transfer_async_write", + &TransferEnginePy::batchTransferAsyncWrite) + .def("batch_transfer_async_read", + &TransferEnginePy::batchTransferAsyncRead) .def("transfer_sync", &TransferEnginePy::transferSync) .def("batch_transfer_sync", &TransferEnginePy::batchTransferSync) .def("batch_transfer_async", &TransferEnginePy::batchTransferAsync) - .def("get_batch_transfer_status", &TransferEnginePy::getBatchTransferStatus) + .def("get_batch_transfer_status", + &TransferEnginePy::getBatchTransferStatus) .def("transfer_submit_write", &TransferEnginePy::transferSubmitWrite) .def("transfer_check_status", @@ -658,8 +668,10 @@ PYBIND11_MODULE(engine, m) { &TransferEnginePy::readBytesFromBuffer) .def("register_memory", &TransferEnginePy::registerMemory) .def("unregister_memory", &TransferEnginePy::unregisterMemory) - .def("batch_register_memory", &TransferEnginePy::batchRegisterMemory) - .def("batch_unregister_memory", &TransferEnginePy::batchUnregisterMemory) + .def("batch_register_memory", + &TransferEnginePy::batchRegisterMemory) + .def("batch_unregister_memory", + 
&TransferEnginePy::batchUnregisterMemory) .def("get_first_buffer_address", &TransferEnginePy::getFirstBufferAddress); diff --git a/mooncake-integration/transfer_engine/transfer_engine_py.h b/mooncake-integration/transfer_engine/transfer_engine_py.h index 558b59ffa..a1337777a 100644 --- a/mooncake-integration/transfer_engine/transfer_engine_py.h +++ b/mooncake-integration/transfer_engine/transfer_engine_py.h @@ -68,14 +68,16 @@ class TransferEnginePy { int transferSyncWrite(const char *target_hostname, uintptr_t buffer, uintptr_t peer_buffer_address, size_t length); - batch_id_t transferSubmitWrite(const char *target_hostname, uintptr_t buffer, - uintptr_t peer_buffer_address, size_t length); + batch_id_t transferSubmitWrite(const char *target_hostname, + uintptr_t buffer, + uintptr_t peer_buffer_address, + size_t length); int transferCheckStatus(batch_id_t batch_id); int transferSyncRead(const char *target_hostname, uintptr_t buffer, uintptr_t peer_buffer_address, size_t length); - + int batchTransferSyncWrite(const char *target_hostname, std::vector buffers, std::vector peer_buffer_addresses, @@ -86,35 +88,33 @@ class TransferEnginePy { std::vector peer_buffer_addresses, std::vector lengths); - batch_id_t batchTransferAsyncWrite(const char *target_hostname, - const std::vector &buffers, - const std::vector &peer_buffer_addresses, - const std::vector &lengths); + batch_id_t batchTransferAsyncWrite( + const char *target_hostname, const std::vector &buffers, + const std::vector &peer_buffer_addresses, + const std::vector &lengths); - batch_id_t batchTransferAsyncRead(const char *target_hostname, - const std::vector &buffers, - const std::vector &peer_buffer_addresses, - const std::vector &lengths); + batch_id_t batchTransferAsyncRead( + const char *target_hostname, const std::vector &buffers, + const std::vector &peer_buffer_addresses, + const std::vector &lengths); int transferSync(const char *target_hostname, uintptr_t buffer, uintptr_t peer_buffer_address, size_t 
length, TransferOpcode opcode); - - // Known issue: in a few inference engines and benchmarks, accuracy - // may be affected when using the batchTransferSync API. We currently + + // Known issue: in a few inference engines and benchmarks, accuracy + // may be affected when using the batchTransferSync API. We currently // found this issue only in multi-node NVLink transfers. int batchTransferSync(const char *target_hostname, std::vector buffers, std::vector peer_buffer_addresses, - std::vector lengths, - TransferOpcode opcode); - - batch_id_t batchTransferAsync(const char *target_hostname, - const std::vector &buffers, - const std::vector &peer_buffer_addresses, - const std::vector &lengths, - TransferOpcode opcode); - + std::vector lengths, TransferOpcode opcode); + + batch_id_t batchTransferAsync( + const char *target_hostname, const std::vector &buffers, + const std::vector &peer_buffer_addresses, + const std::vector &lengths, TransferOpcode opcode); + int getBatchTransferStatus(const std::vector &batch_ids); uintptr_t getFirstBufferAddress(const std::string &segment_name); @@ -138,7 +138,8 @@ class TransferEnginePy { // must be called before TransferEnginePy::~TransferEnginePy() int unregisterMemory(uintptr_t buffer_addr); - int batchRegisterMemory(std::vector buffer_addresses, std::vector capacities); + int batchRegisterMemory(std::vector buffer_addresses, + std::vector capacities); int batchUnregisterMemory(std::vector buffer_addresses); diff --git a/mooncake-store/benchmarks/allocator_bench.cpp b/mooncake-store/benchmarks/allocator_bench.cpp index 9cf848810..e330b4c4e 100644 --- a/mooncake-store/benchmarks/allocator_bench.cpp +++ b/mooncake-store/benchmarks/allocator_bench.cpp @@ -12,7 +12,8 @@ using namespace mooncake::offset_allocator; class OffsetAllocatorBenchHelper { public: - OffsetAllocatorBenchHelper(uint64_t baseAddress, uint32_t poolSize, uint32_t maxAllocs) + OffsetAllocatorBenchHelper(uint64_t baseAddress, uint32_t poolSize, + uint32_t maxAllocs) 
: pool_size_(poolSize), allocated_size_(0), allocator_(OffsetAllocator::create(baseAddress, poolSize, maxAllocs)), @@ -58,7 +59,8 @@ class OffsetAllocatorBenchHelper { template void uniform_size_allocation_benchmark() { - std::cout << std::endl << "=== Uniform Size Allocation Benchmark ===" << std::endl; + std::cout << std::endl + << "=== Uniform Size Allocation Benchmark ===" << std::endl; const size_t max_pool_size = 2ull * 1024 * 1024 * 1024; std::vector allocation_sizes; for (uint32_t i = 32; i < (1 << 26); i *= 4) { @@ -103,26 +105,30 @@ void uniform_size_allocation_benchmark() { total_util_ratio += util_ratio; } auto end_time = std::chrono::high_resolution_clock::now(); - auto duration = std::chrono::duration_cast(end_time - start_time); + auto duration = std::chrono::duration_cast( + end_time - start_time); // END double avg_util_ratio = total_util_ratio / benchmark_num; std::cout << "Alloc size: " << alloc_size << ", min util ratio: " << min_util_ratio << ", avg util ratio: " << avg_util_ratio - << ", time: " << duration.count() / benchmark_num << " ns" << std::endl; + << ", time: " << duration.count() / benchmark_num << " ns" + << std::endl; } } template void random_size_allocation_benchmark() { - std::cout << std::endl << "=== Random Size Allocation Benchmark ===" << std::endl; + std::cout << std::endl + << "=== Random Size Allocation Benchmark ===" << std::endl; const size_t pool_size = 2ull * 1024 * 1024 * 1024; const size_t max_alloc_size = 1ull << 26; const size_t min_alloc_size = 1024; std::random_device rd; std::mt19937 gen(rd()); - std::uniform_int_distribution dist(min_alloc_size, max_alloc_size); + std::uniform_int_distribution dist(min_alloc_size, + max_alloc_size); // Warmup size_t max_allocs = pool_size / min_alloc_size + 10; @@ -138,7 +144,7 @@ void random_size_allocation_benchmark() { util_ratios.reserve(benchmark_num); // Run benchmark - auto start_time = std::chrono::high_resolution_clock::now(); + auto start_time = 
std::chrono::high_resolution_clock::now(); for (int i = 0; i < benchmark_num; i++) { size_t alloc_size = dist(gen); bench_helper.allocate(alloc_size); diff --git a/mooncake-store/include/allocation_strategy.h b/mooncake-store/include/allocation_strategy.h index 30113b90c..2a92e401b 100644 --- a/mooncake-store/include/allocation_strategy.h +++ b/mooncake-store/include/allocation_strategy.h @@ -22,8 +22,8 @@ class AllocationStrategy { * @brief Given all mounted BufferAllocators and required object size, * the strategy can freely choose a suitable BufferAllocator. * @param allocators Container of mounted allocators - * @param allocators_by_name Container of mounted allocators, key is segment_name, - * value is the corresponding allocator + * @param allocators_by_name Container of mounted allocators, key is + * segment_name, value is the corresponding allocator * @param objectSize Size of object to be allocated * @param config Replica configuration * @return Selected allocator; returns nullptr if allocation is not possible @@ -31,7 +31,8 @@ class AllocationStrategy { */ virtual std::unique_ptr Allocate( const std::vector>& allocators, - const std::unordered_map>>& + const std::unordered_map< + std::string, std::vector>>& allocators_by_name, size_t objectSize, const ReplicateConfig& config) = 0; }; @@ -49,7 +50,8 @@ class RandomAllocationStrategy : public AllocationStrategy { std::unique_ptr Allocate( const std::vector>& allocators, - const std::unordered_map>>& + const std::unordered_map< + std::string, std::vector>>& allocators_by_name, size_t objectSize, const ReplicateConfig& config) override { // Fast path: single allocator case @@ -77,7 +79,8 @@ class RandomAllocationStrategy : public AllocationStrategy { * eligible */ std::unique_ptr TryPreferredAllocate( - const std::unordered_map>>& + const std::unordered_map< + std::string, std::vector>>& allocators, size_t objectSize, const ReplicateConfig& config) { if (config.preferred_segment.empty()) { diff --git 
a/mooncake-store/include/allocator.h b/mooncake-store/include/allocator.h index 68dcbb157..2d33a726b 100644 --- a/mooncake-store/include/allocator.h +++ b/mooncake-store/include/allocator.h @@ -24,7 +24,7 @@ class AllocatedBuffer; class BufferAllocatorBase { public: virtual ~BufferAllocatorBase() = default; - + virtual std::unique_ptr allocate(size_t size) = 0; virtual void deallocate(AllocatedBuffer* handle) = 0; virtual size_t capacity() const = 0; @@ -33,8 +33,8 @@ class BufferAllocatorBase { }; /** - * CachelibBufferAllocator manages memory allocation using CacheLib's slab allocation - * strategy. + * CachelibBufferAllocator manages memory allocation using CacheLib's slab + * allocation strategy. * * Important alignment requirements: * 1. Base address must be at least 8-byte aligned (CacheLib requirement) @@ -88,7 +88,8 @@ class CachelibBufferAllocator /** * OffsetBufferAllocator manages memory allocation using the OffsetAllocator - * strategy, which provides efficient memory allocation with bin-based optimization. + * strategy, which provides efficient memory allocation with bin-based + * optimization. */ class OffsetBufferAllocator : public BufferAllocatorBase, diff --git a/mooncake-store/include/eviction_strategy.h b/mooncake-store/include/eviction_strategy.h index a2eb2e5d1..2088a87ba 100644 --- a/mooncake-store/include/eviction_strategy.h +++ b/mooncake-store/include/eviction_strategy.h @@ -13,7 +13,7 @@ namespace mooncake { * @brief Abstract interface for eviction strategy, responsible for choosing * which kvcache object to be evicted before pool overflow. 
*/ -class EvictionStrategy : public std::enable_shared_from_this{ +class EvictionStrategy : public std::enable_shared_from_this { public: virtual ~EvictionStrategy() = default; virtual ErrorCode AddKey(const std::string& key) = 0; @@ -28,23 +28,23 @@ class EvictionStrategy : public std::enable_shared_from_this{ return ErrorCode::OK; } virtual std::string EvictKey(void) = 0; - virtual size_t GetSize(void) { - return all_key_list_.size(); - } + virtual size_t GetSize(void) { return all_key_list_.size(); } void CleanUp(void) { all_key_list_.clear(); all_key_idx_map_.clear(); } + protected: std::list all_key_list_; - std::unordered_map::iterator> all_key_idx_map_; + std::unordered_map::iterator> + all_key_idx_map_; }; class LRUEvictionStrategy : public EvictionStrategy { public: virtual ErrorCode AddKey(const std::string& key) override { // Add key to the front of the list - if(all_key_idx_map_.find(key) != all_key_idx_map_.end()) { + if (all_key_idx_map_.find(key) != all_key_idx_map_.end()) { all_key_list_.erase(all_key_idx_map_[key]); all_key_idx_map_.erase(key); } @@ -93,7 +93,7 @@ class FIFOEvictionStrategy : public EvictionStrategy { std::string evicted_key = all_key_list_.back(); all_key_list_.pop_back(); return evicted_key; - } + } }; -} \ No newline at end of file +} // namespace mooncake \ No newline at end of file diff --git a/mooncake-store/include/file_interface.h b/mooncake-store/include/file_interface.h index e3aefddf1..48f67dc2b 100644 --- a/mooncake-store/include/file_interface.h +++ b/mooncake-store/include/file_interface.h @@ -11,7 +11,7 @@ namespace mooncake { class FileLockRAII { -public: + public: enum class LockType { READ, WRITE }; FileLockRAII(int fd, LockType type) : fd_(fd), locked_(false) { @@ -28,18 +28,17 @@ class FileLockRAII { } } + FileLockRAII(const FileLockRAII &) = delete; + FileLockRAII &operator=(const FileLockRAII &) = delete; - FileLockRAII(const FileLockRAII&) = delete; - FileLockRAII& operator=(const FileLockRAII&) = delete; - - 
FileLockRAII(FileLockRAII&& other) noexcept + FileLockRAII(FileLockRAII &&other) noexcept : fd_(other.fd_), locked_(other.locked_) { other.locked_ = false; } bool is_locked() const { return locked_; } -private: + private: int fd_; bool locked_; }; @@ -47,15 +46,17 @@ class FileLockRAII { /** * @class LocalFile * @brief RAII wrapper for file operations with thread-safe locking support - * - * Provides thread-safe file I/O operations including read/write and vectorized I/O. - * Implements proper resource management through RAII pattern. + * + * Provides thread-safe file I/O operations including read/write and vectorized + * I/O. Implements proper resource management through RAII pattern. */ class StorageFile { -public: - + public: StorageFile(const std::string &filename, int fd) - : filename_(filename), fd_(fd), error_code_(ErrorCode::OK), is_locked_(false) {} + : filename_(filename), + fd_(fd), + error_code_(ErrorCode::OK), + is_locked_(false) {} /** * @brief Destructor * @note Automatically closes the file and releases resources @@ -66,50 +67,62 @@ class StorageFile { * @brief Writes data from buffer to file * @param buffer Input buffer containing data to write * @param length Number of bytes to write - * @return tl::expected containing number of bytes written on success, or ErrorCode on failure + * @return tl::expected containing number of bytes + * written on success, or ErrorCode on failure * @note Thread-safe operation with write locking */ - virtual tl::expected write(const std::string &buffer, size_t length) = 0; + virtual tl::expected write(const std::string &buffer, + size_t length) = 0; /** * @brief Writes data from buffer to file * @param data Input span containing data to write * @param length Number of bytes to write - * @return tl::expected containing number of bytes written on success, or ErrorCode on failure + * @return tl::expected containing number of bytes + * written on success, or ErrorCode on failure * @note Thread-safe operation with write 
locking */ - virtual tl::expected write(std::span data, size_t length) = 0; + virtual tl::expected write(std::span data, + size_t length) = 0; /** * @brief Reads data from file into buffer * @param buffer Output buffer for read data * @param length Maximum number of bytes to read - * @return tl::expected containing number of bytes read on success, or ErrorCode on failure + * @return tl::expected containing number of bytes read + * on success, or ErrorCode on failure * @note Thread-safe operation with read locking */ - virtual tl::expected read(std::string &buffer, size_t length) = 0; + virtual tl::expected read(std::string &buffer, + size_t length) = 0; /** * @brief Scattered write at specified file offset * @param iov Array of I/O vectors * @param iovcnt Number of elements in iov array * @param offset File offset to write at - * @return tl::expected containing total bytes written on success, or ErrorCode on failure + * @return tl::expected containing total bytes written on + * success, or ErrorCode on failure * @note Thread-safe operation with write locking */ - virtual tl::expected vector_write(const iovec *iov, int iovcnt, off_t offset) = 0; + virtual tl::expected vector_write(const iovec *iov, + int iovcnt, + off_t offset) = 0; /** * @brief Scattered read from specified file offset * @param iov Array of I/O vectors * @param iovcnt Number of elements in iov array * @param offset File offset to read from - * @return tl::expected containing total bytes read on success, or ErrorCode on failure + * @return tl::expected containing total bytes read on + * success, or ErrorCode on failure * @note Thread-safe operation with read locking */ - virtual tl::expected vector_read(const iovec *iov, int iovcnt, off_t offset) = 0; + virtual tl::expected vector_read(const iovec *iov, + int iovcnt, + off_t offset) = 0; - template + template tl::expected make_error(ErrorCode code) { error_code_ = code; return tl::make_unexpected(code); @@ -130,11 +143,9 @@ class StorageFile { * 
@brief Gets the current error code * @return Current error code */ - ErrorCode get_error_code(){ - return error_code_; - } + ErrorCode get_error_code() { return error_code_; } -protected: + protected: std::string filename_; int fd_; ErrorCode error_code_{ErrorCode::OK}; @@ -142,21 +153,24 @@ class StorageFile { }; class PosixFile : public StorageFile { -public: + public: PosixFile(const std::string &filename, int fd); ~PosixFile() override; - tl::expected write(const std::string &buffer, size_t length) override; - tl::expected write(std::span data, size_t length) override; - tl::expected read(std::string &buffer, size_t length) override; - tl::expected vector_write(const iovec *iov, int iovcnt, off_t offset) override; - tl::expected vector_read(const iovec *iov, int iovcnt, off_t offset) override; + tl::expected write(const std::string &buffer, + size_t length) override; + tl::expected write(std::span data, + size_t length) override; + tl::expected read(std::string &buffer, + size_t length) override; + tl::expected vector_write(const iovec *iov, int iovcnt, + off_t offset) override; + tl::expected vector_read(const iovec *iov, int iovcnt, + off_t offset) override; }; -} // namespace mooncake +} // namespace mooncake #ifdef USE_3FS #include #endif - - diff --git a/mooncake-store/include/ha_helper.h b/mooncake-store/include/ha_helper.h index d07a58e1a..504d75b65 100644 --- a/mooncake-store/include/ha_helper.h +++ b/mooncake-store/include/ha_helper.h @@ -79,9 +79,8 @@ class MasterServiceSupervisor { int rpc_port, size_t rpc_thread_num, bool enable_gc, bool enable_metric_reporting, int metrics_port, int64_t default_kv_lease_ttl, int64_t default_kv_soft_pin_ttl, - bool allow_evict_soft_pinned_objects, - double eviction_ratio, double eviction_high_watermark_ratio, - int64_t client_live_ttl_sec, + bool allow_evict_soft_pinned_objects, double eviction_ratio, + double eviction_high_watermark_ratio, int64_t client_live_ttl_sec, const std::string& etcd_endpoints = 
"0.0.0.0:2379", const std::string& local_hostname = "0.0.0.0:50051", const std::string& rpc_address = "0.0.0.0", diff --git a/mooncake-store/include/hf3fs/hf3fs.h b/mooncake-store/include/hf3fs/hf3fs.h index 48b6f01bf..b624d1562 100644 --- a/mooncake-store/include/hf3fs/hf3fs.h +++ b/mooncake-store/include/hf3fs/hf3fs.h @@ -14,39 +14,38 @@ class StorageFile; // Forward declaration of USRBIOResourceManager struct Hf3fsConfig { // 3FS cluster related parameters - + // USRBIO related parameters - std::string mount_root = "/"; // Mount point root directory - size_t iov_size = 32 << 20; // Shared memory size (32MB) - size_t ior_entries = 16; // Maximum number of requests in IO ring + std::string mount_root = "/"; // Mount point root directory + size_t iov_size = 32 << 20; // Shared memory size (32MB) + size_t ior_entries = 16; // Maximum number of requests in IO ring //`0` for no control with I/O depth. - // If greater than 0, then only when `io_depth` I/O requests are in queue, they will be issued to server as a batch. - // If smaller than 0, then USRBIO will wait for at most `-io_depth` I/O requests are in queue and issue them in one batch. - // If io_depth is 0, then USRBIO will issue all the prepared I/O requests to server ASAP. - size_t io_depth = 0; // IO batch processing depth - int ior_timeout = 0; // IO timeout (milliseconds) + // If greater than 0, then only when `io_depth` I/O requests are in queue, + // they will be issued to server as a batch. If smaller than 0, then USRBIO + // will wait for at most `-io_depth` I/O requests are in queue and issue + // them in one batch. If io_depth is 0, then USRBIO will issue all the + // prepared I/O requests to server ASAP. 
+ size_t io_depth = 0; // IO batch processing depth + int ior_timeout = 0; // IO timeout (milliseconds) }; class USRBIOResourceManager { -public: - + public: USRBIOResourceManager() {} - void setDefaultParams(const Hf3fsConfig& config) { + void setDefaultParams(const Hf3fsConfig &config) { default_config_ = config; } - struct ThreadUSRBIOResource* getThreadResource( - const Hf3fsConfig &config); + struct ThreadUSRBIOResource *getThreadResource(const Hf3fsConfig &config); - struct ThreadUSRBIOResource* getThreadResource() { + struct ThreadUSRBIOResource *getThreadResource() { return getThreadResource(default_config_); } ~USRBIOResourceManager(); - -private: + private: USRBIOResourceManager(const USRBIOResourceManager &) = delete; USRBIOResourceManager &operator=(const USRBIOResourceManager &) = delete; Hf3fsConfig default_config_; @@ -84,18 +83,24 @@ struct ThreadUSRBIOResource { }; class ThreeFSFile : public StorageFile { -public: - ThreeFSFile(const std::string &filename, int fd, USRBIOResourceManager* resource_manager); + public: + ThreeFSFile(const std::string &filename, int fd, + USRBIOResourceManager *resource_manager); ~ThreeFSFile() override; - tl::expected write(const std::string &buffer, size_t length) override; - tl::expected write(std::span data, size_t length) override; - tl::expected read(std::string &buffer, size_t length) override; - tl::expected vector_write(const iovec *iov, int iovcnt, off_t offset) override; - tl::expected vector_read(const iovec *iov, int iovcnt, off_t offset) override; - -private: - USRBIOResourceManager* resource_manager_; + tl::expected write(const std::string &buffer, + size_t length) override; + tl::expected write(std::span data, + size_t length) override; + tl::expected read(std::string &buffer, + size_t length) override; + tl::expected vector_write(const iovec *iov, int iovcnt, + off_t offset) override; + tl::expected vector_read(const iovec *iov, int iovcnt, + off_t offset) override; + + private: + USRBIOResourceManager 
*resource_manager_; }; -} \ No newline at end of file +} // namespace mooncake \ No newline at end of file diff --git a/mooncake-store/include/master_metric_manager.h b/mooncake-store/include/master_metric_manager.h index e35b359c5..f7dc3ce40 100644 --- a/mooncake-store/include/master_metric_manager.h +++ b/mooncake-store/include/master_metric_manager.h @@ -78,7 +78,6 @@ class MasterMetricManager { void inc_batch_put_revoke_requests(int64_t val = 1); void inc_batch_put_revoke_failures(int64_t val = 1); - // Operation Statistics Getters int64_t get_put_start_requests(); int64_t get_put_start_failures(); @@ -117,7 +116,7 @@ class MasterMetricManager { // Eviction Metrics void inc_eviction_success(int64_t key_count, int64_t size); - void inc_eviction_fail(); // not a single object is evicted + void inc_eviction_fail(); // not a single object is evicted // Eviction Metrics Getters int64_t get_eviction_success(); diff --git a/mooncake-store/include/master_service.h b/mooncake-store/include/master_service.h index 9ecaff9a4..fb9ae8317 100644 --- a/mooncake-store/include/master_service.h +++ b/mooncake-store/include/master_service.h @@ -59,20 +59,20 @@ class MasterService { }; public: - MasterService(bool enable_gc = true, - uint64_t default_kv_lease_ttl = DEFAULT_DEFAULT_KV_LEASE_TTL, - uint64_t default_kv_soft_pin_ttl = DEFAULT_KV_SOFT_PIN_TTL_MS, - bool allow_evict_soft_pinned_objects = - DEFAULT_ALLOW_EVICT_SOFT_PINNED_OBJECTS, - double eviction_ratio = DEFAULT_EVICTION_RATIO, - double eviction_high_watermark_ratio = - DEFAULT_EVICTION_HIGH_WATERMARK_RATIO, - ViewVersionId view_version = 0, - int64_t client_live_ttl_sec = DEFAULT_CLIENT_LIVE_TTL_SEC, - bool enable_ha = false, - const std::string& cluster_id = DEFAULT_CLUSTER_ID, - BufferAllocatorType memory_allocator = - BufferAllocatorType::CACHELIB); + MasterService( + bool enable_gc = true, + uint64_t default_kv_lease_ttl = DEFAULT_DEFAULT_KV_LEASE_TTL, + uint64_t default_kv_soft_pin_ttl = 
DEFAULT_KV_SOFT_PIN_TTL_MS, + bool allow_evict_soft_pinned_objects = + DEFAULT_ALLOW_EVICT_SOFT_PINNED_OBJECTS, + double eviction_ratio = DEFAULT_EVICTION_RATIO, + double eviction_high_watermark_ratio = + DEFAULT_EVICTION_HIGH_WATERMARK_RATIO, + ViewVersionId view_version = 0, + int64_t client_live_ttl_sec = DEFAULT_CLIENT_LIVE_TTL_SEC, + bool enable_ha = false, + const std::string& cluster_id = DEFAULT_CLUSTER_ID, + BufferAllocatorType memory_allocator = BufferAllocatorType::CACHELIB); ~MasterService(); /** @@ -235,8 +235,7 @@ class MasterService { * @return ErrorCode::OK on success, ErrorCode::INTERNAL_ERROR if the client * ping queue is full */ - auto Ping(const UUID& client_id) - -> tl::expected; + auto Ping(const UUID& client_id) -> tl::expected; /** * @brief Get the master service cluster ID to use as subdirectory name diff --git a/mooncake-store/include/rpc_service.h b/mooncake-store/include/rpc_service.h index 01d887a88..77cb9995f 100644 --- a/mooncake-store/include/rpc_service.h +++ b/mooncake-store/include/rpc_service.h @@ -80,8 +80,7 @@ class WrappedMasterService { tl::expected GetFsdir(); - tl::expected Ping( - const UUID& client_id); + tl::expected Ping(const UUID& client_id); private: MasterService master_service_; diff --git a/mooncake-store/include/segment.h b/mooncake-store/include/segment.h index 8e4ce619e..f3369f344 100644 --- a/mooncake-store/include/segment.h +++ b/mooncake-store/include/segment.h @@ -67,8 +67,8 @@ class ScopedSegmentAccess { /** * @brief Re-mount a segment. To avoid infinite remount trying, only the * errors that may be solved by subsequent remount tryings are considered as - * errors. When encounters unsolvable errors, the segment will not be mounted - * while the return value will be OK. + * errors. When encounters unsolvable errors, the segment will not be + * mounted while the return value will be OK. 
*/ ErrorCode ReMountSegment(const std::vector& segments, const UUID& client_id); @@ -125,7 +125,7 @@ class ScopedAllocatorAccess { lock_(mutex) {} const std::unordered_map>>& + std::vector>>& getAllocatorsByName() { return allocators_by_name_; } @@ -136,7 +136,7 @@ class ScopedAllocatorAccess { private: const std::unordered_map>>& + std::vector>>& allocators_by_name_; // segment name -> allocators const std::vector>& allocators_; std::shared_lock lock_; @@ -148,7 +148,8 @@ class SegmentManager { * @brief Constructor for SegmentManager * @param memory_allocator Type of buffer allocator to use for new segments */ - explicit SegmentManager(BufferAllocatorType memory_allocator = BufferAllocatorType::CACHELIB) + explicit SegmentManager( + BufferAllocatorType memory_allocator = BufferAllocatorType::CACHELIB) : memory_allocator_(memory_allocator) {} /** @@ -171,20 +172,22 @@ class SegmentManager { private: mutable std::shared_mutex segment_mutex_; std::shared_ptr allocation_strategy_; - const BufferAllocatorType memory_allocator_; // Type of buffer allocator to use + const BufferAllocatorType + memory_allocator_; // Type of buffer allocator to use // Each allocator is put into both of allocators_by_name_ and allocators_. // These two containers only contain allocators whose segment status is OK. 
std::unordered_map>> allocators_by_name_; // segment name -> allocators - std::vector> allocators_; // allocators + std::vector> + allocators_; // allocators std::unordered_map> mounted_segments_; // segment_id -> mounted segment std::unordered_map, boost::hash> client_segments_; // client_id -> segment_ids friend class ScopedSegmentAccess; - friend class SegmentTest; // for unit tests + friend class SegmentTest; // for unit tests }; } // namespace mooncake \ No newline at end of file diff --git a/mooncake-store/include/storage_backend.h b/mooncake-store/include/storage_backend.h index eb01efd34..120bd6970 100644 --- a/mooncake-store/include/storage_backend.h +++ b/mooncake-store/include/storage_backend.h @@ -13,42 +13,47 @@ namespace mooncake { /** * @class StorageBackend - * @brief Implementation of StorageBackend interface using local filesystem storage. - * - * Provides thread-safe operations for storing and retrieving objects in a directory hierarchy. + * @brief Implementation of StorageBackend interface using local filesystem + * storage. + * + * Provides thread-safe operations for storing and retrieving objects in a + * directory hierarchy. 
*/ -class StorageBackend { +class StorageBackend { public: - /** - * @brief Constructs a new StorageBackend instance - * @param root_dir Root directory path for object storage - * @param fsdir subdirectory name - * @note Directory existence is not checked in constructor - */ - #ifdef USE_3FS - explicit StorageBackend(const std::string& root_dir, const std::string& fsdir, bool is_3fs_dir) +/** + * @brief Constructs a new StorageBackend instance + * @param root_dir Root directory path for object storage + * @param fsdir subdirectory name + * @note Directory existence is not checked in constructor + */ +#ifdef USE_3FS + explicit StorageBackend(const std::string& root_dir, + const std::string& fsdir, bool is_3fs_dir) : root_dir_(root_dir), fsdir_(fsdir), is_3fs_dir_(is_3fs_dir) { - resource_manager_ = std::make_unique(); - Hf3fsConfig config; - config.mount_root = root_dir; - resource_manager_->setDefaultParams(config); + resource_manager_ = std::make_unique(); + Hf3fsConfig config; + config.mount_root = root_dir; + resource_manager_->setDefaultParams(config); } - #else - explicit StorageBackend(const std::string& root_dir, const std::string& fsdir) +#else + explicit StorageBackend(const std::string& root_dir, + const std::string& fsdir) : root_dir_(root_dir), fsdir_(fsdir) {} - #endif +#endif /** * @brief Factory method to create a StorageBackend instance * @param root_dir Root directory path for object storage * @param fsdir subdirectory name * @return shared_ptr to new instance or nullptr if directory is invalid - * + * * Performs validation of the root directory before creating the instance: * - Verifies directory exists * - Verifies path is actually a directory */ - static std::shared_ptr Create(const std::string& root_dir, const std::string& fsdir) { + static std::shared_ptr Create(const std::string& root_dir, + const std::string& fsdir) { namespace fs = std::filesystem; if (!fs::exists(root_dir)) { LOG(INFO) << "Root directory does not exist: " << root_dir; @@ 
-64,22 +69,24 @@ class StorageBackend { fs::path root_path(root_dir); std::string real_fsdir = "moon_" + fsdir; - #ifdef USE_3FS - bool is_3fs_dir = fs::exists(root_path / "3fs-virt") && - fs::is_directory(root_path / "3fs-virt"); - return std::make_shared(root_dir, real_fsdir, is_3fs_dir); - #else +#ifdef USE_3FS + bool is_3fs_dir = fs::exists(root_path / "3fs-virt") && + fs::is_directory(root_path / "3fs-virt"); + return std::make_shared(root_dir, real_fsdir, + is_3fs_dir); +#else return std::make_shared(root_dir, real_fsdir); - #endif - } - +#endif + } + /** * @brief Stores an object composed of multiple slices * @param key Object identifier * @param slices Vector of data slices to store * @return tl::expected indicating operation status */ - tl::expected StoreObject(const ObjectKey& key, const std::vector& slices) ; + tl::expected StoreObject(const ObjectKey& key, + const std::vector& slices); /** * @brief Stores an object from a string @@ -87,7 +94,8 @@ class StorageBackend { * @param str String containing object data * @return tl::expected indicating operation status */ - tl::expected StoreObject(const ObjectKey& key, const std::string& str) ; + tl::expected StoreObject(const ObjectKey& key, + const std::string& str); /** * @brief Stores an object from a span of data @@ -95,8 +103,9 @@ class StorageBackend { * @param data Span containing object data * @return tl::expected indicating operation status */ - tl::expected StoreObject(const ObjectKey& key, std::span data); - + tl::expected StoreObject(const ObjectKey& key, + std::span data); + /** * @brief Loads an object into slices * @param path KVCache File path to load from @@ -104,8 +113,10 @@ class StorageBackend { * @param length Expected length of data to read * @return tl::expected indicating operation status */ - tl::expected LoadObject(std::string& path, std::vector& slices, size_t length) ; - + tl::expected LoadObject(std::string& path, + std::vector& slices, + size_t length); + /** * @brief Loads an 
object as a string * @param path KVCache File path to load from @@ -113,20 +124,22 @@ class StorageBackend { * @param length Expected length of data to read * @return tl::expected indicating operation status */ - tl::expected LoadObject(std::string& path, std::string& str, size_t length) ; + tl::expected LoadObject(std::string& path, + std::string& str, size_t length); /** * @brief Checks if an object with the given key exists * @param key Object identifier * @return bool indicating whether the object exists */ - bool Existkey(const ObjectKey& key) ; + bool Existkey(const ObjectKey& key); /** * @brief Queries metadata for an object by key * @param key Object identifier - * @return Optional Replica::Descriptor containing object metadata, or empty if not found - * + * @return Optional Replica::Descriptor containing object metadata, or empty + * if not found + * * This method retrieves the file path and size for the given object key. */ std::optional Querykey(const ObjectKey& key); @@ -136,40 +149,39 @@ class StorageBackend { * @param keys Vector of object identifiers * @return unordered_map mapping ObjectKey to Replica::Descriptor */ - std::unordered_map BatchQueryKey(const std::vector& keys); + std::unordered_map BatchQueryKey( + const std::vector& keys); /** * @brief Deletes the physical file associated with the given object key * @param key Object identifier */ - void RemoveFile(const ObjectKey& key) ; + void RemoveFile(const ObjectKey& key); /** * @brief Deletes all objects from the storage backend - * + * * Removes all files in the cluster subdirectory. 
*/ - void RemoveAll() ; + void RemoveAll(); - enum class FileMode { - Read, - Write - }; + enum class FileMode { Read, Write }; // Root directory path for storage and subdirectory name std::string root_dir_; std::string fsdir_; - #ifdef USE_3FS - bool is_3fs_dir_{false}; // Flag to indicate if the storage is using 3FS directory structure +#ifdef USE_3FS + bool is_3fs_dir_{false}; // Flag to indicate if the storage is using 3FS + // directory structure std::unique_ptr resource_manager_; - #endif +#endif private: /** * @brief Sanitizes object key for filesystem safety */ std::string SanitizeKey(const ObjectKey& key) const; - + /** * @brief Resolves full filesystem path for an object */ @@ -181,9 +193,8 @@ class StorageBackend { * @param mode File access mode (read/write) * @return Unique pointer to the created StorageFile, or nullptr on failure */ - std::unique_ptr create_file(const std::string& path, - FileMode mode) const; - + std::unique_ptr create_file(const std::string& path, + FileMode mode) const; }; } // namespace mooncake \ No newline at end of file diff --git a/mooncake-store/include/thread_pool.h b/mooncake-store/include/thread_pool.h index 614b00e25..f588f87ec 100644 --- a/mooncake-store/include/thread_pool.h +++ b/mooncake-store/include/thread_pool.h @@ -11,15 +11,15 @@ namespace mooncake { /** * @class ThreadPool * @brief A thread pool implementation for concurrent task execution. - * + * * Manages a fixed number of worker threads that process tasks from a queue. * Supports task enqueueing and graceful shutdown. */ class ThreadPool { -public: + public: /// Constructs a thread pool with specified number of worker threads explicit ThreadPool(size_t num_threads); - + /// Destructor (stops all threads and completes pending tasks) ~ThreadPool(); @@ -31,39 +31,39 @@ class ThreadPool { * @param args Arguments to forward to the callable * @throws std::runtime_error if enqueued after stop */ - template + template void enqueue(F&& f, Args&&... 
args); /// Stops the thread pool (waits for current tasks to complete) void stop(); -private: - std::vector workers; ///< Worker thread pool - std::queue> tasks; ///< Task queue - - std::mutex queue_mutex; ///< Protects task queue access - std::condition_variable condition; ///< Synchronizes task assignment - std::atomic stop_flag; ///< Termination signal + private: + std::vector workers; ///< Worker thread pool + std::queue> tasks; ///< Task queue + + std::mutex queue_mutex; ///< Protects task queue access + std::condition_variable condition; ///< Synchronizes task assignment + std::atomic stop_flag; ///< Termination signal }; /** * @brief Enqueues a task by wrapping it in a void() function object * @details Locks the queue, checks stop condition, and notifies a worker */ -template +template void ThreadPool::enqueue(F&& f, Args&&... args) { auto task = std::make_shared>( - [f = std::forward(f), ...args = std::forward(args)]() mutable { - std::invoke(f, args...); - } - ); + [f = std::forward(f), + ... 
args = std::forward(args)]() mutable { + std::invoke(f, args...); + }); { std::unique_lock lock(queue_mutex); - if(stop_flag) { + if (stop_flag) { throw std::runtime_error("enqueue on stopped ThreadPool"); } - tasks.emplace([task]{(*task)();}); + tasks.emplace([task] { (*task)(); }); } condition.notify_one(); ///< Wake one waiting worker } -} +} // namespace mooncake diff --git a/mooncake-store/include/transfer_task.h b/mooncake-store/include/transfer_task.h index 40b9590a1..c88c8863e 100644 --- a/mooncake-store/include/transfer_task.h +++ b/mooncake-store/include/transfer_task.h @@ -23,9 +23,9 @@ namespace mooncake { * @brief Transfer strategy enumeration */ enum class TransferStrategy { - LOCAL_MEMCPY = 0, // Local memory copy using memcpy + LOCAL_MEMCPY = 0, // Local memory copy using memcpy TRANSFER_ENGINE = 1, // Remote transfer using transfer engine - FILE_READ = 2 // File read operation + FILE_READ = 2 // File read operation }; /** @@ -126,7 +126,6 @@ class MemcpyOperationState : public OperationState { class FilereadOperationState : public OperationState { public: - bool is_completed() override { std::lock_guard lock(mutex_); return result_.has_value(); @@ -299,14 +298,13 @@ struct FilereadTask { std::vector slices; std::shared_ptr state; - FilereadTask(const std::string &path, - size_t size, - const std::vector& slices_ref, - std::shared_ptr s) - : file_path(path), - file_size(size), - slices(slices_ref), - state(std::move(s)) {} + FilereadTask(const std::string& path, size_t size, + const std::vector& slices_ref, + std::shared_ptr s) + : file_path(path), + file_size(size), + slices(slices_ref), + state(std::move(s)) {} }; /** @@ -370,8 +368,8 @@ class TransferSubmitter { * failure */ std::optional submit( - const Replica::Descriptor& replica, - std::vector& slices, Transport::TransferRequest::OpCode op_code); + const Replica::Descriptor& replica, std::vector& slices, + Transport::TransferRequest::OpCode op_code); private: TransferEngine& engine_; @@ 
-415,8 +413,8 @@ class TransferSubmitter { std::vector& slices, Transport::TransferRequest::OpCode op_code); std::optional submitFileReadOperation( - const Replica::Descriptor& replica, std::vector& slices, - Transport::TransferRequest::OpCode op_code); + const Replica::Descriptor& replica, std::vector& slices, + Transport::TransferRequest::OpCode op_code); }; } // namespace mooncake \ No newline at end of file diff --git a/mooncake-store/include/types.h b/mooncake-store/include/types.h index 76b4922d3..df9816cf1 100644 --- a/mooncake-store/include/types.h +++ b/mooncake-store/include/types.h @@ -89,7 +89,8 @@ enum class ErrorCode : int32_t { SEGMENT_ALREADY_EXISTS = -102, ///< Segment already exists. // Handle selection errors (Range: -200 to -299) - NO_AVAILABLE_HANDLE = -200, ///< Memory allocation failed due to insufficient space. + NO_AVAILABLE_HANDLE = + -200, ///< Memory allocation failed due to insufficient space. // Version errors (Range: -300 to -399) INVALID_VERSION = -300, ///< Invalid version. 
@@ -488,14 +489,15 @@ inline std::ostream& operator<<(std::ostream& os, struct PingResponse { ViewVersionId view_version_id; ClientStatus client_status; - + PingResponse() = default; PingResponse(ViewVersionId view_version, ClientStatus status) : view_version_id(view_version), client_status(status) {} - + friend std::ostream& operator<<(std::ostream& os, const PingResponse& response) noexcept { - return os << "PingResponse: { view_version_id: " << response.view_version_id + return os << "PingResponse: { view_version_id: " + << response.view_version_id << ", client_status: " << response.client_status << " }"; } }; diff --git a/mooncake-store/src/allocator.cpp b/mooncake-store/src/allocator.cpp index 694e92348..8b75ca4a3 100644 --- a/mooncake-store/src/allocator.cpp +++ b/mooncake-store/src/allocator.cpp @@ -22,8 +22,8 @@ AllocatedBuffer::~AllocatedBuffer() { } // Removed allocated_bytes parameter and member initialization -CachelibBufferAllocator::CachelibBufferAllocator(std::string segment_name, size_t base, - size_t size) +CachelibBufferAllocator::CachelibBufferAllocator(std::string segment_name, + size_t base, size_t size) : segment_name_(segment_name), base_(base), total_size_(size), @@ -60,7 +60,8 @@ CachelibBufferAllocator::CachelibBufferAllocator(std::string segment_name, size_ CachelibBufferAllocator::~CachelibBufferAllocator() = default; -std::unique_ptr CachelibBufferAllocator::allocate(size_t size) { +std::unique_ptr CachelibBufferAllocator::allocate( + size_t size) { void* buffer = nullptr; try { // Allocate memory using CacheLib. @@ -120,7 +121,8 @@ OffsetBufferAllocator::OffsetBufferAllocator(std::string segment_name, uint32_t max_allocs = size < (1ull << 32) ? 
size / 4096 : 1024 * 1024; // min(size / 4K, 1M) - max_allocs = std::max(max_allocs, 1024u * 64u); // at least 64K allocations + max_allocs = + std::max(max_allocs, 1024u * 64u); // at least 64K allocations // Create the offset allocator offset_allocator_ = offset_allocator::OffsetAllocator::create(base, size, max_allocs); diff --git a/mooncake-store/src/client.cpp b/mooncake-store/src/client.cpp index e5dd27db4..33a8d799c 100644 --- a/mooncake-store/src/client.cpp +++ b/mooncake-store/src/client.cpp @@ -136,7 +136,7 @@ static std::vector get_auto_discover_filters(bool auto_discover) { } tl::expected CheckRegisterMemoryParams(const void* addr, - size_t length) { + size_t length) { if (addr == nullptr) { LOG(ERROR) << "addr is nullptr"; return tl::unexpected(ErrorCode::INVALID_PARAMS); @@ -669,7 +669,8 @@ void Client::SubmitTransfers(std::vector& ops) { if (!transfer_submitter_) { LOG(ERROR) << "TransferSubmitter not initialized"; for (auto& op : ops) { - op.SetError(ErrorCode::INVALID_PARAMS, "TransferSubmitter not initialized"); + op.SetError(ErrorCode::INVALID_PARAMS, + "TransferSubmitter not initialized"); } return; } @@ -918,8 +919,7 @@ void Client::BatchPuttoLocalFile(std::vector& ops) { PutToLocalFile(op.key, op.slices); } else { LOG(ERROR) << "Skipping local file storage for key " << op.key - << " due to failure: " - << toString(op.result.error()); + << " due to failure: " << toString(op.result.error()); } } } @@ -1072,7 +1072,7 @@ tl::expected Client::unregisterLocalMemory( tl::expected Client::IsExist(const std::string& key) { auto result = master_client_.ExistKey(key); if (!result) { - if(storage_backend_) { + if (storage_backend_) { // If master query fails, check storage backend if (storage_backend_->Existkey(key)) { return true; // Key exists in storage backend @@ -1123,10 +1123,12 @@ void Client::PutToLocalFile(const std::string& key, total_size += slice.size; } - // Currently, persistence is achieved through asynchronous writes, but before 
asynchronous - // writing in 3FS, significant performance degradation may occur due to data copying. - // Profiling reveals that the number of page faults triggered in this scenario is nearly double the normal count. - // Future plans include introducing a reuse buffer list to address this performance degradation issue. + // Currently, persistence is achieved through asynchronous writes, but + // before asynchronous writing in 3FS, significant performance degradation + // may occur due to data copying. Profiling reveals that the number of page + // faults triggered in this scenario is nearly double the normal count. + // Future plans include introducing a reuse buffer list to address this + // performance degradation issue. std::string value; value.reserve(total_size); diff --git a/mooncake-store/src/etcd_helper.cpp b/mooncake-store/src/etcd_helper.cpp index e72320c4d..5417de5af 100644 --- a/mooncake-store/src/etcd_helper.cpp +++ b/mooncake-store/src/etcd_helper.cpp @@ -166,7 +166,8 @@ ErrorCode EtcdHelper::Get(const char* key, const size_t key_size, } ErrorCode EtcdHelper::CreateWithLease(const char* key, const size_t key_size, - const char* value, const size_t value_size, + const char* value, + const size_t value_size, EtcdLeaseId lease_id, EtcdRevisionId& revision_id) { LOG(FATAL) << "Etcd is not enabled in compilation"; diff --git a/mooncake-store/src/ha_helper.cpp b/mooncake-store/src/ha_helper.cpp index b10eb3937..12a48fd01 100644 --- a/mooncake-store/src/ha_helper.cpp +++ b/mooncake-store/src/ha_helper.cpp @@ -93,13 +93,12 @@ MasterServiceSupervisor::MasterServiceSupervisor( int rpc_port, size_t rpc_thread_num, bool enable_gc, bool enable_metric_reporting, int metrics_port, int64_t default_kv_lease_ttl, int64_t default_kv_soft_pin_ttl, - bool allow_evict_soft_pinned_objects, - double eviction_ratio, double eviction_high_watermark_ratio, - int64_t client_live_ttl_sec, const std::string& etcd_endpoints, - const std::string& local_hostname, const std::string& 
rpc_address, + bool allow_evict_soft_pinned_objects, double eviction_ratio, + double eviction_high_watermark_ratio, int64_t client_live_ttl_sec, + const std::string& etcd_endpoints, const std::string& local_hostname, + const std::string& rpc_address, std::chrono::steady_clock::duration rpc_conn_timeout, - bool rpc_enable_tcp_no_delay, - const std::string& cluster_id, + bool rpc_enable_tcp_no_delay, const std::string& cluster_id, BufferAllocatorType memory_allocator) : enable_gc_(enable_gc), enable_metric_reporting_(enable_metric_reporting), diff --git a/mooncake-store/src/hf3fs/hf3fs_file.cpp b/mooncake-store/src/hf3fs/hf3fs_file.cpp index 0c7be83b2..c50b094f3 100644 --- a/mooncake-store/src/hf3fs/hf3fs_file.cpp +++ b/mooncake-store/src/hf3fs/hf3fs_file.cpp @@ -7,7 +7,8 @@ namespace mooncake { -ThreeFSFile::ThreeFSFile(const std::string& filename, int fd, USRBIOResourceManager* resource_manager) +ThreeFSFile::ThreeFSFile(const std::string& filename, int fd, + USRBIOResourceManager* resource_manager) : StorageFile(filename, fd), resource_manager_(resource_manager) {} ThreeFSFile::~ThreeFSFile() { @@ -30,11 +31,13 @@ ThreeFSFile::~ThreeFSFile() { } } -tl::expected ThreeFSFile::write(const std::string& buffer, size_t length) { +tl::expected ThreeFSFile::write(const std::string& buffer, + size_t length) { return write(std::span(buffer.data(), length), length); } -tl::expected ThreeFSFile::write(std::span data, size_t length) { +tl::expected ThreeFSFile::write(std::span data, + size_t length) { // 1. 
Parameter validation if (length == 0) { return make_error(ErrorCode::FILE_INVALID_BUFFER); @@ -62,14 +65,16 @@ tl::expected ThreeFSFile::write(std::span data, s while (total_bytes_written < length) { // Calculate current chunk size - size_t chunk_size = std::min(length - total_bytes_written, max_chunk_size); + size_t chunk_size = + std::min(length - total_bytes_written, max_chunk_size); // Copy data to shared buffer memcpy(threefs_iov.base, data_ptr + total_bytes_written, chunk_size); // Prepare IO request - int ret = hf3fs_prep_io(&ior_write, &threefs_iov, false, - threefs_iov.base, fd_, current_offset, chunk_size, nullptr); + int ret = + hf3fs_prep_io(&ior_write, &threefs_iov, false, threefs_iov.base, + fd_, current_offset, chunk_size, nullptr); if (ret < 0) { return make_error(ErrorCode::FILE_WRITE_FAIL); } @@ -92,18 +97,19 @@ tl::expected ThreeFSFile::write(std::span data, s current_offset += bytes_written; if (bytes_written < chunk_size) { - break; // Short write, possibly disk full + break; // Short write, possibly disk full } } - if(total_bytes_written != length) { + if (total_bytes_written != length) { return make_error(ErrorCode::FILE_WRITE_FAIL); } return total_bytes_written; } -tl::expected ThreeFSFile::read(std::string& buffer, size_t length) { +tl::expected ThreeFSFile::read(std::string& buffer, + size_t length) { // 1. Parameter validation if (length == 0) { return make_error(ErrorCode::FILE_INVALID_BUFFER); @@ -132,14 +138,12 @@ tl::expected ThreeFSFile::read(std::string& buffer, size_t le // 5. 
Read in chunks while (total_bytes_read < length) { // Calculate current chunk size - size_t chunk_size = std::min( - length - total_bytes_read, - resource->config_.iov_size - ); + size_t chunk_size = std::min(length - total_bytes_read, + resource->config_.iov_size); // Prepare IO request - int ret = hf3fs_prep_io(&ior_read, &threefs_iov, true, - threefs_iov.base, fd_, current_offset, chunk_size, nullptr); + int ret = hf3fs_prep_io(&ior_read, &threefs_iov, true, threefs_iov.base, + fd_, current_offset, chunk_size, nullptr); if (ret < 0) { return make_error(ErrorCode::FILE_READ_FAIL); } @@ -158,7 +162,7 @@ tl::expected ThreeFSFile::read(std::string& buffer, size_t le } size_t bytes_read = cqe.result; - if (bytes_read == 0) { // EOF + if (bytes_read == 0) { // EOF break; } @@ -167,19 +171,21 @@ tl::expected ThreeFSFile::read(std::string& buffer, size_t le total_bytes_read += bytes_read; current_offset += bytes_read; - if (bytes_read < chunk_size) { // Short read + if (bytes_read < chunk_size) { // Short read break; } } - if(total_bytes_read != length) { + if (total_bytes_read != length) { return make_error(ErrorCode::FILE_READ_FAIL); } - + return total_bytes_read; } -tl::expected ThreeFSFile::vector_write(const iovec* iov, int iovcnt, off_t offset) { +tl::expected ThreeFSFile::vector_write(const iovec* iov, + int iovcnt, + off_t offset) { auto* resource = resource_manager_->getThreadResource(); if (!resource || !resource->initialized) { return make_error(ErrorCode::FILE_OPEN_FAIL); @@ -202,36 +208,34 @@ tl::expected ThreeFSFile::vector_write(const iovec* iov, int size_t total_bytes_written = 0; off_t current_offset = offset; size_t bytes_remaining = total_length; - int current_iov_index = 0; + int current_iov_index = 0; size_t current_iov_offset = 0; while (bytes_remaining > 0) { - // 2. Determine current write chunk size (not exceeding shared buffer size) - size_t current_chunk_size = std::min( - bytes_remaining, - resource->config_.iov_size - ); + // 2. 
Determine current write chunk size (not exceeding shared buffer + // size) + size_t current_chunk_size = + std::min(bytes_remaining, resource->config_.iov_size); // 3. Copy data from user IOV to shared buffer size_t bytes_copied = 0; char* dest_ptr = reinterpret_cast(threefs_iov.base); - - while (bytes_copied < current_chunk_size && current_iov_index < iovcnt) { + + while (bytes_copied < current_chunk_size && + current_iov_index < iovcnt) { const iovec* current_iov = &iov[current_iov_index]; - size_t copy_size = std::min( - current_chunk_size - bytes_copied, - current_iov->iov_len - current_iov_offset - ); + size_t copy_size = + std::min(current_chunk_size - bytes_copied, + current_iov->iov_len - current_iov_offset); - memcpy( - dest_ptr + bytes_copied, - reinterpret_cast(current_iov->iov_base) + current_iov_offset, - copy_size - ); + memcpy(dest_ptr + bytes_copied, + reinterpret_cast(current_iov->iov_base) + + current_iov_offset, + copy_size); bytes_copied += copy_size; current_iov_offset += copy_size; - + if (current_iov_offset >= current_iov->iov_len) { current_iov_index++; current_iov_offset = 0; @@ -239,8 +243,9 @@ tl::expected ThreeFSFile::vector_write(const iovec* iov, int } // 4. 
Prepare and submit IO request - int ret = hf3fs_prep_io(&ior_write, &threefs_iov, false, - threefs_iov.base, fd_, current_offset, current_chunk_size, nullptr); + int ret = + hf3fs_prep_io(&ior_write, &threefs_iov, false, threefs_iov.base, + fd_, current_offset, current_chunk_size, nullptr); if (ret < 0) { return make_error(ErrorCode::FILE_WRITE_FAIL); } @@ -263,14 +268,16 @@ tl::expected ThreeFSFile::vector_write(const iovec* iov, int current_offset += bytes_written; if (bytes_written < current_chunk_size) { - break; // Short write, possibly disk full + break; // Short write, possibly disk full } } return total_bytes_written; } -tl::expected ThreeFSFile::vector_read(const iovec* iov, int iovcnt, off_t offset) { +tl::expected ThreeFSFile::vector_read(const iovec* iov, + int iovcnt, + off_t offset) { auto* resource = resource_manager_->getThreadResource(); if (!resource || !resource->initialized) { return make_error(ErrorCode::FILE_OPEN_FAIL); @@ -296,14 +303,15 @@ tl::expected ThreeFSFile::vector_read(const iovec* iov, int i int current_iov_index = 0; size_t current_iov_offset = 0; - while(bytes_remaining > 0) { + while (bytes_remaining > 0) { // Determine current block size size_t current_chunk_size = std::min(bytes_remaining, resource->config_.iov_size); // Prepare IO request - int ret = hf3fs_prep_io(&ior_read, &threefs_iov, true, - threefs_iov.base, fd_, current_offset, current_chunk_size, nullptr); + int ret = + hf3fs_prep_io(&ior_read, &threefs_iov, true, threefs_iov.base, fd_, + current_offset, current_chunk_size, nullptr); if (ret < 0) { return make_error(ErrorCode::FILE_READ_FAIL); } @@ -325,19 +333,15 @@ tl::expected ThreeFSFile::vector_read(const iovec* iov, int i // Copy data from shared buffer to user IOV size_t bytes_to_copy = bytes_read; char* src_ptr = reinterpret_cast(threefs_iov.base); - + while (bytes_to_copy > 0 && current_iov_index < iovcnt) { const iovec* current_iov = &iov[current_iov_index]; size_t copy_size = std::min( - bytes_to_copy, - 
current_iov->iov_len - current_iov_offset - ); + bytes_to_copy, current_iov->iov_len - current_iov_offset); memcpy( static_cast(current_iov->iov_base) + current_iov_offset, - src_ptr, - copy_size - ); + src_ptr, copy_size); src_ptr += copy_size; bytes_to_copy -= copy_size; @@ -351,8 +355,9 @@ tl::expected ThreeFSFile::vector_read(const iovec* iov, int i current_iov_offset = 0; } } - if(bytes_read < current_chunk_size) { - // If bytes read is less than requested chunk size, we've reached EOF + if (bytes_read < current_chunk_size) { + // If bytes read is less than requested chunk size, we've reached + // EOF break; } } @@ -360,4 +365,4 @@ tl::expected ThreeFSFile::vector_read(const iovec* iov, int i return total_bytes_read; } -} // namespace mooncake \ No newline at end of file +} // namespace mooncake \ No newline at end of file diff --git a/mooncake-store/src/hf3fs/hf3fs_resource_manager.cpp b/mooncake-store/src/hf3fs/hf3fs_resource_manager.cpp index 0af52c1fa..d59c662b9 100644 --- a/mooncake-store/src/hf3fs/hf3fs_resource_manager.cpp +++ b/mooncake-store/src/hf3fs/hf3fs_resource_manager.cpp @@ -12,16 +12,16 @@ bool ThreadUSRBIOResource::Initialize(const Hf3fsConfig &config) { this->config_ = config; // Create shared memory - int ret = - hf3fs_iovcreate(&iov_, config.mount_root.c_str(), config.iov_size, 0, -1); + int ret = hf3fs_iovcreate(&iov_, config.mount_root.c_str(), config.iov_size, + 0, -1); if (ret < 0) { return false; } // Create read I/O ring - ret = - hf3fs_iorcreate4(&ior_read_, config.mount_root.c_str(), config.ior_entries, - true, config.io_depth, config.ior_timeout, -1, 0); + ret = hf3fs_iorcreate4(&ior_read_, config.mount_root.c_str(), + config.ior_entries, true, config.io_depth, + config.ior_timeout, -1, 0); if (ret < 0) { hf3fs_iovdestroy(&iov_); return false; @@ -29,8 +29,8 @@ bool ThreadUSRBIOResource::Initialize(const Hf3fsConfig &config) { // Create write I/O ring ret = hf3fs_iorcreate4(&ior_write_, config.mount_root.c_str(), - 
config.ior_entries, false, config.io_depth, - config.ior_timeout, -1, 0); + config.ior_entries, false, config.io_depth, + config.ior_timeout, -1, 0); if (ret < 0) { hf3fs_iordestroy(&ior_read_); hf3fs_iovdestroy(&iov_); @@ -89,4 +89,4 @@ USRBIOResourceManager::~USRBIOResourceManager() { thread_resources.clear(); } -} \ No newline at end of file +} // namespace mooncake \ No newline at end of file diff --git a/mooncake-store/src/master.cpp b/mooncake-store/src/master.cpp index dbc26e8ea..a9c8f080b 100644 --- a/mooncake-store/src/master.cpp +++ b/mooncake-store/src/master.cpp @@ -65,7 +65,8 @@ DEFINE_int64(client_ttl, mooncake::DEFAULT_CLIENT_LIVE_TTL_SEC, "used in HA mode"); DEFINE_string(cluster_id, mooncake::DEFAULT_CLUSTER_ID, - "Cluster ID for the master service, used for kvcache persistence in HA mode"); + "Cluster ID for the master service, used for kvcache persistence " + "in HA mode"); DEFINE_string(memory_allocator, "offset", "Memory allocator for global segments, cachelib | offset"); @@ -157,7 +158,8 @@ int main(int argc, char* argv[]) { } else if (FLAGS_memory_allocator == "offset") { allocator_type = mooncake::BufferAllocatorType::OFFSET; } else { - LOG(FATAL) << "Invalid memory allocator type: " << FLAGS_memory_allocator; + LOG(FATAL) << "Invalid memory allocator type: " + << FLAGS_memory_allocator; return 1; } diff --git a/mooncake-store/src/posix_file.cpp b/mooncake-store/src/posix_file.cpp index cceb553a2..4ebbae7c3 100644 --- a/mooncake-store/src/posix_file.cpp +++ b/mooncake-store/src/posix_file.cpp @@ -10,7 +10,8 @@ #include "file_interface.h" namespace mooncake { -PosixFile::PosixFile(const std::string& filename, int fd) : StorageFile(filename, fd) { +PosixFile::PosixFile(const std::string &filename, int fd) + : StorageFile(filename, fd) { if (fd < 0) { error_code_ = ErrorCode::FILE_INVALID_HANDLE; } @@ -29,21 +30,22 @@ PosixFile::~PosixFile() { } else { LOG(INFO) << "Deleted corrupted file: " << filename_; } - } + } } fd_ = -1; } -tl::expected 
PosixFile::write(const std::string &buffer, size_t length) { +tl::expected PosixFile::write(const std::string &buffer, + size_t length) { return write(std::span(buffer.data(), length), length); } -tl::expected PosixFile::write(std::span data, size_t length) { - +tl::expected PosixFile::write(std::span data, + size_t length) { if (fd_ < 0) { return make_error(ErrorCode::FILE_NOT_FOUND); } - + if (length == 0) { return make_error(ErrorCode::FILE_INVALID_BUFFER); } @@ -55,7 +57,7 @@ tl::expected PosixFile::write(std::span data, siz size_t remaining = length; size_t written_bytes = 0; - const char* ptr = data.data(); + const char *ptr = data.data(); while (remaining > 0) { ssize_t written = ::write(fd_, ptr, remaining); @@ -68,18 +70,18 @@ tl::expected PosixFile::write(std::span data, siz written_bytes += written; } - if(written_bytes != length) { + if (written_bytes != length) { return make_error(ErrorCode::FILE_WRITE_FAIL); } return written_bytes; } -tl::expected PosixFile::read(std::string &buffer, size_t length) { - +tl::expected PosixFile::read(std::string &buffer, + size_t length) { if (fd_ < 0) { return make_error(ErrorCode::FILE_NOT_FOUND); } - + if (length == 0) { return make_error(ErrorCode::FILE_INVALID_BUFFER); } @@ -91,7 +93,7 @@ tl::expected PosixFile::read(std::string &buffer, size_t leng buffer.resize(length); size_t read_bytes = 0; - char* ptr = buffer.data(); + char *ptr = buffer.data(); while (read_bytes < length) { ssize_t n = ::read(fd_, ptr, length - read_bytes); @@ -100,19 +102,21 @@ tl::expected PosixFile::read(std::string &buffer, size_t leng buffer.clear(); return make_error(ErrorCode::FILE_READ_FAIL); } - if (n == 0) break; // EOF + if (n == 0) break; // EOF read_bytes += n; ptr += n; } buffer.resize(read_bytes); - if(read_bytes != length) { + if (read_bytes != length) { return make_error(ErrorCode::FILE_READ_FAIL); } return read_bytes; } -tl::expected PosixFile::vector_write(const iovec *iov, int iovcnt, off_t offset) { +tl::expected 
PosixFile::vector_write(const iovec *iov, + int iovcnt, + off_t offset) { if (fd_ < 0) { return make_error(ErrorCode::FILE_NOT_FOUND); } @@ -130,7 +134,9 @@ tl::expected PosixFile::vector_write(const iovec *iov, int io return ret; } -tl::expected PosixFile::vector_read(const iovec *iov, int iovcnt, off_t offset) { +tl::expected PosixFile::vector_read(const iovec *iov, + int iovcnt, + off_t offset) { if (fd_ < 0) { return make_error(ErrorCode::FILE_NOT_FOUND); } @@ -148,4 +154,4 @@ tl::expected PosixFile::vector_read(const iovec *iov, int iov return ret; } -} // namespace mooncake \ No newline at end of file +} // namespace mooncake \ No newline at end of file diff --git a/mooncake-store/src/rpc_service.cpp b/mooncake-store/src/rpc_service.cpp index 221cc3959..0bc85b07a 100644 --- a/mooncake-store/src/rpc_service.cpp +++ b/mooncake-store/src/rpc_service.cpp @@ -33,7 +33,8 @@ WrappedMasterService::WrappedMasterService( : master_service_(enable_gc, default_kv_lease_ttl, default_kv_soft_pin_ttl, allow_evict_soft_pinned_objects, eviction_ratio, eviction_high_watermark_ratio, view_version, - client_live_ttl_sec, enable_ha, cluster_id, memory_allocator), + client_live_ttl_sec, enable_ha, cluster_id, + memory_allocator), http_server_(4, http_port), metric_report_running_(enable_metric_reporting) { init_http_server(); @@ -450,8 +451,8 @@ tl::expected WrappedMasterService::GetFsdir() { return result; } -tl::expected -WrappedMasterService::Ping(const UUID& client_id) { +tl::expected WrappedMasterService::Ping( + const UUID& client_id) { ScopedVLogTimer timer(1, "Ping"); timer.LogRequest("client_id=", client_id); diff --git a/mooncake-store/src/segment.cpp b/mooncake-store/src/segment.cpp index ffe926cc1..ae634fde4 100644 --- a/mooncake-store/src/segment.cpp +++ b/mooncake-store/src/segment.cpp @@ -59,7 +59,8 @@ ErrorCode ScopedSegmentAccess::MountSegment(const Segment& segment, default: LOG(ERROR) << "segment_name=" << segment.name << ", error=unknown_memory_allocator=" - << 
static_cast(segment_manager_->memory_allocator_); + << static_cast( + segment_manager_->memory_allocator_); return ErrorCode::INVALID_PARAMS; } @@ -133,7 +134,8 @@ ErrorCode ScopedSegmentAccess::PrepareUnmountSegment( metrics_dec_capacity = segment.size; // Remove the allocator from the segment manager - std::shared_ptr allocator = mounted_segment.buf_allocator; + std::shared_ptr allocator = + mounted_segment.buf_allocator; // 1. Remove from allocators auto alloc_it = std::find(segment_manager_->allocators_.begin(), @@ -246,7 +248,7 @@ ErrorCode ScopedSegmentAccess::QuerySegments(const std::string& segment, VLOG(1) << "### DEBUG ### MasterService::QuerySegments(" << segment << ") not found!"; return ErrorCode::SEGMENT_NOT_FOUND; - } + } return ErrorCode::OK; } } // namespace mooncake \ No newline at end of file diff --git a/mooncake-store/src/storage_backend.cpp b/mooncake-store/src/storage_backend.cpp index 4ed88a0a2..e72d613c6 100644 --- a/mooncake-store/src/storage_backend.cpp +++ b/mooncake-store/src/storage_backend.cpp @@ -7,14 +7,13 @@ #include #include - namespace mooncake { - -tl::expected StorageBackend::StoreObject(const ObjectKey& key, - const std::vector& slices) { + +tl::expected StorageBackend::StoreObject( + const ObjectKey& key, const std::vector& slices) { std::string path = ResolvePath(key); - if(std::filesystem::exists(path)) { + if (std::filesystem::exists(path)) { return tl::make_unexpected(ErrorCode::FILE_OPEN_FAIL); } @@ -27,40 +26,42 @@ tl::expected StorageBackend::StoreObject(const ObjectKey& key, std::vector iovs; size_t slices_total_size = 0; for (const auto& slice : slices) { - iovec io{ slice.ptr, slice.size }; + iovec io{slice.ptr, slice.size}; iovs.push_back(io); slices_total_size += slice.size; } - auto write_result = file->vector_write(iovs.data(), static_cast(iovs.size()), 0); + auto write_result = + file->vector_write(iovs.data(), static_cast(iovs.size()), 0); if (!write_result) { - LOG(INFO) << "vector_write failed for: " << path 
<< ", error: " << write_result.error(); + LOG(INFO) << "vector_write failed for: " << path + << ", error: " << write_result.error(); return tl::make_unexpected(write_result.error()); } if (*write_result != slices_total_size) { LOG(INFO) << "Write size mismatch for: " << path - << ", expected: " << slices_total_size - << ", got: " << *write_result; + << ", expected: " << slices_total_size + << ", got: " << *write_result; return tl::make_unexpected(ErrorCode::FILE_WRITE_FAIL); } return {}; } -tl::expected StorageBackend::StoreObject(const ObjectKey& key, - const std::string& str) { - return StoreObject(key, std::span(str.data(), str.size())); +tl::expected StorageBackend::StoreObject( + const ObjectKey& key, const std::string& str) { + return StoreObject(key, std::span(str.data(), str.size())); } -tl::expected StorageBackend::StoreObject(const ObjectKey& key, - std::span data) { +tl::expected StorageBackend::StoreObject( + const ObjectKey& key, std::span data) { std::string path = ResolvePath(key); if (std::filesystem::exists(path)) { return tl::make_unexpected(ErrorCode::FILE_OPEN_FAIL); } - + auto file = create_file(path, FileMode::Write); if (!file) { LOG(INFO) << "Failed to open file for writing: " << path; @@ -68,45 +69,47 @@ tl::expected StorageBackend::StoreObject(const ObjectKey& key, } size_t file_total_size = data.size(); - auto write_result = file->write(data, file_total_size); + auto write_result = file->write(data, file_total_size); if (!write_result) { - LOG(INFO) << "Write failed for: " << path << ", error: " << write_result.error(); + LOG(INFO) << "Write failed for: " << path + << ", error: " << write_result.error(); return tl::make_unexpected(write_result.error()); } if (*write_result != file_total_size) { LOG(INFO) << "Write size mismatch for: " << path - << ", expected: " << file_total_size - << ", got: " << *write_result; + << ", expected: " << file_total_size + << ", got: " << *write_result; return tl::make_unexpected(ErrorCode::FILE_WRITE_FAIL); 
} return {}; } -tl::expected StorageBackend::LoadObject(std::string& path, - std::vector& slices, size_t length) { +tl::expected StorageBackend::LoadObject( + std::string& path, std::vector& slices, size_t length) { auto file = create_file(path, FileMode::Read); if (!file) { LOG(INFO) << "Failed to open file for reading: " << path; return tl::make_unexpected(ErrorCode::FILE_OPEN_FAIL); } - std::vector iovs; + std::vector iovs; for (const auto& slice : slices) { - iovec io{ slice.ptr, slice.size }; + iovec io{slice.ptr, slice.size}; iovs.push_back(io); } - auto read_result = file->vector_read(iovs.data(), static_cast(iovs.size()), 0); + auto read_result = + file->vector_read(iovs.data(), static_cast(iovs.size()), 0); if (!read_result) { - LOG(INFO) << "vector_read failed for: " << path << ", error: " << read_result.error(); + LOG(INFO) << "vector_read failed for: " << path + << ", error: " << read_result.error(); return tl::make_unexpected(read_result.error()); } if (*read_result != length) { LOG(INFO) << "Read size mismatch for: " << path - << ", expected: " << length - << ", got: " << *read_result; + << ", expected: " << length << ", got: " << *read_result; return tl::make_unexpected(ErrorCode::FILE_READ_FAIL); } @@ -114,7 +117,8 @@ tl::expected StorageBackend::LoadObject(std::string& path, } tl::expected StorageBackend::LoadObject(std::string& path, - std::string& str, size_t length) { + std::string& str, + size_t length) { auto file = create_file(path, FileMode::Read); if (!file) { LOG(INFO) << "Failed to open file for reading: " << path; @@ -123,13 +127,13 @@ tl::expected StorageBackend::LoadObject(std::string& path, auto read_result = file->read(str, length); if (!read_result) { - LOG(INFO) << "read failed for: " << path << ", error: " << read_result.error(); + LOG(INFO) << "read failed for: " << path + << ", error: " << read_result.error(); return tl::make_unexpected(read_result.error()); } if (*read_result != length) { LOG(INFO) << "Read size mismatch for: " 
<< path - << ", expected: " << length - << ", got: " << *read_result; + << ", expected: " << length << ", got: " << *read_result; return tl::make_unexpected(ErrorCode::FILE_READ_FAIL); } @@ -148,7 +152,8 @@ bool StorageBackend::Existkey(const ObjectKey& key) { } } -std::optional StorageBackend::Querykey(const ObjectKey& key) { +std::optional StorageBackend::Querykey( + const ObjectKey& key) { std::string path = ResolvePath(key); namespace fs = std::filesystem; @@ -163,7 +168,7 @@ std::optional StorageBackend::Querykey(const ObjectKey& key disk_desc.file_path = path; disk_desc.file_size = fs::file_size(path); desc.status = ReplicaStatus::COMPLETE; - + return desc; } @@ -177,7 +182,7 @@ StorageBackend::BatchQueryKey(const std::vector& keys) { if (!fs::exists(path)) { LOG(WARNING) << "Key not found: " << key << ", skipping..."; - return {}; + return {}; } Replica::Descriptor desc; @@ -195,18 +200,21 @@ StorageBackend::BatchQueryKey(const std::vector& keys) { void StorageBackend::RemoveFile(const ObjectKey& key) { std::string path = ResolvePath(key); namespace fs = std::filesystem; - // TODO: attention: this function is not thread-safe, need to add lock if used in multi-thread environment - // Check if the file exists before attempting to remove it - // TODO: add a sleep to ensure the write thread has time to create the corresponding file - // it will be fixed in the next version - std::this_thread::sleep_for(std::chrono::microseconds(50)); //sleep for 50 us + // TODO: attention: this function is not thread-safe, need to add lock if + // used in multi-thread environment Check if the file exists before + // attempting to remove it + // TODO: add a sleep to ensure the write thread has time to create the + // corresponding file it will be fixed in the next version + std::this_thread::sleep_for( + std::chrono::microseconds(50)); // sleep for 50 us if (fs::exists(path)) { std::error_code ec; fs::remove(path, ec); if (ec) { - LOG(ERROR) << "Failed to delete file: " << path 
<< ", error: " << ec.message(); + LOG(ERROR) << "Failed to delete file: " << path + << ", error: " << ec.message(); } - } + } } void StorageBackend::RemoveAll() { @@ -215,13 +223,13 @@ void StorageBackend::RemoveAll() { for (const auto& entry : fs::directory_iterator(root_dir_)) { if (fs::is_regular_file(entry.status())) { std::error_code ec; - fs::remove(entry.path(),ec); + fs::remove(entry.path(), ec); if (ec) { - LOG(ERROR) << "Failed to delete file: " << entry.path() << ", error: " << ec.message(); + LOG(ERROR) << "Failed to delete file: " << entry.path() + << ", error: " << ec.message(); } } } - } std::string StorageBackend::SanitizeKey(const ObjectKey& key) const { @@ -229,40 +237,43 @@ std::string StorageBackend::SanitizeKey(const ObjectKey& key) const { constexpr std::string_view kInvalidChars = "/\\:*?\"<>|"; std::string sanitized_key; sanitized_key.reserve(key.size()); - + for (char c : key) { // Replace invalid characters with underscore sanitized_key.push_back( - kInvalidChars.find(c) != std::string_view::npos ? '_' : c - ); + kInvalidChars.find(c) != std::string_view::npos ? '_' : c); } return sanitized_key; } std::string StorageBackend::ResolvePath(const ObjectKey& key) const { - // Compute hash of the key + // Compute hash of the key size_t hash = std::hash{}(key); - + // Use low 8 bits to create 2-level directory structure (e.g. 
"a1/b2") - char dir1 = static_cast('a' + (hash & 0x0F)); // Lower 4 bits -> 16 dirs - char dir2 = static_cast('a' + ((hash >> 4) & 0x0F)); // Next 4 bits -> 16 subdirs - + char dir1 = + static_cast('a' + (hash & 0x0F)); // Lower 4 bits -> 16 dirs + char dir2 = static_cast( + 'a' + ((hash >> 4) & 0x0F)); // Next 4 bits -> 16 subdirs + // Safely construct path using std::filesystem namespace fs = std::filesystem; - fs::path dir_path = fs::path(root_dir_) / fsdir_ / std::string(1, dir1) / std::string(1, dir2); + fs::path dir_path = fs::path(root_dir_) / fsdir_ / std::string(1, dir1) / + std::string(1, dir2); // Create directory if not exists std::error_code ec; if (!fs::exists(dir_path)) { if (!fs::create_directories(dir_path, ec) && ec) { - LOG(INFO) << "Failed to create directory: " << dir_path << ", error: " << ec.message(); - return ""; // Empty string indicates failure + LOG(INFO) << "Failed to create directory: " << dir_path + << ", error: " << ec.message(); + return ""; // Empty string indicates failure } } // Combine directory path with sanitized filename fs::path full_path = dir_path / SanitizeKey(key); - + return full_path.lexically_normal().string(); } @@ -278,10 +289,10 @@ std::unique_ptr StorageBackend::create_file( access_mode = O_WRONLY | O_CREAT | O_TRUNC; break; } - + int fd = open(path.c_str(), flags | access_mode, 0644); if (fd < 0) { - return nullptr; + return nullptr; } #ifdef USE_3FS @@ -290,8 +301,9 @@ std::unique_ptr StorageBackend::create_file( close(fd); return nullptr; } - return resource_manager_ ? - std::make_unique(path, fd, resource_manager_.get()) : nullptr; + return resource_manager_ ? 
std::make_unique( + path, fd, resource_manager_.get()) + : nullptr; } #endif diff --git a/mooncake-store/src/thread_pool.cpp b/mooncake-store/src/thread_pool.cpp index 1f7f70a7b..e69a5f4c8 100644 --- a/mooncake-store/src/thread_pool.cpp +++ b/mooncake-store/src/thread_pool.cpp @@ -3,19 +3,18 @@ #include namespace mooncake { ThreadPool::ThreadPool(size_t num_threads) : stop_flag(false) { - for(size_t i = 0; i < num_threads; ++i) { + for (size_t i = 0; i < num_threads; ++i) { workers.emplace_back([this] { - while(true) { + while (true) { std::function task; { std::unique_lock lock(this->queue_mutex); this->condition.wait(lock, [this] { return this->stop_flag || !this->tasks.empty(); }); - - if(this->stop_flag && this->tasks.empty()) - return; - + + if (this->stop_flag && this->tasks.empty()) return; + task = std::move(this->tasks.front()); this->tasks.pop(); } @@ -31,13 +30,10 @@ void ThreadPool::stop() { stop_flag = true; } condition.notify_all(); - for(std::thread &worker : workers) { - if(worker.joinable()) - worker.join(); + for (std::thread &worker : workers) { + if (worker.joinable()) worker.join(); } } -ThreadPool::~ThreadPool() { - stop(); -} -} \ No newline at end of file +ThreadPool::~ThreadPool() { stop(); } +} // namespace mooncake \ No newline at end of file diff --git a/mooncake-store/src/transfer_task.cpp b/mooncake-store/src/transfer_task.cpp index 820817abf..4ce4dbbee 100644 --- a/mooncake-store/src/transfer_task.cpp +++ b/mooncake-store/src/transfer_task.cpp @@ -12,10 +12,12 @@ namespace mooncake { // ============================================================================ // FilereadWorkerPool Implementation // ============================================================================ -//to fully utilize the available ssd bandwidth, we use a default of 10 worker threads. +// to fully utilize the available ssd bandwidth, we use a default of 10 worker +// threads. 
constexpr int kDefaultFilereadWorkers = 10; -FilereadWorkerPool::FilereadWorkerPool(std::shared_ptr& backend) : shutdown_(false) { +FilereadWorkerPool::FilereadWorkerPool(std::shared_ptr& backend) + : shutdown_(false) { VLOG(1) << "Creating FilereadWorkerPool with " << kDefaultFilereadWorkers << " workers"; @@ -86,20 +88,22 @@ void FilereadWorkerPool::workerThread() { if (task.state) { try { if (!backend_) { - LOG(ERROR) << "Backend is not initialized, cannot load object"; + LOG(ERROR) + << "Backend is not initialized, cannot load object"; task.state->set_completed(ErrorCode::TRANSFER_FAIL); continue; } - auto load_result = backend_->LoadObject(task.file_path, task.slices, task.file_size); + auto load_result = backend_->LoadObject( + task.file_path, task.slices, task.file_size); if (load_result) { - VLOG(2) << "Fileread task completed successfully with " - << task.file_path ; + VLOG(2) << "Fileread task completed successfully with " + << task.file_path; task.state->set_completed(ErrorCode::OK); } else { - LOG(ERROR) << "Fileread task failed for file: " - << task.file_path - << " with error: " << toString(load_result.error()); + LOG(ERROR) + << "Fileread task failed for file: " << task.file_path + << " with error: " << toString(load_result.error()); task.state->set_completed(ErrorCode::TRANSFER_FAIL); } } catch (const std::exception& e) { @@ -389,10 +393,9 @@ TransferSubmitter::TransferSubmitter(TransferEngine& engine, } std::optional TransferSubmitter::submit( - const Replica::Descriptor& replica, - std::vector& slices, Transport::TransferRequest::OpCode op_code) { - - if(replica.is_memory_replica()) { + const Replica::Descriptor& replica, std::vector& slices, + Transport::TransferRequest::OpCode op_code) { + if (replica.is_memory_replica()) { std::vector handles; auto& mem_desc = replica.get_memory_descriptor(); handles = mem_desc.buffer_descriptors; @@ -412,7 +415,7 @@ std::optional TransferSubmitter::submit( LOG(ERROR) << "Unknown transfer strategy: " << 
strategy; return std::nullopt; } - }else{ + } else { return submitFileReadOperation(replica, slices, op_code); } } @@ -515,7 +518,7 @@ std::optional TransferSubmitter::submitTransferEngineOperation( } std::optional TransferSubmitter::submitFileReadOperation( - const Replica::Descriptor& replica, std::vector& slices, + const Replica::Descriptor& replica, std::vector& slices, Transport::TransferRequest::OpCode op_code) { auto state = std::make_shared(); auto disk_replica = replica.get_disk_descriptor(); @@ -526,8 +529,7 @@ std::optional TransferSubmitter::submitFileReadOperation( FilereadTask task(file_path, file_length, slices, state); fileread_pool_->submitTask(std::move(task)); - VLOG(1) << "Fileread transfer submitted to worker pool with " - << file_path ; + VLOG(1) << "Fileread transfer submitted to worker pool with " << file_path; return TransferFuture(state); } diff --git a/mooncake-store/src/types.cpp b/mooncake-store/src/types.cpp index 99392b264..f8d6b0fb5 100644 --- a/mooncake-store/src/types.cpp +++ b/mooncake-store/src/types.cpp @@ -32,7 +32,8 @@ const std::string& toString(ErrorCode errorCode) noexcept { {ErrorCode::ETCD_KEY_NOT_EXIST, "ETCD_KEY_NOT_EXIST"}, {ErrorCode::ETCD_TRANSACTION_FAIL, "ETCD_TRANSACTION_FAIL"}, {ErrorCode::ETCD_CTX_CANCELLED, "ETCD_CTX_CANCELLED"}, - {ErrorCode::UNAVAILABLE_IN_CURRENT_STATUS, "UNAVAILABLE_IN_CURRENT_STATUS"}, + {ErrorCode::UNAVAILABLE_IN_CURRENT_STATUS, + "UNAVAILABLE_IN_CURRENT_STATUS"}, {ErrorCode::UNAVAILABLE_IN_CURRENT_MODE, "UNAVAILABLE_IN_CURRENT_MODE"}, {ErrorCode::FILE_NOT_FOUND, "FILE_NOT_FOUND"}, {ErrorCode::FILE_OPEN_FAIL, "FILE_OPEN_FAIL"}, @@ -40,8 +41,7 @@ const std::string& toString(ErrorCode errorCode) noexcept { {ErrorCode::FILE_WRITE_FAIL, "FILE_WRITE_FAIL"}, {ErrorCode::FILE_INVALID_BUFFER, "FILE_INVALID_BUFFER"}, {ErrorCode::FILE_LOCK_FAIL, "FILE_LOCK_FAIL"}, - {ErrorCode::FILE_INVALID_HANDLE, "FILE_INVALID_HANDLE"} - }; + {ErrorCode::FILE_INVALID_HANDLE, "FILE_INVALID_HANDLE"}}; auto it = 
errorCodeMap.find(errorCode); static const std::string unknownError = "UNKNOWN_ERROR"; @@ -61,7 +61,8 @@ UUID generate_uuid() { boost::uuids::random_generator gen; boost::uuids::uuid uuid = gen(); std::memcpy(&pair_uuid.first, uuid.data, sizeof(uint64_t)); - std::memcpy(&pair_uuid.second, uuid.data + sizeof(uint64_t), sizeof(uint64_t)); + std::memcpy(&pair_uuid.second, uuid.data + sizeof(uint64_t), + sizeof(uint64_t)); return pair_uuid; } diff --git a/mooncake-store/src/utils.cpp b/mooncake-store/src/utils.cpp index 8cf85032b..d9fbb7c60 100644 --- a/mooncake-store/src/utils.cpp +++ b/mooncake-store/src/utils.cpp @@ -4,7 +4,7 @@ #include namespace mooncake { -void* allocate_buffer_allocator_memory(size_t total_size) { +void *allocate_buffer_allocator_memory(size_t total_size) { const size_t alignment = facebook::cachelib::Slab::kSize; // Ensure total_size is a multiple of alignment if (total_size < alignment) { diff --git a/mooncake-store/tests/buffer_allocator_test.cpp b/mooncake-store/tests/buffer_allocator_test.cpp index 92501b763..bbd1bc6a3 100644 --- a/mooncake-store/tests/buffer_allocator_test.cpp +++ b/mooncake-store/tests/buffer_allocator_test.cpp @@ -63,7 +63,8 @@ TEST_F(BufferAllocatorTest, AllocateAndDeallocate) { for (const auto& allocator_type : allocator_types_) { std::string segment_name = "1"; size_t size = 1024 * 1024 * 16; // 16MB (multiple of 4MB) - auto allocator = CreateTestAllocator(segment_name, 0, size, allocator_type); + auto allocator = + CreateTestAllocator(segment_name, 0, size, allocator_type); // Allocate memory block size_t alloc_size = 1024; @@ -83,14 +84,15 @@ TEST_F(BufferAllocatorTest, AllocateMultiple) { for (const auto& allocator_type : allocator_types_) { std::string segment_name = "1"; size_t size = 1024 * 1024 * 16; // 16MB (must be multiple of 4MB) - auto allocator = CreateTestAllocator(segment_name, 0, size, allocator_type); + auto allocator = + CreateTestAllocator(segment_name, 0, size, allocator_type); // Allocate multiple 
memory blocks size_t alloc_size = 1024 * 1024; // 1MB per block std::vector> handles; - // Attempt to allocate 8 blocks (should succeed as total size is less than - // buffer size) + // Attempt to allocate 8 blocks (should succeed as total size is less + // than buffer size) for (int i = 0; i < 8; ++i) { auto bufHandle = allocator->allocate(alloc_size); ASSERT_NE(bufHandle, nullptr); @@ -110,7 +112,8 @@ TEST_F(BufferAllocatorTest, AllocateTooLarge) { std::string segment_name = "3"; size_t size = 1024 * 1024 * 16; // 16MB (must be multiple of 4MB) - auto allocator = CreateTestAllocator(segment_name, 0x20000000ULL, size, allocator_type); + auto allocator = CreateTestAllocator(segment_name, 0x20000000ULL, size, + allocator_type); // Attempt to allocate more than total buffer size size_t alloc_size = size + 1; @@ -143,22 +146,26 @@ TEST_F(BufferAllocatorTest, ParallelAllocation) { for (const auto& allocator_type : allocator_types_) { std::string segment_name = "test"; size_t size = 1024 * 1024 * 16; // 16MB (must be multiple of 4MB) - auto allocator = CreateTestAllocator(segment_name, 0x20000000ULL, size, allocator_type); + auto allocator = CreateTestAllocator(segment_name, 0x20000000ULL, size, + allocator_type); const int num_threads = 4; const auto test_duration = std::chrono::seconds(1); std::vector threads; - // Create 4 threads, each performing repeated allocation and deallocation for 1 second + // Create 4 threads, each performing repeated allocation and + // deallocation for 1 second for (int thread_id = 0; thread_id < num_threads; ++thread_id) { - threads.emplace_back([this, &allocator, test_duration, segment_name]() { + threads.emplace_back([this, &allocator, test_duration, + segment_name]() { auto start_time = std::chrono::steady_clock::now(); - - while (std::chrono::steady_clock::now() - start_time < test_duration) { + + while (std::chrono::steady_clock::now() - start_time < + test_duration) { // Allocate memory of varying sizes size_t alloc_size = 477; auto 
bufHandle = allocator->allocate(alloc_size); - + ASSERT_NE(bufHandle, nullptr); VerifyAllocatedBuffer(*bufHandle, alloc_size, segment_name); } @@ -171,7 +178,9 @@ TEST_F(BufferAllocatorTest, ParallelAllocation) { } LOG(INFO) << "Completed parallel allocation/deallocation test for " - << (allocator_type == BufferAllocatorType::CACHELIB ? "CACHELIB" : "OFFSET"); + << (allocator_type == BufferAllocatorType::CACHELIB + ? "CACHELIB" + : "OFFSET"); } } diff --git a/mooncake-store/tests/e2e/chaosctl.cpp b/mooncake-store/tests/e2e/chaosctl.cpp index 4c47c622b..eaf22faeb 100644 --- a/mooncake-store/tests/e2e/chaosctl.cpp +++ b/mooncake-store/tests/e2e/chaosctl.cpp @@ -185,18 +185,18 @@ int main(int argc, char* argv[]) { std::vector> clients; for (int i = 0; i < client_num; ++i) { - mooncake::testing::ClientRunnerConfig client_config{ - .put_prob = std::nullopt, - .get_prob = std::nullopt, - .mount_prob = std::nullopt, - .unmount_prob = std::nullopt, - .port = 17812 + i, - .master_server_entry = "etcd://" + FLAGS_etcd_endpoints, - .engine_meta_url = FLAGS_engine_meta_url, - .protocol = FLAGS_protocol, - .device_name = FLAGS_device_name, - }; - clients.emplace_back( + mooncake::testing::ClientRunnerConfig client_config{ + .put_prob = std::nullopt, + .get_prob = std::nullopt, + .mount_prob = std::nullopt, + .unmount_prob = std::nullopt, + .port = 17812 + i, + .master_server_entry = "etcd://" + FLAGS_etcd_endpoints, + .engine_meta_url = FLAGS_engine_meta_url, + .protocol = FLAGS_protocol, + .device_name = FLAGS_device_name, + }; + clients.emplace_back( std::make_unique( FLAGS_client_path, i, FLAGS_out_dir, client_config)); clients.back()->start(); diff --git a/mooncake-store/tests/e2e/client_wrapper.cpp b/mooncake-store/tests/e2e/client_wrapper.cpp index 7f93007d3..7c4146f8b 100644 --- a/mooncake-store/tests/e2e/client_wrapper.cpp +++ b/mooncake-store/tests/e2e/client_wrapper.cpp @@ -44,7 +44,8 @@ ClientTestWrapper::CreateClientWrapper(const std::string& hostname, auto 
register_result = client_opt.value()->RegisterLocalMemory( allocator->getBase(), local_buffer_size, "cpu:0", false, false); - ErrorCode error_code = register_result.has_value() ? ErrorCode::OK : register_result.error(); + ErrorCode error_code = + register_result.has_value() ? ErrorCode::OK : register_result.error(); if (error_code != ErrorCode::OK) { LOG(ERROR) << "register_local_memory_failed base=" << allocator->getBase() << " size=" << local_buffer_size @@ -62,7 +63,8 @@ ErrorCode ClientTestWrapper::Mount(const size_t size, void*& buffer) { } auto mount_result = client_->MountSegment(buffer, size); - ErrorCode error_code = mount_result.has_value() ? ErrorCode::OK : mount_result.error(); + ErrorCode error_code = + mount_result.has_value() ? ErrorCode::OK : mount_result.error(); if (error_code != ErrorCode::OK) { free(buffer); return error_code; @@ -80,7 +82,8 @@ ErrorCode ClientTestWrapper::Unmount(const void* buffer) { } SegmentInfo& segment = it->second; auto unmount_result = client_->UnmountSegment(segment.base, segment.size); - ErrorCode error_code = unmount_result.has_value() ? ErrorCode::OK : unmount_result.error(); + ErrorCode error_code = + unmount_result.has_value() ? ErrorCode::OK : unmount_result.error(); if (error_code != ErrorCode::OK) { return error_code; } else { @@ -97,7 +100,7 @@ ErrorCode ClientTestWrapper::Get(const std::string& key, std::string& value) { if (!query_result.has_value()) { return query_result.error(); } - + auto replica_list = query_result.value(); if (replica_list.empty()) { return ErrorCode::OBJECT_NOT_FOUND; @@ -110,7 +113,8 @@ ErrorCode ClientTestWrapper::Get(const std::string& key, std::string& value) { // Perform get operation auto get_result = client_->Get(key, replica_list, slice_guard.slices_); - ErrorCode error_code = get_result.has_value() ? ErrorCode::OK : get_result.error(); + ErrorCode error_code = + get_result.has_value() ? 
ErrorCode::OK : get_result.error(); if (error_code != ErrorCode::OK) { return error_code; } diff --git a/mooncake-store/tests/e2e/clientctl.cpp b/mooncake-store/tests/e2e/clientctl.cpp index 11043535b..ec1b7d69b 100644 --- a/mooncake-store/tests/e2e/clientctl.cpp +++ b/mooncake-store/tests/e2e/clientctl.cpp @@ -12,8 +12,8 @@ #include "e2e_utils.h" // Command line flags -USE_engine_flags -DEFINE_string(master_server_entry, "localhost:50051", "Master server address"); +USE_engine_flags DEFINE_string(master_server_entry, "localhost:50051", + "Master server address"); namespace mooncake { namespace testing { @@ -72,8 +72,8 @@ class ClientCtl { std::string hostname = "localhost:" + port; auto client_opt = ClientTestWrapper::CreateClientWrapper( - hostname, FLAGS_engine_meta_url, FLAGS_protocol, - FLAGS_device_name, FLAGS_master_server_entry); + hostname, FLAGS_engine_meta_url, FLAGS_protocol, FLAGS_device_name, + FLAGS_master_server_entry); if (!client_opt.has_value()) { std::cout << "Failed to create client: " << name << std::endl; @@ -194,14 +194,14 @@ class ClientCtl { iss >> seconds; if (seconds <= 0) { - std::cout << "Invalid sleep command format. Expected: sleep [seconds]" - << std::endl; + std::cout + << "Invalid sleep command format. 
Expected: sleep [seconds]" + << std::endl; return; } std::this_thread::sleep_for(std::chrono::seconds(seconds)); - std::cout << "Slept for " << seconds << " seconds" - << std::endl; + std::cout << "Slept for " << seconds << " seconds" << std::endl; } std::unordered_map clients_; diff --git a/mooncake-store/tests/e2e/e2e_utils.h b/mooncake-store/tests/e2e/e2e_utils.h index c0be7400a..5ce5e17f6 100644 --- a/mooncake-store/tests/e2e/e2e_utils.h +++ b/mooncake-store/tests/e2e/e2e_utils.h @@ -24,7 +24,7 @@ namespace testing { #define FLAG_master_path \ DEFINE_string(master_path, "./mooncake-store/src/mooncake_master", \ "Path to the master executable"); -#define FLAG_client_path \ +#define FLAG_client_path \ DEFINE_string(client_path, "./mooncake-store/tests/e2e/client_runner", \ "Path to the client executable"); #define FLAG_out_dir \ diff --git a/mooncake-store/tests/eviction_strategy_test.cpp b/mooncake-store/tests/eviction_strategy_test.cpp index 61d5d0a85..f99a378c7 100644 --- a/mooncake-store/tests/eviction_strategy_test.cpp +++ b/mooncake-store/tests/eviction_strategy_test.cpp @@ -23,7 +23,6 @@ class EvictionStrategyTest : public ::testing::Test { } }; - // Test LRUEvictionStrategy AddKey and RemoveKey functionality TEST_F(EvictionStrategyTest, AddAndRemoveKey) { LRUEvictionStrategy eviction_strategy; @@ -87,7 +86,9 @@ TEST_F(EvictionStrategyTest, FIFOAddAndRemoveKey) { EXPECT_EQ(eviction_strategy.GetSize(), 2); // Remove a key - EXPECT_EQ(eviction_strategy.RemoveKey("key1"), ErrorCode::OK); // FIFO not support remove a randomly accessed key + EXPECT_EQ( + eviction_strategy.RemoveKey("key1"), + ErrorCode::OK); // FIFO not support remove a randomly accessed key EXPECT_EQ(eviction_strategy.GetSize(), 2); // Clean up diff --git a/mooncake-store/tests/posix_file_test.cpp b/mooncake-store/tests/posix_file_test.cpp index f15762d98..543c36aa3 100644 --- a/mooncake-store/tests/posix_file_test.cpp +++ b/mooncake-store/tests/posix_file_test.cpp @@ -8,11 +8,11 @@ namespace 
mooncake { class PosixFileTest : public ::testing::Test { -protected: + protected: void SetUp() override { google::InitGoogleLogging("PosixFileTest"); FLAGS_logtostderr = 1; - + // Create and open a test file test_filename = "test_file.txt"; test_fd = open(test_filename.c_str(), O_CREAT | O_RDWR, 0644); @@ -41,11 +41,12 @@ TEST_F(PosixFileTest, FileLifecycle) { // Test basic write operation TEST_F(PosixFileTest, BasicWrite) { PosixFile posix_file(test_filename, test_fd); - + std::string test_data = "Test write data"; auto result = posix_file.write(test_data, test_data.size()); - - ASSERT_TRUE(result) << "Write failed with error: " << toString(result.error()); + + ASSERT_TRUE(result) << "Write failed with error: " + << toString(result.error()); EXPECT_EQ(*result, test_data.size()); EXPECT_EQ(posix_file.get_error_code(), ErrorCode::OK); } @@ -59,15 +60,18 @@ TEST_F(PosixFileTest, BasicRead) { // Write test data const char* test_data = "Test read data"; ssize_t written = write(test_fd, test_data, strlen(test_data)); - ASSERT_EQ(written, static_cast(strlen(test_data))) << "Write failed"; + ASSERT_EQ(written, static_cast(strlen(test_data))) + << "Write failed"; ASSERT_NE(lseek(test_fd, 0, SEEK_SET), -1) << "Seek failed"; - + PosixFile posix_file(test_filename, test_fd); - + std::string buffer; - auto result = posix_file.read(buffer, strlen(test_data)); // Read up to test_data bytes + auto result = posix_file.read( + buffer, strlen(test_data)); // Read up to test_data bytes - ASSERT_TRUE(result) << "Read failed with error: " << toString(result.error()); + ASSERT_TRUE(result) << "Read failed with error: " + << toString(result.error()); EXPECT_EQ(*result, strlen(test_data)); EXPECT_EQ(buffer, test_data); EXPECT_EQ(posix_file.get_error_code(), ErrorCode::OK); @@ -76,19 +80,20 @@ TEST_F(PosixFileTest, BasicRead) { // Test vectorized write operation TEST_F(PosixFileTest, VectorizedWrite) { PosixFile posix_file(test_filename, test_fd); - + std::string data1 = "First part "; 
std::string data2 = "Second part"; - + iovec iov[2]; iov[0].iov_base = const_cast(data1.data()); iov[0].iov_len = data1.size(); iov[1].iov_base = const_cast(data2.data()); iov[1].iov_len = data2.size(); - + auto result = posix_file.vector_write(iov, 2, 0); - - ASSERT_TRUE(result) << "Vector write failed with error: " << toString(result.error()); + + ASSERT_TRUE(result) << "Vector write failed with error: " + << toString(result.error()); EXPECT_EQ(*result, data1.size() + data2.size()); EXPECT_EQ(posix_file.get_error_code(), ErrorCode::OK); } @@ -102,23 +107,25 @@ TEST_F(PosixFileTest, VectorizedRead) { // Write test data const char* test_data = "Vectorized read test data"; ssize_t written = write(test_fd, test_data, strlen(test_data)); - ASSERT_EQ(written, static_cast(strlen(test_data))) << "Write failed"; + ASSERT_EQ(written, static_cast(strlen(test_data))) + << "Write failed"; ASSERT_NE(lseek(test_fd, 0, SEEK_SET), -1) << "Seek failed"; - + PosixFile posix_file(test_filename, test_fd); - + char buf1[11] = {0}; // "Vectorized" + null char buf2[16] = {0}; // " read test data" + null - + iovec iov[2]; iov[0].iov_base = buf1; iov[0].iov_len = 10; // Exact length of "Vectorized" iov[1].iov_base = buf2; iov[1].iov_len = 15; // Exact length of " read test data" - + auto result = posix_file.vector_read(iov, 2, 0); - - ASSERT_TRUE(result) << "Vector read failed with error: " << toString(result.error()); + + ASSERT_TRUE(result) << "Vector read failed with error: " + << toString(result.error()); EXPECT_EQ(*result, strlen(test_data)); EXPECT_STREQ(buf1, "Vectorized"); EXPECT_STREQ(buf2, " read test data"); @@ -130,13 +137,13 @@ TEST_F(PosixFileTest, ErrorCases) { // Test invalid file descriptor PosixFile posix_file("invalid.txt", -1); EXPECT_EQ(posix_file.get_error_code(), ErrorCode::FILE_INVALID_HANDLE); - + // Test write to invalid file std::string test_data = "test"; auto write_result = posix_file.write(test_data, test_data.size()); EXPECT_FALSE(write_result); 
EXPECT_EQ(write_result.error(), ErrorCode::FILE_NOT_FOUND); - + // Test read from invalid file std::string buffer; auto read_result = posix_file.read(buffer, test_data.size()); @@ -147,18 +154,18 @@ TEST_F(PosixFileTest, ErrorCases) { // Test file locking TEST_F(PosixFileTest, FileLocking) { PosixFile posix_file(test_filename, test_fd); - + { // Acquire write lock auto lock = posix_file.acquire_write_lock(); EXPECT_TRUE(lock.is_locked()); - + // Try to read while locked std::string buffer; auto result = posix_file.read(buffer, 10); EXPECT_FALSE(result); } - + { // Acquire read lock auto lock = posix_file.acquire_read_lock(); diff --git a/mooncake-store/tests/stress_workload_test.cpp b/mooncake-store/tests/stress_workload_test.cpp index 6e1955c75..bcd7d6508 100644 --- a/mooncake-store/tests/stress_workload_test.cpp +++ b/mooncake-store/tests/stress_workload_test.cpp @@ -87,7 +87,8 @@ void cleanup_segment() { auto result = g_client->UnmountSegment(g_segment_ptr, g_ram_buffer_size); if (!result.has_value()) { - LOG(ERROR) << "Failed to unmount segment: " << toString(result.error()); + LOG(ERROR) << "Failed to unmount segment: " + << toString(result.error()); } } } diff --git a/mooncake-store/tests/thread_pool_test.cpp b/mooncake-store/tests/thread_pool_test.cpp index 6e4b976b8..0787530d6 100644 --- a/mooncake-store/tests/thread_pool_test.cpp +++ b/mooncake-store/tests/thread_pool_test.cpp @@ -9,15 +9,13 @@ namespace mooncake { class ThreadPoolTest : public ::testing::Test { -protected: + protected: void SetUp() override { google::InitGoogleLogging("ThreadPoolTest"); FLAGS_logtostderr = 1; } - void TearDown() override { - google::ShutdownGoogleLogging(); - } + void TearDown() override { google::ShutdownGoogleLogging(); } }; // Test basic task execution @@ -59,8 +57,9 @@ TEST_F(ThreadPoolTest, ParallelExecution) { pool.enqueue([&]() { int current = ++running_threads; int old_max = max_concurrent_threads.load(); - while (old_max < current && - 
!max_concurrent_threads.compare_exchange_weak(old_max, current)) { + while (old_max < current && + !max_concurrent_threads.compare_exchange_weak(old_max, + current)) { // Keep trying to update max } @@ -107,9 +106,7 @@ TEST_F(ThreadPoolTest, ProperStop) { pool.stop(); EXPECT_EQ(counter.load(), total_tasks); - EXPECT_THROW({ - pool.enqueue([](){}); - }, std::runtime_error); + EXPECT_THROW({ pool.enqueue([]() {}); }, std::runtime_error); } // Test stress with many tasks @@ -136,8 +133,7 @@ TEST_F(ThreadPoolTest, StressTest) { EXPECT_EQ(counter.load(), num_tasks); } - -} // namespace mooncake +} // namespace mooncake int main(int argc, char** argv) { ::testing::InitGoogleTest(&argc, argv); diff --git a/mooncake-transfer-engine/example/transfer_engine_ascend_one_sided.cpp b/mooncake-transfer-engine/example/transfer_engine_ascend_one_sided.cpp index c9ae097da..babf671b3 100644 --- a/mooncake-transfer-engine/example/transfer_engine_ascend_one_sided.cpp +++ b/mooncake-transfer-engine/example/transfer_engine_ascend_one_sided.cpp @@ -45,7 +45,8 @@ DEFINE_bool(auto_discovery, false, "Enable auto discovery"); DEFINE_uint64(device_id, 65536, "The device logic and phy ID of this machine"); DEFINE_uint64(device_logicid, 0, "The device logic ID of this machine"); DEFINE_uint64(device_phyid, 0, "The device phy ID of this machine"); -DEFINE_string(segment_id_1, "NA", "A segment ID that a sender wants to another receiver"); +DEFINE_string(segment_id_1, "NA", + "A segment ID that a sender wants to another receiver"); DEFINE_uint64(recv_num, 1, "Num of coonections received by the receiver"); DEFINE_uint64(send_index, 0, "which one is sent to the same receiver"); DEFINE_string(report_unit, "GB", "Report unit: GB|GiB|Gb|MB|MiB|Mb|KB|KiB|Kb"); @@ -68,7 +69,8 @@ const static std::unordered_map RATE_UNIT_MP = { {"KiB", 1ull << 10}, {"Kb", 1000ull / 8}}; -static inline std::string calculateRate(uint64_t data_bytes, uint64_t duration) { +static inline std::string calculateRate(uint64_t 
data_bytes, + uint64_t duration) { if (!RATE_UNIT_MP.count(FLAGS_report_unit)) { LOG(WARNING) << "Invalid flag: report_unit only support " "GB|GiB|Gb|MB|MiB|Mb|KB|KiB|Kb, not support " @@ -78,12 +80,13 @@ static inline std::string calculateRate(uint64_t data_bytes, uint64_t duration) } std::ostringstream oss; oss << std::fixed << std::setprecision(FLAGS_report_precision) - << 1.0 * data_bytes * 1000000 / duration / RATE_UNIT_MP.at(FLAGS_report_unit) + << 1.0 * data_bytes * 1000000 / duration / + RATE_UNIT_MP.at(FLAGS_report_unit) << " " << FLAGS_report_unit << "/s"; return oss.str(); } -int allocateDevMem(void* &devAddr, size_t size) { +int allocateDevMem(void *&devAddr, size_t size) { // malloc device mem aclError ret = aclrtMalloc(&devAddr, size, ACL_MEM_MALLOC_NORMAL_ONLY); if (ret != ACL_ERROR_NONE) { @@ -92,7 +95,7 @@ int allocateDevMem(void* &devAddr, size_t size) { } // malloc host mem - void* host_addr = nullptr; + void *host_addr = nullptr; ret = aclrtMallocHost(&host_addr, size); if (ret != ACL_ERROR_NONE || host_addr == nullptr) { LOG(ERROR) << "Failed to allocate device memory, ret:" << ret; @@ -100,25 +103,26 @@ int allocateDevMem(void* &devAddr, size_t size) { } for (size_t i = 0; i < size; i += sizeof(uint32_t)) { - *(uint32_t*)((char *)host_addr + i) = 0x12345678; + *(uint32_t *)((char *)host_addr + i) = 0x12345678; } // copy data from host mem to device mem - ret = aclrtMemcpy(devAddr, size, host_addr, size, ACL_MEMCPY_HOST_TO_DEVICE); + ret = + aclrtMemcpy(devAddr, size, host_addr, size, ACL_MEMCPY_HOST_TO_DEVICE); if (ret != ACL_ERROR_NONE) { LOG(ERROR) << "Failed to copy data from host to device, ret: " << ret; aclrtFreeHost(host_addr); aclrtFree(devAddr); return ret; } - + // release resource ret = aclrtFreeHost(host_addr); if (ret != ACL_ERROR_NONE) { LOG(ERROR) << "Failed to aclrtFreeHost, ret: " << ret; return ret; } - + return 0; } @@ -133,10 +137,12 @@ int initiator() { auto engine = std::make_unique(FLAGS_auto_discovery); auto 
hostname_port = parseHostNameWithPort(FLAGS_local_server_name); - std::string FLAGS_local_server_name_npu = hostname_port.first + ":" + std::to_string(hostname_port.second) + ":npu_" + std::to_string(g_devicePhyId); + std::string FLAGS_local_server_name_npu = + hostname_port.first + ":" + std::to_string(hostname_port.second) + + ":npu_" + std::to_string(g_devicePhyId); engine->init(FLAGS_metadata_server, FLAGS_local_server_name_npu.c_str(), hostname_port.first.c_str(), hostname_port.second); - + void *devAddr = nullptr; ret = allocateDevMem(devAddr, FLAGS_block_size * FLAGS_batch_size); if (ret) { @@ -147,7 +153,7 @@ int initiator() { LOG(INFO) << "devAddr_initiator: " << devAddr; ret = engine->registerLocalMemory(devAddr, g_TotalSize, - "npu:" + std::to_string(g_devicePhyId)); + "npu:" + std::to_string(g_devicePhyId)); if (ret) { LOG(ERROR) << "Failed to registerLocalMemory, ret: " << ret; return ret; @@ -163,7 +169,7 @@ int initiator() { LOG(INFO) << "devAddr_initiator2: " << devAddr2; ret = engine->registerLocalMemory(devAddr2, g_TotalSize, - "npu:" + std::to_string(g_devicePhyId)); + "npu:" + std::to_string(g_devicePhyId)); if (ret) { LOG(ERROR) << "Failed to registerLocalMemory, ret: " << ret; return ret; @@ -189,8 +195,7 @@ int initiator() { LOG(ERROR) << "Unable to get target segment ID, please recheck"; return -1; } - uint64_t remote_base = - (uint64_t)segment_desc->buffers[0].addr; + uint64_t remote_base = (uint64_t)segment_desc->buffers[0].addr; auto batch_id = engine->allocateBatchID(FLAGS_batch_size); Status s; @@ -201,7 +206,8 @@ int initiator() { entry.length = FLAGS_block_size; entry.source = (uint8_t *)(devAddr) + FLAGS_block_size * i; entry.target_id = segment_id; - entry.target_offset = remote_base + FLAGS_block_size * i + g_TotalSize * FLAGS_send_index; + entry.target_offset = + remote_base + FLAGS_block_size * i + g_TotalSize * FLAGS_send_index; requests.emplace_back(entry); } @@ -228,8 +234,7 @@ int initiator() { LOG(INFO) << "The First Time 
Send OK"; - uint64_t remote_base2 = - (uint64_t)segment_desc->buffers[1].addr; + uint64_t remote_base2 = (uint64_t)segment_desc->buffers[1].addr; auto batch_id_2 = engine->allocateBatchID(FLAGS_batch_size); std::vector requests2; @@ -239,7 +244,8 @@ int initiator() { entry.length = FLAGS_block_size; entry.source = (uint8_t *)(devAddr2) + FLAGS_block_size * i; entry.target_id = segment_id; - entry.target_offset = remote_base2 + FLAGS_block_size * i + g_TotalSize * FLAGS_send_index; + entry.target_offset = remote_base2 + FLAGS_block_size * i + + g_TotalSize * FLAGS_send_index; requests2.emplace_back(entry); } completed = false; @@ -266,16 +272,16 @@ int initiator() { gettimeofday(&stop_tv, nullptr); uint64_t duration = (stop_tv.tv_sec - start_tv.tv_sec) * 1000000.0 + - (stop_tv.tv_usec - start_tv.tv_usec); + (stop_tv.tv_usec - start_tv.tv_usec); LOG(INFO) << "Test completed: duration " << duration << "us, batch count " << FLAGS_batch_size * FLAGS_block_size << ", throughput " - << calculateRate( - FLAGS_batch_size * FLAGS_block_size, - duration); - - // When testing 1-to-2 transmission (1 initiator to 2 targets), fill in the segment_id of the second receiver. - // If not filled, it defaults to "NA" and 1-to-2 transmission is not enabled, only 1-to-1 transmission is performed. + << calculateRate(FLAGS_batch_size * FLAGS_block_size, duration); + + // When testing 1-to-2 transmission (1 initiator to 2 targets), fill in the + // segment_id of the second receiver. If not filled, it defaults to "NA" and + // 1-to-2 transmission is not enabled, only 1-to-1 transmission is + // performed. 
if (FLAGS_segment_id_1 != "NA") { sleep(10); auto segment_id_1 = engine->openSegment(FLAGS_segment_id_1.c_str()); @@ -290,13 +296,13 @@ int initiator() { return -1; } - auto segment_desc_1 = engine->getMetadata()->getSegmentDescByID(segment_id_1); + auto segment_desc_1 = + engine->getMetadata()->getSegmentDescByID(segment_id_1); if (!segment_desc_1) { LOG(ERROR) << "Unable to get target segment ID, please recheck"; return -1; } - uint64_t remote_base_1 = - (uint64_t)segment_desc_1->buffers[0].addr; + uint64_t remote_base_1 = (uint64_t)segment_desc_1->buffers[0].addr; auto batch_id = engine->allocateBatchID(FLAGS_batch_size); std::vector requests; @@ -306,7 +312,8 @@ int initiator() { entry.length = FLAGS_block_size; entry.source = (uint8_t *)(devAddr) + FLAGS_block_size * i; entry.target_id = segment_id_1; - entry.target_offset = remote_base_1 + FLAGS_block_size * i + g_TotalSize * FLAGS_send_index; + entry.target_offset = remote_base_1 + FLAGS_block_size * i + + g_TotalSize * FLAGS_send_index; requests.emplace_back(entry); } @@ -346,17 +353,18 @@ int target() { aclrtContext context = nullptr; aclError ret = aclrtCreateContext(&context, g_deviceLogicId); if (ret != ACL_ERROR_NONE) { - LOG(ERROR) <<"Failed to create context, ret: " << ret; + LOG(ERROR) << "Failed to create context, ret: " << ret; return -1; } auto engine = std::make_unique(FLAGS_auto_discovery); auto hostname_port = parseHostNameWithPort(FLAGS_local_server_name); - std::string FLAGS_local_server_name_npu = hostname_port.first + ":" + std::to_string(hostname_port.second) + ":npu_" + std::to_string(g_devicePhyId); + std::string FLAGS_local_server_name_npu = + hostname_port.first + ":" + std::to_string(hostname_port.second) + + ":npu_" + std::to_string(g_devicePhyId); engine->init(FLAGS_metadata_server, FLAGS_local_server_name_npu.c_str(), hostname_port.first.c_str(), hostname_port.second); - void *devAddr = nullptr; ret = allocateDevMem(devAddr, FLAGS_block_size * FLAGS_batch_size); @@ -368,7 +376,7 
@@ int target() { LOG(INFO) << "devAddr_target: " << devAddr; ret = engine->registerLocalMemory(devAddr, g_TotalSize * FLAGS_recv_num, - "npu:" + std::to_string(g_devicePhyId)); + "npu:" + std::to_string(g_devicePhyId)); if (ret) { LOG(ERROR) << "Failed to registerLocalMemory, ret: " << ret; return ret; @@ -384,7 +392,7 @@ int target() { LOG(INFO) << "devAddr_target_2: " << devAddr2; ret = engine->registerLocalMemory(devAddr2, g_TotalSize * FLAGS_recv_num, - "npu:" + std::to_string(g_devicePhyId)); + "npu:" + std::to_string(g_devicePhyId)); if (ret) { LOG(ERROR) << "Failed to registerLocalMemory, ret: " << ret; return ret; @@ -395,14 +403,14 @@ int target() { // release resource aclrtFree(devAddr); aclrtFree(devAddr2); - + return 0; } int main(int argc, char **argv) { gflags::ParseCommandLineFlags(&argc, &argv, false); g_TotalSize = (uint64_t)(FLAGS_batch_size * FLAGS_block_size); - + if (FLAGS_device_id != 65536) { g_deviceLogicId = FLAGS_device_id; g_devicePhyId = FLAGS_device_id; diff --git a/mooncake-transfer-engine/example/transfer_engine_ascend_perf.cpp b/mooncake-transfer-engine/example/transfer_engine_ascend_perf.cpp index 024f33efb..b35473f1b 100644 --- a/mooncake-transfer-engine/example/transfer_engine_ascend_perf.cpp +++ b/mooncake-transfer-engine/example/transfer_engine_ascend_perf.cpp @@ -66,7 +66,8 @@ const static std::unordered_map RATE_UNIT_MP = { {"KiB", 1ull << 10}, {"Kb", 1000ull / 8}}; -static inline std::string calculateRate(uint64_t data_bytes, uint64_t duration) { +static inline std::string calculateRate(uint64_t data_bytes, + uint64_t duration) { if (!RATE_UNIT_MP.count(FLAGS_report_unit)) { LOG(WARNING) << "Invalid flag: report_unit only support " "GB|GiB|Gb|MB|MiB|Mb|KB|KiB|Kb, not support " @@ -76,12 +77,13 @@ static inline std::string calculateRate(uint64_t data_bytes, uint64_t duration) } std::ostringstream oss; oss << std::fixed << std::setprecision(FLAGS_report_precision) - << 1.0 * data_bytes * 1000000 / duration / 
RATE_UNIT_MP.at(FLAGS_report_unit) + << 1.0 * data_bytes * 1000000 / duration / + RATE_UNIT_MP.at(FLAGS_report_unit) << " " << FLAGS_report_unit << "/s"; return oss.str(); } -int allocateDevMem(void* &devAddr, size_t size) { +int allocateDevMem(void *&devAddr, size_t size) { // malloc device mem aclError ret = aclrtMalloc(&devAddr, size, ACL_MEM_MALLOC_NORMAL_ONLY); if (ret != ACL_ERROR_NONE) { @@ -90,7 +92,7 @@ int allocateDevMem(void* &devAddr, size_t size) { } // malloc host mem - void* host_addr = nullptr; + void *host_addr = nullptr; ret = aclrtMallocHost(&host_addr, size); if (ret != ACL_ERROR_NONE || host_addr == nullptr) { LOG(ERROR) << "Failed to allocate device memory, ret:" << ret; @@ -98,25 +100,26 @@ int allocateDevMem(void* &devAddr, size_t size) { } for (size_t i = 0; i < size; i += sizeof(uint32_t)) { - *(uint32_t*)((char *)host_addr + i) = 0x12345678; + *(uint32_t *)((char *)host_addr + i) = 0x12345678; } // copy data from host mem to device mem - ret = aclrtMemcpy(devAddr, size, host_addr, size, ACL_MEMCPY_HOST_TO_DEVICE); + ret = + aclrtMemcpy(devAddr, size, host_addr, size, ACL_MEMCPY_HOST_TO_DEVICE); if (ret != ACL_ERROR_NONE) { LOG(ERROR) << "Failed to copy data from host to device, ret: " << ret; aclrtFreeHost(host_addr); aclrtFree(devAddr); return ret; } - - //release resource + + // release resource ret = aclrtFreeHost(host_addr); if (ret != ACL_ERROR_NONE) { LOG(ERROR) << "Failed to aclrtFreeHost, ret: " << ret; return ret; } - + return 0; } @@ -131,10 +134,12 @@ int initiator() { auto engine = std::make_unique(FLAGS_auto_discovery); auto hostname_port = parseHostNameWithPort(FLAGS_local_server_name); - std::string FLAGS_local_server_name_new = hostname_port.first + ":" + std::to_string(hostname_port.second) + ":npu_" + std::to_string(g_devicePhyId); + std::string FLAGS_local_server_name_new = + hostname_port.first + ":" + std::to_string(hostname_port.second) + + ":npu_" + std::to_string(g_devicePhyId); engine->init(FLAGS_metadata_server, 
FLAGS_local_server_name_new.c_str(), hostname_port.first.c_str(), hostname_port.second); - + // Warm-up transmission void *tmp_devAddr = NULL; ret = allocateDevMem(tmp_devAddr, FLAGS_block_size); @@ -143,22 +148,25 @@ int initiator() { return ret; } - LOG(INFO) << "tmp_devAddr_target: " << tmp_devAddr << ", len: " << FLAGS_block_size; + LOG(INFO) << "tmp_devAddr_target: " << tmp_devAddr + << ", len: " << FLAGS_block_size; ret = engine->registerLocalMemory(tmp_devAddr, FLAGS_block_size, - "npu:" + std::to_string(g_devicePhyId)); + "npu:" + std::to_string(g_devicePhyId)); void *devAddr = NULL; std::vector g_addr; - for (uint32_t i = 0; i < FLAGS_block_iteration; i++ ) { + for (uint32_t i = 0; i < FLAGS_block_iteration; i++) { uint64_t block_size = FLAGS_block_size * (1 << i); ret = allocateDevMem(devAddr, FLAGS_batch_size * block_size * 2); if (ret) { LOG(ERROR) << "Failed to allocateDevMem, ret: " << ret; return -1; } - LOG(INFO) << "dev_addr_initiator: " << devAddr << " len:" << FLAGS_batch_size * block_size * 2; - ret = engine->registerLocalMemory(devAddr, FLAGS_batch_size * block_size * 2, - "npu:" + std::to_string(g_devicePhyId)); + LOG(INFO) << "dev_addr_initiator: " << devAddr + << " len:" << FLAGS_batch_size * block_size * 2; + ret = engine->registerLocalMemory( + devAddr, FLAGS_batch_size * block_size * 2, + "npu:" + std::to_string(g_devicePhyId)); if (ret) { LOG(ERROR) << "Failed to registerLocalMemory, ret: " << ret; return ret; @@ -195,7 +203,7 @@ int initiator() { entry.length = FLAGS_block_size; entry.source = (uint8_t *)tmp_devAddr; entry.target_id = segment_id; - entry.target_offset = remote_base; + entry.target_offset = remote_base; tmp_requests.emplace_back(entry); s = engine->submitTransfer(tmp_batch_id, tmp_requests); @@ -226,14 +234,15 @@ int initiator() { remote_base = (uint64_t)segment_desc->buffers[i + 1].addr; auto batch_id = engine->allocateBatchID(FLAGS_batch_size); std::vector requests; - // Send every other block to ensure that all the 
sent memory is non-contiguous + // Send every other block to ensure that all the sent memory is + // non-contiguous for (int j = 0; j < FLAGS_batch_size; ++j) { TransferRequest entry; entry.opcode = opcode; entry.length = block_size; entry.source = (uint8_t *)(g_addr[i]) + block_size * 2 * j; entry.target_id = segment_id; - entry.target_offset = remote_base + block_size * 2 * j; + entry.target_offset = remote_base + block_size * 2 * j; requests.emplace_back(entry); } s = engine->submitTransfer(batch_id, requests); @@ -255,14 +264,12 @@ int initiator() { } gettimeofday(&stop_tv, nullptr); uint64_t duration = (stop_tv.tv_sec - start_tv.tv_sec) * 1000000.0 + - (stop_tv.tv_usec - start_tv.tv_usec); - - LOG(INFO) << "Test completed: duration " << duration << "us, block size " - << block_size / 1024 << "KB, total size " - << FLAGS_batch_size * block_size / 1024 << "KB , throughput " - << calculateRate( - FLAGS_batch_size * block_size, - duration); + (stop_tv.tv_usec - start_tv.tv_usec); + + LOG(INFO) << "Test completed: duration " << duration + << "us, block size " << block_size / 1024 << "KB, total size " + << FLAGS_batch_size * block_size / 1024 << "KB , throughput " + << calculateRate(FLAGS_batch_size * block_size, duration); s = engine->freeBatchID(batch_id); LOG_ASSERT(s.ok()); } @@ -280,14 +287,16 @@ int target() { aclrtContext context = nullptr; aclError ret = aclrtCreateContext(&context, g_deviceLogicId); if (ret != ACL_ERROR_NONE) { - LOG(ERROR) <<"Failed to create context, ret: " << ret; + LOG(ERROR) << "Failed to create context, ret: " << ret; return -1; } auto engine = std::make_unique(FLAGS_auto_discovery); auto hostname_port = parseHostNameWithPort(FLAGS_local_server_name); - std::string FLAGS_local_server_name_new = hostname_port.first + ":" + std::to_string(hostname_port.second) + ":npu_" + std::to_string(g_devicePhyId); + std::string FLAGS_local_server_name_new = + hostname_port.first + ":" + std::to_string(hostname_port.second) + + ":npu_" + 
std::to_string(g_devicePhyId); engine->init(FLAGS_metadata_server, FLAGS_local_server_name_new.c_str(), hostname_port.first.c_str(), hostname_port.second); @@ -299,9 +308,10 @@ int target() { return ret; } - LOG(INFO) << "tmp_devAddr_target: " << tmp_devAddr << ", len: " << FLAGS_block_size; + LOG(INFO) << "tmp_devAddr_target: " << tmp_devAddr + << ", len: " << FLAGS_block_size; ret = engine->registerLocalMemory(tmp_devAddr, FLAGS_block_size, - "npu:" + std::to_string(g_devicePhyId)); + "npu:" + std::to_string(g_devicePhyId)); void *devAddr = NULL; std::vector g_addr; @@ -313,9 +323,11 @@ int target() { return ret; } - LOG(INFO) << "devAddr_target: " << devAddr << ", len: " << FLAGS_batch_size * block_size * 2; - ret = engine->registerLocalMemory(devAddr, FLAGS_batch_size * block_size * 2, - "npu:" + std::to_string(g_devicePhyId)); + LOG(INFO) << "devAddr_target: " << devAddr + << ", len: " << FLAGS_batch_size * block_size * 2; + ret = engine->registerLocalMemory( + devAddr, FLAGS_batch_size * block_size * 2, + "npu:" + std::to_string(g_devicePhyId)); if (ret) { LOG(ERROR) << "Failed to registerLocalMemory, ret: " << ret; return ret; @@ -331,7 +343,7 @@ int target() { for (uint32_t i = 0; i < FLAGS_block_iteration; i++) { aclrtFree(g_addr[i]); } - + return 0; } diff --git a/mooncake-transfer-engine/example/transfer_engine_bench.cpp b/mooncake-transfer-engine/example/transfer_engine_bench.cpp index 76d4a64ae..ac4ecd4d0 100644 --- a/mooncake-transfer-engine/example/transfer_engine_bench.cpp +++ b/mooncake-transfer-engine/example/transfer_engine_bench.cpp @@ -315,10 +315,12 @@ int initiator() { int gpu_num; LOG(INFO) << "VRAM is used"; if (FLAGS_gpu_id == -1 && cudaGetDeviceCount(&gpu_num) == cudaSuccess) { - LOG(INFO) << "GPU ID is not specified, found " << gpu_num << " GPUs to use"; + LOG(INFO) << "GPU ID is not specified, found " << gpu_num + << " GPUs to use"; buffer_num = gpu_num; } else { - LOG(INFO) << "GPU ID is specified or failed to get GPU count, use " << 
FLAGS_gpu_id << " GPU"; + LOG(INFO) << "GPU ID is specified or failed to get GPU count, use " + << FLAGS_gpu_id << " GPU"; buffer_num = 1; } } else { @@ -377,7 +379,6 @@ int initiator() { (stop_tv.tv_usec - start_tv.tv_usec) / 1000000.0; auto batch_count = total_batch_count.load(); - LOG(INFO) << "Test completed: duration " << std::fixed << std::setprecision(2) << duration << ", batch count " << batch_count << ", throughput " diff --git a/mooncake-transfer-engine/example/transfer_engine_bench_with_retry.cpp b/mooncake-transfer-engine/example/transfer_engine_bench_with_retry.cpp index 8f052ef1e..e2f2dfa28 100644 --- a/mooncake-transfer-engine/example/transfer_engine_bench_with_retry.cpp +++ b/mooncake-transfer-engine/example/transfer_engine_bench_with_retry.cpp @@ -152,8 +152,8 @@ static inline std::string calculateRate(uint64_t data_bytes, double duration) { volatile bool running = true; std::atomic total_batch_count(0); -Status initiatorWorker(TransferEngine *engine, SegmentID segment_id, int thread_id, - void *addr) { +Status initiatorWorker(TransferEngine *engine, SegmentID segment_id, + int thread_id, void *addr) { bindToSocket(thread_id % NR_SOCKETS); TransferRequest::OpCode opcode; if (FLAGS_operation == "read") @@ -192,14 +192,15 @@ Status initiatorWorker(TransferEngine *engine, SegmentID segment_id, int thread_ } s = engine->submitTransfer(batch_id, requests); - if (!s.ok()) + if (!s.ok()) LOG(INFO) << "Found Failed Requests"; else { for (int task_id = 0; task_id < FLAGS_batch_size; ++task_id) { bool completed = false; TransferStatus status; while (!completed) { - Status s = engine->getTransferStatus(batch_id, task_id, status); + Status s = + engine->getTransferStatus(batch_id, task_id, status); LOG_ASSERT(s.ok()); if (status.s == TransferStatusEnum::COMPLETED) completed = true; @@ -319,8 +320,8 @@ int initiator() { auto segment_id = engine->openSegment(segment_name.c_str()); running = true; for (int i = 0; i < FLAGS_threads; ++i) - workers[i] = 
std::thread(initiatorWorker, engine.get(), segment_id, i, - addr[i % buffer_num]); + workers[i] = std::thread(initiatorWorker, engine.get(), segment_id, + i, addr[i % buffer_num]); sleep(FLAGS_duration); running = false; @@ -354,7 +355,7 @@ volatile bool target_running = true; void signalHandler(int signum) { LOG(INFO) << "Received signal " << signum << ", stopping target server..."; - target_running = false; + target_running = false; } int target() { diff --git a/mooncake-transfer-engine/include/common.h b/mooncake-transfer-engine/include/common.h index a18b19ad0..3cfcf8fec 100644 --- a/mooncake-transfer-engine/include/common.h +++ b/mooncake-transfer-engine/include/common.h @@ -123,48 +123,54 @@ static inline std::string getCurrentDateTime() { uint16_t getDefaultHandshakePort(); -template +template std::optional parseFromString(std::string_view str) { T result = T(); - auto [ptr, ec] = std::from_chars(str.data(), str.data() + str.size(), result); + auto [ptr, ec] = + std::from_chars(str.data(), str.data() + str.size(), result); if (ec != std::errc() || ptr != str.data() + str.size()) { return {}; } return {std::move(result)}; } -static inline uint16_t getPortFromString(std::string_view port_string, uint16_t default_port) { +static inline uint16_t getPortFromString(std::string_view port_string, + uint16_t default_port) { std::optional port = parseFromString(port_string); if (port.has_value()) { return *port; } - LOG(WARNING) << "Illegal port number in " << port_string << ". Use default port " << default_port << " instead"; + LOG(WARNING) << "Illegal port number in " << port_string + << ". 
Use default port " << default_port << " instead"; return default_port; } -static inline bool isValidIpV6(const std::string& address) { +static inline bool isValidIpV6(const std::string &address) { sockaddr_in6 addr; std::memset(&addr, 0, sizeof(addr)); return inet_pton(AF_INET6, address.c_str(), &addr.sin6_addr) == 1; } -static inline std::string maybeWrapIpV6(const std::string& address) { +static inline std::string maybeWrapIpV6(const std::string &address) { if (isValidIpV6(address)) { return "[" + address + "]"; } return address; } -static inline std::pair parseHostNameWithPort(const std::string &server_name) { +static inline std::pair parseHostNameWithPort( + const std::string &server_name) { uint16_t port = getDefaultHandshakePort(); if (server_name.starts_with("[")) { // [ipv6] or [ipv6]:port const size_t closing_bracket_pos = server_name.find(']'); const size_t colon_pos = server_name.find(':', closing_bracket_pos); - std::string potentialHost = server_name.substr(1, closing_bracket_pos - 1); + std::string potentialHost = + server_name.substr(1, closing_bracket_pos - 1); if (isValidIpV6(potentialHost)) { - return {std::move(potentialHost), getPortFromString(server_name.substr(colon_pos + 1), port)}; + return {std::move(potentialHost), + getPortFromString(server_name.substr(colon_pos + 1), port)}; } // Not valid ipv6, fallback to ipv4/host/etc mode } else if (isValidIpV6(server_name)) { @@ -177,10 +183,13 @@ static inline std::pair parseHostNameWithPort(const std:: if (colon_pos == server_name.npos) { return {server_name, port}; } - return {server_name.substr(0, colon_pos), getPortFromString(server_name.substr(colon_pos + 1), port)}; + return {server_name.substr(0, colon_pos), + getPortFromString(server_name.substr(colon_pos + 1), port)}; } -static inline uint16_t parsePortAndDevice(std::string_view suffix, uint16_t default_port, int *device_id) { +static inline uint16_t parsePortAndDevice(std::string_view suffix, + uint16_t default_port, + int *device_id) { 
auto colon_pos = suffix.find(':'); if (colon_pos == suffix.npos) { return getPortFromString(suffix, default_port); @@ -189,8 +198,10 @@ static inline uint16_t parsePortAndDevice(std::string_view suffix, uint16_t defa auto npu_str = suffix.substr(colon_pos + 1); auto npu_ops = npu_str.find('_'); - if (npu_ops != npu_str.npos && npu_ops != 0 && npu_ops != npu_str.size() - 1) { - *device_id = parseFromString(npu_str.substr(npu_ops + 1)).value_or(0); + if (npu_ops != npu_str.npos && npu_ops != 0 && + npu_ops != npu_str.size() - 1) { + *device_id = + parseFromString(npu_str.substr(npu_ops + 1)).value_or(0); } return getPortFromString(port_str, default_port); } @@ -203,12 +214,12 @@ static inline std::pair parseHostNameWithPortAscend( // [ipv6] or [ipv6]:port const size_t closing_bracket_pos = server_name.find(']'); const size_t colon_pos = server_name.find(':', closing_bracket_pos); - std::string potentialHost = server_name.substr(1, closing_bracket_pos - 1); + std::string potentialHost = + server_name.substr(1, closing_bracket_pos - 1); if (isValidIpV6(potentialHost)) { - return { - std::move(potentialHost), - parsePortAndDevice(server_name.substr(colon_pos + 1), port, device_id) - }; + return {std::move(potentialHost), + parsePortAndDevice(server_name.substr(colon_pos + 1), port, + device_id)}; } // Not valid ipv6, fallback to ipv4/host/etc mode } else if (isValidIpV6(server_name)) { @@ -222,8 +233,7 @@ static inline std::pair parseHostNameWithPortAscend( return { server_name.substr(0, colon_pos), - parsePortAndDevice(server_name.substr(colon_pos + 1), port, device_id) - }; + parsePortAndDevice(server_name.substr(colon_pos + 1), port, device_id)}; } static inline ssize_t writeFully(int fd, const void *buf, size_t len) { diff --git a/mooncake-transfer-engine/include/common/base/status.h b/mooncake-transfer-engine/include/common/base/status.h index b3b30d853..4015a9cad 100644 --- a/mooncake-transfer-engine/include/common/base/status.h +++ 
b/mooncake-transfer-engine/include/common/base/status.h @@ -28,257 +28,238 @@ namespace mooncake { class Status final { - public: - // The code of the status. - enum class Code : uint16_t { - kOk = 0, - kInvalidArgument = 1, - kTooManyRequests = 2, - kAddressNotRegistered = 3, - kBatchBusy = 4, - kDeviceNotFound = 6, - kAddressOverlapped = 7, - kNotSupportedTransport = 8, - kDns = 101, - kSocket = 102, - kMalformedJson = 103, - kRejectHandshake = 104, - kMetadata = 200, - kEndpoint = 201, - kContext = 202, - kNuma = 300, - kClock = 301, - kMemory = 302, - kNotImplemented = 999, - kMaxCode - }; - - // Builds an OK Status. - Status() = default; - - ~Status() { delete[] message_; } - - // Constructs a Status object containing a status code and message. - // If 'code == Code::kOk', 'msg' is ignored and an object identical to an OK - // status is constructed. - Status(Code code, std::string_view message); - - Status(const Status& s); - Status& operator=(const Status& s); - Status(Status&& s); - Status& operator=(Status&& s); - - // Returns the stored status code. - Code code() const { return code_; } - - // Return the error message (if any). - std::string_view message() const { - if (message_) { - return message_; - } else { - return std::string_view(); + public: + // The code of the status. + enum class Code : uint16_t { + kOk = 0, + kInvalidArgument = 1, + kTooManyRequests = 2, + kAddressNotRegistered = 3, + kBatchBusy = 4, + kDeviceNotFound = 6, + kAddressOverlapped = 7, + kNotSupportedTransport = 8, + kDns = 101, + kSocket = 102, + kMalformedJson = 103, + kRejectHandshake = 104, + kMetadata = 200, + kEndpoint = 201, + kContext = 202, + kNuma = 300, + kClock = 301, + kMemory = 302, + kNotImplemented = 999, + kMaxCode + }; + + // Builds an OK Status. + Status() = default; + + ~Status() { delete[] message_; } + + // Constructs a Status object containing a status code and message. 
+ // If 'code == Code::kOk', 'msg' is ignored and an object identical to an OK + // status is constructed. + Status(Code code, std::string_view message); + + Status(const Status& s); + Status& operator=(const Status& s); + Status(Status&& s); + Status& operator=(Status&& s); + + // Returns the stored status code. + Code code() const { return code_; } + + // Return the error message (if any). + std::string_view message() const { + if (message_) { + return message_; + } else { + return std::string_view(); + } } - } - - // Returns true if the Status is OK. - [[nodiscard]] bool ok() const { return Code::kOk == code_; } - - // Returns true iff the status indicates an InvalidArgument error. - [[nodiscard]] bool IsInvalidArgument() const { - return Code::kInvalidArgument == code_; - } - - // Returns true iff the status indicates a TooManyRequests error. - [[nodiscard]] bool IsTooManyRequests() const { - return Code::kTooManyRequests == code_; - } - - // Returns true iff the status indicates an AddressNotRegistered error. - [[nodiscard]] bool IsAddressNotRegistered() const { - return Code::kAddressNotRegistered == code_; - } - - // Returns true iff the status indicates a BatchBusy error. - [[nodiscard]] bool IsBatchBusy() const { - return Code::kBatchBusy == code_; - } - - // Returns true iff the status indicates an DeviceNotFound error. - [[nodiscard]] bool IsDeviceNotFound() const { - return Code::kDeviceNotFound == code_; - } - - // Returns true iff the status indicates an AddressOverlapped error. - [[nodiscard]] bool IsAddressOverlapped() const { - return Code::kAddressOverlapped == code_; - } - - // Returns true iff the status indicates a dns error. - [[nodiscard]] bool IsDns() const { - return Code::kDns == code_; - } - - // Returns true iff the status indicates an Socket error. - [[nodiscard]] bool IsSocket() const { - return Code::kSocket == code_; - } - - // Returns true iff the status indicates a MalformedJson error. 
- [[nodiscard]] bool IsMalformedJson() const { - return Code::kMalformedJson == code_; - } - - // Returns true iff the status indicates a RejectHandshake error. - [[nodiscard]] bool IsRejectHandshake() const { - return Code::kRejectHandshake == code_; - } - - // Returns true iff the status indicates a Metadata error. - [[nodiscard]] bool IsMetadata() const { - return Code::kMetadata == code_; - } - - // Returns true iff the status indicates an Endpoint error. - [[nodiscard]] bool IsEndpoint() const { - return Code::kEndpoint == code_; - } - - // Returns true iff the status indicates a Context error. - [[nodiscard]] bool IsContext() const { - return Code::kContext == code_; - } - - // Returns true iff the status indicates a Numa error. - [[nodiscard]] bool IsNuma() const { - return Code::kNuma == code_; - } - - // Returns true iff the status indicates a Clock error. - [[nodiscard]] bool IsClock() const { - return Code::kClock == code_; - } - - // Returns true iff the status indicates a Memory error. - [[nodiscard]] bool IsMemory() const { - return Code::kMemory == code_; - } - - // Returns true iff the status indicates a NotImplemented error. - [[nodiscard]] bool IsNotImplemented() const { - return Code::kNotImplemented == code_; - } - - // Returns true iff the status indicates a NotImplemented error. - [[nodiscard]] bool IsNotSupportedTransport() const { - return Code::kNotSupportedTransport == code_; - } - - // Return a combination of the error code name and message. - std::string ToString() const; - - bool operator==(const Status& s) const; - bool operator!=(const Status& s) const; - - // Return a status of an appropriate type. 
- static Status OK() { return Status(); } - static Status InvalidArgument(std::string_view msg) { - return Status(Code::kInvalidArgument, msg); - } - static Status TooManyRequests(std::string_view msg) { - return Status(Code::kTooManyRequests, msg); - } - static Status AddressNotRegistered(std::string_view msg) { - return Status(Code::kAddressNotRegistered, msg); - } - static Status BatchBusy(std::string_view msg) { - return Status(Code::kBatchBusy, msg); - } - static Status DeviceNotFound(std::string_view msg) { - return Status(Code::kDeviceNotFound, msg); - } - static Status AddressOverlapped(std::string_view msg) { - return Status(Code::kAddressOverlapped, msg); - } - static Status Dns(std::string_view msg) { - return Status(Code::kDns, msg); - } - static Status Socket(std::string_view msg) { - return Status(Code::kSocket, msg); - } - static Status MalformedJson(std::string_view msg) { - return Status(Code::kMalformedJson, msg); - } - static Status RejectHandshake(std::string_view msg) { - return Status(Code::kRejectHandshake, msg); - } - static Status Metadata(std::string_view msg) { - return Status(Code::kMetadata, msg); - } - static Status Endpoint(std::string_view msg) { - return Status(Code::kEndpoint, msg); - } - static Status Context(std::string_view msg) { - return Status(Code::kContext, msg); - } - static Status Numa(std::string_view msg) { - return Status(Code::kNuma, msg); - } - static Status Clock(std::string_view msg) { - return Status(Code::kClock, msg); - } - static Status Memory(std::string_view msg) { - return Status(Code::kMemory, msg); - } - static Status NotImplemented(std::string_view msg) { - return Status(Code::kNotImplemented, msg); - } - static Status NotSupportedTransport(std::string_view msg) { - return Status(Code::kNotSupportedTransport, msg); - } - - // Return a human-readable name of the 'code'. - static std::string_view CodeToString(Code code); - - private: - // Return a copy of the message 'msg'. 
- static const char* CopyMessage(const char* msg); - - // The code of the status. - Code code_ = Code::kOk; - // The error message of the status. Refer to the Status definition in RocksDB, - // we don't use 'std::string' type message but 'const char*' type one for the - // performance considerations. A memory allocation in the std::string - // construction could be avoid for the most cases that the Status is OK. And - // the total size of 'message_' is only 8 bytes on a x86-64 platform, while - // the size of a uninitialized strings with SSO (Small String Optimization) - // will be 24 to 32 bytes big, excluding the dynamically allocated memory. - const char* message_ = nullptr; + + // Returns true if the Status is OK. + [[nodiscard]] bool ok() const { return Code::kOk == code_; } + + // Returns true iff the status indicates an InvalidArgument error. + [[nodiscard]] bool IsInvalidArgument() const { + return Code::kInvalidArgument == code_; + } + + // Returns true iff the status indicates a TooManyRequests error. + [[nodiscard]] bool IsTooManyRequests() const { + return Code::kTooManyRequests == code_; + } + + // Returns true iff the status indicates an AddressNotRegistered error. + [[nodiscard]] bool IsAddressNotRegistered() const { + return Code::kAddressNotRegistered == code_; + } + + // Returns true iff the status indicates a BatchBusy error. + [[nodiscard]] bool IsBatchBusy() const { return Code::kBatchBusy == code_; } + + // Returns true iff the status indicates an DeviceNotFound error. + [[nodiscard]] bool IsDeviceNotFound() const { + return Code::kDeviceNotFound == code_; + } + + // Returns true iff the status indicates an AddressOverlapped error. + [[nodiscard]] bool IsAddressOverlapped() const { + return Code::kAddressOverlapped == code_; + } + + // Returns true iff the status indicates a dns error. + [[nodiscard]] bool IsDns() const { return Code::kDns == code_; } + + // Returns true iff the status indicates an Socket error. 
+ [[nodiscard]] bool IsSocket() const { return Code::kSocket == code_; } + + // Returns true iff the status indicates a MalformedJson error. + [[nodiscard]] bool IsMalformedJson() const { + return Code::kMalformedJson == code_; + } + + // Returns true iff the status indicates a RejectHandshake error. + [[nodiscard]] bool IsRejectHandshake() const { + return Code::kRejectHandshake == code_; + } + + // Returns true iff the status indicates a Metadata error. + [[nodiscard]] bool IsMetadata() const { return Code::kMetadata == code_; } + + // Returns true iff the status indicates an Endpoint error. + [[nodiscard]] bool IsEndpoint() const { return Code::kEndpoint == code_; } + + // Returns true iff the status indicates a Context error. + [[nodiscard]] bool IsContext() const { return Code::kContext == code_; } + + // Returns true iff the status indicates a Numa error. + [[nodiscard]] bool IsNuma() const { return Code::kNuma == code_; } + + // Returns true iff the status indicates a Clock error. + [[nodiscard]] bool IsClock() const { return Code::kClock == code_; } + + // Returns true iff the status indicates a Memory error. + [[nodiscard]] bool IsMemory() const { return Code::kMemory == code_; } + + // Returns true iff the status indicates a NotImplemented error. + [[nodiscard]] bool IsNotImplemented() const { + return Code::kNotImplemented == code_; + } + + // Returns true iff the status indicates a NotImplemented error. + [[nodiscard]] bool IsNotSupportedTransport() const { + return Code::kNotSupportedTransport == code_; + } + + // Return a combination of the error code name and message. + std::string ToString() const; + + bool operator==(const Status& s) const; + bool operator!=(const Status& s) const; + + // Return a status of an appropriate type. 
+ static Status OK() { return Status(); } + static Status InvalidArgument(std::string_view msg) { + return Status(Code::kInvalidArgument, msg); + } + static Status TooManyRequests(std::string_view msg) { + return Status(Code::kTooManyRequests, msg); + } + static Status AddressNotRegistered(std::string_view msg) { + return Status(Code::kAddressNotRegistered, msg); + } + static Status BatchBusy(std::string_view msg) { + return Status(Code::kBatchBusy, msg); + } + static Status DeviceNotFound(std::string_view msg) { + return Status(Code::kDeviceNotFound, msg); + } + static Status AddressOverlapped(std::string_view msg) { + return Status(Code::kAddressOverlapped, msg); + } + static Status Dns(std::string_view msg) { return Status(Code::kDns, msg); } + static Status Socket(std::string_view msg) { + return Status(Code::kSocket, msg); + } + static Status MalformedJson(std::string_view msg) { + return Status(Code::kMalformedJson, msg); + } + static Status RejectHandshake(std::string_view msg) { + return Status(Code::kRejectHandshake, msg); + } + static Status Metadata(std::string_view msg) { + return Status(Code::kMetadata, msg); + } + static Status Endpoint(std::string_view msg) { + return Status(Code::kEndpoint, msg); + } + static Status Context(std::string_view msg) { + return Status(Code::kContext, msg); + } + static Status Numa(std::string_view msg) { + return Status(Code::kNuma, msg); + } + static Status Clock(std::string_view msg) { + return Status(Code::kClock, msg); + } + static Status Memory(std::string_view msg) { + return Status(Code::kMemory, msg); + } + static Status NotImplemented(std::string_view msg) { + return Status(Code::kNotImplemented, msg); + } + static Status NotSupportedTransport(std::string_view msg) { + return Status(Code::kNotSupportedTransport, msg); + } + + // Return a human-readable name of the 'code'. + static std::string_view CodeToString(Code code); + + private: + // Return a copy of the message 'msg'. 
+ static const char* CopyMessage(const char* msg); + + // The code of the status. + Code code_ = Code::kOk; + // The error message of the status. Refer to the Status definition in + // RocksDB, we don't use 'std::string' type message but 'const char*' type + // one for the performance considerations. A memory allocation in the + // std::string construction could be avoid for the most cases that the + // Status is OK. And the total size of 'message_' is only 8 bytes on a + // x86-64 platform, while the size of a uninitialized strings with SSO + // (Small String Optimization) will be 24 to 32 bytes big, excluding the + // dynamically allocated memory. + const char* message_ = nullptr; }; inline Status::Status(const Status& s) : code_(s.code_) { - message_ = (s.message_ == nullptr) ? nullptr : CopyMessage(s.message_); + message_ = (s.message_ == nullptr) ? nullptr : CopyMessage(s.message_); } inline Status& Status::operator=(const Status& s) { - if (this != &s) { - code_ = s.code_; - delete[] message_; - message_ = (s.message_ == nullptr) ? nullptr : CopyMessage(s.message_); - } - return *this; + if (this != &s) { + code_ = s.code_; + delete[] message_; + message_ = (s.message_ == nullptr) ? nullptr : CopyMessage(s.message_); + } + return *this; } inline Status::Status(Status&& s) : Status() { *this = std::move(s); } inline Status& Status::operator=(Status&& s) { - if (this != &s) { - code_ = std::move(s.code_); - s.code_ = Code::kOk; - delete[] message_; - message_ = nullptr; - std::swap(message_, s.message_); - } - return *this; + if (this != &s) { + code_ = std::move(s.code_); + s.code_ = Code::kOk; + delete[] message_; + message_ = nullptr; + std::swap(message_, s.message_); + } + return *this; } // Prints a human-readable representation name of the 'code' to 'os'. 
diff --git a/mooncake-transfer-engine/include/transfer_engine.h b/mooncake-transfer-engine/include/transfer_engine.h index 5105f2490..5cea5c4af 100644 --- a/mooncake-transfer-engine/include/transfer_engine.h +++ b/mooncake-transfer-engine/include/transfer_engine.h @@ -222,7 +222,7 @@ class TransferEngine { std::shared_mutex mutex_; std::vector local_memory_regions_; std::shared_ptr local_topology_; - + RWSpinlock send_notifies_lock_; std::unordered_map> diff --git a/mooncake-transfer-engine/include/transfer_metadata.h b/mooncake-transfer-engine/include/transfer_metadata.h index cc583f1a8..70f15c8d4 100644 --- a/mooncake-transfer-engine/include/transfer_metadata.h +++ b/mooncake-transfer-engine/include/transfer_metadata.h @@ -65,12 +65,12 @@ class TransferMetadata { }; struct RankInfoDesc { - uint64_t rankId = 0xFFFFFFFF; // rank id, user rank + uint64_t rankId = 0xFFFFFFFF; // rank id, user rank std::string hostIp; uint64_t hostPort; uint64_t deviceLogicId; uint64_t devicePhyId; - uint64_t deviceType = 5; // default + uint64_t deviceType = 5; // default std::string deviceIp; uint64_t devicePort; uint64_t pid; diff --git a/mooncake-transfer-engine/include/transport/ascend_transport/hccl_transport/hccl_transport.h b/mooncake-transfer-engine/include/transport/ascend_transport/hccl_transport/hccl_transport.h index 5cc3e0960..94a09787f 100644 --- a/mooncake-transfer-engine/include/transport/ascend_transport/hccl_transport/hccl_transport.h +++ b/mooncake-transfer-engine/include/transport/ascend_transport/hccl_transport/hccl_transport.h @@ -32,7 +32,7 @@ #include "hccl_transport_mem_c.h" #define THREAD_NUM 1 -#define ASCEND_DEFAULT_HOST_PORT 10000 +#define ASCEND_DEFAULT_HOST_PORT 10000 #define ASCEND_DEFAULT_DEVICE_PORT 16666 namespace mooncake { @@ -58,14 +58,17 @@ class HcclTransport : public Transport { TransferStatus &status) override; int install(std::string &local_server_name, - std::shared_ptr meta, std::shared_ptr topo) override; + std::shared_ptr meta, + 
std::shared_ptr topo) override; const char *getName() const override { return "hccl"; } - + int registerLocalMemory(void *addr, size_t length, - const std::string &location, bool remote_accessible, bool update_metadata) override; + const std::string &location, bool remote_accessible, + bool update_metadata) override; - int unregisterLocalMemory(void *addr, bool update_metadata = false) override; + int unregisterLocalMemory(void *addr, + bool update_metadata = false) override; int registerLocalMemoryBatch( const std::vector &buffer_list, @@ -83,7 +86,9 @@ class HcclTransport : public Transport { void acceptLoop(int deviceLogicId); - int getDevIdAndIpPortFromServerName(std::string& local_server_name, std::string& ip, int &ip_port, int& devicePhyId); + int getDevIdAndIpPortFromServerName(std::string &local_server_name, + std::string &ip, int &ip_port, + int &devicePhyId); int rankInfoParse(int devicePhyId, std::string hostIp); diff --git a/mooncake-transfer-engine/include/transport/ascend_transport/hccl_transport/hccl_transport_mem_c.h b/mooncake-transfer-engine/include/transport/ascend_transport/hccl_transport/hccl_transport_mem_c.h index 15a1a78e7..832ef910b 100644 --- a/mooncake-transfer-engine/include/transport/ascend_transport/hccl_transport/hccl_transport_mem_c.h +++ b/mooncake-transfer-engine/include/transport/ascend_transport/hccl_transport/hccl_transport_mem_c.h @@ -43,11 +43,11 @@ #ifdef __cplusplus extern "C" { -#endif // __cplusplus +#endif // __cplusplus struct RankInfo { uint64_t rankId = 0xFFFFFFFF; - uint64_t serverIdx; + uint64_t serverIdx; struct in_addr hostIp; uint64_t hostPort; uint64_t deviceLogicId; @@ -69,7 +69,7 @@ struct RankControlInfo { struct MergeMem { void *addr = nullptr; uint64_t len = 0; - MergeMem(void* addr_, size_t len_) : addr(addr_), len(len_) {} + MergeMem(void *addr_, size_t len_) : addr(addr_), len(len_) {} }; struct ConnectionInfo { @@ -80,25 +80,27 @@ struct ConnectionInfo { }; // Retry mechanism for initialization function 
failure -#define RETRY_CALL(funcCall, errorMsg) \ - do { \ - int retryCount = 0; \ - int __ret = funcCall; \ - while (__ret && retryCount < 3) { \ - LOG(ERROR) << errorMsg << ", retrying... (" << ++retryCount << "/3)"; \ - __ret = funcCall; \ - } \ - if (__ret) { \ - LOG(ERROR) << errorMsg << " failed after 3 retries."; \ - return __ret; \ - } \ +#define RETRY_CALL(funcCall, errorMsg) \ + do { \ + int retryCount = 0; \ + int __ret = funcCall; \ + while (__ret && retryCount < 3) { \ + LOG(ERROR) << errorMsg << ", retrying... (" << ++retryCount \ + << "/3)"; \ + __ret = funcCall; \ + } \ + if (__ret) { \ + LOG(ERROR) << errorMsg << " failed after 3 retries."; \ + return __ret; \ + } \ } while (0) extern int initTransportMem(RankInfo *local_rank_info); -extern int transportMemTask(RankInfo *local_rank_info, - RankInfo *remote_rank_info, int op_code, uint64_t offset, - uint64_t req_len, void *local_mem, aclrtStream stream); +extern int transportMemTask(RankInfo *local_rank_info, + RankInfo *remote_rank_info, int op_code, + uint64_t offset, uint64_t req_len, void *local_mem, + aclrtStream stream); extern int transportMemAccept(RankInfo *local_rank_info); @@ -106,10 +108,11 @@ extern int regLocalRmaMem(void *addr, uint64_t length); extern bool printEnabled(); -extern int transportMemAddOpFence(RankInfo *remote_rank_info, aclrtStream stream); +extern int transportMemAddOpFence(RankInfo *remote_rank_info, + aclrtStream stream); #ifdef __cplusplus } -#endif // __cplusplus +#endif // __cplusplus -#endif // HCCL_TRANSPORT_MEM_C_H +#endif // HCCL_TRANSPORT_MEM_C_H diff --git a/mooncake-transfer-engine/include/transport/cxl_transport/cxl_transport.h b/mooncake-transfer-engine/include/transport/cxl_transport/cxl_transport.h index 3d6b2d8fb..db06d64c1 100644 --- a/mooncake-transfer-engine/include/transport/cxl_transport/cxl_transport.h +++ b/mooncake-transfer-engine/include/transport/cxl_transport/cxl_transport.h @@ -15,7 +15,6 @@ #ifndef CXL_TRANSPORT_H_ #define CXL_TRANSPORT_H_ 
- #include #include #include @@ -51,7 +50,7 @@ class CxlTransport : public Transport { Status getTransferStatus(BatchID batch_id, size_t task_id, TransferStatus &status) override; - void* getCxlBaseAddr() { return cxl_base_addr; } + void *getCxlBaseAddr() { return cxl_base_addr; } private: int install(std::string &local_server_name, @@ -86,9 +85,9 @@ class CxlTransport : public Transport { bool validateMemoryBounds(void *dest, void *src, size_t size); private: - void* cxl_base_addr; + void *cxl_base_addr; size_t cxl_dev_size; - char* cxl_dev_path; + char *cxl_dev_path; }; } // namespace mooncake diff --git a/mooncake-transfer-engine/include/transport/nvlink_transport/nvlink_transport.h b/mooncake-transfer-engine/include/transport/nvlink_transport/nvlink_transport.h index 7b6db2cc1..eddac0ac2 100644 --- a/mooncake-transfer-engine/include/transport/nvlink_transport/nvlink_transport.h +++ b/mooncake-transfer-engine/include/transport/nvlink_transport/nvlink_transport.h @@ -46,9 +46,9 @@ class NvlinkTransport : public Transport { Status getTransferStatus(BatchID batch_id, size_t task_id, TransferStatus& status) override; - static void *allocatePinnedLocalMemory(size_t length); + static void* allocatePinnedLocalMemory(size_t length); - static void freePinnedLocalMemory(void *addr); + static void freePinnedLocalMemory(void* addr); protected: int install(std::string& local_server_name, @@ -80,7 +80,8 @@ class NvlinkTransport : public Transport { uint64_t length; }; - std::unordered_map, OpenedShmEntry, PairHash> remap_entries_; + std::unordered_map, OpenedShmEntry, PairHash> + remap_entries_; RWSpinlock remap_lock_; bool use_fabric_mem_; diff --git a/mooncake-transfer-engine/include/transport/nvmeof_transport/cufile_desc_pool.h b/mooncake-transfer-engine/include/transport/nvmeof_transport/cufile_desc_pool.h index f8d81ba8b..bda91920d 100644 --- a/mooncake-transfer-engine/include/transport/nvmeof_transport/cufile_desc_pool.h +++ 
b/mooncake-transfer-engine/include/transport/nvmeof_transport/cufile_desc_pool.h @@ -37,7 +37,7 @@ class CUFileDescPool { CUFileDescPool(CUFileDescPool &&) = delete; int allocCUfileDesc(size_t batch_size); // ret: (desc_idx, start_idx) - + int pushParams(int idx, CUfileIOParams_t &io_params); int submitBatch(int idx); diff --git a/mooncake-transfer-engine/include/transport/nvmeof_transport/nvmeof_transport.h b/mooncake-transfer-engine/include/transport/nvmeof_transport/nvmeof_transport.h index a45763ff8..120e25db6 100644 --- a/mooncake-transfer-engine/include/transport/nvmeof_transport/nvmeof_transport.h +++ b/mooncake-transfer-engine/include/transport/nvmeof_transport/nvmeof_transport.h @@ -30,8 +30,7 @@ namespace mooncake { -struct NVMeoFBatchDesc -{ +struct NVMeoFBatchDesc { int desc_idx_; std::vector transfer_status; std::vector> task_to_slices; @@ -57,11 +56,9 @@ class NVMeoFTransport : public Transport { Status freeBatchID(BatchID batch_id) override; void addSliceToTask(void *source_addr, uint64_t slice_len, - uint64_t target_start, - TransferRequest::OpCode op, - TransferTask &task, - const char *file_path); - + uint64_t target_start, TransferRequest::OpCode op, + TransferTask &task, const char *file_path); + private: void startTransfer(Slice *slice); diff --git a/mooncake-transfer-engine/include/transport/rdma_transport/rdma_endpoint.h b/mooncake-transfer-engine/include/transport/rdma_transport/rdma_endpoint.h index 9c583d152..b9052cc9b 100644 --- a/mooncake-transfer-engine/include/transport/rdma_transport/rdma_endpoint.h +++ b/mooncake-transfer-engine/include/transport/rdma_transport/rdma_endpoint.h @@ -73,9 +73,9 @@ class RdmaEndPoint { bool active() const { return active_; } - void set_active(bool flag) { + void set_active(bool flag) { RWSpinlock::WriteGuard guard(lock_); - active_ = flag; + active_ = flag; if (!flag) inactive_time_ = getCurrentTimeInNano(); } diff --git a/mooncake-transfer-engine/src/common/base/status.cpp 
b/mooncake-transfer-engine/src/common/base/status.cpp index 9508d10b2..cf75c2e91 100644 --- a/mooncake-transfer-engine/src/common/base/status.cpp +++ b/mooncake-transfer-engine/src/common/base/status.cpp @@ -24,104 +24,104 @@ namespace mooncake { -Status::Status(Status::Code code, std::string_view message) - : code_(code) { - if (code != Code::kOk) { - // Only store the message when it is not empty. - if (!message.empty()) { - const size_t len = message.size(); - // +1 for null terminator - char* const result = new char[len + 1]; - memcpy(result, message.data(), len); - result[len] = '\0'; - message_ = result; +Status::Status(Status::Code code, std::string_view message) : code_(code) { + if (code != Code::kOk) { + // Only store the message when it is not empty. + if (!message.empty()) { + const size_t len = message.size(); + // +1 for null terminator + char* const result = new char[len + 1]; + memcpy(result, message.data(), len); + result[len] = '\0'; + message_ = result; + } } - } } std::string Status::ToString() const { - if (ok()) { - return "OK"; - } else { - return std::string(CodeToString(code())) + ": " + std::string(message()); - } + if (ok()) { + return "OK"; + } else { + return std::string(CodeToString(code())) + ": " + + std::string(message()); + } } std::string_view Status::CodeToString(Status::Code code) { - switch (code) { - case Code::kOk: - return "OK"; - case Code::kInvalidArgument: - return "InvalidArgument"; - case Code::kTooManyRequests: - return "TooManyRequests"; - case Code::kAddressNotRegistered: - return "AddressNotRegistered"; - case Code::kBatchBusy: - return "BatchBusy"; - case Code::kDeviceNotFound: - return "DeviceNotFound"; - case Code::kAddressOverlapped: - return "AddressOverlapped"; - case Code::kDns: - return "Dns"; - case Code::kSocket: - return "Socket"; - case Code::kMalformedJson: - return "MalformedJson"; - case Code::kRejectHandshake: - return "RejectHandshake"; - case Code::kMetadata: - return "Metadata"; - case 
Code::kEndpoint: - return "Endpoint"; - case Code::kContext: - return "Context"; - case Code::kNuma: - return "Numa"; - case Code::kClock: - return "Clock"; - case Code::kMemory: - return "Memory"; - case Code::kNotImplemented: - return "NotImplemented"; - case Code::kNotSupportedTransport: - return "NotSupportedTransport"; - default: - LOG(ERROR) << "Unknown code: " << static_cast(code); - return "UnknownCode"; - } + switch (code) { + case Code::kOk: + return "OK"; + case Code::kInvalidArgument: + return "InvalidArgument"; + case Code::kTooManyRequests: + return "TooManyRequests"; + case Code::kAddressNotRegistered: + return "AddressNotRegistered"; + case Code::kBatchBusy: + return "BatchBusy"; + case Code::kDeviceNotFound: + return "DeviceNotFound"; + case Code::kAddressOverlapped: + return "AddressOverlapped"; + case Code::kDns: + return "Dns"; + case Code::kSocket: + return "Socket"; + case Code::kMalformedJson: + return "MalformedJson"; + case Code::kRejectHandshake: + return "RejectHandshake"; + case Code::kMetadata: + return "Metadata"; + case Code::kEndpoint: + return "Endpoint"; + case Code::kContext: + return "Context"; + case Code::kNuma: + return "Numa"; + case Code::kClock: + return "Clock"; + case Code::kMemory: + return "Memory"; + case Code::kNotImplemented: + return "NotImplemented"; + case Code::kNotSupportedTransport: + return "NotSupportedTransport"; + default: + LOG(ERROR) << "Unknown code: " << static_cast(code); + return "UnknownCode"; + } } const char* Status::CopyMessage(const char* msg) { - // +1 for the null terminator - const size_t len = std::strlen(msg) + 1; - return std::strncpy(new char[len], msg, len); + // +1 for the null terminator + const size_t len = std::strlen(msg) + 1; + return std::strncpy(new char[len], msg, len); } bool Status::operator==(const Status& s) const { - // Compare the code. - if (code_ != s.code_) { + // Compare the code. + if (code_ != s.code_) { + return false; + } + // Compare the message content. 
+ if (message_ == nullptr && s.message_ == nullptr) { + return true; + } + if (message_ != nullptr && s.message_ != nullptr) { + return strcmp(message_, s.message_) == 0; + } return false; - } - // Compare the message content. - if (message_ == nullptr && s.message_ == nullptr) { - return true; - } - if (message_ != nullptr && s.message_ != nullptr) { - return strcmp(message_, s.message_) == 0; - } - return false; } bool Status::operator!=(const Status& s) const { return !(*this == s); } std::ostream& operator<<(std::ostream& os, Status::Code code) { - return os << Status::CodeToString(code); + return os << Status::CodeToString(code); } std::ostream& operator<<(std::ostream& os, const Status& s) { - return os << s.ToString(); + return os << s.ToString(); } } // namespace mooncake diff --git a/mooncake-transfer-engine/src/config.cpp b/mooncake-transfer-engine/src/config.cpp index 1fef662fb..61fcc835e 100644 --- a/mooncake-transfer-engine/src/config.cpp +++ b/mooncake-transfer-engine/src/config.cpp @@ -249,8 +249,8 @@ void loadGlobalConfig(GlobalConfig &config) { if (val > 0 && val < config.slice_size) config.fragment_limit = config.slice_size / val; else { - LOG(WARNING) - << "Ignore value from environment variable MC_FRAGMENT_RATIO and set it to 4 as default"; + LOG(WARNING) << "Ignore value from environment variable " + "MC_FRAGMENT_RATIO and set it to 4 as default"; config.fragment_limit = config.slice_size / 4; } } diff --git a/mooncake-transfer-engine/src/memory_location.cpp b/mooncake-transfer-engine/src/memory_location.cpp index 615950f7d..6ade7c559 100644 --- a/mooncake-transfer-engine/src/memory_location.cpp +++ b/mooncake-transfer-engine/src/memory_location.cpp @@ -56,7 +56,8 @@ const std::vector getMemoryLocation(void *start, // start and end address may not be page aligned. 
uintptr_t aligned_start = alignPage((uintptr_t)start); - long long n = (uintptr_t(start) - aligned_start + len + pagesize - 1) / pagesize; + long long n = + (uintptr_t(start) - aligned_start + len + pagesize - 1) / pagesize; void **pages = (void **)malloc(sizeof(void *) * n); int *status = (int *)malloc(sizeof(int) * n); diff --git a/mooncake-transfer-engine/src/multi_transport.cpp b/mooncake-transfer-engine/src/multi_transport.cpp index b6f32efcb..35b18ab2c 100644 --- a/mooncake-transfer-engine/src/multi_transport.cpp +++ b/mooncake-transfer-engine/src/multi_transport.cpp @@ -150,26 +150,27 @@ Status MultiTransport::getTransferStatus(BatchID batch_id, size_t task_id, return Status::OK(); } -Status MultiTransport::getBatchTransferStatus(BatchID batch_id, TransferStatus &status) { +Status MultiTransport::getBatchTransferStatus(BatchID batch_id, + TransferStatus &status) { auto &batch_desc = *((BatchDesc *)(batch_id)); const size_t task_count = batch_desc.task_list.size(); status.transferred_bytes = 0; - + if (task_count == 0) { status.s = Transport::TransferStatusEnum::COMPLETED; return Status::OK(); } - + size_t success_count = 0; for (size_t task_id = 0; task_id < task_count; task_id++) { TransferStatus task_status; auto ret = getTransferStatus(batch_id, task_id, task_status); - + if (!ret.ok()) { status.s = Transport::TransferStatusEnum::FAILED; return Status::OK(); } - + if (task_status.s == Transport::TransferStatusEnum::COMPLETED) { status.transferred_bytes += task_status.transferred_bytes; success_count++; @@ -178,10 +179,10 @@ Status MultiTransport::getBatchTransferStatus(BatchID batch_id, TransferStatus & return Status::OK(); } } - - status.s = (success_count == task_count) ? - Transport::TransferStatusEnum::COMPLETED : - Transport::TransferStatusEnum::WAITING; + + status.s = (success_count == task_count) + ? 
Transport::TransferStatusEnum::COMPLETED + : Transport::TransferStatusEnum::WAITING; return Status::OK(); } diff --git a/mooncake-transfer-engine/src/topology.cpp b/mooncake-transfer-engine/src/topology.cpp index 2b69bc754..f9e87ef2b 100644 --- a/mooncake-transfer-engine/src/topology.cpp +++ b/mooncake-transfer-engine/src/topology.cpp @@ -177,11 +177,11 @@ static std::vector discoverCudaTopology( std::vector preferred_hca; std::vector avail_hca; - + // Find HCAs with minimum distance in one pass int min_distance = INT_MAX; std::vector min_distance_hcas; - + for (const auto &hca : all_hca) { int distance = getPciDistance(hca.pci_bus_id.c_str(), pci_bus_id); if (distance >= 0) { @@ -194,10 +194,11 @@ static std::vector discoverCudaTopology( } } } - + // Add HCAs with minimum distance to preferred_hca, others to avail_hca for (const auto &hca : all_hca) { - if (std::find(min_distance_hcas.begin(), min_distance_hcas.end(), hca.name) != min_distance_hcas.end()) { + if (std::find(min_distance_hcas.begin(), min_distance_hcas.end(), + hca.name) != min_distance_hcas.end()) { preferred_hca.push_back(hca.name); } else { avail_hca.push_back(hca.name); @@ -218,7 +219,7 @@ Topology::Topology() {} Topology::~Topology() {} bool Topology::empty() const { - for (const auto& entry : resolved_matrix_) { + for (const auto &entry : resolved_matrix_) { if (!entry.second.preferred_hca.empty() || !entry.second.avail_hca.empty()) { return false; diff --git a/mooncake-transfer-engine/src/transfer_engine.cpp b/mooncake-transfer-engine/src/transfer_engine.cpp index 6d8e2e29c..7c192f9ad 100644 --- a/mooncake-transfer-engine/src/transfer_engine.cpp +++ b/mooncake-transfer-engine/src/transfer_engine.cpp @@ -50,17 +50,22 @@ int TransferEngine::init(const std::string &metadata_conn_string, std::string rpc_binding_method; #ifdef USE_ASCEND - // The only difference in initializing the Ascend Transport is that the `local_server_name` must include the physical NPU card ID. 
- // The format changes from `ip:port` to `ip:port:npu_x`, e.g., `"0.0.0.0:12345:npu_2"`. - // While the desc_name stored in the metadata remains in the format of ip:port. + // The only difference in initializing the Ascend Transport is that the + // `local_server_name` must include the physical NPU card ID. The format + // changes from `ip:port` to `ip:port:npu_x`, e.g., `"0.0.0.0:12345:npu_2"`. + // While the desc_name stored in the metadata remains in the format of + // ip:port. int devicePhyId = -1; - auto[host_name, port] = parseHostNameWithPortAscend(local_server_name, &devicePhyId); - LOG(INFO) << "Transfer Engine parseHostNameWithPortAscend. server_name: " << host_name - << " port: " << port << " devicePhyId: " << devicePhyId; + auto [host_name, port] = + parseHostNameWithPortAscend(local_server_name, &devicePhyId); + LOG(INFO) << "Transfer Engine parseHostNameWithPortAscend. server_name: " + << host_name << " port: " << port + << " devicePhyId: " << devicePhyId; local_server_name_ = host_name + ":" + std::to_string(port); #else - auto[host_name, port] = parseHostNameWithPort(local_server_name); - LOG(INFO) << "Transfer Engine parseHostNameWithPort. server_name: " << host_name << " port: " << port; + auto [host_name, port] = parseHostNameWithPort(local_server_name); + LOG(INFO) << "Transfer Engine parseHostNameWithPort. server_name: " + << host_name << " port: " << port; local_server_name_ = local_server_name; #endif @@ -79,10 +84,13 @@ int TransferEngine::init(const std::string &metadata_conn_string, return -1; } #ifdef USE_ASCEND - // The current version of Ascend Transport does not support IPv6, but it will be added in a future release. - local_server_name_ = desc.ip_or_host_name + ":" + std::to_string(desc.rpc_port); + // The current version of Ascend Transport does not support IPv6, + // but it will be added in a future release. 
+ local_server_name_ = + desc.ip_or_host_name + ":" + std::to_string(desc.rpc_port); #else - local_server_name_ = maybeWrapIpV6(desc.ip_or_host_name) + ":" + std::to_string(desc.rpc_port); + local_server_name_ = maybeWrapIpV6(desc.ip_or_host_name) + ":" + + std::to_string(desc.rpc_port); #endif } } else { @@ -117,7 +125,8 @@ int TransferEngine::init(const std::string &metadata_conn_string, metadata_ = std::make_shared(metadata_conn_string); #ifdef USE_ASCEND - std::string mutable_server_name = local_server_name_ + ":npu_" + std::to_string(devicePhyId); + std::string mutable_server_name = + local_server_name_ + ":npu_" + std::to_string(devicePhyId); multi_transports_ = std::make_shared(metadata_, mutable_server_name); #else @@ -128,7 +137,8 @@ int TransferEngine::init(const std::string &metadata_conn_string, if (ret) return ret; #ifdef USE_ASCEND - Transport* ascend_transport = multi_transports_->installTransport("ascend", local_topology_); + Transport *ascend_transport = + multi_transports_->installTransport("ascend", local_topology_); if (!ascend_transport) { LOG(ERROR) << "Failed to install Ascend transport"; return -1; @@ -136,8 +146,10 @@ int TransferEngine::init(const std::string &metadata_conn_string, #else #if defined(USE_CXL) && !defined(USE_ASCEND) - if (std::getenv("MC_CXL_DEV_PATH") != nullptr && std::getenv("MC_CXL_DEV_SIZE") != nullptr) { - Transport* cxl_transport = multi_transports_->installTransport("cxl", local_topology_); + if (std::getenv("MC_CXL_DEV_PATH") != nullptr && + std::getenv("MC_CXL_DEV_SIZE") != nullptr) { + Transport *cxl_transport = + multi_transports_->installTransport("cxl", local_topology_); if (!cxl_transport) { LOG(ERROR) << "Failed to install CXL transport"; return -1; @@ -167,13 +179,15 @@ int TransferEngine::init(const std::string &metadata_conn_string, #ifdef USE_MNNVL if (local_topology_->getHcaList().size() > 0 && !getenv("MC_FORCE_MNNVL")) { - Transport* rdma_transport = multi_transports_->installTransport("rdma", 
local_topology_); + Transport *rdma_transport = + multi_transports_->installTransport("rdma", local_topology_); if (!rdma_transport) { LOG(ERROR) << "Failed to install RDMA transport"; return -1; } } else { - Transport* nvlink_transport = multi_transports_->installTransport("nvlink", nullptr); + Transport *nvlink_transport = + multi_transports_->installTransport("nvlink", nullptr); if (!nvlink_transport) { LOG(ERROR) << "Failed to install NVLink transport"; return -1; @@ -182,13 +196,15 @@ int TransferEngine::init(const std::string &metadata_conn_string, #else if (local_topology_->getHcaList().size() > 0) { // only install RDMA transport when there is at least one HCA - Transport* rdma_transport = multi_transports_->installTransport("rdma", local_topology_); + Transport *rdma_transport = + multi_transports_->installTransport("rdma", local_topology_); if (!rdma_transport) { LOG(ERROR) << "Failed to install RDMA transport"; return -1; } } else { - Transport* tcp_transport = multi_transports_->installTransport("tcp", nullptr); + Transport *tcp_transport = + multi_transports_->installTransport("tcp", nullptr); if (!tcp_transport) { LOG(ERROR) << "Failed to install TCP transport"; return -1; @@ -315,7 +331,8 @@ int TransferEngine::registerLocalMemory(void *addr, size_t length, return ERR_ADDRESS_OVERLAPPED; } if (length == 0) { - LOG(ERROR) << "Transfer Engine does not support zero length memory region"; + LOG(ERROR) + << "Transfer Engine does not support zero length memory region"; return ERR_INVALID_ARGUMENT; } for (auto transport : multi_transports_->listTransports()) { diff --git a/mooncake-transfer-engine/src/transfer_metadata.cpp b/mooncake-transfer-engine/src/transfer_metadata.cpp index a6b8e541e..52f5a228c 100644 --- a/mooncake-transfer-engine/src/transfer_metadata.cpp +++ b/mooncake-transfer-engine/src/transfer_metadata.cpp @@ -215,7 +215,8 @@ int TransferMetadata::encodeSegmentDesc(const SegmentDesc &desc, segmentJSON["buffers"] = buffersJSON; } else if 
(segmentJSON["protocol"] == "cxl") { segmentJSON["cxl_name"] = desc.cxl_name; - segmentJSON["cxl_base_addr"] = static_cast(desc.cxl_base_addr); + segmentJSON["cxl_base_addr"] = + static_cast(desc.cxl_base_addr); Json::Value buffersJSON(Json::arrayValue); for (const auto &buffer : desc.buffers) { Json::Value bufferJSON; @@ -584,10 +585,11 @@ int TransferMetadata::removeLocalMemoryBuffer(void *addr, *new_segment_desc = *segment_desc; segment_desc = new_segment_desc; for (auto iter = segment_desc->buffers.begin(); - iter != segment_desc->buffers.end(); ++iter) { + iter != segment_desc->buffers.end(); ++iter) { if (iter->addr == (uint64_t)addr #ifdef USE_CXL - || (iter->offset + segment_desc->cxl_base_addr) == (uint64_t)addr + || + (iter->offset + segment_desc->cxl_base_addr) == (uint64_t)addr #endif ) { segment_desc->buffers.erase(iter); diff --git a/mooncake-transfer-engine/src/transfer_metadata_plugin.cpp b/mooncake-transfer-engine/src/transfer_metadata_plugin.cpp index 846422b71..b5995aa77 100644 --- a/mooncake-transfer-engine/src/transfer_metadata_plugin.cpp +++ b/mooncake-transfer-engine/src/transfer_metadata_plugin.cpp @@ -688,9 +688,11 @@ struct SocketHandShakePlugin : public HandShakePlugin { // old protocol equals Connection type if (type == HandShakeRequestType::Connection || type == HandShakeRequestType::OldProtocol) { - if (on_connection_callback_) on_connection_callback_(peer, local); + if (on_connection_callback_) + on_connection_callback_(peer, local); } else if (type == HandShakeRequestType::Metadata) { - if (on_metadata_callback_) on_metadata_callback_(peer, local); + if (on_metadata_callback_) + on_metadata_callback_(peer, local); } else if (type == HandShakeRequestType::Notify) { if (on_notify_callback_) on_notify_callback_(peer, local); } else { diff --git a/mooncake-transfer-engine/src/transport/ascend_transport/hccl_transport/ascend_transport_c/hccl_transport_mem_c.cpp 
b/mooncake-transfer-engine/src/transport/ascend_transport/hccl_transport/ascend_transport_c/hccl_transport_mem_c.cpp index 3fe281d24..74788a480 100644 --- a/mooncake-transfer-engine/src/transport/ascend_transport/hccl_transport/ascend_transport_c/hccl_transport_mem_c.cpp +++ b/mooncake-transfer-engine/src/transport/ascend_transport/hccl_transport/ascend_transport_c/hccl_transport_mem_c.cpp @@ -32,7 +32,7 @@ #ifdef __cplusplus extern "C" { -#endif // __cplusplus +#endif // __cplusplus #define READ 0 #define WRITE 1 @@ -55,7 +55,7 @@ struct epoll_event g_ev; struct epoll_event g_events[MAX_EVENTS]; bool printEnabled() { - char* env = getenv("ASCEND_TRANSPORT_PRINT"); + char *env = getenv("ASCEND_TRANSPORT_PRINT"); return env != nullptr && std::string(env) == "1"; } @@ -75,7 +75,8 @@ uint16_t findAvailableTcpPort(int &sockfd, bool use_ipv6) { bind_address.sin6_family = AF_INET6; bind_address.sin6_port = htons(port); bind_address.sin6_addr = IN6ADDR_ANY_INIT; - if (bind(sockfd, (sockaddr *)&bind_address, sizeof(sockaddr_in6)) < 0) { + if (bind(sockfd, (sockaddr *)&bind_address, sizeof(sockaddr_in6)) < + 0) { continue; } } else { @@ -84,7 +85,8 @@ uint16_t findAvailableTcpPort(int &sockfd, bool use_ipv6) { bind_address.sin_family = AF_INET; bind_address.sin_port = htons(port); bind_address.sin_addr.s_addr = INADDR_ANY; - if (bind(sockfd, (sockaddr *)&bind_address, sizeof(sockaddr_in)) < 0) { + if (bind(sockfd, (sockaddr *)&bind_address, sizeof(sockaddr_in)) < + 0) { continue; } } @@ -96,14 +98,19 @@ uint16_t findAvailableTcpPort(int &sockfd, bool use_ipv6) { static int initServerNetSocket(RankInfo *local_rank_info) { RETRY_CALL(HcclNetInit(NICDeployment::NIC_DEPLOYMENT_DEVICE, - local_rank_info->devicePhyId, local_rank_info->deviceLogicId, false), "HcclNetInit failed"); + local_rank_info->devicePhyId, + local_rank_info->deviceLogicId, false), + "HcclNetInit failed"); // Use the physical network card of the device across HCCS hccl::HcclIpAddress 
localIp(local_rank_info->deviceIp); - RETRY_CALL(HcclNetOpenDev(&nicNetDevCtx_, NicType::DEVICE_NIC_TYPE, local_rank_info->devicePhyId, - local_rank_info->deviceLogicId, localIp), "HcclNetOpenDev DEVICE_NIC_TYPE failed"); + RETRY_CALL(HcclNetOpenDev(&nicNetDevCtx_, NicType::DEVICE_NIC_TYPE, + local_rank_info->devicePhyId, + local_rank_info->deviceLogicId, localIp), + "HcclNetOpenDev DEVICE_NIC_TYPE failed"); - nicServerSocket_ = std::make_shared(nicNetDevCtx_, local_rank_info->devicePort); + nicServerSocket_ = std::make_shared( + nicNetDevCtx_, local_rank_info->devicePort); if (nicServerSocket_ == NULL) { LOG(ERROR) << "make nicNetDevCtx_ failed"; return -1; @@ -114,18 +121,20 @@ static int initServerNetSocket(RankInfo *local_rank_info) { // Use virtual network card within HCCS hccl::HcclIpAddress localVnicIp(local_rank_info->devicePhyId); - RETRY_CALL(hrtRaGetSingleSocketVnicIpInfo( + RETRY_CALL( + hrtRaGetSingleSocketVnicIpInfo( local_rank_info->devicePhyId, DeviceIdType::DEVICE_ID_TYPE_PHY_ID, local_rank_info->devicePhyId, localVnicIp), - "hrtRaGetSingleSocketVnicIpInfo failed"); + "hrtRaGetSingleSocketVnicIpInfo failed"); RETRY_CALL(HcclNetOpenDev(&vnicNetDevCtx_, NicType::VNIC_TYPE, - local_rank_info->devicePhyId, - local_rank_info->deviceLogicId, localVnicIp), - "HcclNetOpenDev vnicNetDevCtx_ failed"); + local_rank_info->devicePhyId, + local_rank_info->deviceLogicId, localVnicIp), + "HcclNetOpenDev vnicNetDevCtx_ failed"); // control plane connection, creat serversocket, listening client - vnicServerSocket_ = std::make_shared(vnicNetDevCtx_, local_rank_info->devicePort); + vnicServerSocket_ = std::make_shared( + vnicNetDevCtx_, local_rank_info->devicePort); if (vnicServerSocket_ == NULL) { LOG(ERROR) << "vnicServerSocket_ make failed"; return -1; @@ -134,21 +143,24 @@ static int initServerNetSocket(RankInfo *local_rank_info) { RETRY_CALL(vnicServerSocket_->Init(), "vnicServerSocket_ Init failed"); RETRY_CALL(vnicServerSocket_->Listen(), "vnicServerSocket_ 
Listen failed"); - RETRY_CALL(HcclDispatcherInit(DispatcherType::DISPATCHER_NORMAL, local_rank_info->devicePhyId, &dispatcher_), - "client HcclDispatcherInit failed"); - + RETRY_CALL(HcclDispatcherInit(DispatcherType::DISPATCHER_NORMAL, + local_rank_info->devicePhyId, &dispatcher_), + "client HcclDispatcherInit failed"); + notifyPool_.reset(new (std::nothrow) hccl::NotifyPool()); if (notifyPool_ == nullptr) { LOG(ERROR) << "reset notifyPool error"; return -1; } - RETRY_CALL(notifyPool_->Init(local_rank_info->devicePhyId), "Init notifyPool error"); + RETRY_CALL(notifyPool_->Init(local_rank_info->devicePhyId), + "Init notifyPool error"); return 0; } -// The out-of-band socket on the host side that ascend_transport depends on, used to convey control information such as deviceId and deviceIp +// The out-of-band socket on the host side that ascend_transport depends on, +// used to convey control information such as deviceId and deviceIp static int initControlSocket(RankInfo *local_rank_info) { int ret = 0; g_server_socket_ = socket(AF_INET, SOCK_STREAM, 0); @@ -158,7 +170,8 @@ static int initControlSocket(RankInfo *local_rank_info) { } int optval = 1; - ret = setsockopt(g_server_socket_, SOL_SOCKET, SO_REUSEADDR, &optval, sizeof(optval)); + ret = setsockopt(g_server_socket_, SOL_SOCKET, SO_REUSEADDR, &optval, + sizeof(optval)); if (ret < 0) { LOG(ERROR) << "set sock opt failed, ret: " << ret; close(g_server_socket_); @@ -171,9 +184,11 @@ static int initControlSocket(RankInfo *local_rank_info) { bind_address.sin_addr.s_addr = INADDR_ANY; bind_address.sin_port = htons(local_rank_info->hostPort); - ret = bind(g_server_socket_, (struct sockaddr*)&bind_address, sizeof(bind_address)); + ret = bind(g_server_socket_, (struct sockaddr *)&bind_address, + sizeof(bind_address)); if (ret < 0) { - LOG(INFO) << "bind failed on the default port, default port: " << local_rank_info->hostPort << ", will find available port"; + LOG(INFO) << "bind failed on the default port, default port: " 
+ << local_rank_info->hostPort << ", will find available port"; uint16_t port = findAvailableTcpPort(g_server_socket_, false); if (port == 0) { LOG(ERROR) << "findAvailableTcpPort failed"; @@ -186,7 +201,8 @@ static int initControlSocket(RankInfo *local_rank_info) { struct timeval timeout; timeout.tv_sec = 120; timeout.tv_usec = 0; - ret = setsockopt(g_server_socket_, SOL_SOCKET, SO_RCVTIMEO, (const char*)&timeout, sizeof(timeout)); + ret = setsockopt(g_server_socket_, SOL_SOCKET, SO_RCVTIMEO, + (const char *)&timeout, sizeof(timeout)); if (ret < 0) { LOG(ERROR) << "Set recv timeout failed, ret: " << ret; close(g_server_socket_); @@ -199,7 +215,9 @@ static int initControlSocket(RankInfo *local_rank_info) { close(g_server_socket_); return ret; } - LOG(INFO) << "initControlSocket successful, Server listening on host port: " << local_rank_info->hostPort << "..." << " g_server_socket_" << g_server_socket_; + LOG(INFO) << "initControlSocket successful, Server listening on host port: " + << local_rank_info->hostPort << "..." 
<< " g_server_socket_" + << g_server_socket_; g_epoll_fd = epoll_create1(0); if (g_epoll_fd == -1) { LOG(ERROR) << "epoll create Failed, ret: " << g_epoll_fd; @@ -225,7 +243,7 @@ int initTransportMem(RankInfo *local_rank_info) { } uint32_t devPid; - ret = SalGetBareTgid(reinterpret_cast(&devPid)); + ret = SalGetBareTgid(reinterpret_cast(&devPid)); if (ret) { LOG(ERROR) << "SalGetBareTgid failed: " << ret; return ret; @@ -244,13 +262,14 @@ int initTransportMem(RankInfo *local_rank_info) { << ", hostPort: " << local_rank_info->hostPort << ", device pid: " << local_rank_info->pid; - // Initialize the virtual network card and socket for the data channel, exchange RmaMem, and create the QP connection + // Initialize the virtual network card and socket for the data channel, + // exchange RmaMem, and create the QP connection ret = initServerNetSocket(local_rank_info); if (ret) { LOG(ERROR) << "initServerNetSocket failed, ret: " << ret; return ret; } - + ret = initControlSocket(local_rank_info); if (ret) { LOG(ERROR) << "initControlSocket failed, ret: " << ret; @@ -273,7 +292,8 @@ static int connectToTarget(std::string target_ip, int target_port) { } int optval = 1; - int ret = setsockopt(client_socket, SOL_SOCKET, SO_REUSEADDR, &optval, sizeof(optval)); + int ret = setsockopt(client_socket, SOL_SOCKET, SO_REUSEADDR, &optval, + sizeof(optval)); if (ret < 0) { LOG(ERROR) << "set sock opt failed, ret: " << ret; close(client_socket); @@ -290,27 +310,31 @@ static int connectToTarget(std::string target_ip, int target_port) { close(client_socket); return -1; } - + int connected = 0; - const char* tcp_timeout_str = std::getenv("Ascend_TCP_TIMEOUT"); + const char *tcp_timeout_str = std::getenv("Ascend_TCP_TIMEOUT"); int ascend_tcp_timeout = tcp_timeout_str ? 
std::atoi(tcp_timeout_str) : 30; int connect_retry_times = ascend_tcp_timeout * 100; for (int i = 0; i < connect_retry_times; ++i) { - if (connect(client_socket, (struct sockaddr*)&server_addr, sizeof(server_addr)) == 0) { - LOG(INFO) << "Connect to host server " << target_ip << ":" << ntohs(server_addr.sin_port) << " successful"; + if (connect(client_socket, (struct sockaddr *)&server_addr, + sizeof(server_addr)) == 0) { + LOG(INFO) << "Connect to host server " << target_ip << ":" + << ntohs(server_addr.sin_port) << " successful"; connected = 1; break; } - LOG(INFO) << "Connect attempt " << i << " failed: " << strerror(errno) << ", retry once"; + LOG(INFO) << "Connect attempt " << i << " failed: " << strerror(errno) + << ", retry once"; std::this_thread::sleep_for(std::chrono::milliseconds(10)); } if (!connected) { - LOG(ERROR) << "Failed to connect to server after " << connect_retry_times << " retries"; + LOG(ERROR) << "Failed to connect to server after " + << connect_retry_times << " retries"; close(client_socket); return HCCL_E_TIMEOUT; } @@ -320,7 +344,8 @@ static int connectToTarget(std::string target_ip, int target_port) { int controlInfoSend(RankInfo *local_rank_info, RankInfo *remote_rank_info) { int ret = 0; - std::string key_str = inet_ntoa(remote_rank_info->hostIp) + std::to_string(remote_rank_info->devicePhyId); + std::string key_str = inet_ntoa(remote_rank_info->hostIp) + + std::to_string(remote_rank_info->devicePhyId); LOG(INFO) << "transportMemTask local_rank_info rankId: " << local_rank_info->rankId << ", serverIdx: " << local_rank_info->serverIdx @@ -331,17 +356,17 @@ int controlInfoSend(RankInfo *local_rank_info, RankInfo *remote_rank_info) { << ", hostIp: " << inet_ntoa(local_rank_info->hostIp) << ", hostPort: " << local_rank_info->hostPort << ", device pid: " << local_rank_info->pid; - + LOG(INFO) << "transportMemTask remote_rank_info rankId: " - << remote_rank_info->rankId - << ", serverIdx: " << remote_rank_info->serverIdx - << ", 
deviceLogicId: " << remote_rank_info->deviceLogicId - << ", devicePhyId: " << remote_rank_info->devicePhyId - << ", deviceIp: " << inet_ntoa(remote_rank_info->deviceIp) - << ", devicePort: " << remote_rank_info->devicePort - << ", hostIp: " << inet_ntoa(remote_rank_info->hostIp) - << ", hostPort: " << remote_rank_info->hostPort - << ", device pid: " << remote_rank_info->pid; + << remote_rank_info->rankId + << ", serverIdx: " << remote_rank_info->serverIdx + << ", deviceLogicId: " << remote_rank_info->deviceLogicId + << ", devicePhyId: " << remote_rank_info->devicePhyId + << ", deviceIp: " << inet_ntoa(remote_rank_info->deviceIp) + << ", devicePort: " << remote_rank_info->devicePort + << ", hostIp: " << inet_ntoa(remote_rank_info->hostIp) + << ", hostPort: " << remote_rank_info->hostPort + << ", device pid: " << remote_rank_info->pid; // Encapsulate control information RankControlInfo control_info; @@ -351,7 +376,7 @@ int controlInfoSend(RankInfo *local_rank_info, RankInfo *remote_rank_info) { control_info.deviceIp = local_rank_info->deviceIp; control_info.pid = local_rank_info->pid; // Self-built out-of-band, host socket for sending control plane - int client_socket = connectToTarget(inet_ntoa(remote_rank_info->hostIp), + int client_socket = connectToTarget(inet_ntoa(remote_rank_info->hostIp), remote_rank_info->hostPort); if (client_socket < 0) { LOG(ERROR) << "client connect failed"; @@ -367,11 +392,16 @@ int controlInfoSend(RankInfo *local_rank_info, RankInfo *remote_rank_info) { return 0; } -int createClientSocket(std::shared_ptr &hccl_socket, RankInfo *local_rank_info, RankInfo *remote_rank_info, bool is_cross_hccs, std::string tag) { +int createClientSocket(std::shared_ptr &hccl_socket, + RankInfo *local_rank_info, RankInfo *remote_rank_info, + bool is_cross_hccs, std::string tag) { int ret = 0; hccl::HcclIpAddress rempoteDevIp; - std::string key_str = inet_ntoa(remote_rank_info->hostIp) + std::to_string(remote_rank_info->devicePhyId); - std::string baseTag_ = 
inet_ntoa(local_rank_info->hostIp) + std::to_string(local_rank_info->devicePhyId) + key_str + tag; + std::string key_str = inet_ntoa(remote_rank_info->hostIp) + + std::to_string(remote_rank_info->devicePhyId); + std::string baseTag_ = inet_ntoa(local_rank_info->hostIp) + + std::to_string(local_rank_info->devicePhyId) + + key_str + tag; if (!is_cross_hccs) { std::vector remoteDevPhyId; remoteDevPhyId.push_back(remote_rank_info->devicePhyId); @@ -386,92 +416,115 @@ int createClientSocket(std::shared_ptr &hccl_socket, RankInfo return ret; } rempoteDevIp = hccl::HcclIpAddress(remote_rank_info->devicePhyId); - ret = hrtRaGetSingleSocketVnicIpInfo(local_rank_info->devicePhyId, - DeviceIdType::DEVICE_ID_TYPE_PHY_ID, - remote_rank_info->devicePhyId, rempoteDevIp); + ret = hrtRaGetSingleSocketVnicIpInfo( + local_rank_info->devicePhyId, DeviceIdType::DEVICE_ID_TYPE_PHY_ID, + remote_rank_info->devicePhyId, rempoteDevIp); if (ret) { LOG(ERROR) << "hrtRaGetSingleSocketVnicIpInfo, ret: " << ret; return ret; - } + } hccl_socket = std::make_shared( - baseTag_, vnicNetDevCtx_, rempoteDevIp, remote_rank_info->devicePort, - hccl::HcclSocketRole::SOCKET_ROLE_CLIENT); + baseTag_, vnicNetDevCtx_, rempoteDevIp, + remote_rank_info->devicePort, + hccl::HcclSocketRole::SOCKET_ROLE_CLIENT); } else { rempoteDevIp = hccl::HcclIpAddress(remote_rank_info->deviceIp); - hccl_socket = - std::make_shared( - baseTag_, nicNetDevCtx_, rempoteDevIp, remote_rank_info->devicePort, - hccl::HcclSocketRole::SOCKET_ROLE_CLIENT); + hccl_socket = std::make_shared( + baseTag_, nicNetDevCtx_, rempoteDevIp, remote_rank_info->devicePort, + hccl::HcclSocketRole::SOCKET_ROLE_CLIENT); } - + ret = hccl_socket->Init(); if (ret) { char deviceIp[64]; inet_ntop(AF_INET, &rempoteDevIp, deviceIp, sizeof(deviceIp)); - LOG(ERROR) << "client hccl_socket init failed, target devicePhyId: " - << remote_rank_info->devicePhyId - << ", local devicePhyId: " << local_rank_info->devicePhyId - << ", rempoteDevIp: " << deviceIp - << ", 
remote port: " << remote_rank_info->devicePort - << ", ret: " << ret; + LOG(ERROR) << "client hccl_socket init failed, target devicePhyId: " + << remote_rank_info->devicePhyId + << ", local devicePhyId: " << local_rank_info->devicePhyId + << ", rempoteDevIp: " << deviceIp + << ", remote port: " << remote_rank_info->devicePort + << ", ret: " << ret; return ret; } ret = hccl_socket->Connect(); if (ret) { char deviceIp[64]; inet_ntop(AF_INET, &rempoteDevIp, deviceIp, sizeof(deviceIp)); - LOG(ERROR) << "client hccl_socket Connect failed, target devicePhyId: " - << remote_rank_info->devicePhyId - << ", local devicePhyId: " << local_rank_info->devicePhyId - << ", rempoteDevIp: " << deviceIp - << ", remote port: " << remote_rank_info->devicePort - << ", ret: " << ret; + LOG(ERROR) << "client hccl_socket Connect failed, target devicePhyId: " + << remote_rank_info->devicePhyId + << ", local devicePhyId: " << local_rank_info->devicePhyId + << ", rempoteDevIp: " << deviceIp + << ", remote port: " << remote_rank_info->devicePort + << ", ret: " << ret; return ret; } - LOG(INFO) << "hccl_socket begin to connect, local devicePhyId: " << local_rank_info->devicePhyId << ", target devicePhyId: " << remote_rank_info->devicePhyId; + LOG(INFO) << "hccl_socket begin to connect, local devicePhyId: " + << local_rank_info->devicePhyId + << ", target devicePhyId: " << remote_rank_info->devicePhyId; hccl::HcclSocketStatus status; struct timespec start, end; - const char* hccl_socket_timeout_str = std::getenv("Ascend_HCCL_SOCKET_TIMEOUT"); - int hccl_socket_timeout = hccl_socket_timeout_str ? std::atoi(hccl_socket_timeout_str) : 30; - long long hccl_socket_timeout_ns = static_cast(hccl_socket_timeout) * 1000000000LL; + const char *hccl_socket_timeout_str = + std::getenv("Ascend_HCCL_SOCKET_TIMEOUT"); + int hccl_socket_timeout = + hccl_socket_timeout_str ? 
std::atoi(hccl_socket_timeout_str) : 30; + long long hccl_socket_timeout_ns = + static_cast(hccl_socket_timeout) * 1000000000LL; clock_gettime(CLOCK_MONOTONIC, &start); do { status = hccl_socket->GetStatus(); clock_gettime(CLOCK_MONOTONIC, &end); - long long elapsed_time = (end.tv_sec - start.tv_sec) * 1000000000LL + (end.tv_nsec - start.tv_nsec); - if (elapsed_time > hccl_socket_timeout_ns) { // Exceeds 20 seconds,TimeOut - LOG(ERROR) << "hccl_socket connect timeout, local devicePhyId: " << local_rank_info->devicePhyId << ", target devicePhyId: " << remote_rank_info->devicePhyId; + long long elapsed_time = (end.tv_sec - start.tv_sec) * 1000000000LL + + (end.tv_nsec - start.tv_nsec); + if (elapsed_time > + hccl_socket_timeout_ns) { // Exceeds 20 seconds,TimeOut + LOG(ERROR) << "hccl_socket connect timeout, local devicePhyId: " + << local_rank_info->devicePhyId + << ", target devicePhyId: " + << remote_rank_info->devicePhyId; return HCCL_E_TIMEOUT; } } while (status != hccl::HcclSocketStatus::SOCKET_OK); - LOG(INFO) << "hccl_socket connect success, local devicePhyId: " << local_rank_info->devicePhyId << ", target devicePhyId: " << remote_rank_info->devicePhyId; + LOG(INFO) << "hccl_socket connect success, local devicePhyId: " + << local_rank_info->devicePhyId + << ", target devicePhyId: " << remote_rank_info->devicePhyId; return 0; } -int createTransportMem(RankInfo *local_rank_info, RankInfo *remote_rank_info, std::shared_ptr& transport_mem) { +int createTransportMem(RankInfo *local_rank_info, RankInfo *remote_rank_info, + std::shared_ptr &transport_mem) { int ret = 0; - bool same_host = local_rank_info->hostIp.s_addr == remote_rank_info->hostIp.s_addr; - // For A2 series, internal communication among 8 cards does not cross HCCS, such as communication among cards 0-7 - bool same_group = (local_rank_info->devicePhyId / 8) == (remote_rank_info->devicePhyId / 8); + bool same_host = + local_rank_info->hostIp.s_addr == remote_rank_info->hostIp.s_addr; + // For A2 
series, internal communication among 8 cards does not cross HCCS, + // such as communication among cards 0-7 + bool same_group = (local_rank_info->devicePhyId / 8) == + (remote_rank_info->devicePhyId / 8); bool is_cross_hccs = !(same_host && same_group); - std::string key_str = inet_ntoa(remote_rank_info->hostIp) + std::to_string(remote_rank_info->devicePhyId); + std::string key_str = inet_ntoa(remote_rank_info->hostIp) + + std::to_string(remote_rank_info->devicePhyId); if (printEnabled()) { - LOG(INFO) << "hccl transport is cross_hccs: " << (is_cross_hccs ? "true (cross-hccs)" : "false (same-hccs)"); + LOG(INFO) << "hccl transport is cross_hccs: " + << (is_cross_hccs ? "true (cross-hccs)" + : "false (same-hccs)"); } std::shared_ptr hccl_ctrl_socket; std::shared_ptr hccl_data_socket; - ret = createClientSocket(hccl_ctrl_socket, local_rank_info, remote_rank_info, is_cross_hccs, "ctrl"); + ret = createClientSocket(hccl_ctrl_socket, local_rank_info, + remote_rank_info, is_cross_hccs, "ctrl"); if (ret) { - LOG(ERROR) << "createClientSocket hccl_ctrl_socket failed, ret: " << ret; + LOG(ERROR) << "createClientSocket hccl_ctrl_socket failed, ret: " + << ret; return ret; } target_key_to_connection_map_[key_str].hccl_ctrl_socket = hccl_ctrl_socket; - ret = createClientSocket(hccl_data_socket, local_rank_info, remote_rank_info, is_cross_hccs, "data"); + ret = createClientSocket(hccl_data_socket, local_rank_info, + remote_rank_info, is_cross_hccs, "data"); if (ret) { - LOG(ERROR) << "createClientSocket hccl_data_socket failed, ret: " << ret; + LOG(ERROR) << "createClientSocket hccl_data_socket failed, ret: " + << ret; return ret; } target_key_to_connection_map_[key_str].hccl_data_socket = hccl_data_socket; @@ -482,49 +535,54 @@ int createTransportMem(RankInfo *local_rank_info, RankInfo *remote_rank_info, st attrInfo.serverId = local_rank_info->serverIdx; if (is_cross_hccs) { transport_mem = hccl::TransportMem::Create( - hccl::TransportMem::TpType::ROCE, notifyPool_, 
nicNetDevCtx_, + hccl::TransportMem::TpType::ROCE, notifyPool_, nicNetDevCtx_, dispatcher_, attrInfo); } else { transport_mem = hccl::TransportMem::Create( - hccl::TransportMem::TpType::IPC, notifyPool_, vnicNetDevCtx_, + hccl::TransportMem::TpType::IPC, notifyPool_, vnicNetDevCtx_, dispatcher_, attrInfo); } ret = transport_mem->SetDataSocket(hccl_data_socket); if (ret) { char deviceIp[64]; - inet_ntop(AF_INET, &remote_rank_info->deviceIp, deviceIp, sizeof(deviceIp)); - LOG(ERROR) << "client SetDataSocket failed, target devicePhyId: " - << remote_rank_info->devicePhyId - << ", local devicePhyId: " << local_rank_info->devicePhyId - << ", rempoteDevIp: " << deviceIp - << ", remote port: " << remote_rank_info->devicePort - << ", ret: " << ret; + inet_ntop(AF_INET, &remote_rank_info->deviceIp, deviceIp, + sizeof(deviceIp)); + LOG(ERROR) << "client SetDataSocket failed, target devicePhyId: " + << remote_rank_info->devicePhyId + << ", local devicePhyId: " << local_rank_info->devicePhyId + << ", rempoteDevIp: " << deviceIp + << ", remote port: " << remote_rank_info->devicePort + << ", ret: " << ret; return ret; } ret = transport_mem->SetSocket(hccl_ctrl_socket); if (ret) { char deviceIp[64]; - inet_ntop(AF_INET, &remote_rank_info->deviceIp, deviceIp, sizeof(deviceIp)); - LOG(ERROR) << "client SetSocket failed, target devicePhyId: " - << remote_rank_info->devicePhyId - << ", local devicePhyId: " << local_rank_info->devicePhyId - << ", rempoteDevIp: " << deviceIp - << ", remote port: " << remote_rank_info->devicePort - << ", ret: " << ret; + inet_ntop(AF_INET, &remote_rank_info->deviceIp, deviceIp, + sizeof(deviceIp)); + LOG(ERROR) << "client SetSocket failed, target devicePhyId: " + << remote_rank_info->devicePhyId + << ", local devicePhyId: " << local_rank_info->devicePhyId + << ", rempoteDevIp: " << deviceIp + << ", remote port: " << remote_rank_info->devicePort + << ", ret: " << ret; return ret; } - const char* transport_mem_timeout_str = 
std::getenv("Ascend_TRANSPORT_MEM_TIMEOUT"); - int transport_mem_timeout = transport_mem_timeout_str ? std::atoi(transport_mem_timeout_str) : 120; + const char *transport_mem_timeout_str = + std::getenv("Ascend_TRANSPORT_MEM_TIMEOUT"); + int transport_mem_timeout = + transport_mem_timeout_str ? std::atoi(transport_mem_timeout_str) : 120; ret = transport_mem->Connect(transport_mem_timeout); if (ret) { char deviceIp[64]; - inet_ntop(AF_INET, &remote_rank_info->deviceIp, deviceIp, sizeof(deviceIp)); - LOG(ERROR) << "client Connect failed, target devicePhyId: " - << remote_rank_info->devicePhyId - << ", local devicePhyId: " << local_rank_info->devicePhyId - << ", rempoteDevIp: " << deviceIp - << ", remote port: " << remote_rank_info->devicePort - << ", ret: " << ret; + inet_ntop(AF_INET, &remote_rank_info->deviceIp, deviceIp, + sizeof(deviceIp)); + LOG(ERROR) << "client Connect failed, target devicePhyId: " + << remote_rank_info->devicePhyId + << ", local devicePhyId: " << local_rank_info->devicePhyId + << ", rempoteDevIp: " << deviceIp + << ", remote port: " << remote_rank_info->devicePort + << ", ret: " << ret; return ret; } LOG(INFO) << "transport_mem connect success"; @@ -544,33 +602,44 @@ int createTransportMem(RankInfo *local_rank_info, RankInfo *remote_rank_info, st ret = HcclMemReg(nicNetDevCtx_, &mem, &buf); } if (ret != 0 && ret != 20) { - LOG(ERROR) << "HcclMemReg failed, ret: " << ret << " addr: " << g_localMergeMem[i].addr << " len: " << g_localMergeMem[i].len; + LOG(ERROR) << "HcclMemReg failed, ret: " << ret + << " addr: " << g_localMergeMem[i].addr + << " len: " << g_localMergeMem[i].len; return ret; } char *desc = nullptr; uint64_t desc_len = 0; ret = HcclMemExport(&buf, &desc, &desc_len); if (ret) { - LOG(ERROR) << "HcclMemExport failed, ret: " << ret << ", addr: " << g_localMergeMem[i].addr << ", len: " << g_localMergeMem[i].len; + LOG(ERROR) << "HcclMemExport failed, ret: " << ret + << ", addr: " << g_localMergeMem[i].addr + << ", len: " << 
g_localMergeMem[i].len; return ret; } - + rmaMemDescs[i].localRankId = local_rank_info->deviceLogicId; rmaMemDescs[i].remoteRankId = remote_rank_info->deviceLogicId; - memset_s(rmaMemDescs[i].memDesc, hccl::TRANSPORT_EMD_ESC_SIZE, 0, hccl::TRANSPORT_EMD_ESC_SIZE); - if (memcpy_s(rmaMemDescs[i].memDesc, hccl::TRANSPORT_EMD_ESC_SIZE, desc, desc_len + 1) != EOK) { - LOG(ERROR) << "memcpy_s failed, ret: " << ret << ", addr: " << g_localMergeMem[i].addr << ", len: " << g_localMergeMem[i].len; + memset_s(rmaMemDescs[i].memDesc, hccl::TRANSPORT_EMD_ESC_SIZE, 0, + hccl::TRANSPORT_EMD_ESC_SIZE); + if (memcpy_s(rmaMemDescs[i].memDesc, hccl::TRANSPORT_EMD_ESC_SIZE, desc, + desc_len + 1) != EOK) { + LOG(ERROR) << "memcpy_s failed, ret: " << ret + << ", addr: " << g_localMergeMem[i].addr + << ", len: " << g_localMergeMem[i].len; return -1; } - // In the scenario within HCCS, it is necessary to call HcclMemGrant to authorize peer memory + // In the scenario within HCCS, it is necessary to call HcclMemGrant to + // authorize peer memory if (!is_cross_hccs) { HcclMemGrantInfo grant_info; grant_info.remotePid = (int32_t)remote_rank_info->pid; grant_info.remoteSdid = 0xFFFFFFFF; ret = HcclMemGrant(&buf, &grant_info); if (ret) { - LOG(ERROR) << "HcclMemGrant failed, ret: " << ret << ", addr: " << g_localMergeMem[i].addr << ", len: " << g_localMergeMem[i].len; + LOG(ERROR) << "HcclMemGrant failed, ret: " << ret + << ", addr: " << g_localMergeMem[i].addr + << ", len: " << g_localMergeMem[i].len; return ret; } } @@ -583,41 +652,54 @@ int createTransportMem(RankInfo *local_rank_info, RankInfo *remote_rank_info, st hccl::TransportMem::RmaMemDescs remoteRmaMemDescs; remoteRmaMemDescs.array = remoteRmaMemDescArray.data(); remoteRmaMemDescs.arrayLength = m_num; - ret = transport_mem->ExchangeMemDesc(localRmaMemDescs, remoteRmaMemDescs, actualNumOfRemote); + ret = transport_mem->ExchangeMemDesc(localRmaMemDescs, remoteRmaMemDescs, + actualNumOfRemote); if (ret) { - LOG(ERROR) << 
"transport_mem->ExchangeMemDesc failed, ret: " << ret << ", local_rank: " << local_rank_info->devicePhyId << ", remote_rank: " << remote_rank_info->devicePhyId; + LOG(ERROR) << "transport_mem->ExchangeMemDesc failed, ret: " << ret + << ", local_rank: " << local_rank_info->devicePhyId + << ", remote_rank: " << remote_rank_info->devicePhyId; return ret; } std::vector remoteRmaMemArray(m_num); for (uint32_t i = 0; i < m_num; ++i) { - ret = transport_mem->EnableMemAccess(remoteRmaMemDescArray[i], remoteRmaMemArray[i]); + ret = transport_mem->EnableMemAccess(remoteRmaMemDescArray[i], + remoteRmaMemArray[i]); if (ret) { - LOG(ERROR) << "transport_mem->EnableMemAccess failed, ret: " << ret << ", i: " << i << ", local_rank: " << local_rank_info->devicePhyId << ", remote_rank: " << remote_rank_info->devicePhyId; + LOG(ERROR) << "transport_mem->EnableMemAccess failed, ret: " << ret + << ", i: " << i + << ", local_rank: " << local_rank_info->devicePhyId + << ", remote_rank: " << remote_rank_info->devicePhyId; return ret; } } - LOG(INFO) << "ExchangeMem and EnableMemAccess Success, local devicePhyId: " << local_rank_info->devicePhyId << ", target devicePhyId: " << remote_rank_info->devicePhyId; + LOG(INFO) << "ExchangeMem and EnableMemAccess Success, local devicePhyId: " + << local_rank_info->devicePhyId + << ", target devicePhyId: " << remote_rank_info->devicePhyId; return 0; } int transportMemAddOpFence(RankInfo *remote_rank_info, aclrtStream stream) { - std::string key_str = inet_ntoa(remote_rank_info->hostIp) + std::to_string(remote_rank_info->devicePhyId); - int ret = target_key_to_connection_map_[key_str].transport_mem->AddOpFence(stream); + std::string key_str = inet_ntoa(remote_rank_info->hostIp) + + std::to_string(remote_rank_info->devicePhyId); + int ret = target_key_to_connection_map_[key_str].transport_mem->AddOpFence( + stream); if (ret) { LOG(ERROR) << "transport_mem AddOpFence failed, ret: " << ret; - return ret; + return ret; } return 0; } -int 
transportMemTask(RankInfo *local_rank_info, RankInfo *remote_rank_info, - int op_code, uint64_t offset, - uint64_t req_len, void *local_mem, aclrtStream stream) { +int transportMemTask(RankInfo *local_rank_info, RankInfo *remote_rank_info, + int op_code, uint64_t offset, uint64_t req_len, + void *local_mem, aclrtStream stream) { int ret = 0; std::shared_ptr transport_mem{}; - // Check if a connection has been established with the peer, and send local information to the peer - std::string key_str = inet_ntoa(remote_rank_info->hostIp) + std::to_string(remote_rank_info->devicePhyId); + // Check if a connection has been established with the peer, and send local + // information to the peer + std::string key_str = inet_ntoa(remote_rank_info->hostIp) + + std::to_string(remote_rank_info->devicePhyId); auto iter = target_key_to_connection_map_.find(key_str); if (iter == target_key_to_connection_map_.end()) { ret = controlInfoSend(local_rank_info, remote_rank_info); @@ -625,7 +707,8 @@ int transportMemTask(RankInfo *local_rank_info, RankInfo *remote_rank_info, LOG(ERROR) << "controlInfoSend failed, ret: " << ret; return ret; } - ret = createTransportMem(local_rank_info, remote_rank_info, transport_mem); + ret = createTransportMem(local_rank_info, remote_rank_info, + transport_mem); if (ret) { LOG(ERROR) << "createTransportMem failed, ret: " << ret; return ret; @@ -639,24 +722,22 @@ int transportMemTask(RankInfo *local_rank_info, RankInfo *remote_rank_info, hccl::TransportMem::RmaOpMem remoteMem; remoteMem.addr = (void *)offset; remoteMem.size = req_len; - if (op_code == WRITE){ + if (op_code == WRITE) { ret = transport_mem->Write(remoteMem, localMem, stream); if (ret) { - LOG(ERROR) << "transport_mem Write failed, localMem.addr: " - << local_mem << "local_mem.size: " << req_len - << ", remoteMem.addr: " << remoteMem.addr - << ", remoteMem.size: " << req_len - << ", ret: " << ret; + LOG(ERROR) << "transport_mem Write failed, localMem.addr: " + << local_mem << 
"local_mem.size: " << req_len + << ", remoteMem.addr: " << remoteMem.addr + << ", remoteMem.size: " << req_len << ", ret: " << ret; return ret; } } else { ret = transport_mem->Read(localMem, remoteMem, stream); if (ret) { - LOG(ERROR) << "transport_mem Read failed, localMem.addr: " - << local_mem << "local_mem.size: " << req_len - << ", remoteMem.addr: " << remoteMem.addr - << ", remoteMem.size: " << req_len - << ", ret: " << ret; + LOG(ERROR) << "transport_mem Read failed, localMem.addr: " + << local_mem << "local_mem.size: " << req_len + << ", remoteMem.addr: " << remoteMem.addr + << ", remoteMem.size: " << req_len << ", ret: " << ret; return ret; } } @@ -668,17 +749,23 @@ static int acceptFromTarget() { int client_socket; struct sockaddr_in client_addr; socklen_t client_len = sizeof(client_addr); - client_socket = accept(g_server_socket_, (struct sockaddr*)&client_addr, &client_len); + client_socket = + accept(g_server_socket_, (struct sockaddr *)&client_addr, &client_len); if (client_socket < 0) { LOG(ERROR) << "Accept failed"; return client_socket; } - LOG(INFO) << "host client connected from " << inet_ntoa(client_addr.sin_addr) << ":" << ntohs(client_addr.sin_port); + LOG(INFO) << "host client connected from " + << inet_ntoa(client_addr.sin_addr) << ":" + << ntohs(client_addr.sin_port); return client_socket; } -int acceptSocket(std::shared_ptr &hccl_socket, RankInfo *local_rank_info, RankControlInfo remote_control_info, std::string baseTag_, hccl::HcclIpAddress rempoteDevIp, bool is_cross_hccs) { +int acceptSocket(std::shared_ptr &hccl_socket, + RankInfo *local_rank_info, RankControlInfo remote_control_info, + std::string baseTag_, hccl::HcclIpAddress rempoteDevIp, + bool is_cross_hccs) { int ret = 0; std::vector wlistInfoVec; SocketWlistInfo wlistInfo = {}; @@ -693,12 +780,14 @@ int acceptSocket(std::shared_ptr &hccl_socket, RankInfo *local LOG(ERROR) << "serverSocket AddWhiteList failed, ret: " << ret; return ret; } - // Before using the device-side network 
card for communication, it is necessary to add the client device address to the whitelist. + // Before using the device-side network card for communication, it is + // necessary to add the client device address to the whitelist. LOG(INFO) << "Add the client's Device IP address to the whitelist success."; ret = serverSocket->Accept(baseTag_, hccl_socket); if (ret) { - LOG(ERROR) << "serverSocket transportMemAccept ctrl socket failed ret: " << ret; + LOG(ERROR) << "serverSocket transportMemAccept ctrl socket failed ret: " + << ret; return ret; } return 0; @@ -732,34 +821,42 @@ int transportMemAccept(RankInfo *local_rank_info) { << ", devicePhyId: " << remote_control_info.devicePhyId << ", hostIp: " << inet_ntoa(remote_control_info.hostIp) << ", deviceIp: " << inet_ntoa(remote_control_info.deviceIp) - << ", device pid: " << remote_control_info.pid; + << ", device pid: " << remote_control_info.pid; // Check if TransportMem has been established with the peer - std::string key_str = inet_ntoa(remote_control_info.hostIp) + std::to_string(remote_control_info.devicePhyId); + std::string key_str = inet_ntoa(remote_control_info.hostIp) + + std::to_string(remote_control_info.devicePhyId); auto iter = target_key_to_connection_map_.find(key_str); if (iter != target_key_to_connection_map_.end()) { - LOG(WARNING) << "A duplicate connection request from the same remote endpoint has been detected, the remote side may have restarted."; + LOG(WARNING) + << "A duplicate connection request from the same remote endpoint " + "has been detected, the remote side may have restarted."; } - std::string baseTag_ = key_str + inet_ntoa(local_rank_info->hostIp) + std::to_string(local_rank_info->devicePhyId); + std::string baseTag_ = key_str + inet_ntoa(local_rank_info->hostIp) + + std::to_string(local_rank_info->devicePhyId); hccl::HcclIpAddress rempoteDevIp; std::shared_ptr hccl_ctrl_socket; std::shared_ptr hccl_data_socket; - bool same_host = local_rank_info->hostIp.s_addr == 
remote_control_info.hostIp.s_addr; - // For A2 series, internal communication among 8 cards does not cross HCCS, such as communication among cards 0-7 - bool same_group = (local_rank_info->devicePhyId / 8) == (remote_control_info.devicePhyId / 8); + bool same_host = + local_rank_info->hostIp.s_addr == remote_control_info.hostIp.s_addr; + // For A2 series, internal communication among 8 cards does not cross HCCS, + // such as communication among cards 0-7 + bool same_group = (local_rank_info->devicePhyId / 8) == + (remote_control_info.devicePhyId / 8); bool is_cross_hccs = !(same_host && same_group); if (printEnabled()) { - LOG(INFO) << "transport is cross_hccs: " << (is_cross_hccs ? "true (cross-hccs)" : "false (same-hccs)"); + LOG(INFO) << "transport is cross_hccs: " + << (is_cross_hccs ? "true (cross-hccs)" + : "false (same-hccs)"); } if (!is_cross_hccs) { std::vector remoteDevPhyId; rempoteDevIp = hccl::HcclIpAddress(remote_control_info.devicePhyId); remoteDevPhyId.push_back(remote_control_info.devicePhyId); - ret = hrtRaGetSingleSocketVnicIpInfo(local_rank_info->devicePhyId, - DeviceIdType::DEVICE_ID_TYPE_PHY_ID, - remote_control_info.devicePhyId, - rempoteDevIp); + ret = hrtRaGetSingleSocketVnicIpInfo( + local_rank_info->devicePhyId, DeviceIdType::DEVICE_ID_TYPE_PHY_ID, + remote_control_info.devicePhyId, rempoteDevIp); if (ret) { LOG(ERROR) << "hrtRaGetSingleSocketVnicIpInfo failed, ret: " << ret; return ret; @@ -774,26 +871,34 @@ int transportMemAccept(RankInfo *local_rank_info) { LOG(ERROR) << "P2PMgmtPub EnableP2P failed, ret: " << ret; return ret; } - ret = acceptSocket(hccl_ctrl_socket, local_rank_info, remote_control_info, baseTag_ + "ctrl", rempoteDevIp, is_cross_hccs); + ret = + acceptSocket(hccl_ctrl_socket, local_rank_info, remote_control_info, + baseTag_ + "ctrl", rempoteDevIp, is_cross_hccs); if (ret) { LOG(ERROR) << "acceptSocket ctrl failed, ret: " << ret; return ret; } - ret = acceptSocket(hccl_data_socket, local_rank_info, remote_control_info, 
baseTag_ + "data", rempoteDevIp, is_cross_hccs); + ret = + acceptSocket(hccl_data_socket, local_rank_info, remote_control_info, + baseTag_ + "data", rempoteDevIp, is_cross_hccs); if (ret) { LOG(ERROR) << "acceptSocket data failed, ret: " << ret; return ret; } } else { rempoteDevIp = hccl::HcclIpAddress(remote_control_info.deviceIp); - ret = acceptSocket(hccl_ctrl_socket, local_rank_info, remote_control_info, baseTag_ + "ctrl", rempoteDevIp, is_cross_hccs); + ret = + acceptSocket(hccl_ctrl_socket, local_rank_info, remote_control_info, + baseTag_ + "ctrl", rempoteDevIp, is_cross_hccs); if (ret) { LOG(ERROR) << "acceptSocket ctrl failed, ret: " << ret; return ret; } - ret = acceptSocket(hccl_data_socket, local_rank_info, remote_control_info, baseTag_ + "data", rempoteDevIp, is_cross_hccs); + ret = + acceptSocket(hccl_data_socket, local_rank_info, remote_control_info, + baseTag_ + "data", rempoteDevIp, is_cross_hccs); if (ret) { LOG(ERROR) << "acceptSocket data failed, ret: " << ret; return ret; @@ -811,51 +916,50 @@ int transportMemAccept(RankInfo *local_rank_info) { attrInfo.serverId = local_rank_info->serverIdx; if (is_cross_hccs) { transport_mem = hccl::TransportMem::Create( - hccl::TransportMem::TpType::ROCE, notifyPool_, nicNetDevCtx_, + hccl::TransportMem::TpType::ROCE, notifyPool_, nicNetDevCtx_, dispatcher_, attrInfo); } else { transport_mem = hccl::TransportMem::Create( - hccl::TransportMem::TpType::IPC, notifyPool_, vnicNetDevCtx_, + hccl::TransportMem::TpType::IPC, notifyPool_, vnicNetDevCtx_, dispatcher_, attrInfo); } ret = transport_mem->SetDataSocket(hccl_data_socket); if (ret) { char deviceIp[64]; inet_ntop(AF_INET, &rempoteDevIp, deviceIp, sizeof(deviceIp)); - LOG(ERROR) << "client SetDataSocket failed, target devicePhyId: " - << remote_control_info.devicePhyId - << ", local devicePhyId: " << local_rank_info->devicePhyId - << ", rempoteDevIp: " << deviceIp - << ", ret: " << ret; + LOG(ERROR) << "client SetDataSocket failed, target devicePhyId: " + << 
remote_control_info.devicePhyId + << ", local devicePhyId: " << local_rank_info->devicePhyId + << ", rempoteDevIp: " << deviceIp << ", ret: " << ret; return ret; } - + ret = transport_mem->SetSocket(hccl_ctrl_socket); if (ret) { char deviceIp[64]; inet_ntop(AF_INET, &rempoteDevIp, deviceIp, sizeof(deviceIp)); - LOG(ERROR) << "client Connect failed, target devicePhyId: " - << remote_control_info.devicePhyId - << ", local devicePhyId: " << local_rank_info->devicePhyId - << ", rempoteDevIp: " << deviceIp - << ", ret: " << ret; + LOG(ERROR) << "client Connect failed, target devicePhyId: " + << remote_control_info.devicePhyId + << ", local devicePhyId: " << local_rank_info->devicePhyId + << ", rempoteDevIp: " << deviceIp << ", ret: " << ret; return ret; } - const char* transport_mem_timeout_str = std::getenv("Ascend_TRANSPORT_MEM_TIMEOUT"); - int transport_mem_timeout = transport_mem_timeout_str ? std::atoi(transport_mem_timeout_str) : 120; + const char *transport_mem_timeout_str = + std::getenv("Ascend_TRANSPORT_MEM_TIMEOUT"); + int transport_mem_timeout = + transport_mem_timeout_str ? 
std::atoi(transport_mem_timeout_str) : 120; ret = transport_mem->Connect(transport_mem_timeout); if (ret) { char deviceIp[64]; inet_ntop(AF_INET, &rempoteDevIp, deviceIp, sizeof(deviceIp)); - LOG(ERROR) << "client Connect failed, target devicePhyId: " - << remote_control_info.devicePhyId - << ", local devicePhyId: " << local_rank_info->devicePhyId - << ", rempoteDevIp: " << deviceIp - << ", ret: " << ret; + LOG(ERROR) << "client Connect failed, target devicePhyId: " + << remote_control_info.devicePhyId + << ", local devicePhyId: " << local_rank_info->devicePhyId + << ", rempoteDevIp: " << deviceIp << ", ret: " << ret; return ret; } - target_key_to_connection_map_[key_str].transport_mem = transport_mem; + target_key_to_connection_map_[key_str].transport_mem = transport_mem; size_t m_num = g_localMergeMem.size(); std::vector rmaMemDescs(m_num); @@ -871,33 +975,44 @@ int transportMemAccept(RankInfo *local_rank_info) { ret = HcclMemReg(nicNetDevCtx_, &mem, &buf); } if (ret != 0 && ret != 20) { - LOG(ERROR) << "HcclMemReg failed, ret: " << ret << ", addr: " << g_localMergeMem[i].addr << ", len: " << g_localMergeMem[i].len; + LOG(ERROR) << "HcclMemReg failed, ret: " << ret + << ", addr: " << g_localMergeMem[i].addr + << ", len: " << g_localMergeMem[i].len; return ret; } char *desc = nullptr; uint64_t desc_len = 0; ret = HcclMemExport(&buf, &desc, &desc_len); if (ret) { - LOG(ERROR) << "HcclMemExport failed, ret: " << ret << ", addr: " << g_localMergeMem[i].addr << ", len: " << g_localMergeMem[i].len; + LOG(ERROR) << "HcclMemExport failed, ret: " << ret + << ", addr: " << g_localMergeMem[i].addr + << ", len: " << g_localMergeMem[i].len; return ret; } - + rmaMemDescs[i].localRankId = local_rank_info->deviceLogicId; rmaMemDescs[i].remoteRankId = remote_control_info.deviceLogicId; - memset_s(rmaMemDescs[i].memDesc, hccl::TRANSPORT_EMD_ESC_SIZE, 0, hccl::TRANSPORT_EMD_ESC_SIZE); - if (memcpy_s(rmaMemDescs[i].memDesc, hccl::TRANSPORT_EMD_ESC_SIZE, desc, desc_len + 1) != EOK) { 
- LOG(ERROR) << "memcpy_s failed, ret: " << ret << ", addr: " << g_localMergeMem[i].addr << ", len: " << g_localMergeMem[i].len; + memset_s(rmaMemDescs[i].memDesc, hccl::TRANSPORT_EMD_ESC_SIZE, 0, + hccl::TRANSPORT_EMD_ESC_SIZE); + if (memcpy_s(rmaMemDescs[i].memDesc, hccl::TRANSPORT_EMD_ESC_SIZE, desc, + desc_len + 1) != EOK) { + LOG(ERROR) << "memcpy_s failed, ret: " << ret + << ", addr: " << g_localMergeMem[i].addr + << ", len: " << g_localMergeMem[i].len; return -1; } - // In the scenario within HCCS, it is necessary to call HcclMemGrant to authorize peer memory + // In the scenario within HCCS, it is necessary to call HcclMemGrant to + // authorize peer memory if (!is_cross_hccs) { HcclMemGrantInfo grant_info; grant_info.remotePid = (int32_t)remote_control_info.pid; grant_info.remoteSdid = 0xFFFFFFFF; ret = HcclMemGrant(&buf, &grant_info); if (ret) { - LOG(ERROR) << "HcclMemGrant failed, ret: " << ret << ", addr: " << g_localMergeMem[i].addr << ", len: " << g_localMergeMem[i].len; + LOG(ERROR) << "HcclMemGrant failed, ret: " << ret + << ", addr: " << g_localMergeMem[i].addr + << ", len: " << g_localMergeMem[i].len; return ret; } } @@ -910,21 +1025,30 @@ int transportMemAccept(RankInfo *local_rank_info) { hccl::TransportMem::RmaMemDescs remoteRmaMemDescs; remoteRmaMemDescs.array = remoteRmaMemDescArray.data(); remoteRmaMemDescs.arrayLength = m_num; - ret = transport_mem->ExchangeMemDesc(localRmaMemDescs, remoteRmaMemDescs, actualNumOfRemote); + ret = transport_mem->ExchangeMemDesc(localRmaMemDescs, remoteRmaMemDescs, + actualNumOfRemote); if (ret) { - LOG(ERROR) << "transport_mem->ExchangeMemDesc failed, ret: " << ret << ", local_rank: " << local_rank_info->devicePhyId << ", remote_rank: " << remote_control_info.devicePhyId; + LOG(ERROR) << "transport_mem->ExchangeMemDesc failed, ret: " << ret + << ", local_rank: " << local_rank_info->devicePhyId + << ", remote_rank: " << remote_control_info.devicePhyId; return ret; } std::vector remoteRmaMemArray(m_num); for 
(uint32_t i = 0; i < m_num; ++i) { - ret = transport_mem->EnableMemAccess(remoteRmaMemDescArray[i], remoteRmaMemArray[i]); + ret = transport_mem->EnableMemAccess(remoteRmaMemDescArray[i], + remoteRmaMemArray[i]); if (ret) { - LOG(ERROR) << "transport_mem->EnableMemAccess failed, ret: " << ret << ", i: " << i << ", local_rank: " << local_rank_info->devicePhyId << ", remote_rank: " << remote_control_info.devicePhyId; + LOG(ERROR) << "transport_mem->EnableMemAccess failed, ret: " << ret + << ", i: " << i + << ", local_rank: " << local_rank_info->devicePhyId + << ", remote_rank: " << remote_control_info.devicePhyId; return ret; } } - LOG(INFO) << "ExchangeMem and EnableMemAccess Success, local devicePhyId: " << local_rank_info->devicePhyId << ", target devicePhyId: " << remote_control_info.devicePhyId; + LOG(INFO) << "ExchangeMem and EnableMemAccess Success, local devicePhyId: " + << local_rank_info->devicePhyId + << ", target devicePhyId: " << remote_control_info.devicePhyId; return 0; } @@ -935,4 +1059,4 @@ int regLocalRmaMem(void *addr, uint64_t length) { #ifdef __cplusplus } -#endif // __cplusplus \ No newline at end of file +#endif // __cplusplus \ No newline at end of file diff --git a/mooncake-transfer-engine/src/transport/ascend_transport/hccl_transport/hccl_transport.cpp b/mooncake-transfer-engine/src/transport/ascend_transport/hccl_transport/hccl_transport.cpp index 22235ccba..b1e1b09a6 100644 --- a/mooncake-transfer-engine/src/transport/ascend_transport/hccl_transport/hccl_transport.cpp +++ b/mooncake-transfer-engine/src/transport/ascend_transport/hccl_transport/hccl_transport.cpp @@ -23,15 +23,15 @@ #include #include "transport/ascend_transport/hccl_transport/hccl_transport.h" -namespace mooncake{ +namespace mooncake { HcclTransport::HcclTransport() : running_(-1) { - //TODO + // TODO } HcclTransport::~HcclTransport() { if (running_) { running_ = false; - + for (size_t i = 0; i < THREAD_NUM; ++i) { allInitiatorThreads_[i].join(); 
allAcceptThreads_[i].join(); @@ -51,11 +51,11 @@ void HcclTransport::initiatorLoop(int deviceLogicId, int selfIdx) { if (ret) { LOG(ERROR) << "HcclTransport: aclrtCreateStream error, ret: " << ret; } - - while(1) { + + while (1) { auto waitlock = std::chrono::high_resolution_clock::now(); std::unique_lock lock(initiator_mutex_); - if (allReqQueues_[selfIdx].empty()){ + if (allReqQueues_[selfIdx].empty()) { initiator_cond_.wait(lock); } auto start = std::chrono::high_resolution_clock::now(); @@ -65,9 +65,12 @@ void HcclTransport::initiatorLoop(int deviceLogicId, int selfIdx) { if (slice_list.empty()) { LOG(ERROR) << "HcclTransport: empty transfer request batch"; } - auto segment_desc = metadata_->getSegmentDescByID(slice_list[0]->target_id); + auto segment_desc = + metadata_->getSegmentDescByID(slice_list[0]->target_id); if (!segment_desc) { - LOG(ERROR) << "Unable to get target segment ID, please recheck, segment ID: " << slice_list[0]->target_id; + LOG(ERROR) << "Unable to get target segment ID, please recheck, " + "segment ID: " + << slice_list[0]->target_id; for (auto slice : slice_list) { slice->markFailed(); } @@ -75,41 +78,42 @@ void HcclTransport::initiatorLoop(int deviceLogicId, int selfIdx) { } remote_rank_info_.rankId = segment_desc->rank_info.rankId; - inet_pton(AF_INET, segment_desc->rank_info.hostIp.c_str(), &remote_rank_info_.hostIp); + inet_pton(AF_INET, segment_desc->rank_info.hostIp.c_str(), + &remote_rank_info_.hostIp); remote_rank_info_.hostPort = segment_desc->rank_info.hostPort; remote_rank_info_.deviceLogicId = segment_desc->rank_info.deviceLogicId; remote_rank_info_.devicePhyId = segment_desc->rank_info.devicePhyId; - inet_pton(AF_INET, segment_desc->rank_info.deviceIp.c_str(), &remote_rank_info_.deviceIp); + inet_pton(AF_INET, segment_desc->rank_info.deviceIp.c_str(), + &remote_rank_info_.deviceIp); remote_rank_info_.devicePort = segment_desc->rank_info.devicePort; remote_rank_info_.serverIdx = 0; remote_rank_info_.pid = 
segment_desc->rank_info.pid; for (auto slice : slice_list) { - ret = transportMemTask(&local_rank_info_, &remote_rank_info_, slice->opcode, - slice->hccl.dest_addr, slice->length, slice->source_addr, stream); + ret = transportMemTask(&local_rank_info_, &remote_rank_info_, + slice->opcode, slice->hccl.dest_addr, + slice->length, slice->source_addr, stream); if (ret) { - LOG(ERROR) << "HcclTransport: transportMemTask error, local devicePhyId: " - << local_rank_info_.devicePhyId - << ", remote devicePhyId: " - << remote_rank_info_.devicePhyId - << ", source_addr: " - << slice->source_addr - << ", dest_addr: " - << slice->hccl.dest_addr - << ", ret: " << ret; + LOG(ERROR) << "HcclTransport: transportMemTask error, local " + "devicePhyId: " + << local_rank_info_.devicePhyId + << ", remote devicePhyId: " + << remote_rank_info_.devicePhyId + << ", source_addr: " << slice->source_addr + << ", dest_addr: " << slice->hccl.dest_addr + << ", ret: " << ret; slice->markFailed(); slice->status = Slice::SliceStatus::FAILED; } } - + auto mid = std::chrono::high_resolution_clock::now(); ret = transportMemAddOpFence(&remote_rank_info_, stream); if (ret) { - LOG(ERROR) << "transportMemAddOpFence failed, local devicePhyId: " - << local_rank_info_.devicePhyId - << ", remote devicePhyId: " - << remote_rank_info_.devicePhyId - << ", ret: " << ret; + LOG(ERROR) << "transportMemAddOpFence failed, local devicePhyId: " + << local_rank_info_.devicePhyId + << ", remote devicePhyId: " + << remote_rank_info_.devicePhyId << ", ret: " << ret; for (auto slice : slice_list) { slice->markFailed(); } @@ -118,11 +122,10 @@ void HcclTransport::initiatorLoop(int deviceLogicId, int selfIdx) { ret = aclrtSynchronizeStream(stream); if (ret) { - LOG(ERROR) << "aclrtSynchronizeStream failed, local devicePhyId: " - << local_rank_info_.devicePhyId - << ", remote devicePhyId: " - << remote_rank_info_.devicePhyId - << ", ret: " << ret; + LOG(ERROR) << "aclrtSynchronizeStream failed, local devicePhyId: " + << 
local_rank_info_.devicePhyId + << ", remote devicePhyId: " + << remote_rank_info_.devicePhyId << ", ret: " << ret; for (auto slice : slice_list) { slice->markFailed(); } @@ -131,23 +134,35 @@ void HcclTransport::initiatorLoop(int deviceLogicId, int selfIdx) { if (slice->status != Slice::SliceStatus::FAILED) { slice->markSuccess(); slice->task->transferred_bytes = slice->length; - } + } } auto stop = std::chrono::high_resolution_clock::now(); if (printEnabled()) { pid_t pid = getpid(); - auto duration_wait = std::chrono::duration_cast(start - waitlock); - auto duration_call = std::chrono::duration_cast(mid - start); - auto duration_addOpfence = std::chrono::duration_cast(addOpfence - mid); - auto duration_sync = std::chrono::duration_cast(stop - addOpfence); - LOG(INFO) << "pid: " << pid - << ", target hostIp: " << segment_desc->rank_info.hostIp.c_str() - << ", local devicePhyId: " << local_rank_info_.devicePhyId - << ", target devicePhyId: " << remote_rank_info_.devicePhyId - << ", batch waitlock spent: "<< duration_wait.count() << "ms" - << ", batch call spent: "<< duration_call.count() << "us" - << ", batch addOpfence spent: " << duration_addOpfence.count() << "us" - << ", batch sync spent: " << duration_sync.count() << "us"; + auto duration_wait = + std::chrono::duration_cast(start - + waitlock); + auto duration_call = + std::chrono::duration_cast(mid - + start); + auto duration_addOpfence = + std::chrono::duration_cast( + addOpfence - mid); + auto duration_sync = + std::chrono::duration_cast( + stop - addOpfence); + LOG(INFO) << "pid: " << pid << ", target hostIp: " + << segment_desc->rank_info.hostIp.c_str() + << ", local devicePhyId: " << local_rank_info_.devicePhyId + << ", target devicePhyId: " + << remote_rank_info_.devicePhyId + << ", batch waitlock spent: " << duration_wait.count() + << "ms" + << ", batch call spent: " << duration_call.count() << "us" + << ", batch addOpfence spent: " + << duration_addOpfence.count() << "us" + << ", batch sync spent: " 
<< duration_sync.count() + << "us"; } else { (void)waitlock; (void)start; @@ -163,11 +178,12 @@ void HcclTransport::acceptLoop(int deviceLogicId) { if (ret) { LOG(ERROR) << "HcclTransport: aclrtSetDevice failed ret: " << ret; } - while(running_) { + while (running_) { ret = transportMemAccept(&local_rank_info_); if (ret) { - LOG(ERROR) << "HcclTransport: transportMemAccept failed ret: " << ret; - } + LOG(ERROR) << "HcclTransport: transportMemAccept failed ret: " + << ret; + } } } @@ -179,34 +195,44 @@ int HcclTransport::initPdThread() { if (ret) { LOG(ERROR) << "HcclTransport: aclrtGetDevice failed, ret: " << ret; return ret; - } + } for (int i = 0; i < THREAD_NUM; ++i) { - allInitiatorThreads_[i] = std::thread(&HcclTransport::initiatorLoop, this, deviceLogicId, i); - allAcceptThreads_[i] = std::thread(&HcclTransport::acceptLoop, this, deviceLogicId); + allInitiatorThreads_[i] = + std::thread(&HcclTransport::initiatorLoop, this, deviceLogicId, i); + allAcceptThreads_[i] = + std::thread(&HcclTransport::acceptLoop, this, deviceLogicId); } - LOG(INFO) << "HcclTransport: initPdThread, pid: " << pid << ";" << "init " << THREAD_NUM << " initiator threads and accept threads, deviceLogicId: " << deviceLogicId; + LOG(INFO) << "HcclTransport: initPdThread, pid: " << pid << ";" << "init " + << THREAD_NUM + << " initiator threads and accept threads, deviceLogicId: " + << deviceLogicId; return 0; } // Get HostIp\Port\DevicePhyId -int HcclTransport::getDevIdAndIpPortFromServerName(std::string& identifier, std::string& hostIp, int &port, int& npuId) { +int HcclTransport::getDevIdAndIpPortFromServerName(std::string &identifier, + std::string &hostIp, + int &port, int &npuId) { size_t firstColon = identifier.find(":"); if (firstColon == std::string::npos) { - LOG(ERROR) << "HcclTransport: getDevIdAndIpPortFromServerName failed, identifier is empty"; + LOG(ERROR) << "HcclTransport: getDevIdAndIpPortFromServerName failed, " + "identifier is empty"; return -1; } size_t secondColon = 
identifier.find(":", firstColon + 1); if (secondColon == std::string::npos) { - LOG(ERROR) << "HcclTransport: getDevIdAndIpPortFromServerName failed, second colon missing"; + LOG(ERROR) << "HcclTransport: getDevIdAndIpPortFromServerName failed, " + "second colon missing"; return -1; } hostIp = identifier.substr(0, firstColon); - std::string portStr = identifier.substr(firstColon + 1, secondColon - firstColon - 1); + std::string portStr = + identifier.substr(firstColon + 1, secondColon - firstColon - 1); try { port = std::stoi(portStr); } catch (const std::exception &e) { @@ -238,7 +264,7 @@ int HcclTransport::rankInfoParse(int devicePhyId, std::string hostIp) { LOG(ERROR) << "HcclTransport: aclrtGetDevice failed, ret: " << ret; return ret; } - + // Default configuration file path for HCCL std::ifstream fin("/etc/hccn.conf"); if (!fin) { @@ -252,31 +278,55 @@ int HcclTransport::rankInfoParse(int devicePhyId, std::string hostIp) { size_t equal_pos = line.find('='); if (equal_pos != std::string::npos) { std::string key = line.substr(8, equal_pos - 8); - key.erase(key.begin(), std::find_if(key.begin(), key.end(), [](unsigned char c){ return !std::isspace(c); })); + key.erase(key.begin(), std::find_if(key.begin(), key.end(), + [](unsigned char c) { + return !std::isspace(c); + })); if (key == std::to_string(devicePhyId)) { std::string deviceIp = line.substr(equal_pos + 1); - deviceIp.erase(deviceIp.begin(), std::find_if(deviceIp.begin(), deviceIp.end(), [](unsigned char c){ return !std::isspace(c); })); - deviceIp.erase(std::find_if(deviceIp.rbegin(), deviceIp.rend(), [](unsigned char c){ return !std::isspace(c); }).base(), deviceIp.end()); - - if (inet_pton(AF_INET, hostIp.c_str(), &local_rank_info_.hostIp) != 1) { - LOG(ERROR) << "HcclTransport: Invalid Host IP format: " << hostIp; + deviceIp.erase( + deviceIp.begin(), + std::find_if( + deviceIp.begin(), deviceIp.end(), + [](unsigned char c) { return !std::isspace(c); })); + deviceIp.erase( + std::find_if( + 
deviceIp.rbegin(), deviceIp.rend(), + [](unsigned char c) { return !std::isspace(c); }) + .base(), + deviceIp.end()); + + if (inet_pton(AF_INET, hostIp.c_str(), + &local_rank_info_.hostIp) != 1) { + LOG(ERROR) << "HcclTransport: Invalid Host IP format: " + << hostIp; return -1; } local_rank_info_.rankId = devicePhyId; local_rank_info_.serverIdx = 0; local_rank_info_.devicePhyId = devicePhyId; - local_rank_info_.hostPort = ASCEND_DEFAULT_HOST_PORT + devicePhyId; + local_rank_info_.hostPort = + ASCEND_DEFAULT_HOST_PORT + devicePhyId; local_rank_info_.deviceLogicId = deviceLogicId; local_rank_info_.devicePort = ASCEND_DEFAULT_DEVICE_PORT; local_rank_info_.pid = 0; - if (inet_pton(AF_INET, deviceIp.c_str(), &local_rank_info_.deviceIp) != 1) { - LOG(ERROR) << "HcclTransport: Invalid Device IP format: " << deviceIp; + if (inet_pton(AF_INET, deviceIp.c_str(), + &local_rank_info_.deviceIp) != 1) { + LOG(ERROR) + << "HcclTransport: Invalid Device IP format: " + << deviceIp; return -1; } - LOG(INFO) << "rankInfoParse Success, hostIp: " << hostIp << ", rankId: " << local_rank_info_.rankId - << ", serverIdx: " << local_rank_info_.serverIdx << ", devicePhyId: " << local_rank_info_.devicePhyId - << ", hostPort: " << local_rank_info_.hostPort << ", deviceLogicId: " << local_rank_info_.deviceLogicId - << ", devicePort: " << local_rank_info_.devicePort << ", deviceIp: " << deviceIp << ", device pid: " << local_rank_info_.pid; + LOG(INFO) + << "rankInfoParse Success, hostIp: " << hostIp + << ", rankId: " << local_rank_info_.rankId + << ", serverIdx: " << local_rank_info_.serverIdx + << ", devicePhyId: " << local_rank_info_.devicePhyId + << ", hostPort: " << local_rank_info_.hostPort + << ", deviceLogicId: " << local_rank_info_.deviceLogicId + << ", devicePort: " << local_rank_info_.devicePort + << ", deviceIp: " << deviceIp + << ", device pid: " << local_rank_info_.pid; // Exit after finishing rankInfoParse return 0; } @@ -288,20 +338,26 @@ int HcclTransport::rankInfoParse(int 
devicePhyId, std::string hostIp) { } int HcclTransport::install(std::string &local_server_name, - std::shared_ptr meta, std::shared_ptr topo) { + std::shared_ptr meta, + std::shared_ptr topo) { int ret = 0; int port; std::string hostIp; int devicePhyId; metadata_ = meta; - ret = getDevIdAndIpPortFromServerName(local_server_name, hostIp, port, devicePhyId); + ret = getDevIdAndIpPortFromServerName(local_server_name, hostIp, port, + devicePhyId); if (ret) { - LOG(ERROR) << "HcclTransport: getDevIdAndIpPortFromServerName failed, ret: " << ret; - return ret; + LOG(ERROR) + << "HcclTransport: getDevIdAndIpPortFromServerName failed, ret: " + << ret; + return ret; } local_server_name_ = hostIp + ":" + std::to_string(port); - LOG(INFO) << "HcclTransport: begin to install transport, local devicePhyId: " << devicePhyId << ", local_server_name: " << local_server_name; + LOG(INFO) + << "HcclTransport: begin to install transport, local devicePhyId: " + << devicePhyId << ", local_server_name: " << local_server_name; // add to local_rank_info_ ret = rankInfoParse(devicePhyId, hostIp); @@ -318,14 +374,16 @@ int HcclTransport::install(std::string &local_server_name, ret = allocateLocalSegmentID(); if (ret) { - LOG(ERROR) << "HcclTransport: cannot allocate local segment, ret: " << ret; + LOG(ERROR) << "HcclTransport: cannot allocate local segment, ret: " + << ret; return ret; } ret = metadata_->updateLocalSegmentDesc(); if (ret) { LOG(ERROR) << "HcclTransport: cannot publish segments, " - "check the availability of metadata storage, ret: " << ret; + "check the availability of metadata storage, ret: " + << ret; return ret; } @@ -412,7 +470,7 @@ Status HcclTransport::submitTransferTask( } Status HcclTransport::getTransferStatus(BatchID batch_id, size_t task_id, - TransferStatus &status) { + TransferStatus &status) { auto &batch_desc = *((BatchDesc *)(batch_id)); const size_t task_count = batch_desc.task_list.size(); if (task_id >= task_count) { @@ -438,9 +496,9 @@ Status 
HcclTransport::getTransferStatus(BatchID batch_id, size_t task_id, } int HcclTransport::registerLocalMemory(void *addr, size_t length, - const std::string &location, - bool remote_accessible, - bool update_metadata) { + const std::string &location, + bool remote_accessible, + bool update_metadata) { (void)remote_accessible; BufferDesc buffer_desc; buffer_desc.name = location; @@ -456,7 +514,8 @@ int HcclTransport::registerLocalMemory(void *addr, size_t length, ret = metadata_->addLocalMemoryBuffer(buffer_desc, update_metadata); if (ret) { - LOG(ERROR) << "HcclTransport: addLocalMemoryBuffer failed, ret: " << ret; + LOG(ERROR) << "HcclTransport: addLocalMemoryBuffer failed, ret: " + << ret; return ret; } @@ -500,4 +559,4 @@ int HcclTransport::unregisterLocalMemoryBatch( return metadata_->updateLocalSegmentDesc(); } -} +} // namespace mooncake diff --git a/mooncake-transfer-engine/src/transport/cxl_transport/cxl_transport.cpp b/mooncake-transfer-engine/src/transport/cxl_transport/cxl_transport.cpp index 421c15166..1b549dead 100644 --- a/mooncake-transfer-engine/src/transport/cxl_transport/cxl_transport.cpp +++ b/mooncake-transfer-engine/src/transport/cxl_transport/cxl_transport.cpp @@ -29,9 +29,9 @@ #include "transfer_metadata.h" #include "transport/transport.h" #include -#include // For O_RDWR, O_CREAT, etc. -#include // For open(), close(), read(), write() -#include // For mmap, munmap +#include // For O_RDWR, O_CREAT, etc. 
+#include // For open(), close(), read(), write() +#include // For mmap, munmap namespace mooncake { @@ -39,17 +39,17 @@ CxlTransport::CxlTransport() { // cxl_dev_path = "/dev/dax0.0"; // cxl_dev_size = 1024 * 1024 * 1024; // get from env - const char* env_cxl_dev_path = std::getenv("MC_CXL_DEV_PATH"); + const char *env_cxl_dev_path = std::getenv("MC_CXL_DEV_PATH"); if (env_cxl_dev_path) { LOG(INFO) << "MC_CXL_DEV_PATH: " << env_cxl_dev_path; - cxl_dev_path = (char *) env_cxl_dev_path; + cxl_dev_path = (char *)env_cxl_dev_path; cxl_dev_size = cxlGetDeviceSize(); } } CxlTransport::~CxlTransport() { - if (cxl_base_addr != nullptr && cxl_base_addr != MAP_FAILED && + if (cxl_base_addr != nullptr && cxl_base_addr != MAP_FAILED && cxl_dev_size != 0) { munmap(cxl_base_addr, cxl_dev_size); } @@ -58,11 +58,11 @@ CxlTransport::~CxlTransport() { size_t CxlTransport::cxlGetDeviceSize() { // for now, get cxl_shm size from env - const char* env_cxl_dev_size = std::getenv("MC_CXL_DEV_SIZE"); + const char *env_cxl_dev_size = std::getenv("MC_CXL_DEV_SIZE"); if (env_cxl_dev_size) { LOG(INFO) << "MC_CXL_DEV_SIZE: " << env_cxl_dev_size; - char* end = nullptr; + char *end = nullptr; unsigned long long val = strtoull(env_cxl_dev_size, &end, 10); if (end != env_cxl_dev_size && *end == '\0') return static_cast(val); @@ -73,25 +73,26 @@ size_t CxlTransport::cxlGetDeviceSize() { int CxlTransport::cxlMemcpy(void *dest, void *src, size_t size) { // Input validation if (!src || !dest) { - LOG(ERROR) << "CxlTransport::cxlMemcpy invalid arguments: null pointer provided."; - return -1; // null pointer + LOG(ERROR) << "CxlTransport::cxlMemcpy invalid arguments: null pointer " + "provided."; + return -1; // null pointer } - + // Validate memory bounds using the helper function if (!validateMemoryBounds(dest, src, size)) { - return -1; // validation failed + return -1; // validation failed } - + // Perform the memory copy std::memcpy(dest, src, size); - + // Memory barriers and cache operations if 
(isAddressInCxlRange(dest) || isAddressInCxlRange(src)) { // Ensure memory ordering for CXL operations __sync_synchronize(); } - - return 0; // success + + return 0; // success } bool CxlTransport::validateMemoryBounds(void *dest, void *src, size_t size) { @@ -99,7 +100,7 @@ bool CxlTransport::validateMemoryBounds(void *dest, void *src, size_t size) { uintptr_t end = base + cxl_dev_size; uintptr_t dest_ptr = reinterpret_cast(dest); uintptr_t src_ptr = reinterpret_cast(src); - + if (isAddressInCxlRange(dest)) { uintptr_t dest_end = dest_ptr + size; if (dest_end > end || dest_end < dest_ptr) { @@ -107,7 +108,7 @@ bool CxlTransport::validateMemoryBounds(void *dest, void *src, size_t size) { return false; } } - + if (isAddressInCxlRange(src)) { uintptr_t src_end = src_ptr + size; if (src_end > end || src_end < src_ptr) { @@ -115,17 +116,17 @@ bool CxlTransport::validateMemoryBounds(void *dest, void *src, size_t size) { return false; } } - + return true; } bool CxlTransport::isAddressInCxlRange(void *addr) { if (!addr || !cxl_base_addr) return false; - + uintptr_t base = reinterpret_cast(cxl_base_addr); uintptr_t end = base + cxl_dev_size; uintptr_t ptr = reinterpret_cast(addr); - + return (ptr >= base && ptr < end); } @@ -136,11 +137,13 @@ int CxlTransport::cxlDevInit() { } int fd = open(cxl_dev_path, O_RDWR); if (fd == -1) { - LOG(ERROR) << "CxlTransport: Cannot open cxl device." << strerror(errno); + LOG(ERROR) << "CxlTransport: Cannot open cxl device." 
+ << strerror(errno); return -1; } - void* ptr = mmap(NULL, cxl_dev_size, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0); + void *ptr = + mmap(NULL, cxl_dev_size, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0); if (ptr == MAP_FAILED) { close(fd); return ERR_MEMORY; @@ -292,13 +295,13 @@ Status CxlTransport::submitTransfer( __sync_fetch_and_add(&task.slice_count, 1); int err; if (slice->opcode == TransferRequest::READ) - //READ: Source is in local memory, Destination is on CXL + // READ: Source is in local memory, Destination is on CXL err = cxlMemcpy(slice->source_addr, (void *)slice->cxl.dest_addr, - slice->length); + slice->length); else - //WRITE: Source is in local memory, Destination is on CXL + // WRITE: Source is in local memory, Destination is on CXL err = cxlMemcpy((void *)slice->cxl.dest_addr, slice->source_addr, - slice->length); + slice->length); if (err != 0) slice->markFailed(); else @@ -317,7 +320,7 @@ Status CxlTransport::submitTransferTask( auto &request = *task.request; uint64_t dest_cxl_offset = request.target_offset; task.total_bytes = request.length; - + Slice *slice = getSliceCache().allocate(); slice->source_addr = (char *)request.source; slice->cxl.dest_addr = (char *)cxl_base_addr + dest_cxl_offset; @@ -330,13 +333,13 @@ Status CxlTransport::submitTransferTask( __sync_fetch_and_add(&task.slice_count, 1); int err; if (slice->opcode == TransferRequest::READ) - //READ: Source is in local memory, Destination is on CXL + // READ: Source is in local memory, Destination is on CXL err = cxlMemcpy(slice->source_addr, (void *)slice->cxl.dest_addr, - slice->length); + slice->length); else - //WRITE: Source is in local memory, Destination is on CXL + // WRITE: Source is in local memory, Destination is on CXL err = cxlMemcpy((void *)slice->cxl.dest_addr, slice->source_addr, - slice->length); + slice->length); if (err != 0) slice->markFailed(); else diff --git a/mooncake-transfer-engine/src/transport/nvlink_transport/nvlink_transport.cpp 
b/mooncake-transfer-engine/src/transport/nvlink_transport/nvlink_transport.cpp index 89a7d1a77..282fdf72f 100644 --- a/mooncake-transfer-engine/src/transport/nvlink_transport/nvlink_transport.cpp +++ b/mooncake-transfer-engine/src/transport/nvlink_transport/nvlink_transport.cpp @@ -45,8 +45,9 @@ namespace mooncake { static int getNumDevices() { static int cached_num_devices = -1; if (cached_num_devices == -1) { - if (!checkCudaErrorReturn(cudaGetDeviceCount(&cached_num_devices), - "NvlinkTransport: cudaGetDeviceCount failed")) { + if (!checkCudaErrorReturn( + cudaGetDeviceCount(&cached_num_devices), + "NvlinkTransport: cudaGetDeviceCount failed")) { return 0; } } @@ -55,13 +56,13 @@ static int getNumDevices() { static bool supportFabricMem() { if (getenv("MC_USE_NVLINK_IPC")) return false; - + int num_devices = getNumDevices(); if (num_devices == 0) { LOG(ERROR) << "NvlinkTransport: no device found"; return false; } - + for (int device_id = 0; device_id < num_devices; ++device_id) { int device_support_fabric_mem = 0; cuDeviceGetAttribute(&device_support_fabric_mem, @@ -406,7 +407,8 @@ int NvlinkTransport::registerLocalMemory(void *addr, size_t length, desc.addr = (uint64_t)real_addr; // (uint64_t)addr; desc.length = real_size; // length; desc.name = location; - desc.shm_name = serializeBinaryData(&export_handle, sizeof(CUmemFabricHandle)); + desc.shm_name = + serializeBinaryData(&export_handle, sizeof(CUmemFabricHandle)); return metadata_->addLocalMemoryBuffer(desc, true); } } @@ -425,7 +427,9 @@ int NvlinkTransport::relocateSharedMemoryAddress(uint64_t &dest_addr, dest_addr + length <= entry.addr + entry.length) { remap_lock_.lockShared(); if (remap_entries_.count(std::make_pair(target_id, entry.addr))) { - auto shm_addr = remap_entries_[std::make_pair(target_id, entry.addr)].shm_addr; + auto shm_addr = + remap_entries_[std::make_pair(target_id, entry.addr)] + .shm_addr; remap_lock_.unlockShared(); dest_addr = dest_addr - entry.addr + ((uint64_t)shm_addr); return 
0; @@ -451,7 +455,8 @@ int NvlinkTransport::relocateSharedMemoryAddress(uint64_t &dest_addr, OpenedShmEntry shm_entry; shm_entry.shm_addr = shm_addr; shm_entry.length = length; - remap_entries_[std::make_pair(target_id, entry.addr)] = shm_entry; + remap_entries_[std::make_pair(target_id, entry.addr)] = + shm_entry; } else if (output_buffer.size() == sizeof(CUmemFabricHandle) && use_fabric_mem_) { CUmemFabricHandle export_handle; @@ -482,14 +487,17 @@ int NvlinkTransport::relocateSharedMemoryAddress(uint64_t &dest_addr, << "NvlinkTransport: cuMemMap failed: " << result; return -1; } - + int device_count; cudaGetDeviceCount(&device_count); CUmemAccessDesc accessDesc[device_count]; - for (int device_id = 0; device_id < device_count; ++device_id) { - accessDesc[device_id].location.type = CU_MEM_LOCATION_TYPE_DEVICE; + for (int device_id = 0; device_id < device_count; + ++device_id) { + accessDesc[device_id].location.type = + CU_MEM_LOCATION_TYPE_DEVICE; accessDesc[device_id].location.id = device_id; - accessDesc[device_id].flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; + accessDesc[device_id].flags = + CU_MEM_ACCESS_FLAGS_PROT_READWRITE; } result = cuMemSetAccess((CUdeviceptr)shm_addr, entry.length, accessDesc, device_count); @@ -501,13 +509,15 @@ int NvlinkTransport::relocateSharedMemoryAddress(uint64_t &dest_addr, OpenedShmEntry shm_entry; shm_entry.shm_addr = shm_addr; shm_entry.length = length; - remap_entries_[std::make_pair(target_id, entry.addr)] = shm_entry; + remap_entries_[std::make_pair(target_id, entry.addr)] = + shm_entry; } else { LOG(ERROR) << "Mismatched NVLink data transfer method"; return -1; } } - auto shm_addr = remap_entries_[std::make_pair(target_id, entry.addr)].shm_addr; + auto shm_addr = + remap_entries_[std::make_pair(target_id, entry.addr)].shm_addr; dest_addr = dest_addr - entry.addr + ((uint64_t)shm_addr); return 0; } diff --git a/mooncake-transfer-engine/src/transport/nvmeof_transport/nvmeof_transport.cpp 
b/mooncake-transfer-engine/src/transport/nvmeof_transport/nvmeof_transport.cpp index 5007fb23e..8f51aeb5c 100644 --- a/mooncake-transfer-engine/src/transport/nvmeof_transport/nvmeof_transport.cpp +++ b/mooncake-transfer-engine/src/transport/nvmeof_transport/nvmeof_transport.cpp @@ -75,7 +75,7 @@ NVMeoFTransport::BatchID NVMeoFTransport::allocateBatchID(size_t batch_size) { } Status NVMeoFTransport::getTransferStatus(BatchID batch_id, size_t task_id, - TransferStatus &status) { + TransferStatus &status) { auto &batch_desc = *((BatchDesc *)(batch_id)); auto &task = batch_desc.task_list[task_id]; auto &nvmeof_desc = *((NVMeoFBatchDesc *)(batch_desc.context)); @@ -118,8 +118,9 @@ Status NVMeoFTransport::submitTransfer( auto &nvmeof_desc = *((NVMeoFBatchDesc *)(batch_desc.context)); if (batch_desc.task_list.size() + entries.size() > batch_desc.batch_size) { - LOG(ERROR) << "NVMeoFTransport: Exceed the limitation of current batch's " - "capacity"; + LOG(ERROR) + << "NVMeoFTransport: Exceed the limitation of current batch's " + "capacity"; return Status::InvalidArgument( "NVMeoFTransport: Exceed the limitation of capacity, batch id: " + std::to_string(batch_id)); diff --git a/mooncake-transfer-engine/src/transport/rdma_transport/rdma_endpoint.cpp b/mooncake-transfer-engine/src/transport/rdma_transport/rdma_endpoint.cpp index 6c5429a4a..daa660038 100644 --- a/mooncake-transfer-engine/src/transport/rdma_transport/rdma_endpoint.cpp +++ b/mooncake-transfer-engine/src/transport/rdma_transport/rdma_endpoint.cpp @@ -86,8 +86,8 @@ int RdmaEndPoint::deconstruct() { bool displayed = false; if (wr_depth_list_[i] != 0) { if (!displayed) { - LOG(WARNING) - << "Outstanding work requests found, CQ will not be generated"; + LOG(WARNING) << "Outstanding work requests found, CQ will not " + "be generated"; displayed = true; } __sync_fetch_and_sub(cq_outstanding_, wr_depth_list_[i]); @@ -236,8 +236,8 @@ void RdmaEndPoint::disconnectUnlocked() { bool displayed = false; if (wr_depth_list_[i] 
!= 0) { if (!displayed) { - LOG(WARNING) - << "Outstanding work requests found, CQ will not be generated"; + LOG(WARNING) << "Outstanding work requests found, CQ will not " + "be generated"; displayed = true; } __sync_fetch_and_sub(cq_outstanding_, wr_depth_list_[i]); diff --git a/mooncake-transfer-engine/src/transport/rdma_transport/rdma_transport.cpp b/mooncake-transfer-engine/src/transport/rdma_transport/rdma_transport.cpp index a4bcb4790..faf80f14b 100644 --- a/mooncake-transfer-engine/src/transport/rdma_transport/rdma_transport.cpp +++ b/mooncake-transfer-engine/src/transport/rdma_transport/rdma_transport.cpp @@ -271,7 +271,8 @@ Status RdmaTransport::submitTransferTask( const size_t kBlockSize = globalConfig().slice_size; const int kMaxRetryCount = globalConfig().retry_cnt; const size_t kFragmentSize = globalConfig().fragment_limit; - const size_t kSubmitWatermark = globalConfig().max_wr * globalConfig().num_qp_per_ep; + const size_t kSubmitWatermark = + globalConfig().max_wr * globalConfig().num_qp_per_ep; uint64_t nr_slices; for (size_t index = 0; index < task_list.size(); ++index) { assert(task_list[index]); @@ -286,11 +287,13 @@ Status RdmaTransport::submitTransferTask( if (!slice->from_cache) { nr_slices++; } - - bool merge_final_slice = request.length - offset <= kBlockSize + kFragmentSize; + + bool merge_final_slice = + request.length - offset <= kBlockSize + kFragmentSize; slice->source_addr = (char *)request.source + offset; - slice->length = merge_final_slice ? request.length - offset : kBlockSize; + slice->length = + merge_final_slice ? 
request.length - offset : kBlockSize; slice->opcode = request.opcode; slice->rdma.dest_addr = request.target_offset + offset; slice->rdma.retry_cnt = request.advise_retry_cnt; @@ -352,8 +355,7 @@ Status RdmaTransport::submitTransferTask( } for (auto &entry : slices_to_post) - if (!entry.second.empty()) - entry.first->submitPostSend(entry.second); + if (!entry.second.empty()) entry.first->submitPostSend(entry.second); return Status::OK(); } @@ -467,33 +469,40 @@ int RdmaTransport::startHandshakeDaemon(std::string &local_server_name) { // According to the request desc, offset and length information, find proper // buffer_id and device_id as output. // Return 0 if successful, ERR_ADDRESS_NOT_REGISTERED otherwise. -int RdmaTransport::selectDevice(SegmentDesc *desc, uint64_t offset, size_t length, - std::string_view hint, int &buffer_id, int &device_id, int retry_count) { +int RdmaTransport::selectDevice(SegmentDesc *desc, uint64_t offset, + size_t length, std::string_view hint, + int &buffer_id, int &device_id, + int retry_count) { if (desc == nullptr) return ERR_ADDRESS_NOT_REGISTERED; const auto &buffers = desc->buffers; - for (buffer_id = 0; buffer_id < static_cast(buffers.size()); ++buffer_id) { + for (buffer_id = 0; buffer_id < static_cast(buffers.size()); + ++buffer_id) { const auto &buffer = buffers[buffer_id]; // Check if offset is within buffer range - if (offset < buffer.addr || length > buffer.length || offset - buffer.addr > buffer.length - length) { + if (offset < buffer.addr || length > buffer.length || + offset - buffer.addr > buffer.length - length) { continue; } - device_id = hint.empty() - ? desc->topology.selectDevice(buffer.name, retry_count) - : desc->topology.selectDevice(buffer.name, hint, retry_count); + device_id = + hint.empty() + ? desc->topology.selectDevice(buffer.name, retry_count) + : desc->topology.selectDevice(buffer.name, hint, retry_count); if (device_id >= 0) return 0; - device_id = hint.empty() - ? 
desc->topology.selectDevice(kWildcardLocation, retry_count) - : desc->topology.selectDevice(kWildcardLocation, hint, retry_count); + device_id = hint.empty() ? desc->topology.selectDevice( + kWildcardLocation, retry_count) + : desc->topology.selectDevice( + kWildcardLocation, hint, retry_count); if (device_id >= 0) return 0; } return ERR_ADDRESS_NOT_REGISTERED; } -int RdmaTransport::selectDevice(SegmentDesc *desc, uint64_t offset, size_t length, - int &buffer_id, int &device_id, int retry_count) { - return selectDevice(desc, offset, length, "", buffer_id, device_id, retry_count); +int RdmaTransport::selectDevice(SegmentDesc *desc, uint64_t offset, + size_t length, int &buffer_id, int &device_id, + int retry_count) { + return selectDevice(desc, offset, length, "", buffer_id, device_id, + retry_count); } } // namespace mooncake - diff --git a/mooncake-transfer-engine/src/transport/rdma_transport/worker_pool.cpp b/mooncake-transfer-engine/src/transport/rdma_transport/worker_pool.cpp index 2d674640d..4728236bf 100644 --- a/mooncake-transfer-engine/src/transport/rdma_transport/worker_pool.cpp +++ b/mooncake-transfer-engine/src/transport/rdma_transport/worker_pool.cpp @@ -110,10 +110,12 @@ int WorkerPool::submitPostSend( } auto &peer_segment_desc = segment_desc_map[slice->target_id]; int buffer_id, device_id; - auto hint = globalConfig().enable_dest_device_affinity ? context_.deviceName() : ""; - if (RdmaTransport::selectDevice( - peer_segment_desc.get(), slice->rdma.dest_addr, slice->length, - hint, buffer_id, device_id)) { + auto hint = globalConfig().enable_dest_device_affinity + ? 
context_.deviceName() + : ""; + if (RdmaTransport::selectDevice(peer_segment_desc.get(), + slice->rdma.dest_addr, slice->length, + hint, buffer_id, device_id)) { peer_segment_desc = context_.engine().meta()->getSegmentDescByID( slice->target_id, true); if (!peer_segment_desc) { @@ -230,7 +232,8 @@ void WorkerPool::performPostSend(int thread_id) { } if (!endpoint->active()) { if (endpoint->inactiveTime() > 1.0) - context_.deleteEndpoint(entry.first); // enable for re-establishation + context_.deleteEndpoint( + entry.first); // enable for re-establishation for (auto &slice : entry.second) failed_slice_list.push_back(slice); entry.second.clear(); continue; diff --git a/mooncake-transfer-engine/src/transport/tcp_transport/tcp_transport.cpp b/mooncake-transfer-engine/src/transport/tcp_transport/tcp_transport.cpp index b186f244a..e07bd930f 100644 --- a/mooncake-transfer-engine/src/transport/tcp_transport/tcp_transport.cpp +++ b/mooncake-transfer-engine/src/transport/tcp_transport/tcp_transport.cpp @@ -246,7 +246,7 @@ int TcpTransport::install(std::string &local_server_name, return -1; } - close(sockfd); // the above function has opened a socket + close(sockfd); // the above function has opened a socket LOG(INFO) << "TcpTransport: listen on port " << tcp_port; context_ = new TcpContext(tcp_port); running_ = true; diff --git a/mooncake-transfer-engine/tests/common_test.cpp b/mooncake-transfer-engine/tests/common_test.cpp index d1236fa95..886f09846 100644 --- a/mooncake-transfer-engine/tests/common_test.cpp +++ b/mooncake-transfer-engine/tests/common_test.cpp @@ -5,14 +5,12 @@ #include "common.h" - namespace { using namespace mooncake; const uint16_t kDefaultPort = getDefaultHandshakePort(); - //------------------------------------------------------------------------------ // parseFromString //------------------------------------------------------------------------------ @@ -149,7 +147,6 @@ TEST(ParsePortAndDevice, InvalidDeviceIdSetsToZero) { // parseHostNameWithPortAscend 
//------------------------------------------------------------------------------ - TEST(ParseHostNameWithPortAscend, BracketedIpv6WithPort) { int dev = -1; auto [host, port] = parseHostNameWithPortAscend("[2001:db8::1]:1234", &dev); @@ -208,7 +205,8 @@ TEST(ParseHostNameWithPortAscend, HostWithoutPort) { TEST(ParseHostNameWithPortAscend, BracketedIpv6PortDevice) { int dev = -1; - auto [host, port] = parseHostNameWithPortAscend("[2001:db8::1]:8080:npu_3", &dev); + auto [host, port] = + parseHostNameWithPortAscend("[2001:db8::1]:8080:npu_3", &dev); EXPECT_EQ(host, "2001:db8::1"); EXPECT_EQ(port, 8080); EXPECT_EQ(dev, 3); @@ -232,7 +230,8 @@ TEST(ParseHostNameWithPortAscend, Ipv4PortInvalidDevice) { TEST(ParseHostNameWithPortAscend, HostPortDevice) { int dev = -1; - auto [host, port] = parseHostNameWithPortAscend("example.com:8080:npu_1", &dev); + auto [host, port] = + parseHostNameWithPortAscend("example.com:8080:npu_1", &dev); EXPECT_EQ(host, "example.com"); EXPECT_EQ(port, 8080); EXPECT_EQ(dev, 1); diff --git a/mooncake-transfer-engine/tests/cxl_transport_test.cpp b/mooncake-transfer-engine/tests/cxl_transport_test.cpp index cb02fff2a..2907535bf 100644 --- a/mooncake-transfer-engine/tests/cxl_transport_test.cpp +++ b/mooncake-transfer-engine/tests/cxl_transport_test.cpp @@ -72,7 +72,6 @@ class CXLTransportTest : public ::testing::Test { mooncake::Transport::SegmentID segment_id; std::shared_ptr segment_desc; const size_t kDataLength = 4 * 1024; - protected: void SetUp() override { @@ -102,22 +101,21 @@ class CXLTransportTest : public ::testing::Test { xport = engine->installTransport("cxl", args); ASSERT_NE(xport, nullptr); - cxl_xport = dynamic_cast(xport); - base_addr = (uint8_t*)cxl_xport->getCxlBaseAddr(); - addr = (uint8_t*) allocateMemoryPool(kDataLength, 0, false); + cxl_xport = dynamic_cast(xport); + base_addr = (uint8_t *)cxl_xport->getCxlBaseAddr(); + addr = (uint8_t *)allocateMemoryPool(kDataLength, 0, false); int rc = 
engine->registerLocalMemory(base_addr + offset_1, len); ASSERT_EQ(rc, 0); segment_id = engine->openSegment(FLAGS_local_server_name.c_str()); // bindToSocket(0); segment_desc = engine->getMetadata()->getSegmentDescByID(segment_id); - } void TearDown() override { - if (tmp_fd >= 0) { - close(tmp_fd); - unlink(FLAGS_device_name.c_str()); + if (tmp_fd >= 0) { + close(tmp_fd); + unlink(FLAGS_device_name.c_str()); } free(args); google::ShutdownGoogleLogging(); @@ -141,7 +139,7 @@ TEST_F(CXLTransportTest, MultiWrite) { // s = xport->submitTransfer(batch_id, {entry}); s = engine->submitTransfer(batch_id, {entry}); LOG_ASSERT(s.ok()); - + bool completed = false; TransferStatus status; while (!completed) { @@ -176,7 +174,7 @@ TEST_F(CXLTransportTest, MultipleRead) { // s = xport->submitTransfer(batch_id, {entry}); s = engine->submitTransfer(batch_id, {entry}); LOG_ASSERT(s.ok()); - + bool completed = false; TransferStatus status; while (!completed) { @@ -189,7 +187,7 @@ TEST_F(CXLTransportTest, MultipleRead) { completed = true; } } - + s = xport->freeBatchID(batch_id); ASSERT_EQ(s, Status::OK()); } @@ -233,7 +231,6 @@ TEST_F(CXLTransportTest, MultipleRead) { freeMemoryPool(src, kDataLength); } engine->unregisterLocalMemory(addr); - } } // namespace mooncake diff --git a/mooncake-transfer-engine/tests/memory_location_test.cpp b/mooncake-transfer-engine/tests/memory_location_test.cpp index 865344568..56dbc7d86 100644 --- a/mooncake-transfer-engine/tests/memory_location_test.cpp +++ b/mooncake-transfer-engine/tests/memory_location_test.cpp @@ -110,13 +110,13 @@ TEST(MemoryLocationTest, MallocMultipleNodes) { // check the second memory location EXPECT_EQ(entries[1].start, - reinterpret_cast(addr) + 4096 * 2); + reinterpret_cast(addr) + 4096 * 2); EXPECT_EQ(entries[1].location, locationa); EXPECT_EQ(entries[1].len, static_cast(4096 * 7)); // check the third memory location EXPECT_EQ(entries[2].start, - reinterpret_cast(addr) + 4096 * 9); + reinterpret_cast(addr) + 4096 * 9); 
EXPECT_EQ(entries[2].location, locationb); EXPECT_EQ(entries[2].len, static_cast(4096 - 1024 * 2)); } diff --git a/mooncake-transfer-engine/tests/nvlink_transport_test.cpp b/mooncake-transfer-engine/tests/nvlink_transport_test.cpp index a31a66504..b49341e84 100644 --- a/mooncake-transfer-engine/tests/nvlink_transport_test.cpp +++ b/mooncake-transfer-engine/tests/nvlink_transport_test.cpp @@ -16,7 +16,7 @@ DEFINE_string(local_server_name, "cuda_server:12345", "Local server name"); DEFINE_string(segment_id, "cuda_server:12345", "Segment ID to access data"); DEFINE_int32(gpu_id, 0, "GPU ID to use"); -static void checkCudaError(cudaError_t result, const char *message) { +static void checkCudaError(cudaError_t result, const char* message) { if (result != cudaSuccess) { LOG(ERROR) << message << " (Error code: " << result << " - " << cudaGetErrorString(result) << ")"; @@ -27,7 +27,8 @@ static void checkCudaError(cudaError_t result, const char *message) { static void* allocateCudaBuffer(size_t size, int gpu_id) { checkCudaError(cudaSetDevice(gpu_id), "Failed to set device"); void* d_buf = nullptr; - checkCudaError(cudaMalloc(&d_buf, size), "Failed to allocate device memory"); + checkCudaError(cudaMalloc(&d_buf, size), + "Failed to allocate device memory"); return d_buf; } @@ -44,11 +45,13 @@ TEST(NvlinkTransportTest, WriteAndRead) { server_engine->init(FLAGS_metadata_server, FLAGS_local_server_name); // Install NvlinkTransport on server - Transport* server_transport = server_engine->installTransport("nvlink", nullptr); + Transport* server_transport = + server_engine->installTransport("nvlink", nullptr); ASSERT_NE(server_transport, nullptr); void* server_buffer = allocateCudaBuffer(kDataLength * 2, gpu_id); - int rc = server_engine->registerLocalMemory(server_buffer, kDataLength * 2, "cuda:0"); + int rc = server_engine->registerLocalMemory(server_buffer, kDataLength * 2, + "cuda:0"); ASSERT_EQ(rc, 0); auto segment_id = server_engine->openSegment(FLAGS_segment_id); @@ -58,7 
+61,8 @@ TEST(NvlinkTransportTest, WriteAndRead) { client_engine->init(FLAGS_metadata_server, "cuda_client:12346"); // Install NvlinkTransport on client - Transport* client_transport = client_engine->installTransport("nvlink", nullptr); + Transport* client_transport = + client_engine->installTransport("nvlink", nullptr); ASSERT_NE(client_transport, nullptr); void* client_buffer = allocateCudaBuffer(kDataLength * 2, gpu_id); @@ -70,7 +74,9 @@ TEST(NvlinkTransportTest, WriteAndRead) { { // Fill client buffer with data std::vector host_data(kDataLength, 'A'); - checkCudaError(cudaMemcpy(client_buffer, host_data.data(), kDataLength, cudaMemcpyHostToDevice), "Memcpy to client_buffer"); + checkCudaError(cudaMemcpy(client_buffer, host_data.data(), kDataLength, + cudaMemcpyHostToDevice), + "Memcpy to client_buffer"); auto batch_id = client_engine->allocateBatchID(1); TransferRequest entry; @@ -120,7 +126,10 @@ TEST(NvlinkTransportTest, WriteAndRead) { // Check data std::vector host_check(kDataLength); - checkCudaError(cudaMemcpy(host_check.data(), (char*)client_buffer + kDataLength, kDataLength, cudaMemcpyDeviceToHost), "Memcpy from client_buffer"); + checkCudaError( + cudaMemcpy(host_check.data(), (char*)client_buffer + kDataLength, + kDataLength, cudaMemcpyDeviceToHost), + "Memcpy from client_buffer"); for (size_t i = 0; i < kDataLength; ++i) { ASSERT_EQ(host_check[i], 'A'); } diff --git a/mooncake-transfer-engine/tests/nvmeof_transport_test.cpp b/mooncake-transfer-engine/tests/nvmeof_transport_test.cpp index a82f3229e..13e285bd5 100644 --- a/mooncake-transfer-engine/tests/nvmeof_transport_test.cpp +++ b/mooncake-transfer-engine/tests/nvmeof_transport_test.cpp @@ -45,7 +45,8 @@ DEFINE_string(device_name, "erdma_1", DEFINE_string(nic_priority_matrix, "", "Path to RDMA NIC priority matrix file (Advanced)"); -// python /workspace/Mooncake/mooncake-transfer-engine/scripts/register.py localhost test_nvmeof /workspace/sample +// python 
/workspace/Mooncake/mooncake-transfer-engine/scripts/register.py +// localhost test_nvmeof /workspace/sample DEFINE_string(segment_id, "nvmeof/test_nvmeof", "Segment ID to access data"); static void *allocateMemoryPool(size_t size, int socket_id, diff --git a/mooncake-transfer-engine/tests/rdma_loopback_test.cpp b/mooncake-transfer-engine/tests/rdma_loopback_test.cpp index f2c1758a0..791d3af26 100644 --- a/mooncake-transfer-engine/tests/rdma_loopback_test.cpp +++ b/mooncake-transfer-engine/tests/rdma_loopback_test.cpp @@ -87,7 +87,7 @@ TEST_F(RDMALoopbackTest, MultiWrite) { } s = engine->freeBatchID(batch_id); ASSERT_EQ(s, Status::OK()); - ASSERT_EQ(0, memcmp(addr, (char *) addr + kDataLength, kDataLength)); + ASSERT_EQ(0, memcmp(addr, (char *)addr + kDataLength, kDataLength)); } } } // namespace mooncake diff --git a/mooncake-transfer-engine/tests/rdma_transport_test.cpp b/mooncake-transfer-engine/tests/rdma_transport_test.cpp index 4f5c48991..f3d3d9903 100644 --- a/mooncake-transfer-engine/tests/rdma_transport_test.cpp +++ b/mooncake-transfer-engine/tests/rdma_transport_test.cpp @@ -266,7 +266,8 @@ int initiator() { LOG_ASSERT(!rc); #else addr = allocateMemoryPool(ram_buffer_size, 0, false); - int rc = engine->registerLocalMemory(addr, ram_buffer_size, kWildcardLocation); + int rc = + engine->registerLocalMemory(addr, ram_buffer_size, kWildcardLocation); LOG_ASSERT(!rc); #endif diff --git a/scripts/ascend/pkg/hccl_mem.h b/scripts/ascend/pkg/hccl_mem.h index 9893b2aba..0cd7e83c6 100644 --- a/scripts/ascend/pkg/hccl_mem.h +++ b/scripts/ascend/pkg/hccl_mem.h @@ -12,7 +12,7 @@ #ifdef __cplusplus extern "C" { -#endif // __cplusplus +#endif // __cplusplus /* 网络设备句柄 */ typedef void *HcclNetDev; @@ -37,7 +37,8 @@ typedef struct { * @param[out] buf 返回的缓冲区描述符 * @return 执行状态码 HcclResult */ -extern HcclResult HcclMemReg(HcclNetDev netDev, const HcclMem *mem, HcclBuf *buf); +extern HcclResult HcclMemReg(HcclNetDev netDev, const HcclMem *mem, + HcclBuf *buf); /** * @brief 
注销已注册的内存区域 @@ -53,7 +54,8 @@ extern HcclResult HcclMemDereg(const HcclBuf *buf); * @param[out] outDescLen 返回描述信息长度 * @return 执行状态码 HcclResult */ -extern HcclResult HcclMemExport(HcclBuf *buf, char **outDesc, uint64_t *outDescLen); +extern HcclResult HcclMemExport(HcclBuf *buf, char **outDesc, + uint64_t *outDescLen); /** * @brief 通过描述信息重建内存缓冲区 @@ -63,7 +65,8 @@ extern HcclResult HcclMemExport(HcclBuf *buf, char **outDesc, uint64_t *outDescL * @param[out] outBuf 返回的缓冲区描述符 * @return 执行状态码 HcclResult */ -extern HcclResult HcclMemImport(const char *description, uint32_t descLen, bool isRemote, HcclBuf *outBuf); +extern HcclResult HcclMemImport(const char *description, uint32_t descLen, + bool isRemote, HcclBuf *outBuf); /** * @brief 关闭已打开的内存缓冲区 @@ -89,7 +92,8 @@ typedef struct { * @param[in] remoteGrantInfo 远端授权目标信息 * @return 执行状态码 HcclResult */ -extern HcclResult HcclMemGrant(HcclBuf *localBuf, const HcclMemGrantInfo *remoteGrantInfo); +extern HcclResult HcclMemGrant(HcclBuf *localBuf, + const HcclMemGrantInfo *remoteGrantInfo); /** * @brief 内存重映射接口 @@ -99,9 +103,10 @@ extern HcclResult HcclMemGrant(HcclBuf *localBuf, const HcclMemGrantInfo *remote * @return 执行状态码 HcclResult * @attention 需确保内存段已经在目标网络设备注册 */ -extern HcclResult HcclMemRemap(HcclNetDev netDev, const HcclMem *memArray, uint64_t arraySize); +extern HcclResult HcclMemRemap(HcclNetDev netDev, const HcclMem *memArray, + uint64_t arraySize); #ifdef __cplusplus } -#endif // __cplusplus +#endif // __cplusplus #endif \ No newline at end of file diff --git a/scripts/ascend/pkg/hccl_mem_defs.h b/scripts/ascend/pkg/hccl_mem_defs.h index 883b0f9f1..4f52d881d 100644 --- a/scripts/ascend/pkg/hccl_mem_defs.h +++ b/scripts/ascend/pkg/hccl_mem_defs.h @@ -10,16 +10,16 @@ #ifdef __cplusplus extern "C" { -#endif // __cplusplus +#endif // __cplusplus /** * @enum HcclMemType * @brief 内存类型枚举定义 */ typedef enum { - HCCL_MEM_TYPE_DEVICE, ///< 设备侧内存(如NPU等) - HCCL_MEM_TYPE_HOST, ///< 主机侧内存 - HCCL_MEM_TYPE_NUM ///< 内存类型数量 + 
HCCL_MEM_TYPE_DEVICE, ///< 设备侧内存(如NPU等) + HCCL_MEM_TYPE_HOST, ///< 主机侧内存 + HCCL_MEM_TYPE_NUM ///< 内存类型数量 } HcclMemType; /** @@ -37,5 +37,5 @@ typedef struct { #ifdef __cplusplus } -#endif // __cplusplus +#endif // __cplusplus #endif \ No newline at end of file diff --git a/scripts/ascend/pkg/transport_mem.h b/scripts/ascend/pkg/transport_mem.h index 63cfeb84e..51e10b245 100644 --- a/scripts/ascend/pkg/transport_mem.h +++ b/scripts/ascend/pkg/transport_mem.h @@ -27,12 +27,8 @@ enum class RmaMemType : int { constexpr size_t TRANSPORT_EMD_ESC_SIZE = 512U - (sizeof(u32) * 2); class TransportMem { -public: - enum class TpType : int { - IPC = 0, - ROCE = 1, - TYPE_NUM - }; + public: + enum class TpType : int { IPC = 0, ROCE = 1, TYPE_NUM }; struct AttrInfo { u32 localRankId; @@ -63,44 +59,54 @@ class TransportMem { u64 size; // segment的size }; - static std::shared_ptr Create(TpType tpType, - const std::unique_ptr ¬ifyPool, const HcclNetDevCtx &netDevCtx, const HcclDispatcher &dispatcher, + static std::shared_ptr Create( + TpType tpType, const std::unique_ptr ¬ifyPool, + const HcclNetDevCtx &netDevCtx, const HcclDispatcher &dispatcher, AttrInfo &attrInfo); - explicit TransportMem(const std::unique_ptr ¬ifyPool, const HcclNetDevCtx &netDevCtx, - const HcclDispatcher &dispatcher, AttrInfo &attrInfo); + explicit TransportMem(const std::unique_ptr ¬ifyPool, + const HcclNetDevCtx &netDevCtx, + const HcclDispatcher &dispatcher, AttrInfo &attrInfo); virtual ~TransportMem(); - virtual HcclResult ExchangeMemDesc( - const RmaMemDescs &localMemDescs, RmaMemDescs &remoteMemDescs, u32 &actualNumOfRemote) = 0; - virtual HcclResult EnableMemAccess(const RmaMemDesc &remoteMemDesc, RmaMem &remoteMem) = 0; + virtual HcclResult ExchangeMemDesc(const RmaMemDescs &localMemDescs, + RmaMemDescs &remoteMemDescs, + u32 &actualNumOfRemote) = 0; + virtual HcclResult EnableMemAccess(const RmaMemDesc &remoteMemDesc, + RmaMem &remoteMem) = 0; virtual HcclResult DisableMemAccess(const RmaMemDesc 
&remoteMemDesc) = 0; virtual HcclResult SetDataSocket(const std::shared_ptr &socket); virtual HcclResult SetSocket(const std::shared_ptr &socket) = 0; virtual HcclResult Connect(s32 timeoutSec) = 0; - virtual HcclResult Write(const RmaOpMem &remoteMem, const RmaOpMem &localMem, const rtStream_t &stream) = 0; - virtual HcclResult Read(const RmaOpMem &localMem, const RmaOpMem &remoteMem, const rtStream_t &stream) = 0; + virtual HcclResult Write(const RmaOpMem &remoteMem, + const RmaOpMem &localMem, + const rtStream_t &stream) = 0; + virtual HcclResult Read(const RmaOpMem &localMem, const RmaOpMem &remoteMem, + const rtStream_t &stream) = 0; virtual HcclResult AddOpFence(const rtStream_t &stream) = 0; -protected: + protected: // 从 string 拷贝到 memDesc - HcclResult RmaMemDescCopyFromStr(RmaMemDesc &rmaMemDesc, const std::string &memDescStr) const - { - if (memcpy_s(rmaMemDesc.memDesc, TRANSPORT_EMD_ESC_SIZE, memDescStr.c_str(), memDescStr.size() + 1) != EOK) { + HcclResult RmaMemDescCopyFromStr(RmaMemDesc &rmaMemDesc, + const std::string &memDescStr) const { + if (memcpy_s(rmaMemDesc.memDesc, TRANSPORT_EMD_ESC_SIZE, + memDescStr.c_str(), memDescStr.size() + 1) != EOK) { return HCCL_E_INTERNAL; } return HCCL_SUCCESS; } // 从 memDesc 转换为 string - std::string RmaMemDescCopyToStr(const RmaMemDesc &rmaMemDesc) const - { + std::string RmaMemDescCopyToStr(const RmaMemDesc &rmaMemDesc) const { return std::string(rmaMemDesc.memDesc, TRANSPORT_EMD_ESC_SIZE); } - HcclResult DoExchangeMemDesc(const RmaMemDescs &localMemDescs, RmaMemDescs &remoteMemDescs, u32 &actualNumOfRemote); + HcclResult DoExchangeMemDesc(const RmaMemDescs &localMemDescs, + RmaMemDescs &remoteMemDescs, + u32 &actualNumOfRemote); HcclResult SendLocalMemDesc(const RmaMemDescs &localMemDescs); - HcclResult ReceiveRemoteMemDesc(RmaMemDescs &remoteMemDescs, u32 &actualNumOfRemote); + HcclResult ReceiveRemoteMemDesc(RmaMemDescs &remoteMemDescs, + u32 &actualNumOfRemote); const std::unique_ptr ¬ifyPool_; HcclNetDevCtx 
netDevCtx_{nullptr};