12 changes: 8 additions & 4 deletions tests/v1/kv_connector/nixl_integration/toy_proxy_server.py
@@ -186,7 +186,7 @@
"""
s1 = time.perf_counter()
headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",

Check failure (GitHub Actions / pre-commit) — Ruff F841: tests/v1/kv_connector/nixl_integration/toy_proxy_server.py:189:5 Local variable `s1` is assigned to but never used
"X-Request-Id": request_id
}

@@ -202,7 +202,7 @@
async def _handle_completions(api: str, request: Request):
s1 = time.perf_counter()
try:
req_data = await request.json()

Check failure (GitHub Actions / pre-commit) — Ruff F841: tests/v1/kv_connector/nixl_integration/toy_proxy_server.py:205:5 Local variable `s1` is assigned to but never used
request_id = str(uuid.uuid4())

# Get the next prefill client in round-robin fashion
@@ -212,16 +212,22 @@
p = send_request_to_service(prefill_client_info, api,
req_data, request_id)
s2 = time.perf_counter()
print(f'libin proxy send to prefill {s2-s1}')
sys.stdout.flush()
response = await p

Check failure (GitHub Actions / pre-commit) — Ruff F841: tests/v1/kv_connector/nixl_integration/toy_proxy_server.py:216:9 Local variable `s2` is assigned to but never used
s3 = time.perf_counter()
# Extract the needed fields
response_json = response.json()

Check failure (GitHub Actions / pre-commit) — Ruff F841: tests/v1/kv_connector/nixl_integration/toy_proxy_server.py:219:9 Local variable `s3` is assigned to but never used
kv_transfer_params = response_json.get('kv_transfer_params', {})
if kv_transfer_params:
remote_block_len = len(kv_transfer_params['remote_block_ids'])
logger.debug('buke: cut:', type(kv_transfer_params), kv_transfer_params['remote_block_ids'],kv_transfer_params['remote_block_ids'][:(remote_block_len//8)*8])

kv_transfer_params['remote_block_ids'] = kv_transfer_params['remote_block_ids'][:(remote_block_len//8)*8]
if remote_block_len % 8 == 0:
kv_transfer_params['remote_block_ids'] = kv_transfer_params['remote_block_ids'][:(remote_block_len//8)*8-1]
logger.info('buke hit corner case multiples of 8:', remote_block_len)
req_data["kv_transfer_params"] = kv_transfer_params

#print(req_data)
# Get the next decode client in round-robin fashion
decode_client_info = get_next_client(request.app, 'decode')

@@ -238,24 +244,22 @@

if is_first is False:
s4 = time.perf_counter()
print(f'libin debug proxy receive decode 1 total:{s4-s1}| prefill:{s3-s1}| in-between:{s6-s3}|decode:{s4-s6}| {s6=} {s4=}')
sys.stdout.flush()
is_first = True

Check failure (GitHub Actions / pre-commit) — Ruff F841: tests/v1/kv_connector/nixl_integration/toy_proxy_server.py:248:13 Local variable `s6` is assigned to but never used
yield chunk

re = StreamingResponse(generate_stream(),
media_type="application/json")
s5 = time.perf_counter()
#print(f'libin debug proxy receive decode2 {s5-s1} {s5-s6}')

#sys.stdout.flush()

Check failure (GitHub Actions / pre-commit) — Ruff F841: tests/v1/kv_connector/nixl_integration/toy_proxy_server.py:255:21 Local variable `s4` is assigned to but never used
return re

except Exception as e:
import traceback
exc_info = sys.exc_info()
print("Error occurred in disagg prefill proxy server"
f" - {api} endpoint")

Check failure (GitHub Actions / pre-commit) — Ruff F841: tests/v1/kv_connector/nixl_integration/toy_proxy_server.py:262:9 Local variable `s5` is assigned to but never used
print(e)
print("".join(traceback.format_exception(*exc_info)))
raise
76 changes: 52 additions & 24 deletions vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -205,18 +205,18 @@
finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
"""Get the finished recving and sending requests."""
s1 = time.perf_counter()
assert self.connector_worker is not None

Check failure (GitHub Actions / pre-commit) — Ruff F841: vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py:208:9 Local variable `s1` is assigned to but never used
re= self.connector_worker.get_finished()
#logger.info(f'libin debug get_finished {os.getenv('RANK')}, takes {time.perf_counter() - s1}')
return re

Check failure (GitHub Actions / pre-commit) — Ruff E501: vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py:211:81 Line too long (103 > 80)

def start_load_kv(self, forward_context: "ForwardContext",
**kwargs) -> None:
s1 = time.perf_counter()
assert self.connector_worker is not None

Check failure (GitHub Actions / pre-commit) — Ruff F841: vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py:216:9 Local variable `s1` is assigned to but never used
assert isinstance(self._connector_metadata, NixlConnectorMetadata)
self.connector_worker.start_load_kv(self._connector_metadata)
logger.info(f"libin debug start_load_kv return {os.getenv('RANK')}, takes {time.perf_counter() - s1}")
#logger.info(f"libin debug start_load_kv return {os.getenv('RANK')}, takes {time.perf_counter() - s1}")

def wait_for_layer_load(self, layer_name: str) -> None:
"""NixlConnector does not do layerwise saving."""
@@ -449,7 +449,8 @@
# Config.
self.vllm_config = vllm_config
self.block_size = vllm_config.cache_config.block_size

self.block_factor = 8 # A100.block_size/G2.block_size
Review thread on `self.block_factor`:
- Author: is this going to be a hardcoded value?
- Reply: it's hardcoded for now; maybe that's OK since this number won't change.
- Reply: it's better to check the block size on both sides.
- Reply: we can check whether remote.block_size is the expected value; we don't know remote.block_size here because the handshake occurs afterwards.
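A minimal sketch of the deferred check discussed above (not part of the diff): keep the hardcoded factor, but validate it once the handshake has delivered the remote metadata, e.g. in `add_remote_agent()` where `remote_block_size` is already derived.

```python
# Sketch only: validate the hardcoded factor once the remote block size is
# known. Assumes `remote_block_size` is the value add_remote_agent() already
# computes from nixl_agent_meta.
observed_factor = self.block_size // remote_block_size
if observed_factor != self.block_factor:
    raise RuntimeError(
        f"Unexpected local/remote block-size ratio: {self.block_size=} "
        f"{remote_block_size=}, expected factor {self.block_factor}")
```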

self.block_shape = None
# Agent.
self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None)
# Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}.
@@ -622,7 +623,6 @@
path = make_zmq_path("tcp", host, port + p_remote_rank)
logger.debug("Querying metadata on path: %s at remote rank %s", path,
p_remote_rank)

# Send query for the request.
with zmq_ctx(zmq.REQ, path) as sock:
sock.send(GET_META_MSG)
@@ -786,6 +786,7 @@
block_shape[0] = block_shape[0] // self.num_blocks
block_shape = torch.Size(block_shape)
block_size, n_kv_heads, head_dim = block_shape[-3:]
self.block_shape = [block_size, n_kv_heads, head_dim]
# head size in bytes.
self.slot_size_bytes = kv_elem_size * n_kv_heads * head_dim
else:
@@ -802,7 +803,7 @@
"per_layer_kv_cache_shape: %s", use_mla, self.kv_buffer_device,
self.use_host_buffer, self.num_blocks, block_shape,
first_kv_cache[0].shape)
self.dst_num_blocks[self.engine_id] = self.num_blocks
self.dst_num_blocks[self.engine_id] = self.num_blocks * self.block_factor
self.device_kv_caches = kv_caches
kv_caches_base_addr = []
caches_data = []
@@ -815,6 +816,7 @@
# (roughly 8KB vs 5KB).
# Conversely for FlashInfer, K and V are transferred in the same tensor
# to better exploit the memory layout (ie num_blocks is the first dim).

for cache_or_caches in xfer_buffers.values():
# Normalize to always be a list of caches
cache_list = [cache_or_caches] if use_mla \
@@ -870,15 +872,15 @@
# could create fewer, but then _get_block_descs_ids needs to
# select agent_meta.num_blocks instead of self.num_blocks for
# local descr, and that makes handling regular flow less clean.
for block_id in range(self.num_blocks):
block_offset = block_id * self.block_len
for block_id in range(self.num_blocks * self.block_factor):
block_offset = block_id * self.block_len // (self.block_factor)
addr = base_addr + block_offset
# (addr, len, device id)
# TODO: does device_id matter to DRAM?
blocks_data.append((addr, self.block_len, self.tp_rank))
blocks_data.append((addr, self.block_len//(self.block_factor), self.tp_rank))
logger.debug("Created %s blocks for src engine %s and rank %s",
len(blocks_data), self.engine_id, self.tp_rank)

#print(f'buke: {blocks_data[0:10]=}')
descs = self.nixl_wrapper.get_xfer_descs(blocks_data,
self.nixl_memory_type)
# NIXL_INIT_AGENT to be used for preparations of local descs.
@@ -955,7 +957,7 @@
assert self._tp_size[engine_id] == remote_tp_size
# We may eventually enable this after asserting equality in cache
# layout and close outputs.
assert nixl_agent_meta.attn_backend_name == self.backend_name
assert nixl_agent_meta.attn_backend_name == "FLASH_ATTN_VLLM_V1" or nixl_agent_meta.attn_backend_name == "HPU_ATTN_V1"

remote_agent_name = self.nixl_wrapper.add_remote_agent(
nixl_agent_meta.agent_metadata)
@@ -984,14 +986,14 @@
# Account for joint KV in FlashInfer.
remote_block_size //= 2

assert nixl_agent_meta.block_len == self.block_len * tp_ratio, (
"Remote P worker KV layer cache must be of shape [2, N, "
"local_kv_heads*tp_ratio, block_size, head_dim] and same dtype."
)
#assert nixl_agent_meta.block_len == self.block_len * tp_ratio, (
# "Remote P worker KV layer cache must be of shape [2, N, "
# "local_kv_heads*tp_ratio, block_size, head_dim] and same dtype."
#)

assert self.block_size == remote_block_size, (
"Remote P worker with different block size is not supported "
f"{self.block_size=} {remote_block_size=}")
#assert self.block_size == remote_block_size, (
# "Remote P worker with different block size is not supported "
# f"{self.block_size=} {remote_block_size=}")

# Create dst descs and xfer side handles. TP workers have same #blocks.
if engine_id in self.dst_num_blocks:
@@ -1007,7 +1009,7 @@
# Only register the remote's descriptors if current rank pulls from it.
self.kv_caches_base_addr[
engine_id] = nixl_agent_meta.kv_caches_base_addr
rank_offset = self.tp_rank % tp_ratio * self.block_len \
rank_offset = self.tp_rank % tp_ratio * nixl_agent_meta.block_len // tp_ratio \
if not (self.use_mla or is_kv_replicated) else 0
# Register all remote blocks, but only the corresponding kv heads.
for base_addr in nixl_agent_meta.kv_caches_base_addr:
@@ -1018,15 +1020,16 @@
# self.block_len == remote_block_len//tp_ratio bytes.
addr = base_addr + block_offset + rank_offset
# (addr, len, device id)
blocks_data.append((addr, self.block_len, remote_tp_rank))
blocks_data.append((addr, nixl_agent_meta.block_len//tp_ratio, remote_tp_rank))
logger.debug(
"Created %s blocks for dst engine %s with remote rank %s and "
"local rank %s", len(blocks_data), engine_id, remote_tp_rank,
self.tp_rank)

logger.debug(f'buke {self.slot_size_bytes=}|{tp_ratio=}|{self.block_len=}|{nixl_agent_meta.block_len=}|{self.tp_rank=}|{self._use_flashinfer=}')
# Register with NIXL.
descs = self.nixl_wrapper.get_xfer_descs(blocks_data,
self.nixl_memory_type)
#print('buke register remote:', len(blocks_data), blocks_data[:10],blocks_data[-1],self.nixl_memory_type)
self.dst_xfer_side_handles[
engine_id] = self.nixl_wrapper.prep_xfer_dlist(
remote_agent_name, descs)
@@ -1081,6 +1084,24 @@
"Rank %s, get_finished: %s requests done sending "
"and %s requests done recving", self.tp_rank,
len(done_sending), len(done_recving))
#import remote_pdb; remote_pdb.set_trace()
Review thread:
- Reviewer: can you add a check that the remote is GPU attention?
- Reply: yes, I can add this.
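A possible shape for that check (a sketch; it assumes `nixl_agent_meta.attn_backend_name` names the remote backend, as the assert in `add_remote_agent()` suggests, and the allow-list below is illustrative rather than the PR's actual set):

```python
# Sketch only: accept only GPU attention backends on the remote (prefill) side.
# "FLASH_ATTN_VLLM_V1" is the name the diff already compares against; the
# allow-list itself is hypothetical.
GPU_ATTN_BACKENDS = {"FLASH_ATTN_VLLM_V1"}
if nixl_agent_meta.attn_backend_name not in GPU_ATTN_BACKENDS:
    raise RuntimeError(
        f"Remote attention backend {nixl_agent_meta.attn_backend_name!r} is "
        "not a supported GPU backend for heterogeneous KV transfer")
```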

t1 = time.perf_counter()
remote_block_size = self.block_size // self.block_factor
block_size, n_kv_heads, head_dim = self.block_shape
for req_id in done_recving:
#print(req_id, self._recving_metadata)
meta = self._recving_metadata.pop(req_id)
for k, v in self.device_kv_caches.values():
local_block_ids = meta.local_block_ids
#print(f'buke {local_block_ids=}|{k.shape=}')
for block_idx in local_block_ids:
#import remote_pdb; remote_pdb.set_trace()
k[block_idx*self.block_size: (1+block_idx)*self.block_size] = k[block_idx*self.block_size: (1+block_idx)*self.block_size].reshape(self.block_factor, n_kv_heads, remote_block_size, head_dim).permute(0,2,1,3).contiguous().reshape(self.block_size,n_kv_heads,head_dim)
v[block_idx*self.block_size: (1+block_idx)*self.block_size] = v[block_idx*self.block_size: (1+block_idx)*self.block_size].reshape(self.block_factor, n_kv_heads, remote_block_size, head_dim).permute(0,2,1,3).contiguous().reshape(self.block_size,n_kv_heads,head_dim)
#import remote_pdb; remote_pdb.set_trace()
t2 = time.perf_counter()
tt = t2-t1
logger.debug(f'buke permute time:{tt}, {req_id=}|{self._recving_metadata=}')
if self.use_host_buffer:
for req_id in done_recving:
s2 = time.perf_counter()
@@ -1178,7 +1199,8 @@
"Num local_block_ids: %s. Num remote_block_ids: %s. ", req_id,
remote_engine_id, len(meta.local_block_ids),
len(meta.remote_block_ids))
if self.use_host_buffer:
is_hetero = True
Review thread on `is_hetero`:
- Author: what's the plan for setting this variable?
- Reply: the flag is needed to enable `_recving_metadata`.
- Author: are you going to add it as an env flag?
- Reply: can we reuse the DECODE_TP_RATIO env var?
- Reply: that seems good.
- Reply: check whether this can be moved to hpu_model_runner.
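A sketch of the env-based option floated above; the exact semantics of DECODE_TP_RATIO are an assumption here, not something the PR defines:

```python
import os

# Sketch only: derive the flag from the DECODE_TP_RATIO environment variable
# suggested in the thread. Treating any value other than "1" as heterogeneous
# is an assumption, not the PR's actual behavior.
is_hetero = os.environ.get("DECODE_TP_RATIO", "1") != "1"
```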

if is_hetero or self.use_host_buffer:
self._recving_metadata[req_id] = meta
if remote_engine_id not in self._remote_agents:
# Initiate handshake with remote engine to exchange metadata.
@@ -1245,8 +1267,8 @@
# Partial prefix cache hit: just read uncomputed blocks.
num_remote_blocks = len(remote_block_ids)
assert num_local_blocks <= num_remote_blocks
if num_local_blocks < num_remote_blocks:
remote_block_ids = remote_block_ids[-num_local_blocks:]
#if num_local_blocks < num_remote_blocks:
# remote_block_ids = remote_block_ids[-num_local_blocks:]
Review comment: add a check for the heterogeneous case here.
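One way to act on that comment (a sketch, assuming the `is_hetero` flag introduced earlier in this file is visible at this point):

```python
# Sketch only: keep the upstream truncation for the homogeneous case and skip
# it only when the heterogeneous path is active.
if not is_hetero and num_local_blocks < num_remote_blocks:
    remote_block_ids = remote_block_ids[-num_local_blocks:]
```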


# Get side handles.
local_xfer_side_handle = self.src_xfer_side_handle
@@ -1259,12 +1281,19 @@
# Get descs ids.
local_block_descs_ids: list[int] = []
remote_block_descs_ids: list[int] = []
local_sub_block_ids: list[int] = []

#print('buke: ', remote_block_ids)
remote_block_ids = remote_block_ids[:(len(remote_block_ids)//self.block_factor)*self.block_factor]
for index,remote_block_id in enumerate(remote_block_ids):
local_sub_block_ids.append(local_block_ids[index//self.block_factor]*self.block_factor + index%self.block_factor)
logger.debug(f'buke {local_block_ids=} |{remote_block_ids=} |{local_sub_block_ids=}')
if not self.block_window_per_layer:
# Default case: assume global attention
remote_block_descs_ids = self._get_block_descs_ids(
dst_engine_id, remote_block_ids)
local_block_descs_ids = self._get_block_descs_ids(
self.engine_id, local_block_ids)
self.engine_id, local_sub_block_ids)
else:
# TODO(mgoin): remove this once we have hybrid memory allocator
# Optimization for models with local attention (Llama 4)
@@ -1350,7 +1379,6 @@

if socket_type not in (zmq.ROUTER, zmq.REQ):
raise ValueError(f"Unexpected socket type: {socket_type}")

ctx: Optional[zmq.Context] = None
try:
ctx = zmq.Context() # type: ignore[attr-defined]