Skip to content

Commit b80c07c

Browse files
Merge branch 'main' into lschneider/fix-static-nccl-tear-down
2 parents 591739e + 55bc6a5 commit b80c07c

File tree

294 files changed

+14922
-5320
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

294 files changed

+14922
-5320
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ state-of-the-art optimizations to perform inference efficiently on NVIDIA GPUs.<
1010
[![python](https://img.shields.io/badge/python-3.10-green)](https://www.python.org/downloads/release/python-31012/)
1111
[![cuda](https://img.shields.io/badge/cuda-13.0.0-green)](https://developer.nvidia.com/cuda-downloads)
1212
[![torch](https://img.shields.io/badge/torch-2.9.0-green)](https://pytorch.org)
13-
[![version](https://img.shields.io/badge/release-1.2.0rc6-green)](https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/version.py)
13+
[![version](https://img.shields.io/badge/release-1.2.0rc7-green)](https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/version.py)
1414
[![license](https://img.shields.io/badge/license-Apache%202-blue)](https://github.com/NVIDIA/TensorRT-LLM/blob/main/LICENSE)
1515

1616
[Architecture](https://nvidia.github.io/TensorRT-LLM/developer-guide/overview.html)&nbsp;&nbsp;&nbsp;|&nbsp;&nbsp;&nbsp;[Performance](https://nvidia.github.io/TensorRT-LLM/developer-guide/perf-overview.html)&nbsp;&nbsp;&nbsp;|&nbsp;&nbsp;&nbsp;[Examples](https://nvidia.github.io/TensorRT-LLM/quick-start-guide.html)&nbsp;&nbsp;&nbsp;|&nbsp;&nbsp;&nbsp;[Documentation](https://nvidia.github.io/TensorRT-LLM/)&nbsp;&nbsp;&nbsp;|&nbsp;&nbsp;&nbsp;[Roadmap](https://github.com/NVIDIA/TensorRT-LLM/issues?q=is%3Aissue%20state%3Aopen%20label%3Aroadmap)

cpp/include/tensorrt_llm/executor/cacheCommunicator.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#pragma once
1818

1919
#include "tensorrt_llm/executor/serialization.h"
20+
#include <atomic>
2021
#include <vector>
2122

2223
namespace tensorrt_llm::executor::kv_cache
@@ -27,8 +28,9 @@ class CommState;
2728
struct DataContext
2829
{
2930
public:
30-
explicit DataContext(int tag)
31+
explicit DataContext(int tag, std::atomic<bool> const& transferTerminate = sDefaultTransferTerminate)
3132
: mTag{tag}
33+
, mTransferTerminate(transferTerminate)
3234
{
3335
}
3436

@@ -37,8 +39,15 @@ struct DataContext
3739
return mTag;
3840
}
3941

42+
[[nodiscard]] std::atomic<bool> const& getTransferTerminate() const noexcept
43+
{
44+
return mTransferTerminate;
45+
}
46+
4047
private:
48+
inline static std::atomic<bool> sDefaultTransferTerminate{false};
4149
int const mTag;
50+
std::atomic<bool> const& mTransferTerminate;
4251
};
4352

4453
class Connection

cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -358,8 +358,9 @@ class CacheSender::Impl
358358

359359
TransceiverTag::Id id;
360360
RequestInfo info;
361-
auto const* connection = isAgent ? agentConnectionManager->recvConnectionAndRequestInfo(info)
362-
: mManager->recvConnect(DataContext{TransceiverTag::kID_TAG}, &id, sizeof(id));
361+
auto const* connection = isAgent
362+
? agentConnectionManager->recvConnectionAndRequestInfo(info, mTerminate)
363+
: mManager->recvConnect(DataContext{TransceiverTag::kID_TAG, mTerminate}, &id, sizeof(id));
363364
if (connection == nullptr && !mManager->isRunning())
364365
{
365366
TLLM_LOG_WARNING(" recvRequestInfo connection is nullptr, maybe the server is terminating");
@@ -395,8 +396,8 @@ class CacheSender::Impl
395396
if (it == mRequestToSession.end())
396397
{
397398
auto session = TransferSession(std::vector<Connection const*>(peerRelativeRanks.size(), nullptr),
398-
DataContext{tagFromRequestId(requestId)}, mSelfState, info.getTransState(), mBufferManager,
399-
info.getIndexFromEnd(), info.getLastBlockKey(), nullptr,
399+
DataContext{tagFromRequestId(requestId), mTerminate}, mSelfState, info.getTransState(),
400+
mBufferManager, info.getIndexFromEnd(), info.getLastBlockKey(), nullptr,
400401
!common::getEnvKVCacheTimeOutputPath().empty());
401402
session.setTime(TransferSession::kTimeRequestInfo);
402403
it = mRequestToSession.emplace(requestId, std::move(session)).first;
@@ -685,6 +686,10 @@ class CacheSender::Impl
685686
{
686687
future.get();
687688
}
689+
if (mResponseFuture.valid())
690+
{
691+
mResponseFuture.get();
692+
}
688693
}
689694

690695
void removeResponse(std::map<RequestIdType, Response>::iterator it)
@@ -886,9 +891,9 @@ class CacheReceiver::Impl
886891
}
887892
}
888893
auto const& resource = getReceiveCacheResource(llmRequest);
889-
return TransferSession(std::move(counterPartConnections), DataContext{tagFromRequestId(requestId)}, mSelfState,
890-
contextState, resource->mBufferManager, requestInfo.getIndexFromEnd(), requestInfo.getLastBlockKey(),
891-
&llmRequest, !common::getEnvKVCacheTimeOutputPath().empty());
894+
return TransferSession(std::move(counterPartConnections), DataContext{tagFromRequestId(requestId), mTerminate},
895+
mSelfState, contextState, resource->mBufferManager, requestInfo.getIndexFromEnd(),
896+
requestInfo.getLastBlockKey(), &llmRequest, !common::getEnvKVCacheTimeOutputPath().empty());
892897
}
893898

894899
std::unique_ptr<ReceiveCacheResource> const& getReceiveCacheResource(LlmRequest const& llmRequest)
@@ -964,7 +969,7 @@ class CacheReceiver::Impl
964969
auto* agentConnection = dynamic_cast<executor::kv_cache::AgentConnection const*>(connections.at(i));
965970
TLLM_CHECK(agentConnection);
966971
isReady = agentConnection->recvReadySignal(
967-
executor::kv_cache::DataContext{TransceiverTag::kREADY_SIGNAL_TAG});
972+
executor::kv_cache::DataContext{TransceiverTag::kREADY_SIGNAL_TAG, mTerminate});
968973
}
969974
else
970975
{
@@ -979,6 +984,7 @@ class CacheReceiver::Impl
979984

980985
~Impl()
981986
{
987+
mTerminate.store(true);
982988
for (auto&& [processInfo, asyncResource] : mInstanceToAsyncResource)
983989
{
984990
asyncResource->mTerminate = true;
@@ -1134,6 +1140,7 @@ class CacheReceiver::Impl
11341140
runtime::BufferManager mBufferManager;
11351141
std::ofstream mMeasuresFile;
11361142
std::mutex mMeasuresFileMutex;
1143+
std::atomic<bool> mTerminate{false};
11371144
};
11381145

11391146
void CacheSender::ImplDeleter::operator()(Impl* ptr)

cpp/tensorrt_llm/common/cudaFp8Utils.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ template <QuantizeMode QUANTIZE_MODE, bool QUANTIZE, typename T_OUT, typename T_
4343
__global__ void scaleMatrix(T_OUT* output, T_S const* input_scale, T_IN const* input, int64_t numel, int64_t lda)
4444
{
4545
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
46-
asm volatile("griddepcontrol.wait;");
46+
cudaGridDependencySynchronize();
4747
#endif
4848

4949
for (int64_t i = threadIdx.x + blockIdx.x * blockDim.x; i < numel; i += blockDim.x * gridDim.x)
@@ -63,7 +63,7 @@ __global__ void scaleMatrix(T_OUT* output, T_S const* input_scale, T_IN const* i
6363
}
6464
}
6565
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
66-
asm volatile("griddepcontrol.launch_dependents;");
66+
cudaTriggerProgrammaticLaunchCompletion();
6767
#endif
6868
}
6969

cpp/tensorrt_llm/common/envUtils.cpp

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -249,15 +249,26 @@ bool getEnvUseTileSizeKv64ForTrtllmGen()
249249
bool getEnvEnablePDL()
250250
{
251251
static std::once_flag flag;
252-
static bool enablePDL = false;
252+
static bool enablePDL = true;
253253

254254
std::call_once(flag,
255255
[&]()
256256
{
257257
if (getSMVersion() >= 90)
258258
{
259259
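// PDL is enabled by default; set the environment variable `TRTLLM_ENABLE_PDL` to `0` to disable it
// (`1` keeps it enabled; any other value leaves the default unchanged).
// NOTE(review): the default `true` is also returned when getSMVersion() < 90 -- confirm this is intended.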
260-
enablePDL = getBoolEnv("TRTLLM_ENABLE_PDL");
260+
char const* env = std::getenv("TRTLLM_ENABLE_PDL");
261+
if (env)
262+
{
263+
if (env[0] == '1' && env[1] == '\0')
264+
{
265+
enablePDL = true;
266+
}
267+
else if (env[0] == '0' && env[1] == '\0')
268+
{
269+
enablePDL = false;
270+
}
271+
};
261272
}
262273
});
263274
return enablePDL;

cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/arch/grid_dependency_control.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ CUTLASS_DEVICE
4646
void launch_dependent_grids()
4747
{
4848
#if (defined(CUTLASS_GDC_ENABLED))
49-
asm volatile("griddepcontrol.launch_dependents;");
49+
cudaTriggerProgrammaticLaunchCompletion();
5050
#endif
5151
}
5252

@@ -57,7 +57,7 @@ CUTLASS_DEVICE
5757
void wait_on_dependent_grids()
5858
{
5959
#if (defined(CUTLASS_GDC_ENABLED))
60-
asm volatile("griddepcontrol.wait;");
60+
cudaGridDependencySynchronize();
6161
#endif
6262
}
6363

cpp/tensorrt_llm/executor/cache_transmission/agent_utils/connection.cpp

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ void AgentConnection::recv(DataContext const& ctx, void* data, size_t size) cons
150150
{
151151

152152
NotificationSyncInfo syncInfo{mAgentName, ctx};
153-
mAgentConnectionManager->waitForSyncInfo(mRemoteAgentName, syncInfo);
153+
mAgentConnectionManager->waitForSyncInfo(mRemoteAgentName, syncInfo, ctx.getTransferTerminate());
154154
}
155155

156156
void AgentConnection::sendRequestAndBufferInfo(batch_manager::RequestInfo& requestInfo,
@@ -230,7 +230,7 @@ void AgentConnection::sendReadySignal(DataContext const& ctx, bool isReady) cons
230230
bool AgentConnection::recvReadySignal(DataContext const& ctx) const
231231
{
232232
ReadySignalInfo readySignalInfo{mAgentName, ctx, false};
233-
mAgentConnectionManager->waitForReadySignal(mRemoteAgentName, readySignalInfo);
233+
mAgentConnectionManager->waitForReadySignal(mRemoteAgentName, readySignalInfo, ctx.getTransferTerminate());
234234
return readySignalInfo.mIsReady;
235235
}
236236

@@ -315,9 +315,10 @@ AgentConnectionManager::AgentConnectionManager(
315315
" ***** AgentConnectionManager::AgentConnectionManager mCommState: %s", mCommState.toString().c_str());
316316
}
317317

318-
AgentConnection const* AgentConnectionManager::recvConnectionAndRequestInfo(batch_manager::RequestInfo& requestInfo)
318+
AgentConnection const* AgentConnectionManager::recvConnectionAndRequestInfo(
319+
batch_manager::RequestInfo& requestInfo, std::atomic<bool> const& terminateFlag)
319320
{
320-
while (true)
321+
while (!terminateFlag.load())
321322
{
322323
if (!mIsRunning)
323324
{
@@ -490,16 +491,16 @@ int AgentConnectionManager::getDeviceId() const
490491
}
491492

492493
template <typename NotificationType>
493-
void AgentConnectionManager::waitForNotification(std::string const& remoteAgentName, NotificationType& expectedInfo)
494+
void AgentConnectionManager::waitForNotification(
495+
std::string const& remoteAgentName, NotificationType& expectedInfo, std::atomic<bool> const& terminateFlag)
494496
{
495-
while (true)
497+
while (!terminateFlag.load())
496498
{
497499

498500
if (!mIsRunning)
499501
{
500502
return;
501503
}
502-
503504
updateUnhandledNotifications();
504505
std::scoped_lock lock(mNotificationMutex);
505506
auto it = mUnhandledNotifications.begin();
@@ -575,18 +576,20 @@ void AgentConnectionManager::waitForNotification(std::string const& remoteAgentN
575576

576577
// Explicit template instantiations
577578
template void AgentConnectionManager::waitForNotification<NotificationSyncInfo>(
578-
std::string const& remoteAgentName, NotificationSyncInfo& expectedInfo);
579+
std::string const& remoteAgentName, NotificationSyncInfo& expectedInfo, std::atomic<bool> const& terminateFlag);
579580
template void AgentConnectionManager::waitForNotification<ReadySignalInfo>(
580-
std::string const& remoteAgentName, ReadySignalInfo& expectedInfo);
581+
std::string const& remoteAgentName, ReadySignalInfo& expectedInfo, std::atomic<bool> const& terminateFlag);
581582

582-
void AgentConnectionManager::waitForSyncInfo(std::string const& remoteAgentName, NotificationSyncInfo& syncInfo)
583+
void AgentConnectionManager::waitForSyncInfo(
584+
std::string const& remoteAgentName, NotificationSyncInfo& syncInfo, std::atomic<bool> const& terminateFlag)
583585
{
584-
waitForNotification(remoteAgentName, syncInfo);
586+
waitForNotification(remoteAgentName, syncInfo, terminateFlag);
585587
}
586588

587-
void AgentConnectionManager::waitForReadySignal(std::string const& remoteAgentName, ReadySignalInfo& readySignalInfo)
589+
void AgentConnectionManager::waitForReadySignal(
590+
std::string const& remoteAgentName, ReadySignalInfo& readySignalInfo, std::atomic<bool> const& terminateFlag)
588591
{
589-
waitForNotification(remoteAgentName, readySignalInfo);
592+
waitForNotification(remoteAgentName, readySignalInfo, terminateFlag);
590593
}
591594

592595
std::string const& AgentConnectionManager::getAgentName() const

cpp/tensorrt_llm/executor/cache_transmission/agent_utils/connection.h

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,8 @@ class AgentConnectionManager : public ConnectionManager
282282
AgentConnection* recvConnect(DataContext const& ctx, void* data, size_t size) override;
283283
[[nodiscard]] std::vector<Connection const*> getConnections(CommState const& state) override;
284284
[[nodiscard]] CommState const& getCommState() const override;
285-
AgentConnection const* recvConnectionAndRequestInfo(batch_manager::RequestInfo& requestInfo);
285+
AgentConnection const* recvConnectionAndRequestInfo(
286+
batch_manager::RequestInfo& requestInfo, std::atomic<bool> const& terminateFlag);
286287
[[nodiscard]] std::vector<batch_manager::kv_cache_manager::CacheTransBufferManager*> const&
287288
getCacheTransBufferManagers() const;
288289
void updateUnhandledNotifications();
@@ -293,9 +294,12 @@ class AgentConnectionManager : public ConnectionManager
293294
[[nodiscard]] std::string const& getAgentName() const;
294295

295296
template <typename NotificationType>
296-
void waitForNotification(std::string const& remoteAgentName, NotificationType& expectedInfo);
297-
void waitForSyncInfo(std::string const& remoteAgentName, NotificationSyncInfo& syncInfo);
298-
void waitForReadySignal(std::string const& remoteAgentName, ReadySignalInfo& readySignalInfo);
297+
void waitForNotification(
298+
std::string const& remoteAgentName, NotificationType& expectedInfo, std::atomic<bool> const& terminateFlag);
299+
void waitForSyncInfo(
300+
std::string const& remoteAgentName, NotificationSyncInfo& syncInfo, std::atomic<bool> const& terminateFlag);
301+
void waitForReadySignal(
302+
std::string const& remoteAgentName, ReadySignalInfo& readySignalInfo, std::atomic<bool> const& terminateFlag);
299303
[[nodiscard]] bool isRunning() const override;
300304

301305
private:

0 commit comments

Comments
 (0)