15 changes: 7 additions & 8 deletions cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp
@@ -148,20 +148,19 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa

if (mCacheState->getParallelConfig().mEnableAttentionDP)
{
int TPSizeInDPGroup
= mCacheState->getParallelConfig().mTensorParallelism / mCacheState->getParallelConfig().mDPsize;
int DPSize = mCacheState->getParallelConfig().mDPsize;
int TPRankInDPGroup = worldConfig.getTensorParallelRank() % TPSizeInDPGroup;

int DPRank = (worldConfig.getRank() - TPSizeInDPGroup * DPSize * worldConfig.getPipelineParallelRank()
- TPRankInDPGroup)
/ TPSizeInDPGroup;
// <PP,DP,TP>
// DPRank is derived from the tensor parallel rank, which already accounts for CP.
// Layout: rank = ppRank * (TP * CP) + tpRank * CP + cpRank.
// getTensorParallelRank() correctly extracts tpRank regardless of CP.
int DPRank = mCacheState->getParallelConfig().mDPrank;
// <PP,DP,TP,CP>
mGroupDataComm = std::make_shared<CacheTransceiverComm>(mGroupComm->split(DPRank, worldConfig.getRank()));
if (worldConfig.isTensorParallel())
{
// Group ranks with same (ppRank, DPRank) accounting for CP.
mGroupTPInDPComm = std::make_shared<CacheTransceiverComm>(
mGroupComm->split(worldConfig.getRank() / TPSizeInDPGroup, worldConfig.getRank()));
mGroupComm->split(worldConfig.getPipelineParallelRank() * DPSize + DPRank, worldConfig.getRank()));
}
}
bool isMLA = attentionType == executor::kv_cache::CacheState::AttentionType::kMLA;
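A minimal standalone sketch, not part of the diff, of how the two split keys in the hunk above partition ranks under the <PP,DP,TP,CP> layout described in the new comments. It assumes attention-DP groups are contiguous sub-blocks of the TP dimension, as implied by the removed TPSizeInDPGroup formula; sizes and names are illustrative.

PP, DP_SIZE, TP_IN_DP, CP = 2, 2, 2, 2
TP = DP_SIZE * TP_IN_DP
for pp_rank in range(PP):
    for tp_rank in range(TP):
        for cp_rank in range(CP):
            rank = pp_rank * (TP * CP) + tp_rank * CP + cp_rank  # rank = ppRank * (TP * CP) + tpRank * CP + cpRank
            dp_rank = tp_rank // TP_IN_DP
            data_color = dp_rank                          # split key for mGroupDataComm
            tp_in_dp_color = pp_rank * DP_SIZE + dp_rank  # split key for mGroupTPInDPComm
            print(f"rank={rank:2d} dp={dp_rank} dataColor={data_color} tpInDPColor={tp_in_dp_color}")

Ranks sharing a dataColor form one mGroupDataComm (one per DP rank, spanning PP, TP-in-DP and CP); ranks sharing a tpInDPColor form one mGroupTPInDPComm (one per (ppRank, DPRank) pair, spanning TP-in-DP and CP).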
116 changes: 92 additions & 24 deletions cpp/tests/unit_tests/multi_gpu/cacheTransceiverTest.cpp
@@ -552,9 +552,10 @@ class AsymmetricalCacheTest : public ::testing::TestWithParam<AsymmetricTestPara
mCpSize = genCp;
}

mTpRank = mRankInInstance % mTpSize;
// Rank formula must match targetIRanks: ppRank * (tpNum * cpNum) + tpRank * cpNum + cpRank.
mCpRank = mRankInInstance % mCpSize;
mTpRank = (mRankInInstance / mCpSize) % mTpSize;
mPpRank = mRankInInstance / (mTpSize * mCpSize);
mCpRank = (mRankInInstance % (mTpSize * mCpSize)) / mTpSize;
mContextRankSize = contextRanks;
mGenRankSize = genRanks;
mContextTpSize = contextTp;
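A quick check, not part of the diff, of why the decomposition order matters. It uses illustrative sizes mTpSize = mCpSize = mPpSize = 2 and the layout from the comment above, rank = ppRank * (tpNum * cpNum) + tpRank * cpNum + cpRank:

tp, cp = 2, 2
rank = 5
assert rank % cp == 1          # cpRank
assert (rank // cp) % tp == 0  # tpRank
assert rank // (tp * cp) == 1  # ppRank
# The replaced formulas effectively assumed rank = ppRank * (tpNum * cpNum) + cpRank * tpNum + tpRank
# and would have mapped the same rank to tpRank = 1, cpRank = 0 instead.
assert rank % tp == 1
assert (rank % (tp * cp)) // tp == 0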
@@ -887,7 +888,16 @@ class AsymmetricalCacheTest : public ::testing::TestWithParam<AsymmetricTestPara
auto makeLlmRequestWithDP(SizeType32 length, LlmRequest::RequestIdType requestId, int contextDpRank)
{
constexpr SizeType32 maxNewTokens{1};
texec::Request request{VecTokens(length), maxNewTokens};
auto const tokensPerBlock = mContextCacheState->getModelConfig().mTokensPerBlock;

std::optional<CPMetaData> cpMetaData;
int seqLen = length;
if (mCpSize > 1)
{
cpMetaData.emplace(length, tokensPerBlock, mCpRank, mCpSize);
seqLen = cpMetaData.value().mSeqLenOnThisCPRank;
}
texec::Request request{VecTokens(seqLen, seqLen), maxNewTokens};

auto state = std::make_unique<texec::DataTransceiverState>();
state->setCommState(texec::kv_cache::CommState{*mContextCommState});
Expand All @@ -905,7 +915,6 @@ class AsymmetricalCacheTest : public ::testing::TestWithParam<AsymmetricTestPara
request.setContextPhaseParams(std::move(stats));
auto llmRequestPtr = std::make_unique<LlmRequest>(requestId, std::move(request));

std::optional<CPMetaData> cpMetaData;
return std::make_unique<WrappedLlmRequest>(std::move(llmRequestPtr), cpMetaData);
}

@@ -1428,6 +1437,27 @@ TEST_P(AsymmetricalCacheTestWithDP, TestCase)
{
GTEST_SKIP() << "Temporarily skipping cache transceiver tests with NIXL and MOONCAKE backend for CP.";
}
// Filter request lengths based on CP requirements.
// Each request must have at least one block per CP rank to be valid for CP tests.
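// For example, with tokensPerBlock = 8 the threshold len > tokensPerBlock * (genCp - 1) is 8 tokens for
// genCp = 2, so all of {60, 30, 60, 10} are kept; for genCp = 4 the threshold would be 24 and the
// length-10 request would be dropped.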
std::vector<int> lenList = {60, 30, 60, 10};
if (genCp > 1)
{
std::vector<int> updatedLenList;
for (auto len : lenList)
{
if (len > tokensPerBlock * (genCp - 1))
{
updatedLenList.push_back(len);
}
}
if (updatedLenList.empty())
{
GTEST_SKIP() << "Skipping test because not even one request has one block per genCP rank. tokensPerBlock="
<< tokensPerBlock << ", genCp=" << genCp;
}
lenList = updatedLenList;
}

setUpCommunicator(contextTp, contextPp, contextCp, genTp, genPp, genCp, isMLA, contextDP, generationDP);

if (mIsContext || mIsGeneration)
@@ -1438,7 +1468,7 @@ TEST_P(AsymmetricalCacheTestWithDP, TestCase)
setUpCacheTransceiver();
std::vector<std::shared_ptr<WrappedLlmRequest>> requests;
int requestId = 0;
for (auto len : {60, 30, 60, 10})
for (auto len : lenList)
{
requests.emplace_back(makeLlmRequestWithDP(len, requestId, requestId % contextTp));
requestId++;
@@ -1814,6 +1844,44 @@ INSTANTIATE_TEST_CASE_P(AsymmetricCaseTest1WithCPForMLA, AsymmetricalCacheTest,
/*generationDP*/ testing::Values(false),
/*isWindow*/ testing::Values(false), testing::Values(false), testing::Values(0), testing::Values(128)));

// Tests cases where there's non-trivial TP and PP on context side while non-trivial CP & DP on gen side.
INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithCPAndDPForMLA0, AsymmetricalCacheTestWithDP,
testing::Combine(/*contextTp*/ testing::Values(1, 2),
/*contextPp*/ testing::Values(1, 2),
/*contextCp*/ testing::Values(1),
/*genTp*/ testing::Values(2),
/*genPp*/ testing::Values(1),
/*genCp*/ testing::Values(2),
/*numLayers*/ testing::Values(4),
/*numHeads*/ testing::Values(1),
/*sizePerHead*/ testing::Values(4),
/*tokensPerBlock*/ testing::Values(8),
/*dataType*/ testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8),
/*kvFactor*/ testing::Values(1),
/*isMLA*/ testing::Values(true),
/*contextDP*/ testing::Values(false),
/*generationDP*/ testing::Values(true),
/*isWindow*/ testing::Values(false), testing::Values(false), testing::Values(0), testing::Values(128)));

// Tests cases where there's non-trivial DP on context side while non-trivial CP & DP on gen side.
INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithCPAndDPForMLA1, AsymmetricalCacheTestWithDP,
testing::Combine(/*contextTp*/ testing::Values(2, 4),
/*contextPp*/ testing::Values(1),
/*contextCp*/ testing::Values(1),
/*genTp*/ testing::Values(2),
/*genPp*/ testing::Values(1),
/*genCp*/ testing::Values(2),
/*numLayers*/ testing::Values(4),
/*numHeads*/ testing::Values(1),
/*sizePerHead*/ testing::Values(4),
/*tokensPerBlock*/ testing::Values(8),
/*dataType*/ testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8),
/*kvFactor*/ testing::Values(1),
/*isMLA*/ testing::Values(true),
/*contextDP*/ testing::Values(true),
/*generationDP*/ testing::Values(true),
/*isWindow*/ testing::Values(false), testing::Values(false), testing::Values(0), testing::Values(128)));

INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForMLA1, AsymmetricalCacheTestWithDP,
testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(1), testing::Values(1, 2),
testing::Values(1, 2), testing::Values(1), testing::Values(4), testing::Values(1), testing::Values(4),
@@ -2226,8 +2294,8 @@ TEST(targetTest, CacheStateContextDP)
auto const verifyContext = [&](int contextRank, int generationRank, std::vector<int> const& expectRanks,
int expectPPDomain, int expectTPDomain, bool expectNeedSend)
{
int contextDPRank = contextRank % contextTP;
int generationDPRank = generationRank % genTP;
int contextDPRank = (contextRank % (contextTP * contextCP)) / contextCP;
int generationDPRank = (generationRank % (genTP * genCP)) / genCP;
auto attentionType = isMLA ? texec::kv_cache::CacheState::AttentionType::kMLA
: texec::kv_cache::CacheState::AttentionType::kDEFAULT;

@@ -2239,12 +2307,12 @@
tokensPerBlock, genTP, genPP, genCP, genAttentionLayerNumPerPP, dataType, attentionType, kvFactor,
genEnableDP, generationDPRank, genTP};

auto const contextTragetInfo
auto const contextTargetInfo
= tensorrt_llm::executor::kv_cache::TargetRanksInfoForDP(genCache, contextCache, contextRank);

EXPECT_EQ(expectRanks, contextTragetInfo.mIRanks);
EXPECT_EQ(expectPPDomain, contextTragetInfo.mDomainPPSize);
EXPECT_EQ(expectTPDomain, contextTragetInfo.mDomainTPSize);
EXPECT_EQ(expectRanks, contextTargetInfo.mIRanks);
EXPECT_EQ(expectPPDomain, contextTargetInfo.mDomainPPSize);
EXPECT_EQ(expectTPDomain, contextTargetInfo.mDomainTPSize);
EXPECT_EQ(expectNeedSend, MLACacheFormatter::needSendCache(contextCache, genCache, contextRank));
};

@@ -2330,11 +2398,11 @@ TEST(targetTest, CacheStateContextDP)
contextTP = 1;
genTP = 2;

auto const verfiyGeneration = [&](int contextRank, int generationRank, std::vector<int> const& expectRanks,
auto const verifyGeneration = [&](int contextRank, int generationRank, std::vector<int> const& expectRanks,
int expectPPDomain, int expectTPDomain)
{
int contextDPRank = contextRank % contextTP;
int generationDPRank = generationRank % genTP;
int contextDPRank = (contextRank % (contextTP * contextCP)) / contextCP;
int generationDPRank = (generationRank % (genTP * genCP)) / genCP;
auto attentionType = isMLA ? texec::kv_cache::CacheState::AttentionType::kMLA
: texec::kv_cache::CacheState::AttentionType::kDEFAULT;

@@ -2346,17 +2414,17 @@
tokensPerBlock, genTP, genPP, genCP, genAttentionLayerNumPerPP, dataType, attentionType, kvFactor,
genEnableDP, generationDPRank, genTP};

auto const contextTragetInfo
auto const contextTargetInfo
= tensorrt_llm::executor::kv_cache::TargetRanksInfoForDP(contextCache, genCache, generationRank);

EXPECT_EQ(expectRanks, contextTragetInfo.mIRanks);
EXPECT_EQ(expectPPDomain, contextTragetInfo.mDomainPPSize);
EXPECT_EQ(expectTPDomain, contextTragetInfo.mDomainTPSize);
EXPECT_EQ(expectRanks, contextTargetInfo.mIRanks);
EXPECT_EQ(expectPPDomain, contextTargetInfo.mDomainPPSize);
EXPECT_EQ(expectTPDomain, contextTargetInfo.mDomainTPSize);
};

verfiyGeneration(
verifyGeneration(
/*contextRank*/ 0, /*generationRank*/ 0, /*expectRanks*/ {0}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1);
verfiyGeneration(
verifyGeneration(
/*contextRank*/ 0, /*generationRank*/ 1, /*expectRanks*/ {0}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1);

contextTP = 1;
@@ -2366,9 +2434,9 @@
contextAttentionLayerNumPerPP = std::vector<SizeType32>(contextPP, numLayers / contextPP);
genAttentionLayerNumPerPP = std::vector<SizeType32>(genPP, numLayers / genPP);

verfiyGeneration(
verifyGeneration(
/*contextRank*/ 0, /*generationRank*/ 0, /*expectRanks*/ {0}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1);
verfiyGeneration(
verifyGeneration(
/*contextRank*/ 0, /*generationRank*/ 1, /*expectRanks*/ {0}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1);

genEnableDP = false;
@@ -2381,8 +2449,8 @@
contextAttentionLayerNumPerPP = std::vector<SizeType32>(contextPP, numLayers / contextPP);
genAttentionLayerNumPerPP = std::vector<SizeType32>(genPP, numLayers / genPP);

verfiyGeneration(
verifyGeneration(
/*contextRank*/ 0, /*generationRank*/ 0, /*expectRanks*/ {0}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1);
verfiyGeneration(
verifyGeneration(
/*contextRank*/ 1, /*generationRank*/ 0, /*expectRanks*/ {1}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1);
}
2 changes: 1 addition & 1 deletion tensorrt_llm/_torch/autotuner.py
@@ -1547,7 +1547,7 @@ def _maybe_sync_cache_data(self, strategy: DistributedTuningStrategy,
def _merge_cache_data(self, custom_op: str):
cache_data = self.profiling_cache.get_specific_custom_op(custom_op)
merged_cache_data = dict()
all_cache_data = self._dist.tp_allgather(obj=cache_data)
all_cache_data = self._dist.tp_cp_allgather(obj=cache_data)

for data in all_cache_data:
for key, value in data.items():
61 changes: 45 additions & 16 deletions tensorrt_llm/_torch/distributed/communicator.py
@@ -1,4 +1,3 @@
import copy
import math
import pickle # nosec B403
from abc import ABC, abstractmethod
@@ -136,6 +135,35 @@ def tp_cp_broadcast(self, obj, root=0, **kwargs):
obj = self.cp_broadcast(obj, root=root, **kwargs)
return obj

@abstractmethod
def tp_allgather(self, obj):
pass

@abstractmethod
def cp_allgather(self, obj):
pass

def tp_cp_allgather(self, obj):
"""Allgather across both TP and CP dimensions.

First gathers within CP group, then across TP groups, returning
a flattened list with tp_size * cp_size entries.
"""
# Gather across CP dimension.
if self.cp_size > 1:
obj = self.cp_allgather(obj)
else:
obj = [obj] # Wrap to match cp_allgather output format.

# Gather across TP dimension.
if self.tp_size > 1:
obj = self.tp_allgather(obj)
else:
obj = [obj] # Wrap to match tp_allgather output format.

# Flatten: [[cp0, cp1], [cp0, cp1], ...] -> [tp0_cp0, tp0_cp1, tp1_cp0, ...]
return [entry for tp_group in obj for entry in tp_group]
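A small sketch, not part of the diff, of the ordering this flattening produces, assuming tp_size = cp_size = 2 and that cp_allgather and tp_allgather each return one entry per rank in their group:

gathered = [["tp0_cp0", "tp0_cp1"], ["tp1_cp0", "tp1_cp1"]]  # what tp_allgather returns after the CP gather
flat = [entry for tp_group in gathered for entry in tp_group]
assert flat == ["tp0_cp0", "tp0_cp1", "tp1_cp0", "tp1_cp1"]  # TP-major, CP-minor order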


def safe_broadcast(comm, obj, root=0, chunk_size: int = 4 * 1024 * 1024):
"""
@@ -363,24 +391,9 @@ class MPIDist(Distributed):
def __init__(self, mapping: Mapping):
super().__init__(mapping)
self.create_cp_comm()
# Repurpose CP ranks to TP for Helix so that the right comms are created.
mapping_with_cp = None
if self.mapping.has_cp_helix():
logger.info(
f"[MPIDist::__init__] Repurposing CP ranks to TP for Helix.")
mapping_with_cp = copy.deepcopy(self.mapping)
self.mapping = self.mapping.repurpose_helix_cp_to_tp()

self.create_tp_comm()
self.create_pp_comm()

# Restore the original mapping.
if mapping_with_cp is not None:
logger.info(
f"[MPIDist::__init__] Restoring original mapping undoing Helix manipulation."
)
self.mapping = mapping_with_cp

def broadcast(self, obj, root=0, chunk_size: int = 4 * 1024 * 1024):
comm = mpi_comm()
return safe_broadcast(comm, obj, root=root, chunk_size=chunk_size)
@@ -758,6 +771,22 @@ def cp_broadcast(self, obj, root=0, **kwargs):
device=torch.device("cpu"))
return ret[0]

@log_op
def cp_allgather(self, obj):
if isinstance(obj, torch.Tensor):
output_list = [
torch.empty_like(obj)
for _ in range(self.mapping.cp_group_pg.size())
]
dist.all_gather(output_list, obj, group=self.mapping.cp_group_pg)
return output_list
else:
output_list = [None] * self.mapping.cp_group_pg.size()
dist.all_gather_object(output_list,
obj,
group=self.mapping.cp_group_pg)
return output_list

@log_op
def pp_allgather(self, obj):
if isinstance(obj, torch.Tensor):
17 changes: 14 additions & 3 deletions tensorrt_llm/_torch/models/modeling_deepseekv3.py
@@ -1120,13 +1120,19 @@ def __init__(self,
reduce_output=not self.enable_attention_dp
and self.mapping.tp_size > 1)
else:
# When enable_attention_dp is True, TP reduction is skipped since each DP rank
# works on different batch elements. However, with CP > 1, attention is split
# across CP ranks for the SAME batch element, so reduction is still needed
# within the CP group.
needs_tp_reduce = not self.enable_attention_dp and self.mapping.tp_size > 1
needs_cp_reduce = mapping_with_cp is not None and mapping_with_cp.has_cp_helix(
)
self.self_attn = DeepseekV3Attention(
model_config,
layer_idx=layer_idx_for_attention,
aux_stream=aux_stream_dict[AuxStreamType.Attention],
mapping_with_cp=mapping_with_cp,
reduce_output=not self.enable_attention_dp
and self.mapping.tp_size > 1)
reduce_output=needs_tp_reduce or needs_cp_reduce)

self.fusion_config = EagerFusionConfig()
self.enable_fusion = os.environ.get(
@@ -1192,10 +1198,15 @@ def __init__(self,
eps=config.rms_norm_eps,
dtype=config.torch_dtype)

# When enable_attention_dp is True, we normally skip attention all-reduce since each
# DP rank works on different batch elements. However, with CP > 1, attention is split
# across CP ranks for the SAME batch element, so all-reduce is still needed.
has_cp = mapping_with_cp is not None and mapping_with_cp.cp_size > 1
can_skip_for_attention_dp = self.enable_attention_dp and not has_cp
self.disable_attn_allreduce = (self.fusion_config.PRE_MOE_FUSION
or self.fusion_config.PRE_MLP_FUSION
or self.mapping.tp_size == 1
or self.enable_attention_dp)
or can_skip_for_attention_dp)

self.post_attention_layernorm = RMSNorm(hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
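A minimal sketch, not part of the diff, of the DP/CP interaction encoded by the new disable_attn_allreduce logic in modeling_deepseekv3.py: attention DP alone lets a layer skip the attention all-reduce, but any cp_size > 1 turns it back on because CP ranks hold partial attention results for the same tokens. The fusion and tp_size terms of the real condition are omitted, and the helper name is illustrative.

def can_skip_attn_allreduce_for_dp(enable_attention_dp: bool, cp_size: int) -> bool:
    has_cp = cp_size > 1
    return enable_attention_dp and not has_cp

assert can_skip_attn_allreduce_for_dp(True, 1)       # pure attention DP: reduction skipped
assert not can_skip_attn_allreduce_for_dp(True, 2)   # attention DP with CP: reduction still needed
assert not can_skip_attn_allreduce_for_dp(False, 2)  # no attention DP: normal reduction path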