     RopeParams,
 )
 from tensorrt_llm._torch.attention_backend.utils import get_attention_backend
-from tensorrt_llm._torch.distributed.ops import cp_allgather
+# cp_allgather import removed - not needed after refactoring
 from tensorrt_llm._torch.model_config import ModelConfig
 from tensorrt_llm._torch.modules.attention import MLA
 from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest, LlmRequestState, SamplingConfig
@@ -488,12 +488,12 @@ def _make_latent_cache_gen(
     cp_size: int,
     ctx_len_per_cp: int,
     input_ctx_bs: torch.Tensor,
-    ref_attn_metadata: Optional[AttentionMetadata],
+    ref_attn_metadata: AttentionMetadata,
 ):
     """Generate latent cache for Helix CP generation phase.
 
-    The latent cache is communicated across CP ranks. All TP ranks in the same
-    CP group share the same latent cache values.
+    Since all ranks compute the reference MLA locally (same random seed and inputs),
+    all ranks have the same KV cache and can compute the latent cache locally.
 
     Args:
         mla: MLA module
@@ -503,93 +503,67 @@ def _make_latent_cache_gen(
         cp_size: Total CP size
         ctx_len_per_cp: Context length per CP rank
         input_ctx_bs: Input context tensor (batch, ctx_len, hidden_size)
-        ref_attn_metadata: Reference attention metadata (only on cp_rank 0)
+        ref_attn_metadata: Reference attention metadata (available on all ranks)
 
     Returns:
         Latent cache tensor for this CP rank, or None for last CP rank
     """
-    world_size = tp_size * cp_size
+    # Last CP rank doesn't need latent cache
+    if cp_rank == cp_size - 1:
+        return None
 
-    # Only the first CP rank (across all TP ranks) has the reference metadata
-    # and generates the latent cache
-    if cp_rank == 0:
-        assert ref_attn_metadata is not None
-        kv_cache_block_offsets = ref_attn_metadata.host_kv_cache_block_offsets
-        kv_buffer = ref_attn_metadata.kv_cache_manager.get_buffers(0)
-        ret = input_ctx_bs.new_empty(
-            (cp_size - 1, input_ctx_bs.shape[0], mla.kv_lora_rank + mla.qk_rope_head_dim)
-        )
-        # the RoPE values in the KV cache are embedded and we need to get the
-        # original values instead for the latent cache
-        # so we first get the cos/sin cache used in MLA
-        _, cos_sin_cache = mla.pos_embd_params.rope.create_rope_const_params()
-        cos_sin_cache = cos_sin_cache.reshape(-1, mla.qk_rope_head_dim, 2)
-        assert cos_sin_cache.dtype == torch.float32
-
-        def rotate_half(x):
-            """Rotates half the hidden dims of the input."""
-            x1 = x[..., : x.shape[-1] // 2]
-            x2 = x[..., x.shape[-1] // 2 :]
-            return torch.cat((-x2, x1), dim=-1)
-
-        def rotate_half_inv(x):
-            """Rotates half the hidden dims of the input."""
-            x1 = x[..., : x.shape[-1] // 2]
-            x2 = x[..., x.shape[-1] // 2 :]
-            return torch.cat((x2, -x1), dim=-1)
-
-        # Generate latent cache for each CP rank (except the first)
-        for r in range(cp_size - 1):
-            for b in range(input_ctx_bs.shape[0]):
-                block, t = divmod(
-                    (r + 1) * ctx_len_per_cp, ref_attn_metadata.kv_cache_manager.tokens_per_block
-                )
-                kv_block = kv_cache_block_offsets[0, b, 0, block].item()
-                ret[r, b] = kv_buffer[kv_block, 0, t, 0, :]
-        rope_values = ret[:, :, mla.kv_lora_rank :].clone().to(dtype=torch.float32)
-        # now we apply the inverse of RoPE embedding to get the original values
-        # rope_values has shape (cp_size - 1, batch_size, rope_dim)
-        # cos_sin_cache has shape (max_pos, rope_dim, 2)
-
-        # Setup position and cos/sin values
-        positions = torch.arange(1, cp_size, device=rope_values.device) * ctx_len_per_cp
-        cos_sin_cache_pos = torch.index_select(cos_sin_cache, 0, positions)
-        cos = cos_sin_cache_pos[..., 0].unsqueeze(1)
-        sin = cos_sin_cache_pos[..., 1].unsqueeze(1)
-        # cos/sin shape is (cp_size - 1, 1, rope_dim) to broadcast with batch
-
-        # Reshape for pairwise rotation
-        rope_values_reshaped = (
-            rope_values.unflatten(-1, [-1, 2]).transpose(-1, -2).flatten(start_dim=-2)
-        )
-        orig_rope_values = rope_values_reshaped * cos + rotate_half_inv(rope_values_reshaped) * sin
-        orig_rope_values_reshaped = (
-            orig_rope_values.unflatten(-1, [2, -1]).transpose(-2, -1).flatten(start_dim=-2)
-        )
-
-        ret[:, :, mla.kv_lora_rank :] = orig_rope_values_reshaped.to(dtype=ret.dtype)
-    else:
-        ret = input_ctx_bs.new_empty(
-            (cp_size - 1, input_ctx_bs.shape[0], mla.kv_lora_rank + mla.qk_rope_head_dim)
+    # All ranks have ref_attn_metadata and can compute the latent cache locally
+    kv_cache_block_offsets = ref_attn_metadata.host_kv_cache_block_offsets
+    kv_buffer = ref_attn_metadata.kv_cache_manager.get_buffers(0)
+
+    # Only need latent cache for this specific cp_rank
+    ret = input_ctx_bs.new_empty(
+        (input_ctx_bs.shape[0], mla.kv_lora_rank + mla.qk_rope_head_dim)
+    )
+
+    # the RoPE values in the KV cache are embedded and we need to get the
+    # original values instead for the latent cache
+    # so we first get the cos/sin cache used in MLA
+    _, cos_sin_cache = mla.pos_embd_params.rope.create_rope_const_params()
+    cos_sin_cache = cos_sin_cache.reshape(-1, mla.qk_rope_head_dim, 2)
+    assert cos_sin_cache.dtype == torch.float32
+
+    def rotate_half_inv(x):
+        """Rotates half the hidden dims of the input."""
+        x1 = x[..., : x.shape[-1] // 2]
+        x2 = x[..., x.shape[-1] // 2 :]
+        return torch.cat((x2, -x1), dim=-1)
+
+    # Get latent cache for this cp_rank's boundary position
+    target_pos = (cp_rank + 1) * ctx_len_per_cp
+    for b in range(input_ctx_bs.shape[0]):
+        block, t = divmod(
+            target_pos, ref_attn_metadata.kv_cache_manager.tokens_per_block
         )
-
-    # Create mapping for allgather across CP group
-    # All TP ranks in the same CP group will have the same latent cache
-    rank = tp_rank + cp_rank * tp_size  # Reconstruct global rank
-    mapping = Mapping(
-        world_size=world_size,
-        rank=rank,
-        tp_size=tp_size,
-        cp_size=cp_size,
-        cp_config={"cp_type": CpType.HELIX, "tokens_per_block": 32}
+        kv_block = kv_cache_block_offsets[0, b, 0, block].item()
+        ret[b] = kv_buffer[kv_block, 0, t, 0, :]
+
+    # Apply inverse RoPE to get original values
+    rope_values = ret[:, mla.kv_lora_rank :].clone().to(dtype=torch.float32)
+
+    # Setup position and cos/sin values for this specific position
+    position = torch.tensor([target_pos], device=rope_values.device)
+    cos_sin_cache_pos = torch.index_select(cos_sin_cache, 0, position)
+    cos = cos_sin_cache_pos[..., 0]  # shape: (1, rope_dim)
+    sin = cos_sin_cache_pos[..., 1]  # shape: (1, rope_dim)
+
+    # Reshape for pairwise rotation
+    rope_values_reshaped = (
+        rope_values.unflatten(-1, [-1, 2]).transpose(-1, -2).flatten(start_dim=-2)
+    )
+    orig_rope_values = rope_values_reshaped * cos + rotate_half_inv(rope_values_reshaped) * sin
+    orig_rope_values_reshaped = (
+        orig_rope_values.unflatten(-1, [2, -1]).transpose(-2, -1).flatten(start_dim=-2)
     )
-    # use cp_allgather here to broadcast from cp_rank 0 to all other cp_ranks
-    ret_all = cp_allgather(ret, mapping=mapping, dim=0)
-    ret = ret_all.view(cp_size, *ret.shape)[0]
+
+    ret[:, mla.kv_lora_rank :] = orig_rope_values_reshaped.to(dtype=ret.dtype)
 
-    if cp_rank == cp_size - 1:
-        return None
-    return ret[cp_rank]
+    return ret
 
 
 
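
Why the inverse RoPE step works: the KV cache stores the RoPE-embedded rope channels, and RoPE is a pure per-pair rotation by a position-dependent angle, so rotating back by the negative angle (the `rotate_half_inv` form above) recovers the pre-embedding values exactly. Below is a minimal self-contained sketch of that identity in the half-split layout; the rope dim, position, and base frequency are hypothetical, and the test's actual `cos_sin_cache` layout and de-interleaving reshape are not reproduced here.

```python
import torch


def rotate_half(x):
    # Standard half-split rotation used by RoPE: (x1, x2) -> (-x2, x1).
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def rotate_half_inv(x):
    # Inverse rotation helper: (x1, x2) -> (x2, -x1).
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((x2, -x1), dim=-1)


dim, pos = 8, 5  # hypothetical rope dim and token position
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
angles = pos * inv_freq  # one angle per channel pair
cos = torch.cat((angles.cos(), angles.cos()))
sin = torch.cat((angles.sin(), angles.sin()))

x = torch.randn(dim)
x_rope = x * cos + rotate_half(x) * sin  # forward RoPE: rotate by +theta
x_back = x_rope * cos + rotate_half_inv(x_rope) * sin  # rotate by -theta
assert torch.allclose(x_back, x, atol=1e-6)  # original values recovered
```
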
@@ -953,109 +927,102 @@ def _full_test_multi_gpu(
     _generate_random_weights(mla)
     weights = mla.state_dict()
 
-    # up to this point, all ranks should have same tensors because the seed is the same
-    # now we run the reference MLA on rank 0
-    if rank == 0:
-        # Reference output (single GPU, no parallelism)
-        ref_mapping = Mapping(world_size=1, tp_size=1, cp_size=1, rank=0)
-        ref_kv_cache_manager, ref_attn_metadata = _setup_kv_and_metadata(
-            scenario, ref_mapping, gen_steps
-        )
-        # this represents the context step
-        mla(position_ids_ctx, input_ctx, ref_attn_metadata)
-        ref_outputs = []
-        start = time.time()
-
-        # CUDA graph setup for timing
-        use_cuda_graph = gen_steps > scenario.ref_steps
-        graph = None
-        graph_output = None
-
-        for step in range(gen_steps):
-            for req_id in range(scenario.batch):
-                ref_kv_cache_manager.impl.add_token(req_id)
-            if step == 0:
-                ref_attn_metadata = get_attention_backend("TRTLLM").Metadata(
-                    seq_lens=torch.tensor([1] * scenario.batch, dtype=torch.int),
-                    request_ids=list(range(scenario.batch)),
-                    max_num_requests=scenario.batch,
-                    num_contexts=0,
-                    prompt_lens=[scenario.ctx_len] * scenario.batch,
-                    max_num_tokens=scenario.ctx_len,
-                    kv_cache_manager=ref_kv_cache_manager,
-                    kv_cache_params=KVCacheParams(
-                        use_cache=True,
-                        num_cached_tokens_per_seq=[
-                            scenario.ctx_len + step for _ in range(scenario.batch)
-                        ],
-                    ),
-                    enable_context_mla_with_cached_kv=True,
-                )
-            else:
-                ref_attn_metadata.kv_cache_params = KVCacheParams(
+    # Since all ranks have the same random seed and inputs, each rank computes
+    # the reference output locally. This avoids the need for collective communication.
+    # Reference output (single GPU, no parallelism)
+    ref_mapping = Mapping(world_size=1, tp_size=1, cp_size=1, rank=0)
+    ref_kv_cache_manager, ref_attn_metadata = _setup_kv_and_metadata(
+        scenario, ref_mapping, gen_steps
+    )
+    # this represents the context step
+    mla(position_ids_ctx, input_ctx, ref_attn_metadata)
+    ref_outputs = []
+    start = time.time()
+
+    # CUDA graph setup for timing
+    use_cuda_graph = gen_steps > scenario.ref_steps
+    graph = None
+    graph_output = None
+
+    for step in range(gen_steps):
+        for req_id in range(scenario.batch):
+            ref_kv_cache_manager.impl.add_token(req_id)
+        if step == 0:
+            ref_attn_metadata = get_attention_backend("TRTLLM").Metadata(
+                seq_lens=torch.tensor([1] * scenario.batch, dtype=torch.int),
+                request_ids=list(range(scenario.batch)),
+                max_num_requests=scenario.batch,
+                num_contexts=0,
+                prompt_lens=[scenario.ctx_len] * scenario.batch,
+                max_num_tokens=scenario.ctx_len,
+                kv_cache_manager=ref_kv_cache_manager,
+                kv_cache_params=KVCacheParams(
                     use_cache=True,
                     num_cached_tokens_per_seq=[
                         scenario.ctx_len + step for _ in range(scenario.batch)
                     ],
-                )
-            ref_attn_metadata.prepare()
+                ),
+                enable_context_mla_with_cached_kv=True,
+            )
+        else:
+            ref_attn_metadata.kv_cache_params = KVCacheParams(
+                use_cache=True,
+                num_cached_tokens_per_seq=[
+                    scenario.ctx_len + step for _ in range(scenario.batch)
+                ],
+            )
+        ref_attn_metadata.prepare()
 
-            if not use_cuda_graph:
-                result = mla(position_ids_gen, input_gen, ref_attn_metadata)
-                if step < scenario.ref_steps:
-                    ref_outputs.append(result)
+        if not use_cuda_graph:
+            result = mla(position_ids_gen, input_gen, ref_attn_metadata)
+            if step < scenario.ref_steps:
+                ref_outputs.append(result)
+                if rank == 0:
                     print(f"Ref result: {result[0, :8]} / {result[-1, -8:]}")
-                # update position_ids_gen
-                position_ids_gen += 1
-                continue
+            # update position_ids_gen
+            position_ids_gen += 1
+            continue
 
-            # CUDA graph capture on first step when timing
-            if step == 0:
+        # CUDA graph capture on first step when timing
+        if step == 0:
+            if rank == 0:
                 print("Creating CUDA graph and capturing")
-                # Create CUDA graph metadata for capture
-                ref_attn_metadata = ref_attn_metadata.create_cuda_graph_metadata(
-                    max_batch_size=scenario.batch
-                )
-                ref_attn_metadata.prepare()
-
-                # Warm-up runs before graph capture
-                for _ in range(2):
-                    result = mla(position_ids_gen, input_gen, ref_attn_metadata)
-
-                # Capture the graph
-                graph = torch.cuda.CUDAGraph()
-                with torch.cuda.graph(graph):
-                    graph_output = mla(position_ids_gen, input_gen, ref_attn_metadata)
-                result = graph_output
-            elif step == scenario.ref_steps:
-                # Start timing with CUDA graph
-                start = time.time()
-            graph.replay()
+            # Create CUDA graph metadata for capture
+            ref_attn_metadata = ref_attn_metadata.create_cuda_graph_metadata(
+                max_batch_size=scenario.batch
+            )
+            ref_attn_metadata.prepare()
+
+            # Warm-up runs before graph capture
+            for _ in range(2):
+                result = mla(position_ids_gen, input_gen, ref_attn_metadata)
+
+            # Capture the graph
+            graph = torch.cuda.CUDAGraph()
+            with torch.cuda.graph(graph):
+                graph_output = mla(position_ids_gen, input_gen, ref_attn_metadata)
             result = graph_output
-            # update position_ids_gen
-            position_ids_gen += 1
-            if step < scenario.ref_steps:
-                ref_outputs.append(result)
-        end = time.time()
-        if gen_steps == scenario.ref_steps:
-            avg_gen_time = float("inf")
-        else:
-            avg_gen_time = (end - start) / (gen_steps - scenario.ref_steps)
-        throughput = scenario.batch / avg_gen_time
+        elif step == scenario.ref_steps:
+            # Start timing with CUDA graph
+            start = time.time()
+        graph.replay()
+        result = graph_output
+        # update position_ids_gen
+        position_ids_gen += 1
+        if step < scenario.ref_steps:
+            ref_outputs.append(result)
+    end = time.time()
+    if gen_steps == scenario.ref_steps:
+        avg_gen_time = float("inf")
+    else:
+        avg_gen_time = (end - start) / (gen_steps - scenario.ref_steps)
+    throughput = scenario.batch / avg_gen_time
+    if rank == 0:
         print(
             f"Time taken for {gen_steps - scenario.ref_steps} steps: "
             f"{end - start}s, throughput: {throughput} MLA/s"
         )
-        ref_output = torch.stack(ref_outputs, dim=0)
-    else:
-        ref_output = torch.empty(
-            scenario.ref_steps,
-            scenario.batch,
-            scenario.hidden_size,
-            dtype=scenario.dtype,
-            device="cuda",
-        )
-        ref_attn_metadata = None
+    ref_output = torch.stack(ref_outputs, dim=0)
 
     # Distributed mapping for mixed TP+CP helix
     mapping = Mapping(
@@ -1066,11 +1033,6 @@ def _full_test_multi_gpu(
         cp_config={"cp_type": CpType.HELIX}
     )
 
-    # Broadcast reference output from rank 0 to all ranks using allgather
-    ref_output_all = cp_allgather(ref_output, mapping=mapping, dim=0)
-    # we only need the values from rank 0
-    ref_output = ref_output_all.view(world_size, *ref_output.shape)[0]
-
     test_params = (
         input_ctx,
         input_gen,
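
The dropped `cp_allgather` broadcast of `ref_output` (and of the latent cache earlier) becomes redundant once every rank seeds its RNG identically and builds the same weights and inputs: each rank then reproduces the reference tensors locally. A small illustrative sketch of that determinism argument, with a hypothetical seed, shapes, and helper name rather than the test's actual setup:

```python
import torch


def build_reference(seed: int) -> torch.Tensor:
    # Every rank calls this with the same seed, so every rank gets
    # bit-identical weights/inputs and therefore the same reference output.
    torch.manual_seed(seed)
    weights = torch.randn(16, 16)
    inputs = torch.randn(4, 16)
    return inputs @ weights.T


# Simulate two "ranks" in one process: identical seeds give identical results,
# so no cross-rank broadcast of the reference output is required.
out_rank0 = build_reference(seed=1234)
out_rank1 = build_reference(seed=1234)
assert torch.equal(out_rank0, out_rank1)
```
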