Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion cpp/include/tensorrt_llm/executor/executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -1468,7 +1468,8 @@ class CacheTransceiverConfig
DEFAULT = 0,
MPI = 1,
UCX = 2,
NIXL = 3
NIXL = 3,
MOONCAKE = 4
};
explicit CacheTransceiverConfig(std::optional<BackendType> backendType = std::nullopt,
std::optional<size_t> maxNumTokens = std::nullopt, std::optional<int> kvTransferTimeoutMs = std::nullopt,
Expand Down
8 changes: 8 additions & 0 deletions cpp/include/tensorrt_llm/executor/transferAgent.h
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,14 @@ template <typename... Args>
"libtensorrt_llm_nixl_wrapper.so", "createNixlTransferAgent");
return func(std::forward<Args>(args)...);
}
if (backend == "mooncake")
{
auto& loader = DynLibLoader::getInstance();
using CreateMooncakeFuncType = std::unique_ptr<BaseTransferAgent> (*)(BaseAgentConfig const*);
auto* func = loader.getFunctionPointer<CreateMooncakeFuncType>(
"libtensorrt_llm_mooncake_wrapper.so", "createMooncakeTransferAgent");
return func(std::forward<Args>(args)...);
}
TLLM_THROW("Unknown backend name.");
}

Expand Down
9 changes: 9 additions & 0 deletions cpp/tensorrt_llm/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,10 @@ if(NIXL_ROOT)
set(NIXL_WRAPPER_TARGET tensorrt_llm_nixl_wrapper)
endif()

if(MOONCAKE_ROOT)
set(MOONCAKE_WRAPPER_TARGET tensorrt_llm_mooncake_wrapper)
endif()

add_subdirectory(executor)

find_package(Threads REQUIRED)
Expand Down Expand Up @@ -272,6 +276,11 @@ if(TARGET ${NIXL_WRAPPER_TARGET})
add_dependencies(${SHARED_TARGET} ${NIXL_WRAPPER_TARGET})
endif()

if(TARGET ${MOONCAKE_WRAPPER_TARGET})
target_link_libraries(${MOONCAKE_WRAPPER_TARGET} INTERFACE ${SHARED_TARGET})
add_dependencies(${SHARED_TARGET} ${MOONCAKE_WRAPPER_TARGET})
endif()

if(NOT WIN32)
# Load libraries at $PREFIX/lib from
# $PREFIX/lib/python3.12/site-packages/tensorrt_llm/libs
Expand Down
13 changes: 12 additions & 1 deletion cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,11 @@ std::unique_ptr<BaseCacheTransceiver> CacheTransceiverFactory::createCacheTransc
backendType = executor::CacheTransceiverConfig::BackendType::NIXL;
TLLM_LOG_INFO("Enable NIXL KV cache transport.");
}
else if (common::getEnvUseMooncakeKvCache())
{
backendType = executor::CacheTransceiverConfig::BackendType::MOONCAKE;
TLLM_LOG_INFO("Enable MOONCAKE KV cache transport.");
}
else if (common::getEnvUseMPIKvCache())
{
backendType = executor::CacheTransceiverConfig::BackendType::MPI;
Expand Down Expand Up @@ -203,9 +208,15 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa
else if (backendType.value() == executor::CacheTransceiverConfig::BackendType::NIXL)
{
mManager = std::make_unique<tensorrt_llm::executor::kv_cache::AgentConnectionManager>(
mCacheTransBufferManagerPtrs, *mCacheState);
mCacheTransBufferManagerPtrs, *mCacheState, "nixl");
TLLM_LOG_INFO("NIXL Connection Manager created");
}
else if (backendType.value() == executor::CacheTransceiverConfig::BackendType::MOONCAKE)
{
mManager = std::make_unique<tensorrt_llm::executor::kv_cache::AgentConnectionManager>(
mCacheTransBufferManagerPtrs, *mCacheState, "mooncake");
TLLM_LOG_INFO("MOONCAKE Connection Manager created");
}
else if (backendType.value() == executor::CacheTransceiverConfig::BackendType::MPI)
{
mMpiWorldComm = std::addressof(tensorrt_llm::mpi::MpiComm::world());
Expand Down
23 changes: 23 additions & 0 deletions cpp/tensorrt_llm/common/envUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,12 @@ bool getEnvUseNixlKvCache()
return useNixlKvCache;
}

bool getEnvUseMooncakeKvCache()
{
static bool const useMooncakeKvCache = getBoolEnv("TRTLLM_USE_MOONCAKE_KVCACHE");
return useMooncakeKvCache;
}

bool getEnvUseRoundRobinBlockDistForCP()
{
static bool const useRoundRobinBlockDistForCP = getBoolEnv("TRTLLM_USE_ROUND_ROBIN_BLOCK_DIST_FOR_CP");
Expand Down Expand Up @@ -343,6 +349,23 @@ std::string getEnvNixlBackend()
return nixlBackend;
}

std::string getEnvMooncakeInterface()
{
static std::once_flag flag;
static std::string mooncakeInterface;

std::call_once(flag,
[&]()
{
char const* mooncake_interface = std::getenv("TRTLLM_MOONCAKE_INTERFACE");
if (mooncake_interface)
{
mooncakeInterface = mooncake_interface;
}
});
return mooncakeInterface;
}

bool getEnvDisaggLayerwise()
{
static bool const disaggLayerwise = getBoolEnv("TRTLLM_DISAGG_LAYERWISE");
Expand Down
5 changes: 5 additions & 0 deletions cpp/tensorrt_llm/common/envUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,11 @@ inline void launchWithPdlWhenEnabled(char const* name, KernelFn kernelFn, dim3 g
bool getEnvUseUCXKvCache();

bool getEnvUseMPIKvCache();

bool getEnvUseNixlKvCache();

bool getEnvUseMooncakeKvCache();

bool getEnvUseRoundRobinBlockDistForCP();

std::string getEnvUCXInterface();
Expand All @@ -93,6 +96,8 @@ std::string getEnvNixlInterface();

std::string getEnvNixlBackend();

std::string getEnvMooncakeInterface();

bool getEnvDisaggLayerwise();

bool getEnvParallelCacheSend();
Expand Down
226 changes: 226 additions & 0 deletions cpp/tensorrt_llm/common/ipUtils.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "ipUtils.h"
#include "tensorrt_llm/common/logger.h"

#include <arpa/inet.h>
#include <dirent.h>
#include <fcntl.h>
#include <ifaddrs.h>
#include <net/if.h>
#include <netdb.h>
#include <netinet/in.h>
#include <string>
#include <sys/socket.h>
#include <unistd.h>

TRTLLM_NAMESPACE_BEGIN

namespace common
{

std::string getLocalIpByNic(std::string const& interface, int rank)
{
struct ifaddrs* ifaddr = nullptr;
if (getifaddrs(&ifaddr) == -1)
{
TLLM_LOG_ERROR(rank,
"getLocalIpByNic: Can't get local ip from NIC Interface. Please check whether corresponding INTERFACE is "
"set "
"correctly.");
return std::string{};
}

for (struct ifaddrs* ifa = ifaddr; ifa != nullptr; ifa = ifa->ifa_next)
{
if (ifa->ifa_addr == nullptr)
{
continue;
}

if (ifa->ifa_name == interface)
{
if (ifa->ifa_addr->sa_family == AF_INET)
{
char ip[INET_ADDRSTRLEN]{};
void* addr = &((reinterpret_cast<struct sockaddr_in*>(ifa->ifa_addr))->sin_addr);
if ((inet_ntop(AF_INET, addr, ip, sizeof(ip)) != nullptr) && std::strcmp(ip, "0.0.0.0") != 0)
{
freeifaddrs(ifaddr);
return std::string(ip);
}
}
else if (ifa->ifa_addr->sa_family == AF_INET6)
{
char ip[INET6_ADDRSTRLEN]{};
void* addr = &((reinterpret_cast<struct sockaddr_in6*>(ifa->ifa_addr))->sin6_addr);
if ((inet_ntop(AF_INET6, addr, ip, sizeof(ip)) != nullptr) && std::strncmp(ip, "fe80::", 6) != 0
&& std::strcmp(ip, "::1") != 0)
{
freeifaddrs(ifaddr);
return std::string(ip);
}
}
}
}

freeifaddrs(ifaddr);
TLLM_LOG_ERROR(
rank, "Can't get local ip from NIC Interface. Please check whether corresponding INTERFACE is set correctly.");
return std::string{};
}

std::string getLocalIpByHostname(int rank)
{
char hostname[256]{};
if (gethostname(hostname, sizeof(hostname)) == -1)
{
TLLM_LOG_ERROR(rank, "getLocalIpByHostname: Can't get hostname");
return std::string{};
}

struct addrinfo hints = {};
hints.ai_family = AF_UNSPEC;
hints.ai_socktype = SOCK_STREAM;
hints.ai_flags = AI_CANONNAME;

struct addrinfo* res = nullptr;
if (getaddrinfo(hostname, nullptr, &hints, &res) != 0)
{
TLLM_LOG_WARNING(rank, "getLocalIpByHostname: Can't get address info for hostname");
return std::string{};
}

for (struct addrinfo* p = res; p != nullptr; p = p->ai_next)
{

if (p->ai_family == AF_INET)
{ // IPv4
char ip[INET_ADDRSTRLEN]{};
struct sockaddr_in* ipv4 = reinterpret_cast<struct sockaddr_in*>(p->ai_addr);
void* addr = &(ipv4->sin_addr);
if ((inet_ntop(AF_INET, addr, ip, sizeof(ip)) != nullptr) && std::strcmp(ip, "127.0.0.1") != 0
&& std::strcmp(ip, "0.0.0.0") != 0)
{
freeaddrinfo(res);
return std::string(ip);
}
}
else if (p->ai_family == AF_INET6)
{ // IPv6
char ip[INET6_ADDRSTRLEN]{};
struct sockaddr_in6* ipv6 = reinterpret_cast<struct sockaddr_in6*>(p->ai_addr);
void* addr = &(ipv6->sin6_addr);
if ((inet_ntop(AF_INET6, addr, ip, sizeof(ip)) != nullptr) && std::strncmp(ip, "fe80::", 6) != 0
&& std::strcmp(ip, "::1") != 0)
{
freeaddrinfo(res);
return std::string(ip);
}
}
}

freeaddrinfo(res);
TLLM_LOG_WARNING(rank, "getLocalIpByHostname: Can't get local ip from hostname");
return std::string{};
}

std::string getLocalIpByRemoteOrHostName(int rank)
{

// Try IPv4
struct sockaddr_in addr
{
};

addr.sin_family = AF_INET;
addr.sin_port = htons(80);
// using google's public dns server to get the local ip which can be accessed from remote
char const* dns_ip_v4 = "8.8.8.8";
inet_pton(AF_INET, dns_ip_v4, &addr.sin_addr);

int sock = socket(AF_INET, SOCK_DGRAM, 0);
if (sock != -1)
{
if (connect(sock, reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr)) != -1)
{
socklen_t addr_len = sizeof(addr);
if (getsockname(sock, reinterpret_cast<struct sockaddr*>(&addr), &addr_len) != -1)
{
char ip[INET_ADDRSTRLEN]{};
inet_ntop(AF_INET, &addr.sin_addr, ip, sizeof(ip));
close(sock);
return std::string(ip);
}
}
close(sock);
}

// Try IPv6
struct sockaddr_in6 addr6
{
};

addr6.sin6_family = AF_INET6;
addr6.sin6_port = htons(80);
// using google's public dns server
char const* dns_ipv6 = "2001:4860:4860::8888";
inet_pton(AF_INET6, dns_ipv6, &addr6.sin6_addr);

sock = socket(AF_INET6, SOCK_DGRAM, 0);
if (sock != -1)
{
if (connect(sock, reinterpret_cast<struct sockaddr*>(&addr6), sizeof(addr6)) != -1)
{
socklen_t addr_len = sizeof(addr6);
if (getsockname(sock, reinterpret_cast<struct sockaddr*>(&addr6), &addr_len) != -1)
{
char ip[INET6_ADDRSTRLEN]{};
inet_ntop(AF_INET6, &addr6.sin6_addr, ip, sizeof(ip));
close(sock);
return std::string(ip);
}
}
close(sock);
}

// Try hostname
return getLocalIpByHostname(rank);
}

std::string getLocalIp(std::string interface, int rank)
{
std::string localIP = {};
if (!interface.empty())
{
localIP = getLocalIpByNic(interface, rank);
}
if (localIP.empty())
{
localIP = getLocalIpByRemoteOrHostName(rank);
}
// check whether the localIP is valid
if (localIP.empty())
{
TLLM_THROW("getLocalIp: Can't get local ip");
}
return localIP;
}
} // namespace common

TRTLLM_NAMESPACE_END
28 changes: 28 additions & 0 deletions cpp/tensorrt_llm/common/ipUtils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "tensorrt_llm/common/config.h"
#include <string>

TRTLLM_NAMESPACE_BEGIN

namespace common
{
std::string getLocalIp(std::string interface, int rank);
} // namespace common

TRTLLM_NAMESPACE_END
1 change: 1 addition & 0 deletions cpp/tensorrt_llm/executor/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,4 @@ target_compile_definitions(${EXECUTOR_STATIC_TARGET}

add_subdirectory(cache_transmission/ucx_utils)
add_subdirectory(cache_transmission/nixl_utils)
add_subdirectory(cache_transmission/mooncake_utils)
Loading
Loading