Skip to content

Commit 4cc4cbe

Browse files
[https://nvbugs/5716787][fix] terminate nixl running when exiting (#9785)
Signed-off-by: Chuang Zhu <[email protected]> Co-authored-by: Patrice Castonguay <[email protected]>
1 parent 9c59c9f commit 4cc4cbe

File tree

9 files changed

+54
-4
lines changed

9 files changed

+54
-4
lines changed

cpp/include/tensorrt_llm/executor/cacheCommunicator.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ class ConnectionManager
6666
[[nodiscard]] virtual std::vector<Connection const*> getConnections(CommState const& state) = 0;
6767

6868
[[nodiscard]] virtual CommState const& getCommState() const = 0;
69+
[[nodiscard]] virtual bool isRunning() const = 0;
6970
};
7071

7172
} // namespace tensorrt_llm::executor::kv_cache

cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,12 @@ class CacheSender::Impl
360360
RequestInfo info;
361361
auto const* connection = isAgent ? agentConnectionManager->recvConnectionAndRequestInfo(info)
362362
: mManager->recvConnect(DataContext{TransceiverTag::kID_TAG}, &id, sizeof(id));
363+
if (connection == nullptr && !mManager->isRunning())
364+
{
365+
TLLM_LOG_WARNING(" recvRequestInfo connection is nullptr, maybe the server is terminating");
366+
return info;
367+
}
368+
363369
if (!isAgent)
364370
{
365371
TLLM_CHECK(id == TransceiverTag::Id::REQUEST_SEND);
@@ -616,6 +622,10 @@ class CacheSender::Impl
616622
if (!mReadyResponses.empty())
617623
{
618624
auto const& requestInfo = recvRequestInfo();
625+
if (mTerminate || !mManager->isRunning())
626+
{
627+
return;
628+
}
619629
auto reqId = requestInfo.getRequestId();
620630

621631
{

cpp/tensorrt_llm/executor/cache_transmission/agent_utils/connection.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,10 @@ AgentConnection const* AgentConnectionManager::recvConnectionAndRequestInfo(batc
319319
{
320320
while (true)
321321
{
322+
if (!mIsRunning)
323+
{
324+
return nullptr;
325+
}
322326
updateUnhandledNotifications();
323327
std::scoped_lock lock(mNotificationMutex);
324328
auto it = mUnhandledNotifications.begin();
@@ -491,6 +495,11 @@ void AgentConnectionManager::waitForNotification(std::string const& remoteAgentN
491495
while (true)
492496
{
493497

498+
if (!mIsRunning)
499+
{
500+
return;
501+
}
502+
494503
updateUnhandledNotifications();
495504
std::scoped_lock lock(mNotificationMutex);
496505
auto it = mUnhandledNotifications.begin();
@@ -587,6 +596,13 @@ std::string const& AgentConnectionManager::getAgentName() const
587596

588597
AgentConnectionManager::~AgentConnectionManager()
589598
{
599+
mIsRunning = false;
590600
m_Agent->deregisterMemory(mRegMemDescs);
591601
}
602+
603+
bool AgentConnectionManager::isRunning() const
604+
{
605+
return mIsRunning;
606+
}
607+
592608
} // namespace tensorrt_llm::executor::kv_cache

cpp/tensorrt_llm/executor/cache_transmission/agent_utils/connection.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,7 @@ class AgentConnectionManager : public ConnectionManager
296296
void waitForNotification(std::string const& remoteAgentName, NotificationType& expectedInfo);
297297
void waitForSyncInfo(std::string const& remoteAgentName, NotificationSyncInfo& syncInfo);
298298
void waitForReadySignal(std::string const& remoteAgentName, ReadySignalInfo& readySignalInfo);
299+
[[nodiscard]] bool isRunning() const override;
299300

300301
private:
301302
std::map<std::string, std::shared_ptr<AgentConnection>> mConnections;
@@ -309,6 +310,7 @@ class AgentConnectionManager : public ConnectionManager
309310
int mDeviceId;
310311
std::string mAgentName;
311312
MemoryDescs mRegMemDescs;
313+
std::atomic<bool> mIsRunning{true};
312314
};
313315

314316
} // namespace tensorrt_llm::executor::kv_cache

cpp/tensorrt_llm/executor/cache_transmission/mpi_utils/connection.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,4 +77,13 @@ CommState const& MpiConnectionManager::getCommState() const
7777
return mCommState;
7878
}
7979

80+
bool MpiConnectionManager::isRunning() const
81+
{
82+
return mIsRunning;
83+
}
84+
85+
MpiConnectionManager::~MpiConnectionManager()
86+
{
87+
mIsRunning = false;
88+
}
8089
} // namespace tensorrt_llm::executor::kv_cache

cpp/tensorrt_llm/executor/cache_transmission/mpi_utils/connection.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,14 +42,17 @@ class MpiConnectionManager : public ConnectionManager
4242
{
4343
public:
4444
MpiConnectionManager(mpi::MpiComm const* comm);
45+
~MpiConnectionManager();
4546
MpiConnection const* recvConnect(DataContext const& ctx, void* data, size_t size) override;
4647
[[nodiscard]] std::vector<Connection const*> getConnections(CommState const& state) override;
4748
[[nodiscard]] CommState const& getCommState() const override;
49+
[[nodiscard]] bool isRunning() const override;
4850

4951
private:
5052
mpi::MpiComm const* mComm;
5153
std::map<int, MpiConnection> mConnections;
5254
CommState mCommState;
55+
std::atomic<bool> mIsRunning{true};
5356
};
5457

5558
} // namespace tensorrt_llm::executor::kv_cache

cpp/tensorrt_llm/executor/cache_transmission/ucx_utils/ucxCacheCommunicator.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -504,7 +504,7 @@ UcxConnectionManager::~UcxConnectionManager()
504504
socket.close();
505505
mZmqRepThread.join();
506506
}
507-
507+
mIsRunning = false;
508508
mZmqRepSocket.close();
509509

510510
mZmqContext.close();
@@ -673,6 +673,11 @@ std::vector<Connection const*> UcxConnectionManager::getConnections(CommState co
673673
return ret;
674674
}
675675

676+
bool UcxConnectionManager::isRunning() const
677+
{
678+
return mIsRunning;
679+
}
680+
676681
CommState const& UcxConnectionManager::getCommState() const
677682
{
678683
return mCommState;

cpp/tensorrt_llm/executor/cache_transmission/ucx_utils/ucxCacheCommunicator.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ class UcxConnectionManager : public ConnectionManager, public std::enable_shared
6262
zmq::socket_t mZmqRepSocket;
6363
std::string mZmqRepEndpoint;
6464
std::thread mZmqRepThread;
65+
std::atomic<bool> mIsRunning{true};
6566

6667
UcxConnection::ConnectionIdType getNewConnectionId(std::shared_ptr<ucxx::Endpoint> const& newEp);
6768
UcxConnection::ConnectionIdType addConnection(std::string const& ip, uint16_t port);
@@ -85,6 +86,8 @@ class UcxConnectionManager : public ConnectionManager, public std::enable_shared
8586
{
8687
return mRank;
8788
}
89+
90+
[[nodiscard]] bool isRunning() const override;
8891
};
8992

9093
#if defined(__clang__)

tests/unittest/llmapi/test_llm_pytorch.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1052,8 +1052,9 @@ def test_llm_context_only_timed_out():
10521052
@pytest.mark.part0
10531053
@skip_ray
10541054
@pytest.mark.parametrize("sender_future_timeout_ms", [100, 1000])
1055-
def test_llm_context_only_timed_out_kv_cache_exhausted(
1056-
sender_future_timeout_ms):
1055+
@pytest.mark.parametrize("backend", ["NIXL", "UCX"])
1056+
def test_llm_context_only_timed_out_kv_cache_exhausted(sender_future_timeout_ms,
1057+
backend):
10571058
tp_size = 1
10581059
use_overlap = False
10591060
enable_iter_req_stats = False
@@ -1073,7 +1074,7 @@ def test_llm_context_only_timed_out_kv_cache_exhausted(
10731074
kv_cache_config=kv_cache_config,
10741075
tensor_parallel_size=tp_size,
10751076
cache_transceiver_config=CacheTransceiverConfig(
1076-
backend="UCX",
1077+
backend=backend,
10771078
kv_transfer_timeout_ms=1000,
10781079
kv_transfer_sender_future_timeout_ms=sender_future_timeout_ms),
10791080
**llm_args_extra)

0 commit comments

Comments
 (0)