12 changes: 8 additions & 4 deletions tests/v1/kv_connector/nixl_integration/toy_proxy_server.py
@@ -212,16 +212,22 @@ async def _handle_completions(api: str, request: Request):
p = send_request_to_service(prefill_client_info, api,
req_data, request_id)
s2 = time.perf_counter()
print(f'libin proxy send to prefill {s2-s1}')
sys.stdout.flush()
response = await p
s3 = time.perf_counter()
# Extract the needed fields
response_json = response.json()
kv_transfer_params = response_json.get('kv_transfer_params', {})
if kv_transfer_params:
remote_block_len = len(kv_transfer_params['remote_block_ids'])
logger.debug('buke: cut: %s %s %s', type(kv_transfer_params), kv_transfer_params['remote_block_ids'], kv_transfer_params['remote_block_ids'][:(remote_block_len // 8) * 8])

kv_transfer_params['remote_block_ids'] = kv_transfer_params['remote_block_ids'][:(remote_block_len//8)*8]
if remote_block_len % 8 == 0:
kv_transfer_params['remote_block_ids'] = kv_transfer_params['remote_block_ids'][:(remote_block_len//8)*8-1]
logger.info('buke hit corner case, multiple of 8: %s', remote_block_len)
req_data["kv_transfer_params"] = kv_transfer_params

#print(req_data)
# Get the next decode client in round-robin fashion
decode_client_info = get_next_client(request.app, 'decode')

@@ -238,15 +244,13 @@ async def generate_stream():

if is_first is False:
s4 = time.perf_counter()
print(f'libin debug proxy receive decode 1 total:{s4-s1}| prefill:{s3-s1}| in-between:{s6-s3}|decode:{s4-s6}| {s6=} {s4=}')
sys.stdout.flush()
is_first = True
yield chunk

re = StreamingResponse(generate_stream(),
media_type="application/json")
s5 = time.perf_counter()
#print(f'libin debug proxy receive decode2 {s5-s1} {s5-s6}')

#sys.stdout.flush()
return re
76 changes: 52 additions & 24 deletions vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -216,7 +216,7 @@ def start_load_kv(self, forward_context: "ForwardContext",
assert self.connector_worker is not None
assert isinstance(self._connector_metadata, NixlConnectorMetadata)
self.connector_worker.start_load_kv(self._connector_metadata)
logger.info(f"libin debug start_load_kv return {os.getenv('RANK')}, takes {time.perf_counter() - s1}")
#logger.info(f"libin debug start_load_kv return {os.getenv('RANK')}, takes {time.perf_counter() - s1}")

def wait_for_layer_load(self, layer_name: str) -> None:
"""NixlConnector does not do layerwise saving."""
@@ -449,7 +449,8 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
# Config.
self.vllm_config = vllm_config
self.block_size = vllm_config.cache_config.block_size

self.block_factor = 8 # A100.block_size/G2.block_size
Author: Is this going to be a hardcoded value?

Reply: It's hardcoded for now; maybe that's OK since this number won't change.

Reply: It's better to check the block size on both sides.

Reply: We can check whether remote.block_size is the expected value, but we don't know remote.block_size here because the handshake occurs afterwards.
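A minimal sketch of that deferred check, assuming a hypothetical helper invoked after the NIXL handshake once the remote block size is known (name and placement are illustrative, not part of this diff):

def _check_remote_block_size(self, remote_block_size: int) -> None:
    # Hypothetical validation of the hardcoded ratio: one local (decode)
    # block must hold exactly block_factor remote (prefill) blocks.
    assert self.block_size == remote_block_size * self.block_factor, (
        "Unexpected remote block size for heterogeneous transfer: "
        f"{self.block_size=} {remote_block_size=} {self.block_factor=}")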

self.block_shape = None
# Agent.
self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None)
# Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}.
@@ -622,7 +623,6 @@ def _nixl_handshake(
path = make_zmq_path("tcp", host, port + p_remote_rank)
logger.debug("Querying metadata on path: %s at remote rank %s", path,
p_remote_rank)

# Send query for the request.
with zmq_ctx(zmq.REQ, path) as sock:
sock.send(GET_META_MSG)
Expand Down Expand Up @@ -786,6 +786,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
block_shape[0] = block_shape[0] // self.num_blocks
block_shape = torch.Size(block_shape)
block_size, n_kv_heads, head_dim = block_shape[-3:]
self.block_shape = [block_size, n_kv_heads, head_dim]
# head size in bytes.
self.slot_size_bytes = kv_elem_size * n_kv_heads * head_dim
else:
@@ -802,7 +803,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
"per_layer_kv_cache_shape: %s", use_mla, self.kv_buffer_device,
self.use_host_buffer, self.num_blocks, block_shape,
first_kv_cache[0].shape)
self.dst_num_blocks[self.engine_id] = self.num_blocks
self.dst_num_blocks[self.engine_id] = self.num_blocks * self.block_factor
self.device_kv_caches = kv_caches
kv_caches_base_addr = []
caches_data = []
@@ -815,6 +816,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
# (roughly 8KB vs 5KB).
# Conversely for FlashInfer, K and V are transferred in the same tensor
# to better exploit the memory layout (ie num_blocks is the first dim).

for cache_or_caches in xfer_buffers.values():
# Normalize to always be a list of caches
cache_list = [cache_or_caches] if use_mla \
@@ -870,15 +872,15 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
# could create fewer, but then _get_block_descs_ids needs to
# select agent_meta.num_blocks instead of self.num_blocks for
# local descr, and that makes handling regular flow less clean.
for block_id in range(self.num_blocks):
block_offset = block_id * self.block_len
for block_id in range(self.num_blocks * self.block_factor):
block_offset = block_id * self.block_len // self.block_factor
addr = base_addr + block_offset
# (addr, len, device id)
# TODO: does device_id matter to DRAM?
blocks_data.append((addr, self.block_len, self.tp_rank))
blocks_data.append((addr, self.block_len // self.block_factor, self.tp_rank))
logger.debug("Created %s blocks for src engine %s and rank %s",
len(blocks_data), self.engine_id, self.tp_rank)

#print(f'buke: {blocks_data[0:10]=}')
descs = self.nixl_wrapper.get_xfer_descs(blocks_data,
self.nixl_memory_type)
# NIXL_INIT_AGENT to be used for preparations of local descs.
@@ -955,7 +957,7 @@ def add_remote_agent(self,
assert self._tp_size[engine_id] == remote_tp_size
# We may eventually enable this after asserting equality in cache
# layout and close outputs.
assert nixl_agent_meta.attn_backend_name == self.backend_name
#assert nixl_agent_meta.attn_backend_name == self.backend_name

remote_agent_name = self.nixl_wrapper.add_remote_agent(
nixl_agent_meta.agent_metadata)
@@ -984,14 +986,14 @@ def add_remote_agent(self,
# Account for joint KV in FlashInfer.
remote_block_size //= 2

assert nixl_agent_meta.block_len == self.block_len * tp_ratio, (
"Remote P worker KV layer cache must be of shape [2, N, "
"local_kv_heads*tp_ratio, block_size, head_dim] and same dtype."
)
#assert nixl_agent_meta.block_len == self.block_len * tp_ratio, (
# "Remote P worker KV layer cache must be of shape [2, N, "
# "local_kv_heads*tp_ratio, block_size, head_dim] and same dtype."
#)

assert self.block_size == remote_block_size, (
"Remote P worker with different block size is not supported "
f"{self.block_size=} {remote_block_size=}")
#assert self.block_size == remote_block_size, (
# "Remote P worker with different block size is not supported "
# f"{self.block_size=} {remote_block_size=}")

# Create dst descs and xfer side handles. TP workers have same #blocks.
if engine_id in self.dst_num_blocks:
@@ -1007,7 +1009,7 @@ def add_remote_agent(self,
# Only register the remote's descriptors if current rank pulls from it.
self.kv_caches_base_addr[
engine_id] = nixl_agent_meta.kv_caches_base_addr
rank_offset = self.tp_rank % tp_ratio * self.block_len \
rank_offset = self.tp_rank % tp_ratio * nixl_agent_meta.block_len // tp_ratio \
if not (self.use_mla or is_kv_replicated) else 0
# Register all remote blocks, but only the corresponding kv heads.
for base_addr in nixl_agent_meta.kv_caches_base_addr:
@@ -1018,15 +1020,16 @@ def add_remote_agent(self,
# self.block_len == remote_block_len//tp_ratio bytes.
addr = base_addr + block_offset + rank_offset
# (addr, len, device id)
blocks_data.append((addr, self.block_len, remote_tp_rank))
blocks_data.append((addr, nixl_agent_meta.block_len // tp_ratio, remote_tp_rank))
logger.debug(
"Created %s blocks for dst engine %s with remote rank %s and "
"local rank %s", len(blocks_data), engine_id, remote_tp_rank,
self.tp_rank)

logger.debug(f'buke {self.slot_size_bytes=}|{tp_ratio=}|{self.block_len=}|{nixl_agent_meta.block_len=}|{self.tp_rank=}|{self._use_flashinfer=}')
# Register with NIXL.
descs = self.nixl_wrapper.get_xfer_descs(blocks_data,
self.nixl_memory_type)
#print('buke register remote:', len(blocks_data), blocks_data[:10],blocks_data[-1],self.nixl_memory_type)
self.dst_xfer_side_handles[
engine_id] = self.nixl_wrapper.prep_xfer_dlist(
remote_agent_name, descs)
@@ -1081,6 +1084,24 @@ def get_finished(self) -> tuple[set[str], set[str]]:
"Rank %s, get_finished: %s requests done sending "
"and %s requests done recving", self.tp_rank,
len(done_sending), len(done_recving))
#import remote_pdb; remote_pdb.set_trace()
Reviewer: Can you add a check that the remote side uses a GPU attention backend?

Reply: Yes, I can add this.
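One possible shape for that check, sketched as a review suggestion (the backend names are placeholders and not verified against vLLM's actual backend registry):

# Hypothetical sketch for add_remote_agent: reject non-GPU prefill backends.
GPU_ATTN_BACKENDS = {"FLASH_ATTN", "FLASHINFER"}  # assumed names
if nixl_agent_meta.attn_backend_name not in GPU_ATTN_BACKENDS:
    raise RuntimeError(
        "Heterogeneous KV transfer expects a GPU attention backend on the "
        f"remote side, got {nixl_agent_meta.attn_backend_name!r}")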

t1 = time.perf_counter()
remote_block_size = self.block_size // self.block_factor
block_size, n_kv_heads, head_dim = self.block_shape
for req_id in done_recving:
#print(req_id, self._recving_metadata)
meta = self._recving_metadata.pop(req_id)
for k, v in self.device_kv_caches.values():
local_block_ids = meta.local_block_ids
#print(f'buke {local_block_ids=}|{k.shape=}')
for block_idx in local_block_ids:
#import remote_pdb; remote_pdb.set_trace()
# Re-layout the received KV data: each local block of block_size slots holds
# block_factor remote blocks laid out as [n_kv_heads, remote_block_size,
# head_dim]; permute them to the local [block_size, n_kv_heads, head_dim] layout.
k[block_idx * self.block_size:(block_idx + 1) * self.block_size] = (
    k[block_idx * self.block_size:(block_idx + 1) * self.block_size]
    .reshape(self.block_factor, n_kv_heads, remote_block_size, head_dim)
    .permute(0, 2, 1, 3).contiguous()
    .reshape(self.block_size, n_kv_heads, head_dim))
v[block_idx * self.block_size:(block_idx + 1) * self.block_size] = (
    v[block_idx * self.block_size:(block_idx + 1) * self.block_size]
    .reshape(self.block_factor, n_kv_heads, remote_block_size, head_dim)
    .permute(0, 2, 1, 3).contiguous()
    .reshape(self.block_size, n_kv_heads, head_dim))
#import remote_pdb; remote_pdb.set_trace()
t2 = time.perf_counter()
tt = t2-t1
logger.debug(f'buke permute time:{tt}, {req_id=}|{self._recving_metadata=}')
if self.use_host_buffer:
for req_id in done_recving:
s2 = time.perf_counter()
@@ -1178,7 +1199,8 @@ def start_load_kv(self, metadata: NixlConnectorMetadata):
"Num local_block_ids: %s. Num remote_block_ids: %s. ", req_id,
remote_engine_id, len(meta.local_block_ids),
len(meta.remote_block_ids))
if self.use_host_buffer:
is_hetero = True
Author: What's the plan for setting this variable?

Reply: We need the flag to enable the _recving_metadata path.

Author: Are you going to add it as an env flag?

Reply: Can we reuse the DECODE_TP_RATIO env?

Reply: That seems good.

Reply: Check whether this can be moved to hpu_model_runner.
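A minimal sketch of the env-flag option discussed above (the DECODE_TP_RATIO name comes from the comment; its exact semantics here are an assumption):

import os

# Hypothetical: enable the heterogeneous (mismatched block size) path when the
# DECODE_TP_RATIO environment variable is set, instead of hardcoding the flag.
is_hetero = os.getenv("DECODE_TP_RATIO") is not None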

if is_hetero or self.use_host_buffer:
self._recving_metadata[req_id] = meta
if remote_engine_id not in self._remote_agents:
# Initiate handshake with remote engine to exchange metadata.
@@ -1245,8 +1267,8 @@ def _read_blocks(self, local_block_ids: list[int],
# Partial prefix cache hit: just read uncomputed blocks.
num_remote_blocks = len(remote_block_ids)
assert num_local_blocks <= num_remote_blocks
if num_local_blocks < num_remote_blocks:
remote_block_ids = remote_block_ids[-num_local_blocks:]
#if num_local_blocks < num_remote_blocks:
# remote_block_ids = remote_block_ids[-num_local_blocks:]
Reviewer: Add a check for the heterogeneous case here.
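One way the requested check could look, assuming the same is_hetero flag used in start_load_kv is available here (illustrative only):

# Hypothetical: only the heterogeneous path skips the prefix-cache trimming;
# the homogeneous path keeps the original behavior.
if not is_hetero and num_local_blocks < num_remote_blocks:
    remote_block_ids = remote_block_ids[-num_local_blocks:]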


# Get side handles.
local_xfer_side_handle = self.src_xfer_side_handle
@@ -1259,12 +1281,19 @@ def _read_blocks(self, local_block_ids: list[int],
# Get descs ids.
local_block_descs_ids: list[int] = []
remote_block_descs_ids: list[int] = []
local_sub_block_ids: list[int] = []

#print('buke: ', remote_block_ids)
# Trim the remote ids to a multiple of block_factor, then map each remote
# block onto the corresponding sub-block inside the larger local block.
remote_block_ids = remote_block_ids[:(len(remote_block_ids) //
                                      self.block_factor) * self.block_factor]
for index, remote_block_id in enumerate(remote_block_ids):
    local_sub_block_ids.append(
        local_block_ids[index // self.block_factor] * self.block_factor +
        index % self.block_factor)
logger.debug(f'buke {local_block_ids=} |{remote_block_ids=} |{local_sub_block_ids=}')
if not self.block_window_per_layer:
# Default case: assume global attention
remote_block_descs_ids = self._get_block_descs_ids(
dst_engine_id, remote_block_ids)
local_block_descs_ids = self._get_block_descs_ids(
self.engine_id, local_block_ids)
self.engine_id, local_sub_block_ids)
else:
# TODO(mgoin): remove this once we have hybrid memory allocator
# Optimization for models with local attention (Llama 4)
@@ -1350,7 +1379,6 @@ def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]:

if socket_type not in (zmq.ROUTER, zmq.REQ):
raise ValueError(f"Unexpected socket type: {socket_type}")

ctx: Optional[zmq.Context] = None
try:
ctx = zmq.Context() # type: ignore[attr-defined]