Commit af69ed7

cleaner order of iteration

1 parent d9384e8
4 files changed: +11 -12 lines changed

cpp/tensorrt_llm/executor/cache_transmission/cacheSplitConcat.cu
Lines changed: 4 additions & 5 deletions

@@ -205,15 +205,14 @@ TargetRanksInfo TargetRanksInfoForDP(
     }
 
     std::vector<int> retRanks;
-    for (int i = peerTPRankStart; i < peerTPRankEnd; i++)
+    for (int i = peerCPRankStart; i < peerCPRankEnd; i++)
     {
-        for (int j = peerCPRankStart; j < peerCPRankEnd; j++)
+        for (int j = peerTPRankStart; j < peerTPRankEnd; j++)
         {
             for (int k = peerPPRankStart; k < peerPPRankEnd; k++)
             {
-                // Rank formula: ppRank * (tpNum * cpNum) + tpRank * cpNum + cpRank
-                // where i=tpRank, j=cpRank, k=ppRank
-                int irank = (k * peerTPNum * peerCPNum) + (i * peerCPNum) + j;
+                // Rank formula: ppRank * (tpNum * cpNum) + tpRank * cpNum + cpRank.
+                int irank = (k * peerTPNum * peerCPNum) + (j * peerCPNum) + i;
                 retRanks.push_back(irank);
             }
         }
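
Note: for intuition about the new nesting, a small standalone Python sketch of the same layout (illustrative only; the loop variables mirror the C++ indices, but this is not code from the repo):

    # Rank layout from the hunk above:
    #   rank = pp_rank * (tp_num * cp_num) + tp_rank * cp_num + cp_rank
    # so cp peers are consecutive, tp peers sit cp_num apart, and pp
    # peers sit tp_num * cp_num apart.
    tp_num, cp_num, pp_num = 2, 2, 2

    ranks = []
    for cp in range(cp_num):            # i in the C++ loop
        for tp in range(tp_num):        # j in the C++ loop
            for pp in range(pp_num):    # k in the C++ loop
                ranks.append(pp * (tp_num * cp_num) + tp * cp_num + cp)

    print(ranks)  # [0, 4, 2, 6, 1, 5, 3, 7]

The formula itself is unchanged; only the order in which retRanks is filled differs from the old tp-outermost nesting.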

tensorrt_llm/_torch/device_mesh.py
Lines changed: 2 additions & 2 deletions

@@ -118,8 +118,8 @@ def build_mesh(self):
             "DeviceMesh creation requested but torch.distributed process group "
             "has not been initialised.")
 
-        # Dimensions go from slowest-varying (outermost) to fastest-varying (innermost)
-        # Layout: pp is outermost, then tp, then cp is innermost (consecutive)
+        # Dimensions go from slowest-varying (outermost) to fastest-varying (innermost).
+        # Layout: pp is outermost, then tp, then cp is innermost (consecutive).
         dims = ["pp", "tp"]
         shape = [self.pp_size, self.tp_size]
 
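Note: a minimal sketch of what this pp/tp/cp ordering implies for torch's DeviceMesh (assumes an initialised process group whose world size matches the mesh; the sizes are hypothetical):

    from torch.distributed.device_mesh import init_device_mesh

    pp_size, tp_size, cp_size = 2, 2, 2  # hypothetical sizes

    # init_device_mesh lays global ranks out row-major over the shape,
    # so the last ("cp") dimension holds consecutive ranks, matching
    # the "cp is innermost (consecutive)" comment in the hunk above.
    mesh = init_device_mesh("cuda", (pp_size, tp_size, cp_size),
                            mesh_dim_names=("pp", "tp", "cp"))
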
tensorrt_llm/mapping.py
Lines changed: 4 additions & 4 deletions

@@ -292,13 +292,13 @@ def has_cp(self):
         return self.cp_size > 1
 
     def prev_cp_rank(self):
-        # cp ranks are consecutive, so prev is rank - 1 with wraparound within cp group
+        # cp ranks are consecutive, so prev is rank - 1 with wraparound within cp group.
         if self.cp_rank == 0:
             return self.rank + self.cp_size - 1
         return self.rank - 1
 
     def next_cp_rank(self):
-        # cp ranks are consecutive, so next is rank + 1 with wraparound within cp group
+        # cp ranks are consecutive, so next is rank + 1 with wraparound within cp group.
         if self.cp_rank == self.cp_size - 1:
             return self.rank - self.cp_size + 1
         return self.rank + 1
@@ -596,15 +596,15 @@ def _init_parallel_groups(self):
             ranks = range(i, self.world_size, self.tp_size * self.cp_size)
             self.pp_groups.append(list(ranks))
 
-        # init cp group (consecutive ranks within each tp slice)
+        # init cp group (consecutive ranks within each tp slice).
         for i in range(self.pp_size):
             for j in range(self.tp_size):
                 ranks = range(
                     i * self.tp_size * self.cp_size + j * self.cp_size,
                     i * self.tp_size * self.cp_size + (j + 1) * self.cp_size)
                 self.cp_groups.append(list(ranks))
 
-        # init tp group (interleaved ranks with stride of cp_size)
+        # init tp group (interleaved ranks with stride of cp_size).
         for i in range(self.pp_size):
             for j in range(self.cp_size):
                 ranks = range(i * self.tp_size * self.cp_size + j,
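
Note: to make the three strides concrete, a standalone enumeration with hypothetical sizes, mirroring the loops in _init_parallel_groups (the tp-group stop and stride continue past the end of the hunk, so they are inferred here from the "stride of cp_size" comment):

    pp_size, tp_size, cp_size = 2, 2, 2
    world_size = pp_size * tp_size * cp_size  # 8

    # pp groups: stride tp_size * cp_size across the whole world.
    pp_groups = [list(range(i, world_size, tp_size * cp_size))
                 for i in range(tp_size * cp_size)]

    # cp groups: consecutive ranks within each tp slice.
    cp_groups = [list(range(i * tp_size * cp_size + j * cp_size,
                            i * tp_size * cp_size + (j + 1) * cp_size))
                 for i in range(pp_size) for j in range(tp_size)]

    # tp groups: interleaved ranks with stride cp_size (inferred bounds).
    tp_groups = [list(range(i * tp_size * cp_size + j,
                            (i + 1) * tp_size * cp_size, cp_size))
                 for i in range(pp_size) for j in range(cp_size)]

    print(pp_groups)  # [[0, 4], [1, 5], [2, 6], [3, 7]]
    print(cp_groups)  # [[0, 1], [2, 3], [4, 5], [6, 7]]
    print(tp_groups)  # [[0, 2], [1, 3], [4, 6], [5, 7]]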

tensorrt_llm/models/modeling_utils.py
Lines changed: 1 addition & 1 deletion

@@ -740,7 +740,7 @@ def from_checkpoint(
 
         rank = config.mapping.rank
         if config.mapping.cp_size > 1:
-            # cp_tp_pp rank -> tp_pp rank: because different cp ranks share the same ckpt
+            # cp_tp_pp rank -> tp_pp rank: because different cp ranks share the same ckpt.
             tp_size = config.mapping.tp_size
             cp_size = config.mapping.cp_size
             rank = (rank % (tp_size * cp_size)) // cp_size + rank // (
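
Note: this hunk ends mid-expression, so the helper below is a hypothetical completion consistent with the comment ("different cp ranks share the same ckpt"), not the repo's code:

    def ckpt_rank(rank: int, tp_size: int, cp_size: int) -> int:
        # Drop the cp coordinate: all cp peers in a tp group load the
        # same tp_pp checkpoint shard. (Assumed completion of the
        # truncated expression above.)
        tp_rank = (rank % (tp_size * cp_size)) // cp_size
        pp_rank = rank // (tp_size * cp_size)
        return pp_rank * tp_size + tp_rank

    # With tp_size=2, cp_size=2, cp peers collapse onto one shard rank:
    assert [ckpt_rank(r, 2, 2) for r in range(8)] == [0, 0, 1, 1, 2, 2, 3, 3]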
