@@ -715,7 +715,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
715
715
# are non-contiguous (it's not locally guaranteed that they will be)
716
716
# Disadvantage is that the encoded NixlAgentMetadata is now larger
717
717
# (roughly 8KB vs 5KB).
718
- # Conversely for FlashInfer, K and V are transferred in the same tensor
718
+ # Conversely for FlashInfer, K and V are registered in the same region
719
719
# to better exploit the memory layout (ie num_blocks is the first dim).
720
720
split_k_and_v = not (self .use_mla or self ._use_pallas_v1
721
721
or self ._use_flashinfer )
@@ -758,12 +758,21 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
758
758
assert tensor_size_bytes % self .num_blocks == 0
759
759
self .block_len = tensor_size_bytes // self .num_blocks
760
760
self .slot_size_bytes = self .block_len // self .block_size
761
+ self .device_kv_caches = kv_caches
762
+ self .dst_num_blocks [self .engine_id ] = self .num_blocks
761
763
if self ._use_flashinfer :
762
764
assert self .slot_size_bytes % 2 == 0
763
765
self .slot_size_bytes /= 2
764
- self .device_kv_caches = kv_caches
765
- self .dst_num_blocks [self .engine_id ] = self .num_blocks
766
766
767
+ # NOTE (NickLucche) When FlashInfer is used, memory is registered
768
+ # with joint KV for each block. This minimizes the overhead in
769
+ # registerMem allowing faster descs queries. In order to be able to
770
+ # split on kv_heads dim as required by heterogeneous TP, one must
771
+ # be able to index K/V separately. Hence we double the number
772
+ # of 'virtual' regions here and halve `block_len` below.
773
+ self .num_regions *= 2
774
+
775
+ kv_block_len = self .get_backend_aware_kv_block_len ()
767
776
# Register local/src descr for NIXL xfer.
768
777
blocks_data = []
769
778
for base_addr in seen_base_addresses :
@@ -776,8 +785,18 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
776
785
block_offset = block_id * self .block_len
777
786
addr = base_addr + block_offset
778
787
# (addr, len, device id)
779
- # TODO: does device_id matter to DRAM?
780
- blocks_data .append ((addr , self .block_len , self .tp_rank ))
788
+ blocks_data .append ((addr , kv_block_len , self .tp_rank ))
789
+
790
+ if self ._use_flashinfer :
791
+ # Separate and interleave K/V regions to maintain the same
792
+ # descs ordering. This is needed for selecting contiguous heads
793
+ # when split across TP ranks.
794
+ for block_id in range (self .num_blocks ):
795
+ block_offset = block_id * self .block_len
796
+ addr = base_addr + block_offset
797
+ # Register addresses for V cache (K registered first).
798
+ v_addr = addr + kv_block_len
799
+ blocks_data .append ((v_addr , kv_block_len , self .tp_rank ))
781
800
logger .debug ("Created %s blocks for src engine %s and rank %s" ,
782
801
len (blocks_data ), self .engine_id , self .tp_rank )
783
802
@@ -903,7 +922,7 @@ def add_remote_agent(self,
903
922
remote_block_size = nixl_agent_meta .block_len // (
904
923
self .slot_size_bytes * tp_ratio )
905
924
if self ._use_flashinfer :
906
- # Account for joint KV in FlashInfer .
925
+ # With flashinfer, KV are sent in the same message .
907
926
remote_block_size //= 2
908
927
if tp_ratio > 1 :
909
928
# Heterogeneous TP expects same kv_cache_layout.
@@ -929,10 +948,10 @@ def add_remote_agent(self,
929
948
# rank. With heterogeneous TP, prepare the descriptors by splitting the
930
949
# P KV cache along kv_head dim, of D worker's kv_head size (D>P).
931
950
# Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..].
932
- # Only register the remote's descriptors if current rank pulls from it.
933
951
self .kv_caches_base_addr [
934
952
engine_id ] = nixl_agent_meta .kv_caches_base_addr
935
- rank_offset = self .tp_rank % tp_ratio * self .block_len \
953
+ kv_block_len = self .get_backend_aware_kv_block_len ()
954
+ rank_offset = self .tp_rank % tp_ratio * kv_block_len \
936
955
if not (self .use_mla or is_kv_replicated ) else 0
937
956
# Register all remote blocks, but only the corresponding kv heads.
938
957
for base_addr in nixl_agent_meta .kv_caches_base_addr :
@@ -943,7 +962,16 @@ def add_remote_agent(self,
943
962
# self.block_len == remote_block_len//tp_ratio bytes.
944
963
addr = base_addr + block_offset + rank_offset
945
964
# (addr, len, device id)
946
- blocks_data .append ((addr , self .block_len , remote_tp_rank ))
965
+ blocks_data .append ((addr , kv_block_len , remote_tp_rank ))
966
+
967
+ if self ._use_flashinfer :
968
+ # With FlashInfer index V separately to allow head splitting.
969
+ for block_id in range (nixl_agent_meta .num_blocks ):
970
+ block_offset = block_id * nixl_agent_meta .block_len
971
+ addr = base_addr + block_offset + rank_offset
972
+ v_addr = addr + nixl_agent_meta .block_len // 2
973
+ blocks_data .append ((v_addr , kv_block_len , remote_tp_rank ))
974
+
947
975
logger .debug (
948
976
"Created %s blocks for dst engine %s with remote rank %s and "
949
977
"local rank %s" , len (blocks_data ), engine_id , remote_tp_rank ,
@@ -1249,6 +1277,22 @@ def _get_block_descs_ids(self,
1249
1277
descs_ids .append (reg_id * num_blocks + block_id )
1250
1278
return descs_ids
1251
1279
1280
+ def get_backend_aware_kv_block_len (self ):
1281
+ """
1282
+ Get the block length for one K/V element (K and V have the same size).
1283
+
1284
+ For FA and other backends, this is equal to the length of the whole
1285
+ block, as K and V are in separate regions.
1286
+ For FlashInfer, this is half the length of the whole block, as K and V
1287
+ share the same region.
1288
+ """
1289
+ if self ._use_flashinfer :
1290
+ # For indexing only half (either just the K or V part).
1291
+ block_len = self .block_len // 2
1292
+ else :
1293
+ block_len = self .block_len
1294
+ return block_len
1295
+
1252
1296
1253
1297
@contextlib .contextmanager
1254
1298
def zmq_ctx (socket_type : Any , addr : str ) -> Iterator [zmq .Socket ]:
0 commit comments