
Commit 57fd9a3

test_mla_helix: Fix NCCL error by computing reference locally on all ranks
Previously, the test tried to use cp_allgather to broadcast the reference output from rank 0 to all other ranks. This failed because NCCL process groups weren't properly initialized in the MPI worker context.

The fix:

1. All ranks now compute the reference MLA output locally, since they all have the same random seed and inputs.
2. _make_latent_cache_gen also computes locally instead of broadcasting, since all ranks now have the reference KV cache.
3. Removed the cp_allgather import, as it's no longer needed.

This approach is actually cleaner, as it avoids inter-process communication for the reference computation.
1 parent 08a8fbd commit 57fd9a3
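
The approach in the commit message relies on deterministic seeding: if every rank seeds its RNG identically, each rank can recompute the reference tensor instead of receiving it over a collective. A minimal sketch of that idea follows; `compute_reference` and the tensor shapes are placeholders for illustration, not the test's actual helpers.

```python
import torch


def compute_reference(seed: int) -> torch.Tensor:
    # Stand-in for the reference MLA forward pass: with a fixed seed,
    # every rank derives identical weights and inputs, hence identical outputs.
    torch.manual_seed(seed)
    weights = torch.randn(64, 64)
    inputs = torch.randn(8, 64)
    return inputs @ weights


# Before the fix: rank 0 computed the reference and broadcast it via cp_allgather,
# which needs an initialized NCCL process group on every rank.
# After the fix: each rank simply recomputes the reference itself.
ref_on_one_rank = compute_reference(seed=42)
ref_on_another_rank = compute_reference(seed=42)  # simulating a second rank
assert torch.equal(ref_on_one_rank, ref_on_another_rank)
```

The trade-off is redundant compute on every rank, which is acceptable in a test and removes the dependency on NCCL group initialization entirely.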

1 file changed: +140 -178 lines changed


tests/unittest/_torch/modules/test_mla_helix.py

Lines changed: 140 additions & 178 deletions
@@ -49,7 +49,7 @@
     RopeParams,
 )
 from tensorrt_llm._torch.attention_backend.utils import get_attention_backend
-from tensorrt_llm._torch.distributed.ops import cp_allgather
+# cp_allgather import removed - not needed after refactoring
 from tensorrt_llm._torch.model_config import ModelConfig
 from tensorrt_llm._torch.modules.attention import MLA
 from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest, LlmRequestState, SamplingConfig
@@ -488,12 +488,12 @@ def _make_latent_cache_gen(
     cp_size: int,
     ctx_len_per_cp: int,
     input_ctx_bs: torch.Tensor,
-    ref_attn_metadata: Optional[AttentionMetadata],
+    ref_attn_metadata: AttentionMetadata,
 ):
     """Generate latent cache for Helix CP generation phase.
 
-    The latent cache is communicated across CP ranks. All TP ranks in the same
-    CP group share the same latent cache values.
+    Since all ranks compute the reference MLA locally (same random seed and inputs),
+    all ranks have the same KV cache and can compute the latent cache locally.
 
     Args:
         mla: MLA module
@@ -503,93 +503,67 @@ def _make_latent_cache_gen(
         cp_size: Total CP size
         ctx_len_per_cp: Context length per CP rank
         input_ctx_bs: Input context tensor (batch, ctx_len, hidden_size)
-        ref_attn_metadata: Reference attention metadata (only on cp_rank 0)
+        ref_attn_metadata: Reference attention metadata (available on all ranks)
 
     Returns:
         Latent cache tensor for this CP rank, or None for last CP rank
     """
-    world_size = tp_size * cp_size
+    # Last CP rank doesn't need latent cache
+    if cp_rank == cp_size - 1:
+        return None
 
-    # Only the first CP rank (across all TP ranks) has the reference metadata
-    # and generates the latent cache
-    if cp_rank == 0:
-        assert ref_attn_metadata is not None
-        kv_cache_block_offsets = ref_attn_metadata.host_kv_cache_block_offsets
-        kv_buffer = ref_attn_metadata.kv_cache_manager.get_buffers(0)
-        ret = input_ctx_bs.new_empty(
-            (cp_size - 1, input_ctx_bs.shape[0], mla.kv_lora_rank + mla.qk_rope_head_dim)
-        )
-        # the RoPE values in the KV cache are embedded and we need to get the
-        # original values instead for the latent cache
-        # so we first get the cos/sin cache used in MLA
-        _, cos_sin_cache = mla.pos_embd_params.rope.create_rope_const_params()
-        cos_sin_cache = cos_sin_cache.reshape(-1, mla.qk_rope_head_dim, 2)
-        assert cos_sin_cache.dtype == torch.float32
-
-        def rotate_half(x):
-            """Rotates half the hidden dims of the input."""
-            x1 = x[..., : x.shape[-1] // 2]
-            x2 = x[..., x.shape[-1] // 2 :]
-            return torch.cat((-x2, x1), dim=-1)
-
-        def rotate_half_inv(x):
-            """Rotates half the hidden dims of the input."""
-            x1 = x[..., : x.shape[-1] // 2]
-            x2 = x[..., x.shape[-1] // 2 :]
-            return torch.cat((x2, -x1), dim=-1)
-
-        # Generate latent cache for each CP rank (except the first)
-        for r in range(cp_size - 1):
-            for b in range(input_ctx_bs.shape[0]):
-                block, t = divmod(
-                    (r + 1) * ctx_len_per_cp, ref_attn_metadata.kv_cache_manager.tokens_per_block
-                )
-                kv_block = kv_cache_block_offsets[0, b, 0, block].item()
-                ret[r, b] = kv_buffer[kv_block, 0, t, 0, :]
-        rope_values = ret[:, :, mla.kv_lora_rank :].clone().to(dtype=torch.float32)
-        # now we apply the inverse of RoPE embedding to get the original values
-        # rope_values has shape (cp_size - 1, batch_size, rope_dim)
-        # cos_sin_cache has shape (max_pos, rope_dim, 2)
-
-        # Setup position and cos/sin values
-        positions = torch.arange(1, cp_size, device=rope_values.device) * ctx_len_per_cp
-        cos_sin_cache_pos = torch.index_select(cos_sin_cache, 0, positions)
-        cos = cos_sin_cache_pos[..., 0].unsqueeze(1)
-        sin = cos_sin_cache_pos[..., 1].unsqueeze(1)
-        # cos/sin shape is (cp_size - 1, 1, rope_dim) to broadcast with batch
-
-        # Reshape for pairwise rotation
-        rope_values_reshaped = (
-            rope_values.unflatten(-1, [-1, 2]).transpose(-1, -2).flatten(start_dim=-2)
-        )
-        orig_rope_values = rope_values_reshaped * cos + rotate_half_inv(rope_values_reshaped) * sin
-        orig_rope_values_reshaped = (
-            orig_rope_values.unflatten(-1, [2, -1]).transpose(-2, -1).flatten(start_dim=-2)
-        )
-
-        ret[:, :, mla.kv_lora_rank :] = orig_rope_values_reshaped.to(dtype=ret.dtype)
-    else:
-        ret = input_ctx_bs.new_empty(
-            (cp_size - 1, input_ctx_bs.shape[0], mla.kv_lora_rank + mla.qk_rope_head_dim)
+    # All ranks have ref_attn_metadata and can compute the latent cache locally
+    kv_cache_block_offsets = ref_attn_metadata.host_kv_cache_block_offsets
+    kv_buffer = ref_attn_metadata.kv_cache_manager.get_buffers(0)
+
+    # Only need latent cache for this specific cp_rank
+    ret = input_ctx_bs.new_empty(
+        (input_ctx_bs.shape[0], mla.kv_lora_rank + mla.qk_rope_head_dim)
+    )
+
+    # the RoPE values in the KV cache are embedded and we need to get the
+    # original values instead for the latent cache
+    # so we first get the cos/sin cache used in MLA
+    _, cos_sin_cache = mla.pos_embd_params.rope.create_rope_const_params()
+    cos_sin_cache = cos_sin_cache.reshape(-1, mla.qk_rope_head_dim, 2)
+    assert cos_sin_cache.dtype == torch.float32
+
+    def rotate_half_inv(x):
+        """Rotates half the hidden dims of the input."""
+        x1 = x[..., : x.shape[-1] // 2]
+        x2 = x[..., x.shape[-1] // 2 :]
+        return torch.cat((x2, -x1), dim=-1)
+
+    # Get latent cache for this cp_rank's boundary position
+    target_pos = (cp_rank + 1) * ctx_len_per_cp
+    for b in range(input_ctx_bs.shape[0]):
+        block, t = divmod(
+            target_pos, ref_attn_metadata.kv_cache_manager.tokens_per_block
         )
-
-    # Create mapping for allgather across CP group
-    # All TP ranks in the same CP group will have the same latent cache
-    rank = tp_rank + cp_rank * tp_size  # Reconstruct global rank
-    mapping = Mapping(
-        world_size=world_size,
-        rank=rank,
-        tp_size=tp_size,
-        cp_size=cp_size,
-        cp_config={"cp_type": CpType.HELIX, "tokens_per_block": 32}
+        kv_block = kv_cache_block_offsets[0, b, 0, block].item()
+        ret[b] = kv_buffer[kv_block, 0, t, 0, :]
+
+    # Apply inverse RoPE to get original values
+    rope_values = ret[:, mla.kv_lora_rank:].clone().to(dtype=torch.float32)
+
+    # Setup position and cos/sin values for this specific position
+    position = torch.tensor([target_pos], device=rope_values.device)
+    cos_sin_cache_pos = torch.index_select(cos_sin_cache, 0, position)
+    cos = cos_sin_cache_pos[..., 0]  # shape: (1, rope_dim)
+    sin = cos_sin_cache_pos[..., 1]  # shape: (1, rope_dim)
+
+    # Reshape for pairwise rotation
+    rope_values_reshaped = (
+        rope_values.unflatten(-1, [-1, 2]).transpose(-1, -2).flatten(start_dim=-2)
+    )
+    orig_rope_values = rope_values_reshaped * cos + rotate_half_inv(rope_values_reshaped) * sin
+    orig_rope_values_reshaped = (
+        orig_rope_values.unflatten(-1, [2, -1]).transpose(-2, -1).flatten(start_dim=-2)
     )
-    # use cp_allgather here to broadcast from cp_rank 0 to all other cp_ranks
-    ret_all = cp_allgather(ret, mapping=mapping, dim=0)
-    ret = ret_all.view(cp_size, *ret.shape)[0]
+
+    ret[:, mla.kv_lora_rank:] = orig_rope_values_reshaped.to(dtype=ret.dtype)
 
-    if cp_rank == cp_size - 1:
-        return None
-    return ret[cp_rank]
+    return ret
 

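For context on the inverse-RoPE step above: RoPE rotates each two-dimensional pair of the rope head dimensions by a position-dependent angle, so the pre-embedding values can be recovered by rotating back through the same cos/sin, which is what `rotate_half_inv` provides. Below is a standalone sketch of that round trip, using a made-up angle table and the half-split pair layout (the `unflatten`/`transpose` reshapes in the hunk appear to convert between interleaved and half-split layouts); it is illustrative only, not the test's actual cos/sin cache.

```python
import torch


def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def rotate_half_inv(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((x2, -x1), dim=-1)


rope_dim = 8
theta = torch.rand(rope_dim // 2)        # hypothetical per-pair rotation angles
cos = torch.cos(theta).repeat(2)         # half-split layout: [c0..c3, c0..c3]
sin = torch.sin(theta).repeat(2)

x = torch.rand(4, rope_dim)              # "original" pre-RoPE values
x_rope = x * cos + rotate_half(x) * sin  # forward RoPE rotation
x_back = x_rope * cos + rotate_half_inv(x_rope) * sin  # rotate back by the same angle

assert torch.allclose(x, x_back, atol=1e-6)
```

Since cos² + sin² = 1, applying the forward rotation followed by the inverse rotation returns the original pair exactly (up to float rounding), which is why the test can undo the embedded RoPE values stored in the KV cache.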
@@ -953,109 +927,102 @@ def _full_test_multi_gpu(
     _generate_random_weights(mla)
     weights = mla.state_dict()
 
-    # up to this point, all ranks should have same tensors because the seed is the same
-    # now we run the reference MLA on rank 0
-    if rank == 0:
-        # Reference output (single GPU, no parallelism)
-        ref_mapping = Mapping(world_size=1, tp_size=1, cp_size=1, rank=0)
-        ref_kv_cache_manager, ref_attn_metadata = _setup_kv_and_metadata(
-            scenario, ref_mapping, gen_steps
-        )
-        # this represents the context step
-        mla(position_ids_ctx, input_ctx, ref_attn_metadata)
-        ref_outputs = []
-        start = time.time()
-
-        # CUDA graph setup for timing
-        use_cuda_graph = gen_steps > scenario.ref_steps
-        graph = None
-        graph_output = None
-
-        for step in range(gen_steps):
-            for req_id in range(scenario.batch):
-                ref_kv_cache_manager.impl.add_token(req_id)
-            if step == 0:
-                ref_attn_metadata = get_attention_backend("TRTLLM").Metadata(
-                    seq_lens=torch.tensor([1] * scenario.batch, dtype=torch.int),
-                    request_ids=list(range(scenario.batch)),
-                    max_num_requests=scenario.batch,
-                    num_contexts=0,
-                    prompt_lens=[scenario.ctx_len] * scenario.batch,
-                    max_num_tokens=scenario.ctx_len,
-                    kv_cache_manager=ref_kv_cache_manager,
-                    kv_cache_params=KVCacheParams(
-                        use_cache=True,
-                        num_cached_tokens_per_seq=[
-                            scenario.ctx_len + step for _ in range(scenario.batch)
-                        ],
-                    ),
-                    enable_context_mla_with_cached_kv=True,
-                )
-            else:
-                ref_attn_metadata.kv_cache_params = KVCacheParams(
+    # Since all ranks have the same random seed and inputs, each rank computes
+    # the reference output locally. This avoids the need for collective communication.
+    # Reference output (single GPU, no parallelism)
+    ref_mapping = Mapping(world_size=1, tp_size=1, cp_size=1, rank=0)
+    ref_kv_cache_manager, ref_attn_metadata = _setup_kv_and_metadata(
+        scenario, ref_mapping, gen_steps
+    )
+    # this represents the context step
+    mla(position_ids_ctx, input_ctx, ref_attn_metadata)
+    ref_outputs = []
+    start = time.time()
+
+    # CUDA graph setup for timing
+    use_cuda_graph = gen_steps > scenario.ref_steps
+    graph = None
+    graph_output = None
+
+    for step in range(gen_steps):
+        for req_id in range(scenario.batch):
+            ref_kv_cache_manager.impl.add_token(req_id)
+        if step == 0:
+            ref_attn_metadata = get_attention_backend("TRTLLM").Metadata(
+                seq_lens=torch.tensor([1] * scenario.batch, dtype=torch.int),
+                request_ids=list(range(scenario.batch)),
+                max_num_requests=scenario.batch,
+                num_contexts=0,
+                prompt_lens=[scenario.ctx_len] * scenario.batch,
+                max_num_tokens=scenario.ctx_len,
+                kv_cache_manager=ref_kv_cache_manager,
+                kv_cache_params=KVCacheParams(
                     use_cache=True,
                     num_cached_tokens_per_seq=[
                         scenario.ctx_len + step for _ in range(scenario.batch)
                     ],
-                )
-            ref_attn_metadata.prepare()
+                ),
+                enable_context_mla_with_cached_kv=True,
+            )
+        else:
+            ref_attn_metadata.kv_cache_params = KVCacheParams(
+                use_cache=True,
+                num_cached_tokens_per_seq=[
+                    scenario.ctx_len + step for _ in range(scenario.batch)
+                ],
+            )
+        ref_attn_metadata.prepare()
 
-            if not use_cuda_graph:
-                result = mla(position_ids_gen, input_gen, ref_attn_metadata)
-                if step < scenario.ref_steps:
-                    ref_outputs.append(result)
+        if not use_cuda_graph:
+            result = mla(position_ids_gen, input_gen, ref_attn_metadata)
+            if step < scenario.ref_steps:
+                ref_outputs.append(result)
+                if rank == 0:
                     print(f"Ref result: {result[0, :8]} / {result[-1, -8:]}")
-                # update position_ids_gen
-                position_ids_gen += 1
-                continue
+            # update position_ids_gen
+            position_ids_gen += 1
+            continue
 
-            # CUDA graph capture on first step when timing
-            if step == 0:
+        # CUDA graph capture on first step when timing
+        if step == 0:
+            if rank == 0:
                 print("Creating CUDA graph and capturing")
-                # Create CUDA graph metadata for capture
-                ref_attn_metadata = ref_attn_metadata.create_cuda_graph_metadata(
-                    max_batch_size=scenario.batch
-                )
-                ref_attn_metadata.prepare()
-
-                # Warm-up runs before graph capture
-                for _ in range(2):
-                    result = mla(position_ids_gen, input_gen, ref_attn_metadata)
-
-                # Capture the graph
-                graph = torch.cuda.CUDAGraph()
-                with torch.cuda.graph(graph):
-                    graph_output = mla(position_ids_gen, input_gen, ref_attn_metadata)
-                result = graph_output
-            elif step == scenario.ref_steps:
-                # Start timing with CUDA graph
-                start = time.time()
-            graph.replay()
+            # Create CUDA graph metadata for capture
+            ref_attn_metadata = ref_attn_metadata.create_cuda_graph_metadata(
+                max_batch_size=scenario.batch
+            )
+            ref_attn_metadata.prepare()
+
+            # Warm-up runs before graph capture
+            for _ in range(2):
+                result = mla(position_ids_gen, input_gen, ref_attn_metadata)
+
+            # Capture the graph
+            graph = torch.cuda.CUDAGraph()
+            with torch.cuda.graph(graph):
+                graph_output = mla(position_ids_gen, input_gen, ref_attn_metadata)
             result = graph_output
-            # update position_ids_gen
-            position_ids_gen += 1
-            if step < scenario.ref_steps:
-                ref_outputs.append(result)
-        end = time.time()
-        if gen_steps == scenario.ref_steps:
-            avg_gen_time = float("inf")
-        else:
-            avg_gen_time = (end - start) / (gen_steps - scenario.ref_steps)
-        throughput = scenario.batch / avg_gen_time
+        elif step == scenario.ref_steps:
+            # Start timing with CUDA graph
+            start = time.time()
+        graph.replay()
+        result = graph_output
+        # update position_ids_gen
+        position_ids_gen += 1
+        if step < scenario.ref_steps:
+            ref_outputs.append(result)
+    end = time.time()
+    if gen_steps == scenario.ref_steps:
+        avg_gen_time = float("inf")
+    else:
+        avg_gen_time = (end - start) / (gen_steps - scenario.ref_steps)
+    throughput = scenario.batch / avg_gen_time
+    if rank == 0:
         print(
             f"Time taken for {gen_steps - scenario.ref_steps} steps: "
             f"{end - start} s, throughput: {throughput} MLA/s"
         )
-        ref_output = torch.stack(ref_outputs, dim=0)
-    else:
-        ref_output = torch.empty(
-            scenario.ref_steps,
-            scenario.batch,
-            scenario.hidden_size,
-            dtype=scenario.dtype,
-            device="cuda",
-        )
-        ref_attn_metadata = None
+    ref_output = torch.stack(ref_outputs, dim=0)
 
     # Distributed mapping for mixed TP+CP helix
     mapping = Mapping(
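
The timing path in the hunk above follows the usual `torch.cuda.CUDAGraph` recipe: warm up, capture one generation step with static tensors, then replay the captured graph for the timed steps. A self-contained sketch of that pattern (a toy matmul standing in for the MLA forward) is roughly:

```python
import torch

if torch.cuda.is_available():
    x = torch.randn(8, 64, device="cuda")   # static input buffer
    w = torch.randn(64, 64, device="cuda")

    # Warm-up iterations so lazy initialization happens before capture.
    for _ in range(2):
        y = x @ w

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        y = x @ w                            # 'y' becomes a static output buffer

    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(10):
        graph.replay()                       # re-runs the captured kernels
    end.record()
    torch.cuda.synchronize()
    print(f"avg step time: {start.elapsed_time(end) / 10:.3f} ms")
```

Because replay reuses the captured buffers, new inputs must be written into the static tensors in place before each replay; this is presumably why the test updates `position_ids_gen` with an in-place `+= 1` rather than rebuilding the tensor each step.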
@@ -1066,11 +1033,6 @@ def _full_test_multi_gpu(
         cp_config={"cp_type": CpType.HELIX}
     )
 
-    # Broadcast reference output from rank 0 to all ranks using allgather
-    ref_output_all = cp_allgather(ref_output, mapping=mapping, dim=0)
-    # we only need the values from rank 0
-    ref_output = ref_output_all.view(world_size, *ref_output.shape)[0]
-
     test_params = (
         input_ctx,
         input_gen,