diff --git a/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh b/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh
index e90b72a7cf24..c17a93ab27c7 100755
--- a/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh
+++ b/tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh
@@ -44,6 +44,39 @@ get_model_args() {
   echo "$extra_args"
 }
 
+set_cli_args() {
+  PREFILLER_TP_SIZE=1
+  DECODER_TP_SIZE=1
+  # Iterate through the rest of the arguments
+  while [[ $# -gt 0 ]]; do
+    echo $#
+    case "$1" in
+      --prefiller-tp-size)
+        if [[ -n "$2" ]]; then
+          PREFILLER_TP_SIZE="$2"
+          shift 2 # Consume the flag and its value ($2)
+        else
+          echo "Error: --prefiller-tp-size requires a value." >&2
+          exit 1
+        fi
+        ;;
+      --decoder-tp-size)
+        if [[ -n "$2" ]]; then
+          DECODER_TP_SIZE="$2"
+          shift 2
+        else
+          echo "Error: --decoder-tp-size requires a value." >&2
+          exit 1
+        fi
+        ;;
+      *)
+        # Handle any arguments not recognized
+        shift # Ignore unknown argument
+        ;;
+    esac
+  done
+}
+
 # Function to run tests for a specific model
 run_tests_for_model() {
@@ -54,6 +87,7 @@ run_tests_for_model() {
 
   # Get model-specific arguments
   local model_args=$(get_model_args "$model_name")
+  set_cli_args "$@"
 
   # Arrays to store all hosts and ports
   PREFILL_HOSTS=()
@@ -64,20 +98,31 @@ run_tests_for_model() {
   # Start prefill instances
   for i in $(seq 0 $((NUM_PREFILL_INSTANCES-1))); do
     # Calculate GPU ID - we'll distribute across available GPUs
-    GPU_ID=$((i % $(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)))
+    # For tensor parallelism, we need to assign multiple consecutive GPUs
+    BASE_GPU_ID=$(((i * $PREFILLER_TP_SIZE) % $(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)))
+    # Create a comma-separated list of GPU IDs for tensor parallelism
+    GPU_IDS=""
+    for ((j=0; j<$PREFILLER_TP_SIZE; j++)); do
+      if [ $j -gt 0 ]; then
+        GPU_IDS+=","
+      fi
+      GPU_IDS+="$((BASE_GPU_ID + j))"
+    done
+
     # Calculate port number (base port + instance number)
     PORT=$((8100 + i))
-    # Calculate side channel port
-    SIDE_CHANNEL_PORT=$((5559 + i))
+    # Calculate side channel port. Avoid clash with TP workers.
+    SIDE_CHANNEL_PORT=$((5559 + i * $PREFILLER_TP_SIZE))
 
-    echo "Starting prefill instance $i on GPU $GPU_ID, port $PORT"
+    echo "Starting prefill instance $i on GPUs $GPU_IDS, port $PORT"
 
     # Build the command with or without model-specific args
-    BASE_CMD="CUDA_VISIBLE_DEVICES=$GPU_ID VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT vllm serve $model_name \
+    BASE_CMD="VLLM_WORKER_MULTIPROC_METHOD=spawn VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=$GPU_IDS VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT vllm serve $model_name \
         --port $PORT \
         --enforce-eager \
         --disable-log-requests \
         --gpu-memory-utilization 0.2 \
+        --tensor-parallel-size $PREFILLER_TP_SIZE \
         --kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\"}'"
 
     if [ -n "$model_args" ]; then
@@ -96,20 +141,31 @@ run_tests_for_model() {
   # Start decode instances
   for i in $(seq 0 $((NUM_DECODE_INSTANCES-1))); do
     # Calculate GPU ID - we'll distribute across available GPUs, starting from after prefill GPUs
-    GPU_ID=$(((i + NUM_PREFILL_INSTANCES) % $(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)))
+    # For tensor parallelism, we need to assign multiple consecutive GPUs
+    BASE_GPU_ID=$(((i * $DECODER_TP_SIZE + $NUM_PREFILL_INSTANCES * $PREFILLER_TP_SIZE) % $(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)))
+    # Create a comma-separated list of GPU IDs for tensor parallelism
+    GPU_IDS=""
+    for ((j=0; j<$DECODER_TP_SIZE; j++)); do
+      if [ $j -gt 0 ]; then
+        GPU_IDS+=","
+      fi
+      GPU_IDS+="$((BASE_GPU_ID + j))"
+    done
+
     # Calculate port number (base port + instance number)
    PORT=$((8200 + i))
     # Calculate side channel port
-    SIDE_CHANNEL_PORT=$((5659 + i))
+    SIDE_CHANNEL_PORT=$((5659 + i * $DECODER_TP_SIZE + $NUM_PREFILL_INSTANCES * $PREFILLER_TP_SIZE))
 
-    echo "Starting decode instance $i on GPU $GPU_ID, port $PORT"
+    echo "Starting decode instance $i on GPUs $GPU_IDS, port $PORT"
 
     # Build the command with or without model-specific args
-    BASE_CMD="CUDA_VISIBLE_DEVICES=$GPU_ID VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT vllm serve $model_name \
+    BASE_CMD="VLLM_WORKER_MULTIPROC_METHOD=spawn VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=$GPU_IDS VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT vllm serve $model_name \
        --port $PORT \
        --enforce-eager \
        --disable-log-requests \
        --gpu-memory-utilization 0.2 \
+        --tensor-parallel-size $DECODER_TP_SIZE \
        --kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\"}'"
 
     if [ -n "$model_args" ]; then
@@ -165,7 +221,7 @@ run_tests_for_model() {
 
 # Run tests for each model
 for model in "${MODELS[@]}"; do
-  run_tests_for_model "$model"
+  run_tests_for_model "$model" "$@"
 done
 
 echo "All tests completed!"
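Usage note: set_cli_args() above only recognizes the two new flags and silently ignores anything else, defaulting both TP sizes to 1. Assuming the script is invoked directly (the surrounding CI wiring is not shown in this hunk), a heterogeneous-TP run would look like:

  ./run_accuracy_test.sh --prefiller-tp-size 1 --decoder-tp-size 2

Each prefill instance is then launched with --tensor-parallel-size 1 and each decode instance with --tensor-parallel-size 2, with consecutive GPU IDs reserved per instance and side-channel ports spaced by the TP size so that TP workers of different instances do not collide.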
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
index d26184982270..e1bb6091388e 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -103,6 +103,8 @@ class NixlAgentMetadata(
     agent_metadata: bytes
     kv_caches_base_addr: list[int]
     num_blocks: int
+    tp_size: int
+    block_len: int
 
 
 @dataclass
@@ -153,7 +155,8 @@ def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole):
             self.connector_worker: Optional[NixlConnectorWorker] = None
         elif role == KVConnectorRole.WORKER:
             self.connector_scheduler = None
-            self.connector_worker = NixlConnectorWorker(str(self.engine_id))
+            self.connector_worker = NixlConnectorWorker(
+                vllm_config, str(self.engine_id))
 
     ############################################################
     # Scheduler Side Methods
@@ -347,17 +350,21 @@ def request_finished(
 class NixlConnectorWorker:
     """Implementation of Worker side methods"""
 
-    def __init__(self, engine_id: str):
+    def __init__(self, vllm_config: VllmConfig, engine_id: str):
         if NixlWrapper is None:
             logger.error("NIXL is not available")
             raise RuntimeError("NIXL is not available")
         logger.info("Initializing NIXL wrapper")
         logger.info("Initializing NIXL worker %s", engine_id)
 
+        # Config.
+        self.vllm_config = vllm_config
+        self.block_size = vllm_config.cache_config.block_size
+
         # Agent.
         self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None)
-        # Map of engine_id -> agent_name.
-        self._remote_agents: dict[str, str] = {}
+        # Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}.
+        self._remote_agents: dict[str, dict[int, str]] = defaultdict(dict)
 
         # Metadata.
         self.engine_id = engine_id
@@ -368,20 +375,22 @@ def __init__(self, engine_id: str):
 
         # KV Caches and nixl tracking data.
         self.kv_caches: dict[str, torch.Tensor] = {}
 
-        # Map of engine_id -> kv_caches_base_addr
-        self.kv_caches_base_addr: dict[str, list[int]] = {}
+        # Map of engine_id -> kv_caches_base_addr. For TP case, each local
+        # rank will still only pull from a single remote TP worker.
+        self.kv_caches_base_addr: dict[str, list[int]] = dict()
 
         # Number of NIXL regions. Currently one region per cache
         # (so 1 per layer for MLA, otherwise 2 per layer)
         self.num_regions = 0
 
-        # nixl_prepped_dlist_handle (int).
+        # nixl_prepped_dlist_handle.
         self.src_xfer_side_handle: int = 0
         # Map of engine_id -> nixl_prepped_dlist_handle (int)].
-        self.dst_xfer_side_handles: dict[str, int] = {}
+        self.dst_xfer_side_handles: dict[str, int] = dict()
 
-        # Map of engine_id -> num_blocks.
-        self.dst_num_blocks: dict[str, int] = {}
+        # Map of engine_id -> num_blocks. All ranks in the same deployment will
+        # have the same number of blocks.
+        self.dst_num_blocks: dict[str, int] = dict()
         self._registered_descs: list[Any] = []
 
         # In progress transfers.
@@ -400,6 +409,13 @@ def __init__(self, engine_id: str):
         # Background thread for establishing new connections.
         self._nixl_handshake_listener_t: Optional[threading.Thread] = None
 
+        self._tp_size: dict[str, int] = {self.engine_id: self.world_size}
+
+        # With heterogeneous TP, P must wait for all assigned D TP workers to
+        # finish reading before safely freeing the blocks.
+        self.consumer_notification_counts_by_req: dict[str,
+                                                        int] = defaultdict(int)
+
     @staticmethod
     def _nixl_handshake_listener(metadata: NixlAgentMetadata,
                                  ready_event: threading.Event, rank: int):
@@ -439,27 +455,44 @@ def _nixl_handshake(self, host: str, port: int):
         """Do a NIXL handshake with a remote instance."""
 
         start_time = time.perf_counter()
+
         # NOTE(rob): we need each rank to have a unique port. This is
         # a hack to keep us moving. We will switch when moving to etcd
         # or where we have a single ZMQ socket in the scheduler.
-        path = f"tcp://{host}:{port + self.rank}"
-        logger.debug("Querying metadata on path: %s", path)
-        with zmq_ctx(zmq.REQ, path) as sock:
-            # Send query for the request.
-            sock.send(GET_META_MSG)
-            metadata_bytes = sock.recv()
-            decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
-            metadata = decoder.decode(metadata_bytes)
-            got_metadata_time = time.perf_counter()
-
-            # Register Remote agent.
-            self.add_remote_agent(metadata)
-            setup_agent_time = time.perf_counter()
-            logger.debug("NIXL handshake: get metadata took: %s",
-                         got_metadata_time - start_time)
-            logger.debug("NIXL handshake: add agent took: %s",
-                         setup_agent_time - got_metadata_time)
+        def handshake(path: str, rank: int) -> NixlAgentMetadata:
+            # Send query for the request.
+            with zmq_ctx(zmq.REQ, path) as sock:
+                sock.send(GET_META_MSG)
+                metadata_bytes = sock.recv()
+                decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
+                metadata = decoder.decode(metadata_bytes)
+                got_metadata_time = time.perf_counter()
+
+                # Register Remote agent.
+                self.add_remote_agent(metadata, rank)
+                setup_agent_time = time.perf_counter()
+
+                logger.debug("NIXL handshake: get metadata took: %s",
+                             got_metadata_time - start_time)
+                logger.debug("NIXL handshake: add agent took: %s",
+                             setup_agent_time - got_metadata_time)
+            return metadata
+
+        # Handshake with remote agent-rank0 first to get the tp_size of remote
+        path = f"tcp://{host}:{port}"
+        logger.debug("Querying master rank metadata on path: %s", path)
+        metadata = handshake(path, 0)
+
+        # Handshake only with the remote TP rank that the current local rank
+        # will pull from. With homogeneous TP it happens to be the same rank_i.
+        tp_ratio = self._tp_size[self.engine_id] // metadata.tp_size
+        p_remote_rank = self.rank // tp_ratio
+        if p_remote_rank > 0:
+            path = f"tcp://{host}:{port + p_remote_rank}"
+            logger.debug("Querying metadata on path: %s at remote rank %s",
+                         path, p_remote_rank)
+            _ = handshake(path, p_remote_rank)
 
     def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         """Register the KV Cache data in nixl."""
@@ -474,14 +507,20 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
             self.num_blocks = first_kv_cache.shape[0]
             block_rank = 2  # [block_size, latent_dim]
             block_shape = first_kv_cache.shape[-block_rank:]
+            block_size, kv_latent_dim = block_shape
+            self.slot_size_bytes = kv_elem_size * kv_latent_dim
         else:
-            # [2 (k and v), num_blocks, ...]
+            # [2 (k and v), num_blocks, block_size, kv_heads, head_dim]
            self.num_blocks = first_kv_cache.shape[1]
             block_rank = 3  # [block_size, kv_heads, head_dim]
             block_shape = first_kv_cache.shape[-block_rank:]
-
+            block_size, n_kv_heads, head_dim = block_shape
+            # head size in bytes.
+            self.slot_size_bytes = kv_elem_size * n_kv_heads * head_dim
+        assert block_size == self.block_size
         # TODO(tms): self.block_len needs to be per-layer for sliding window,
         # hybrid attn, etc
+        # block size in bytes
         self.block_len = kv_elem_size * math.prod(block_shape)
 
         logger.debug("Registering KV_Caches. use_mla: %s, shape %s", use_mla,
@@ -515,16 +554,39 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         logger.debug("Registering descs: %s", caches_data)
         self.nixl_wrapper.register_memory(descs)
         logger.debug("Done registering descs")
-
         self._registered_descs.append(descs)
 
+        # Register local/src descr for NIXL xfer.
+        blocks_data = []
+        for base_addr in self.kv_caches_base_addr[self.engine_id]:
+            # NOTE With heter-TP, more blocks are prepared than what are
+            # needed as self.num_blocks >= nixl_agent_meta.num_blocks. We
+            # could create fewer, but then _get_block_descs_ids needs to
+            # select agent_meta.num_blocks instead of self.num_blocks for
+            # local descr, and that makes handling regular flow less clean.
+            for block_id in range(self.num_blocks):
+                block_offset = block_id * self.block_len
+                for slot_idx in range(self.block_size):
+                    slot_offset = slot_idx * self.slot_size_bytes
+                    addr = base_addr + block_offset + slot_offset
+                    # (addr, len, device id)
+                    blocks_data.append((addr, self.slot_size_bytes, self.rank))
+        logger.debug("Created %s blocks for src engine %s and rank %s",
+                     len(blocks_data), self.engine_id, self.rank)
+
+        descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
+        # NIXL_INIT_AGENT to be used for preparations of local descs.
+        self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist(
+            "NIXL_INIT_AGENT", descs)
+
         # After KV Caches registered, listen for new connections.
         metadata = NixlAgentMetadata(
             engine_id=self.engine_id,
             agent_metadata=self.nixl_wrapper.get_agent_metadata(),
             kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id],
             num_blocks=self.num_blocks,
-        )
+            tp_size=self.world_size,
+            block_len=self.block_len)
         ready_event = threading.Event()
         self._nixl_handshake_listener_t = threading.Thread(
             target=self._nixl_handshake_listener,
@@ -534,49 +596,108 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         self._nixl_handshake_listener_t.start()
         ready_event.wait()
 
-    def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata):
+    def add_remote_agent(self,
+                         nixl_agent_meta: NixlAgentMetadata,
+                         remote_rank: int = 0):
+        """
+        Add the remote NIXL agent and prepare the descriptors for reading cache
+        blocks from remote.
+
+        In particular, handle both homogeneous and heterogeneous TP. The former
+        requires local rank_i to read from remote rank_i.
+        The latter, assuming D.world_size > P.world_size, requires that two or
+        more local TP workers share the xfer from a single TP worker.
+
+        Here's an example:
+
+        rank_offset     p_remote_rank
+        (kv split no)
+        --------------------------------
+            0                 0      Worker0  ---- 1st half of KV ----> Worker0  [ KV Cache ]
+                                                                       /
+            1                 0      Worker1  ---- 2nd half of KV -----/
+
+            0                 1      Worker2  ---- 1st half of KV ----> Worker1  [ KV Cache ]
+                                                                       /
+            1                 1      Worker3  ---- 2nd half of KV -----/
+
+
+                       Decoder TP workers                  Prefill TP workers
+                        (world_size=4)                       (world_size=2)
+                                     tp_ratio = 4 // 2 = 2
+
+        Considering the KV Caches, if P-Worker_i has cache size [2, num_blocksP, block_size, kv_heads, head_dim]
+        then D-Worker_j has [2, num_blocksD, block_size, kv_heads//tp_ratio, head_dim].
+        Assuming num_blocksD >= num_blocksP, D-Worker0 reads from P-Worker0 by preparing the kv_heads//tp_ratio
+        first heads from all the slots of all the blocks in the cache.
+        D-Worker1 will do the same, but reading the second split along the kv_heads dimension.
+
+        Note that the above will also hold true for the homogeneous TP case.
+ """ # noqa: E501 + engine_id = nixl_agent_meta.engine_id - if engine_id in self._remote_agents: + # TODO re-evaluate refreshing for scaling/recovery + if (engine_id in self._remote_agents and \ + remote_rank in self._remote_agents[engine_id]): return - self._remote_agents[engine_id] = self.nixl_wrapper.add_remote_agent( - nixl_agent_meta.agent_metadata) - self.kv_caches_base_addr[ - engine_id] = nixl_agent_meta.kv_caches_base_addr - - # Create src descs and xfer side handles. - blocks_data = [] - for base_addr in self.kv_caches_base_addr[self.engine_id]: - for block_id in range(self.num_blocks): - block_offset = block_id * self.block_len - # (addr, len, device id) - blocks_data.append( - (base_addr + block_offset, self.block_len, self.rank)) - logger.debug("Created %s blocks for src engine %s and rank %s", - len(blocks_data), self.engine_id, self.rank) - - # Register with NIXL. - descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM") - self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist( - "NIXL_INIT_AGENT", descs) + if engine_id in self._tp_size: + assert self._tp_size[engine_id] == nixl_agent_meta.tp_size + self._tp_size[engine_id] = nixl_agent_meta.tp_size + self._remote_agents[engine_id][ + remote_rank] = self.nixl_wrapper.add_remote_agent( + nixl_agent_meta.agent_metadata) + + # Number of D TP workers reading from a single P TP worker. This is + # 1 when P and D `--tensor-parallel-size` match. + tp_ratio = self._tp_size[self.engine_id] // self._tp_size[engine_id] + assert tp_ratio > 0, "Decode TP cannot be smaller than" + " prefill TP" + + # TODO we should also check hidden_dim and kv precision, they must match + remote_block_size = nixl_agent_meta.block_len / (self.slot_size_bytes * + tp_ratio) + assert self.block_size == remote_block_size, "Remote P worker with " + "different block size is not supported" + + # Create dst descs and xfer side handles. TP workers have same #blocks. + if engine_id in self.dst_num_blocks: + assert self.dst_num_blocks[engine_id] == nixl_agent_meta.num_blocks - # Create dst descs and xfer side handles. self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks + blocks_data = [] - for base_addr in self.kv_caches_base_addr[engine_id]: - for block_id in range(nixl_agent_meta.num_blocks): - block_offset = block_id * self.block_len - # (addr, len, device id) - blocks_data.append( - (base_addr + block_offset, self.block_len, self.rank)) - logger.debug("Created %s blocks for dst engine %s and rank %s", - len(blocks_data), engine_id, self.rank) + # With homogeneous TP, D pulls the whole kv cache from corresponding + # rank. With heterogeneous TP, prepare the descriptors by splitting the + # P KV cache along kv_head dim, of D worker's kv_head size (D>P). + # Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..]. + p_remote_rank = self.rank // tp_ratio + # Only register the remote's descriptors if current rank pulls from it. + if p_remote_rank == remote_rank: + self.kv_caches_base_addr[ + engine_id] = nixl_agent_meta.kv_caches_base_addr + rank_offset = self.rank % tp_ratio * self.slot_size_bytes + # Register all remote blocks, but only the corresponding kv heads. + for base_addr in nixl_agent_meta.kv_caches_base_addr: + for block_id in range(nixl_agent_meta.num_blocks): + block_offset = block_id * nixl_agent_meta.block_len + for slot_idx in range(self.block_size): + # Remote has `tp_ratio` times the kv_heads of local. 
+                        slot_offset = slot_idx * self.slot_size_bytes * tp_ratio
+                        addr = base_addr + block_offset + slot_offset
+                        # (addr, len, device id)
+                        blocks_data.append((addr + rank_offset,
+                                            self.slot_size_bytes, remote_rank))
+            logger.debug(
+                "Created %s blocks for dst engine %s with remote rank %s and "
+                "local rank %s",
+                len(blocks_data), engine_id, remote_rank, self.rank)
 
-        # Register with NIXL.
-        descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
-        self.dst_xfer_side_handles[
-            engine_id] = self.nixl_wrapper.prep_xfer_dlist(
-                self._remote_agents[engine_id], descs)
+            # Register with NIXL.
+            descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
+            self.dst_xfer_side_handles[
+                engine_id] = self.nixl_wrapper.prep_xfer_dlist(
+                    self._remote_agents[engine_id][remote_rank], descs)
 
     def get_finished(self) -> tuple[set[str], set[str]]:
         """
@@ -647,10 +768,15 @@ def _get_new_notifs(self) -> set[str]:
         """Get req_ids which got a remote xfer message."""
         notified_req_ids: set[str] = set()
-        for req_ids in self.nixl_wrapper.get_new_notifs().values():
-            for req_id in req_ids:
-                assert req_id not in notified_req_ids
-                notified_req_ids.add(req_id.decode("utf-8"))
+        for notifs in self.nixl_wrapper.get_new_notifs().values():
+            for notif in notifs:
+                req_id, tp_ratio = notif.decode("utf-8").rsplit(":", 1)
+                self.consumer_notification_counts_by_req[req_id] += 1
+                # Wait for all consumers (D) to be done reading before freeing.
+                if self.consumer_notification_counts_by_req[req_id] == int(
+                        tp_ratio):
+                    notified_req_ids.add(req_id)
+                    del self.consumer_notification_counts_by_req[req_id]
         return notified_req_ids
 
     def _pop_done_transfers(self, transfers: dict[str, list[int]]) -> set[str]:
@@ -725,12 +851,17 @@ def _read_blocks(
         # saturate IB with heterogeneous TP sizes. We should remove the staging
         # blocks until we are ready.
 
+        # Number of D TP workers that will read from dst P. Propagate tp_ratio
+        # on notification so that dst worker can wait before freeing blocks.
+        tp_ratio = self._tp_size[
+            self.engine_id] // self._tp_size[dst_engine_id]
+        notif_id = f"{request_id}:{tp_ratio}".encode()
+
         # Full prefix cache hit: do not need to read remote blocks,
         # just notify P worker that we have the blocks we need.
         num_local_blocks = len(local_block_ids)
         if num_local_blocks == 0:
-            self.nixl_wrapper.send_notif(dst_engine_id,
-                                         notif_msg=request_id.encode("utf-8"))
+            self.nixl_wrapper.send_notif(dst_engine_id, notif_msg=notif_id)
             return
 
         # Partial prefix cache hit: just read uncomputed blocks.
@@ -743,6 +874,10 @@ def _read_blocks(
         local_xfer_side_handle = self.src_xfer_side_handle
         remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id]
 
+        # NOTE (nicolo) With homogeneous TP, each TP worker loads KV from
+        # corresponding rank. With heterogeneous TP, fixing D>P, the D tp
+        # workers will issue xfers to parts of the P worker remote kv caches.
+
         # Get descs ids.
         remote_block_descs_ids = self._get_block_descs_ids(
             dst_engine_id, remote_block_ids)
@@ -757,7 +892,7 @@ def _read_blocks(
             local_block_descs_ids,
             remote_xfer_side_handle,
             remote_block_descs_ids,
-            notif_msg=request_id.encode("utf-8"),
+            notif_msg=notif_id,
         )
 
         # Begin async xfer.
@@ -769,16 +904,20 @@ def _read_blocks(
     def _get_block_descs_ids(self, engine_id: str,
                              block_ids: list[int]) -> list[int]:
         """Get the descs ids for a set of block ids."""
+        # TODO docs
         # range(1) for MLA, range(2) otherwise.
         region_ids = range(self.num_regions)
+        # TODO using a diff num of blocks here in dst and src
         num_blocks = self.dst_num_blocks[engine_id]
 
         # Compute the desc ids for each block.
         descs_ids: list[int] = []
         for reg_id in region_ids:
             for block_id in block_ids:
-                descs_ids.append(reg_id * num_blocks + block_id)
+                for slot_id in range(self.block_size):
+                    descs_ids.append(reg_id * num_blocks * self.block_size +
+                                     block_id * self.block_size + slot_id)
         return descs_ids
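For reference, the heterogeneous-TP address math introduced in add_remote_agent() and _get_block_descs_ids() can be sanity-checked in isolation. The sketch below is not part of the patch: the helper name remote_read_layout and the toy sizes are made up, and nothing in it calls vLLM or NIXL. It only mirrors the formulas from the diff: tp_ratio = D_tp // P_tp, p_remote_rank = rank // tp_ratio, a per-rank offset of (rank % tp_ratio) * slot_size_bytes, and remote slots that are tp_ratio times wider than local ones.

# Standalone sketch (hypothetical helper, illustrative sizes only).
def remote_read_layout(rank: int, d_tp: int, p_tp: int, slot_size_bytes: int,
                       block_size: int, block_len_remote: int,
                       num_blocks: int) -> tuple[int, list[int]]:
    """Return the P rank and per-slot byte offsets a D rank would read."""
    tp_ratio = d_tp // p_tp  # D workers reading from a single P worker
    assert tp_ratio > 0, "Decode TP cannot be smaller than prefill TP"
    p_remote_rank = rank // tp_ratio  # which P worker this D rank pulls from
    # Offset of this rank's kv-head split inside each remote slot.
    rank_offset = rank % tp_ratio * slot_size_bytes
    # Remote slots are tp_ratio times larger than local ones.
    assert block_len_remote == slot_size_bytes * tp_ratio * block_size
    offsets = []
    for block_id in range(num_blocks):
        block_offset = block_id * block_len_remote
        for slot_idx in range(block_size):
            slot_offset = slot_idx * slot_size_bytes * tp_ratio
            offsets.append(block_offset + slot_offset + rank_offset)
    return p_remote_rank, offsets


# D TP=4 pulling from P TP=2 (tp_ratio=2), as in the docstring diagram:
# D rank 3 reads the second kv-head split of every slot from P rank 1.
p_rank, offs = remote_read_layout(rank=3, d_tp=4, p_tp=2, slot_size_bytes=256,
                                  block_size=16, block_len_remote=256 * 2 * 16,
                                  num_blocks=2)
print(p_rank)    # 1
print(offs[:2])  # [256, 768]

Running the example reproduces the mapping in the add_remote_agent() docstring diagram: with four decode workers over two prefill workers, decode worker 3 pulls the second half of every slot from prefill worker 1, which is exactly the descriptor list the connector prepares per rank.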