Skip to content

Commit e3c1a94

Browse files
authored
[TRTLLM-6549][fix] add kv cache time output back (#7798)
Signed-off-by: zhengd-nv <200704041+zhengd-nv@users.noreply.github.com>
1 parent 7d4d6cc commit e3c1a94

File tree

3 files changed

+91
-2
lines changed

3 files changed

+91
-2
lines changed

cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp

Lines changed: 44 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -156,6 +156,21 @@ static int32_t tagFromRequestId(LlmRequest::RequestIdType requestId)
156156
return ((requestId & 0xFFF) << 8) | (kDATA_TAG & 0xFF);
157157
}
158158

159+
namespace fs = std::filesystem;
160+
161+
//! \brief Build the per-rank CSV path used for KV cache transfer measurements.
//!
//! The output directory comes from common::getEnvKVCacheTransferOutputPath(). When it is
//! non-empty, the directory (including missing parents) is created and the returned file name
//! has the form `rank_<world rank>_<tag>.csv`.
//!
//! \param tag Suffix placed between the rank and the `.csv` extension.
//! \return The CSV path inside the output directory, or an empty path when no output
//!         directory is configured.
static fs::path getTransferOutputPath(char const* tag)
{
    auto const outputDir = common::getEnvKVCacheTransferOutputPath();
    if (outputDir.empty())
    {
        // Measurement output is disabled: signal that with an empty path.
        return {};
    }

    auto const worldRank = mpi::MpiComm::world().getRank();
    fs::path const directory(outputDir);
    fs::create_directories(directory);

    std::string fileName = "rank_";
    fileName += std::to_string(worldRank);
    fileName += "_";
    fileName += tag;
    fileName += ".csv";
    return directory / fileName;
}
173+
159174
struct ReceiveCacheResource
160175
{
161176
runtime::BufferManager mBufferManager;
@@ -282,6 +297,17 @@ class CacheSender::Impl
282297
auto it = mRequestToSession.find(requestId);
283298
TLLM_CHECK(it != mRequestToSession.end());
284299
std::unique_lock<std::mutex> lk(mMtxForMap);
300+
if (!common::getEnvKVCacheTransferOutputPath().empty())
301+
{
302+
if (!mMeasuresFile.is_open())
303+
{
304+
auto outputPath = getTransferOutputPath("send");
305+
mMeasuresFile.open(outputPath);
306+
TLLM_CHECK_WITH_INFO(
307+
mMeasuresFile.is_open(), "Failed to open transfer output file: %s", outputPath.string().c_str());
308+
}
309+
it->second.exportMeasure(mMeasuresFile, true);
310+
}
285311
mRequestToSession.erase(it);
286312
}
287313

@@ -331,7 +357,8 @@ class CacheSender::Impl
331357
if (it == mRequestToSession.end())
332358
{
333359
auto session = TransferSession(std::vector<Connection const*>(peerRelativeRanks.size(), nullptr),
334-
DataContext{tagFromRequestId(requestId)}, mSelfState, info.getTransState(), mBufferManager);
360+
DataContext{tagFromRequestId(requestId)}, mSelfState, info.getTransState(), mBufferManager, nullptr,
361+
!common::getEnvKVCacheTransferOutputPath().empty());
335362
it = mRequestToSession.emplace(requestId, std::move(session)).first;
336363
}
337364
it->second.setConnection(peerIdx, connection);
@@ -527,6 +554,7 @@ class CacheSender::Impl
527554
std::unique_ptr<BaseCacheFormatter> mFormatter;
528555
std::mutex mMtxForMap;
529556
runtime::BufferManager mBufferManager;
557+
std::ofstream mMeasuresFile;
530558
};
531559

532560
class CacheReceiver::Impl
@@ -587,6 +615,18 @@ class CacheReceiver::Impl
587615
//! \brief Unformat the received KV cache for a session and, when enabled, record its measurements.
//!
//! After the formatter's unformat step, measurement export runs only if
//! common::getEnvKVCacheTransferOutputPath() is non-empty. The receiver's CSV file is opened
//! the first time it is needed and the session's measurements are written to it.
//!
//! \param session Transfer session to unformat and to export measurements from.
void receiveSync(TransferSession& session)
{
    mFormatter->unformat(session);

    // No output path configured: skip the measurement export entirely.
    if (common::getEnvKVCacheTransferOutputPath().empty())
    {
        return;
    }

    // Both the lazy open and the write happen while holding mMeasuresFileMutex.
    std::unique_lock<std::mutex> fileGuard(mMeasuresFileMutex);
    if (!mMeasuresFile.is_open())
    {
        auto const csvPath = getTransferOutputPath("recv");
        mMeasuresFile.open(csvPath);
        TLLM_CHECK_WITH_INFO(
            mMeasuresFile.is_open(), "Failed to open transfer output file: %s", csvPath.string().c_str());
    }
    session.exportMeasure(mMeasuresFile, false);
}
591631

592632
TransferSession sendRequestInfo(LlmRequest const& llmRequest)
@@ -652,7 +692,7 @@ class CacheReceiver::Impl
652692
}
653693
auto const& resource = getReceiveCacheResource(llmRequest);
654694
return TransferSession(std::move(counterPartConnections), DataContext{tagFromRequestId(requestId)}, mSelfState,
655-
contextState, resource->mBufferManager, &llmRequest);
695+
contextState, resource->mBufferManager, &llmRequest, !common::getEnvKVCacheTransferOutputPath().empty());
656696
}
657697

658698
std::unique_ptr<ReceiveCacheResource> const& getReceiveCacheResource(LlmRequest const& llmRequest)
@@ -831,6 +871,8 @@ class CacheReceiver::Impl
831871
std::unordered_map<std::string, std::unique_ptr<ReceiveCacheResource>> mProcessToResources;
832872
std::mutex mProcessIoResouceMutex;
833873
runtime::BufferManager mBufferManager;
874+
std::ofstream mMeasuresFile;
875+
std::mutex mMeasuresFileMutex;
834876
};
835877

836878
void CacheSender::ImplDeleter::operator()(Impl* ptr)

tests/integration/defs/disaggregated/test_disaggregated.py

Lines changed: 46 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -584,6 +584,52 @@ def extra_endpoints_test(server_url: str):
584584
extra_endpoints_test=extra_endpoints_test)
585585

586586

587+
@pytest.mark.parametrize("llama_model_root", ['TinyLlama-1.1B-Chat-v1.0'],
                         indirect=True)
def test_disaggregated_kv_cache_time_output(disaggregated_test_root, llm_venv,
                                            disaggregated_example_root,
                                            llama_model_root):
    """Check that KV cache transfer timings are written as per-rank CSV files.

    Runs the "perf_metrics" disaggregated configuration with
    TRTLLM_KVCACHE_TIME_OUTPUT_PATH pointing at a directory under the working
    directory, then verifies that `rank_0_send.csv` and `rank_1_recv.csv`
    exist, that the send file starts with the expected header, and that the
    first send sample has a receive row with the same request ID whose second
    column is not larger than the send row's second column.
    """
    work_dir = llm_venv.get_working_directory()

    # Make the model reachable under the working directory through a symlink.
    model_link = f"{work_dir}/TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    if not os.path.islink(model_link):
        os.makedirs(os.path.dirname(model_link), exist_ok=True)
        os.symlink(llama_model_root, model_link, target_is_directory=True)

    output_path = os.path.join(work_dir, "cache_time")
    timing_env = llm_venv._new_env | {
        "TRTLLM_KVCACHE_TIME_OUTPUT_PATH": output_path
    }
    run_disaggregated_test(disaggregated_example_root,
                           "perf_metrics",
                           env=timing_env,
                           cwd=work_dir)

    assert os.path.isdir(output_path)
    send_file = os.path.join(output_path, "rank_0_send.csv")
    recv_file = os.path.join(output_path, "rank_1_recv.csv")
    assert os.path.exists(send_file)
    assert os.path.exists(recv_file)

    with open(send_file, "r") as f:
        send_rows = f.readlines()
    assert len(send_rows) > 1
    assert send_rows[0].startswith(
        "RequestID,Delay(ms),Duration(ms),Bandwidth(Gbps)")
    # Take the first send sample; its request ID is looked up in the recv file.
    send_sample = send_rows[1].split(',')
    assert len(send_sample) >= 4

    with open(recv_file, "r") as f:
        recv_rows = f.readlines()
    assert len(recv_rows) > 1

    # First recv row whose request ID equals the send sample's, if any.
    recv_sample = next(
        (fields for fields in (row.split(',') for row in recv_rows)
         if fields[0] == send_sample[0]), None)
    assert recv_sample is not None
    assert float(recv_sample[1]) <= float(send_sample[1])
631+
632+
587633
@pytest.mark.parametrize("llama_model_root", ['TinyLlama-1.1B-Chat-v1.0'],
588634
indirect=True)
589635
def test_disaggregated_trtllm_sampler(disaggregated_test_root, llm_venv,

tests/integration/test_lists/test-db/l0_h100.yml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -75,6 +75,7 @@ l0_h100:
7575
- disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_tp1_two_mtp[DeepSeek-V3-Lite-fp8]
7676
- disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_ucx_tp1_single_gpu[DeepSeek-V3-Lite-fp8]
7777
- disaggregated/test_disaggregated.py::test_disaggregated_cuda_graph[TinyLlama-1.1B-Chat-v1.0]
78+
- disaggregated/test_disaggregated.py::test_disaggregated_kv_cache_time_output[TinyLlama-1.1B-Chat-v1.0]
7879
- disaggregated/test_disaggregated.py::test_disaggregated_mixed[TinyLlama-1.1B-Chat-v1.0]
7980
- disaggregated/test_disaggregated.py::test_disaggregated_overlap[TinyLlama-1.1B-Chat-v1.0]
8081
- disaggregated/test_disaggregated.py::test_disaggregated_perf_metrics[TinyLlama-1.1B-Chat-v1.0]

0 commit comments

Comments
 (0)