diff --git a/sharktank/sharktank/examples/export_paged_llm_v1.py b/sharktank/sharktank/examples/export_paged_llm_v1.py index 4c23b8316fc..adc639be12e 100644 --- a/sharktank/sharktank/examples/export_paged_llm_v1.py +++ b/sharktank/sharktank/examples/export_paged_llm_v1.py @@ -75,6 +75,7 @@ def generate_batch_prefill(bs: int): } bs_min = bs + prefill_name = f"prefill_bs{bs}" if export_config.has_prefill_position: seq_len_blocks_dim_chunked = torch.export.Dim( @@ -95,6 +96,9 @@ def generate_batch_prefill(bs: int): if "start_pos" in dynamic_shapes: dynamic_shapes["start_pos"][0] = extend_bs + prefill_name = "prefill_bs_extend" + export_config.bs_prefill = None + seq_block_ids = torch.empty(bs_min, block_dim_min, dtype=torch.int64) tokens = torch.empty( bs_min, @@ -103,13 +107,13 @@ def generate_batch_prefill(bs: int): ) seq_lens = torch.empty(bs_min, dtype=torch.int64) - print(f"Exporting prefill_bs{bs}") + print(f"Exporting {prefill_name}") if export_config.has_prefill_position: arg_devices = model.setup_arg_devices(cache_affinities, len(dynamic_shapes)) @fxb.export_program( - name=f"prefill_bs{bs}", + name=prefill_name, args=(tokens, start_pos, seq_lens, seq_block_ids, cache), dynamic_shapes=dynamic_shapes, arg_device=arg_devices, @@ -132,7 +136,7 @@ def _( arg_devices = model.setup_arg_devices(cache_affinities, len(dynamic_shapes)) @fxb.export_program( - name=f"prefill_bs{bs}", + name=prefill_name, args=(tokens, seq_lens, seq_block_ids, cache), dynamic_shapes=dynamic_shapes, arg_device=arg_devices, diff --git a/sharktank/sharktank/models/llm/config.py b/sharktank/sharktank/models/llm/config.py index 0f8a0cb6047..bcb6ab5d586 100644 --- a/sharktank/sharktank/models/llm/config.py +++ b/sharktank/sharktank/models/llm/config.py @@ -27,12 +27,13 @@ class ServiceConfig: max_seq_len: int attn_head_dim: int prefill_batch_sizes: list[int] - has_prefill_position: bool decode_batch_sizes: list[int] transformer_block_count: int logits_normalization: Optional[str] top_k: Optional[int] paged_kv_cache: KVCacheConfig + has_prefill_position: bool = False + use_extend_attention: bool = False @staticmethod def load(fp: Path): diff --git a/sharktank/sharktank/models/llm/export.py b/sharktank/sharktank/models/llm/export.py index 8a06b815663..4a9e62b5a18 100644 --- a/sharktank/sharktank/models/llm/export.py +++ b/sharktank/sharktank/models/llm/export.py @@ -202,10 +202,11 @@ def build_service_config( max_seq_len=hp.context_length, attn_head_dim=hp.attn_head_dim, prefill_batch_sizes=export_config.bs_prefill, - has_prefill_position=export_config.has_prefill_position, decode_batch_sizes=export_config.bs_decode, transformer_block_count=hp.block_count, logits_normalization=export_config.logits_normalization, top_k=export_config.top_k, paged_kv_cache=kv_config, + has_prefill_position=export_config.has_prefill_position, + use_extend_attention=export_config.use_extend_attention, ) diff --git a/sharktank/sharktank/utils/llm_utils.py b/sharktank/sharktank/utils/llm_utils.py index 98522b1db0b..e9eb63fd440 100644 --- a/sharktank/sharktank/utils/llm_utils.py +++ b/sharktank/sharktank/utils/llm_utils.py @@ -177,7 +177,10 @@ def __init__( setattr(self, funcname, func) if "prefill_bs" in funcname: self._prefill = func - self.prefill_bs = int(funcname[10:]) + if funcname[10:] == "_extend": + self.prefill_bs = 4 + else: + self.prefill_bs = int(funcname[10:]) if "decode_bs" in funcname: self._decode = func self.decode_bs = int(funcname[9:]) diff --git a/shortfin/python/shortfin_apps/llm/cli.py 
b/shortfin/python/shortfin_apps/llm/cli.py index 7f99d2951be..f319b5c61ef 100644 --- a/shortfin/python/shortfin_apps/llm/cli.py +++ b/shortfin/python/shortfin_apps/llm/cli.py @@ -173,9 +173,12 @@ def send_error( self.ensure_response() def _print_response(self, response): - response_data = json.loads(response.decode("utf-8")) - responses_array = response_data["responses"][0].get("responses", []) - print(json.dumps(responses_array, indent=1)) + if isinstance(response, bytes): + response_data = json.loads(response.decode("utf-8")) + responses_array = response_data["responses"][0].get("responses", []) + print(json.dumps(responses_array, indent=1)) + else: + print(response) def send_response(self, response): logger.info(f"{self.name} Sending response") diff --git a/shortfin/python/shortfin_apps/llm/components/batching/config.py b/shortfin/python/shortfin_apps/llm/components/batching/config.py index a732116f80c..205043dfe8e 100644 --- a/shortfin/python/shortfin_apps/llm/components/batching/config.py +++ b/shortfin/python/shortfin_apps/llm/components/batching/config.py @@ -25,6 +25,7 @@ class BatchMode(Enum): DEFAULT = "Default" + EXTEND_ATTENTION = "ExtendAttention" @dataclass(slots=True) @@ -35,3 +36,4 @@ class BatchConfig: decode_functions: dict[int, sf.ProgramFunction] # type: ignore prog_isolation: sf.ProgramIsolation # type: ignore chunk_block_size: Optional[int] = None + token_budget: Optional[int] = None diff --git a/shortfin/python/shortfin_apps/llm/components/batching/factory.py b/shortfin/python/shortfin_apps/llm/components/batching/factory.py index bd08fd6590e..9fe89cc0b13 100644 --- a/shortfin/python/shortfin_apps/llm/components/batching/factory.py +++ b/shortfin/python/shortfin_apps/llm/components/batching/factory.py @@ -16,6 +16,7 @@ from ..kvcache.base_attention_cache import BasePagedAttentionCache from .batching_trait import BatchingTrait from .modes.default import DefaultBatchingEngine +from .modes.extend_attention import ExtendAttentionBatchingEngine from ..messages import LlmInferenceExecRequest @@ -61,5 +62,15 @@ def _create_impl(batch_cfg: BatchConfig, page_cache: BasePagedAttentionCache, pr ), page_cache=page_cache, ) + elif batch_cfg.mode == BatchMode.EXTEND_ATTENTION: + return _BatchingEngineImpl( + ExtendAttentionBatchingEngine.create( + batch_cfg=batch_cfg, + page_cache=page_cache, + prefill_fiber=prefill_fiber, + decode_fiber=decode_fiber, + ), + page_cache=page_cache, + ) raise ValueError(f"Unsupported Batching Mode: {batch_cfg.mode}") diff --git a/shortfin/python/shortfin_apps/llm/components/batching/modes/default.py b/shortfin/python/shortfin_apps/llm/components/batching/modes/default.py index 44f6c735d31..6a5c2945ce4 100644 --- a/shortfin/python/shortfin_apps/llm/components/batching/modes/default.py +++ b/shortfin/python/shortfin_apps/llm/components/batching/modes/default.py @@ -457,12 +457,17 @@ def __init__( def make_task_inputs( self, exec_request: LlmInferenceExecRequest ) -> List[LlmTaskInput]: + logger.info( + f"DEBUG DECODE: rid={exec_request.orig_instance_id}, start_position={exec_request.start_position}," + f"block_count={exec_request.block_count}, input_tokens={exec_request.input_token_ids}," + f"page_ids={exec_request.page_ids}" + ) return [ LlmTaskInput( rid=exec_request.orig_instance_id, instance_id=exec_request.instance_id, block_count=exec_request.block_count, - seq_len=exec_request.start_position + 1, + seq_len=exec_request.start_position, input_tokens=tuple(exec_request.input_token_ids), page_ids=tuple(exec_request.page_ids), 
start_position=exec_request.start_position, diff --git a/shortfin/python/shortfin_apps/llm/components/batching/modes/extend_attention.py b/shortfin/python/shortfin_apps/llm/components/batching/modes/extend_attention.py new file mode 100644 index 00000000000..e1934257068 --- /dev/null +++ b/shortfin/python/shortfin_apps/llm/components/batching/modes/extend_attention.py @@ -0,0 +1,210 @@ +# Copyright 2025 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import logging +from typing import List + +import shortfin as sf + +from shortfin import Fiber + +from ..batching_trait import BatchingTrait +from ..config import BatchConfig +from ...config_struct import ModelParams +from ...invocation import ( + ExtendAttentionPrefillTask, + LlmInvocationProcess, + LlmTask, + LlmTaskInput, +) +from ...kvcache.base_attention_cache import BasePagedAttentionCache +from ...messages import InferencePhase, LlmInferenceExecRequest +from ...scheduler import ExtendAttentionScheduler + +from .default import ( + LlmBatcherProcess, + PrefillTaskResponder, + DecodeBatcherProcess, +) + +logger = logging.getLogger(__name__) + + +class ExtendAttentionPrefillBatcherProcess(LlmBatcherProcess): + """Batcher process optimized for extend-attention prefill.""" + + STROBE_SHORT_DELAY = 0.065 + STROBE_LONG_DELAY = 0.065 + + def __init__( + self, + fiber: Fiber, + page_cache: BasePagedAttentionCache, + model_params: ModelParams, + prefill_functions: dict[int, sf.ProgramFunction], + program_isolation: str, + token_budget: int, + ): + # Use the extend-attention aware scheduler + block_seq_stride = model_params.paged_kv_cache.block_seq_stride + + scheduler = ExtendAttentionScheduler( + token_budget=token_budget, block_seq_stride=block_seq_stride + ) + + llm_task_responder = PrefillTaskResponder(scheduler=scheduler) + + # ideal_batch_size - not really important. we can set it to + # maximum number of requests that can be batched together. + ideal_batch_size = token_budget // block_seq_stride + + super().__init__( + name="extend_attention_prefill", + fiber=fiber, + page_cache=page_cache, + model_params=model_params, + functions=prefill_functions, + ideal_batch_size=ideal_batch_size, + program_isolation=program_isolation, + scheduler=scheduler, + llm_task_responder=llm_task_responder, + ) + + def make_task_inputs( + self, exec_request: LlmInferenceExecRequest + ) -> List[LlmTaskInput]: + """Create a single task input containing all tokens. + + The scheduler will dynamically chunk this request at scheduling time based + on the number of active requests and the token budget. 
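To make the dynamic chunking concrete: the batcher hands the scheduler one task holding the full prompt, and the scheduler then consumes it in page-aligned slices. A minimal sketch of that walk, assuming a fixed chunk size for simplicity (the real scheduler recomputes the size every batch as other requests finish); the helper name is illustrative:

    def planned_chunks(num_tokens: int, chunk_size: int) -> list[tuple[int, int]]:
        # (start_position, length) pairs a single request would be processed in.
        chunks, pos = [], 0
        while pos < num_tokens:
            take = min(chunk_size, num_tokens - pos)
            chunks.append((pos, take))
            pos += take
        return chunks

    # A 300-token prompt consumed in 128-token (8-page) chunks:
    assert planned_chunks(300, 128) == [(0, 128), (128, 128), (256, 44)]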
+ """ + total_tokens = len(exec_request.input_token_ids) + + logger.info( + f"ExtendAttention make_task_inputs: input_token_ids={exec_request.input_token_ids}" + ) + + # Return a single task with ALL tokens + # The scheduler will chunk it dynamically + return [ + LlmTaskInput( + rid=exec_request.orig_instance_id, + instance_id=exec_request.instance_id, + block_count=exec_request.block_count, + seq_len=total_tokens, + input_tokens=tuple(exec_request.input_token_ids), + page_ids=tuple(exec_request.page_ids), + start_position=0 + if exec_request.start_position is None + else exec_request.start_position, + ) + ] + + def make_task( + self, + task_inputs: List[LlmTaskInput], + page_cache: BasePagedAttentionCache, + ) -> LlmTask: + """Create an extend-attention aware prefill task.""" + return ExtendAttentionPrefillTask( + task_inputs=task_inputs, + array_cache=self.array_cache, + page_tables=page_cache.page_pool.page_tables, + has_prefill_position=self.model_params.has_prefill_position, + block_seq_stride=self.page_seq_stride, + ) + + def make_invoker( + self, + page_cache: BasePagedAttentionCache, + fiber: Fiber, + task_inputs: list[LlmTaskInput], + ) -> LlmInvocationProcess: + """Create invoker for extend-attention prefill.""" + return LlmInvocationProcess( + name="extend_attention_prefill_invocation", + fiber=fiber, + llm_task=self.make_task(task_inputs, page_cache), + functions=self.functions, + program_isolation=self.program_isolation, + responder=self._llm_task_responder, + ) + + +class ExtendAttentionBatchingEngine(BatchingTrait): + """Batching engine that uses extend-attention for improved prefill batching.""" + + def __init__( + self, + prefill_lane: ExtendAttentionPrefillBatcherProcess, + decode_lane: DecodeBatcherProcess, + ): + self.prefill_lane = prefill_lane + self.decode_lane = decode_lane + + def submit(self, request: LlmInferenceExecRequest): + if request.phase == InferencePhase.PREFILL: + self.prefill_lane.submit(request) + elif request.phase == InferencePhase.DECODE: + self.decode_lane.submit(request) + else: + raise ValueError( + "Requested unsupported batching lane: Supported only either prefill or decode." + ) + + def launch(self): + self.prefill_lane.launch() + self.decode_lane.launch() + + def shutdown(self): + self.prefill_lane.shutdown() + self.decode_lane.shutdown() + + def reserve_workload(self, rid: str, count: int): + self.decode_lane.reserve_workload(rid=rid, count=count) + + def get_model_params(self) -> ModelParams: + return self.prefill_lane.model_params + + @staticmethod + def create( + batch_cfg: BatchConfig, + page_cache: BasePagedAttentionCache, + prefill_fiber: sf.Fiber, + decode_fiber: sf.Fiber, + ): + """Create an extend-attention batching engine.""" + + # Check if the model was exported with extend-attention support + if not batch_cfg.model_params.use_extend_attention: + raise ValueError( + "Model was not exported with extend-attention support. " + "Please export the model with --use-extend-attention flag." 
+ ) + assert batch_cfg.token_budget is not None + token_budget = batch_cfg.token_budget + + prefill_batcher = ExtendAttentionPrefillBatcherProcess( + fiber=prefill_fiber, + page_cache=page_cache, + model_params=batch_cfg.model_params, + prefill_functions=batch_cfg.prefill_functions, + program_isolation=batch_cfg.prog_isolation, + token_budget=token_budget, + ) + + decode_batcher = DecodeBatcherProcess( + fiber=decode_fiber, + page_cache=page_cache, + model_params=batch_cfg.model_params, + decode_functions=batch_cfg.decode_functions, + program_isolation=batch_cfg.prog_isolation, + ) + + return ExtendAttentionBatchingEngine( + prefill_lane=prefill_batcher, + decode_lane=decode_batcher, + ) diff --git a/shortfin/python/shortfin_apps/llm/components/buffers.py b/shortfin/python/shortfin_apps/llm/components/buffers.py index 1226c547531..d206bdea6ba 100644 --- a/shortfin/python/shortfin_apps/llm/components/buffers.py +++ b/shortfin/python/shortfin_apps/llm/components/buffers.py @@ -42,6 +42,10 @@ def create_argument_buffers( if default is not None: host_buffer.fill(default) host_buffer.items = buffer_data + # Log the final buffer contents for debugging + logger.info( + f"Buffer {index}: shape={buffer.shape}, data={list(host_buffer.items)[:min(20, len(host_buffer.items))]}" + ) args.append(buffer) diff --git a/shortfin/python/shortfin_apps/llm/components/config_struct.py b/shortfin/python/shortfin_apps/llm/components/config_struct.py index 1126e985929..f81f516530e 100644 --- a/shortfin/python/shortfin_apps/llm/components/config_struct.py +++ b/shortfin/python/shortfin_apps/llm/components/config_struct.py @@ -142,6 +142,9 @@ class ModelParams: # Whether the model was exported with `start_positions` for prefill. has_prefill_position: bool = False + # Whether the model was exported with extend attention support. + use_extend_attention: bool = False + # Name of the IREE module implementing the model. module_name: str = "module" @@ -239,6 +242,10 @@ class ServerParams: # KV cache configuration prefix_sharing_algorithm: str = "none" # none or trie + # Batching configuration + batch_mode: str = "default" # default or extend_attention + token_budget: Optional[int] = None # Token budget for extend_attention mode + # Program isolation configuration program_isolation: str = "per_call" diff --git a/shortfin/python/shortfin_apps/llm/components/invocation.py b/shortfin/python/shortfin_apps/llm/components/invocation.py index 8c1512031cc..e7d7afdec86 100644 --- a/shortfin/python/shortfin_apps/llm/components/invocation.py +++ b/shortfin/python/shortfin_apps/llm/components/invocation.py @@ -270,7 +270,7 @@ async def prepare_args( buffers.extend([seq_lens_allocation, seq_block_ids_allocation]) data.extend([seq_lens_data, seq_block_ids_data]) - defaults.extend([1, 0]) + defaults.extend([0, 0]) args = create_argument_buffers( buffers=buffers, @@ -372,7 +372,118 @@ async def prepare_args( start_positions, seq_block_ids_data, ], - defaults=[0, 1, 0, 0], + defaults=[0, 0, 0, 0], + ) + + for page_table in self._page_tables: + args.append(WrappedAllocation(sfnp.disable_barrier(page_table))) + + return args + + +class ExtendAttentionPrefillTask(PrefillTask): + """Prefill task that supports extend-attention with varying sequence lengths. + + This task is designed for extend-attention batching where requests can have + different sequence lengths within a batch. It properly handles offset prefill + by using the parent class's methods for calculating batch sequence length and + block counts. 
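A worked example of the offset-prefill bookkeeping, using the same numbers as the tests further down (16-token pages): a chunk of 64 new tokens arriving after 96 already-cached tokens reports the cumulative sequence length and page count, history included. Values are illustrative:

    import math

    block_seq_stride = 16
    start_position = 96          # tokens already in the KV cache
    new_tokens = 64              # tokens carried by this chunk

    seq_len = start_position + new_tokens                    # cumulative length
    block_count = math.ceil(seq_len / block_seq_stride)      # pages, history included
    assert (seq_len, block_count) == (160, 10)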
+ """ + + def __init__( + self, + task_inputs: List[LlmTaskInput], + array_cache: DeviceArrayCache, + page_tables: List[sfnp.device_array], + has_prefill_position: bool, + block_seq_stride: int, + ): + super().__init__( + task_inputs=task_inputs, + array_cache=array_cache, + page_tables=page_tables, + seq_stride=block_seq_stride, + has_prefill_position=has_prefill_position, + chunk_block_size=None, + ) + + async def prepare_args(self, batch_size: int) -> List[sfnp.device_array]: + """Prepare arguments for extend-attention prefill. + + Each request's tokens are divided into pages of + block_seq_stride size. Each page can track its own history, allowing + efficient batching of variable-length sequences. + """ + task_inputs = self._task_inputs + + # Prepare the batch with page-aligned tokens + tokens = [] + seq_lens = [] + page_ids = [] + start_positions = [] + + for task_input in task_inputs: + # Each task's tokens are organized by pages + task_tokens = list(task_input.input_tokens) + + # Pad each sequence to page boundaries + tokens.append(task_tokens) + seq_lens.append(task_input.seq_len) + page_ids.append(list(task_input.page_ids)) + if self._has_prefill_position: + start_positions.append(task_input.start_position) + + batch_seq_len = self._get_batch_seq_len(task_inputs) + max_blocks = self._get_block_count(batch_seq_len, task_inputs) + + logger.debug( + f"ExtendAttention Prefill bs={batch_size}, " + f"batch_seq_len={batch_seq_len}, " + f"max_blocks={max_blocks}, seq_stride={self._seq_stride}" + ) + + array_cache = self._array_cache + int_dtype = sfnp.int64 + + # Allocate buffers + tokens_allocation = array_cache.allocate([batch_size, batch_seq_len], int_dtype) + seq_lens_allocation = array_cache.allocate([batch_size], int_dtype) + seq_block_ids_allocation = array_cache.allocate( + [batch_size, max_blocks], int_dtype + ) + + tokens_data = list( + chain.from_iterable(_pad_list(t, batch_seq_len) for t in tokens) + ) + + seq_block_ids_data = list( + chain.from_iterable( + _pad_list(pages, target_length=max_blocks) for pages in page_ids + ) + ) + + buffers = [tokens_allocation] + data = [tokens_data] + defaults = [0] + + if self._has_prefill_position: + start_positions_allocation = array_cache.allocate([batch_size], int_dtype) + buffers.append(start_positions_allocation) + data.append(start_positions) + defaults.append(0) + + buffers.extend([seq_lens_allocation, seq_block_ids_allocation]) + data.extend([seq_lens, seq_block_ids_data]) + defaults.extend([0, 0]) + + logger.info( + f"ExtendAttention Prefill: batch_size={batch_size}, actual_requests={len(task_inputs)}, seq_lens={seq_lens}, defaults={defaults}" + ) + + args = create_argument_buffers( + buffers=buffers, + data=data, + defaults=defaults, ) for page_table in self._page_tables: diff --git a/shortfin/python/shortfin_apps/llm/components/lifecycle.py b/shortfin/python/shortfin_apps/llm/components/lifecycle.py index d90abee143f..274e149873a 100644 --- a/shortfin/python/shortfin_apps/llm/components/lifecycle.py +++ b/shortfin/python/shortfin_apps/llm/components/lifecycle.py @@ -151,6 +151,19 @@ def _validate_initialization_args( "Export from `sharktank` with `--has-prefill-position` for full trie prefix sharing benefits." 
) + batch_mode = server_params.batch_mode + use_extend_attention = model_params.use_extend_attention + if batch_mode == "extend_attention" and not use_extend_attention: + logger.error( + "INCOMPATIBLE SERVER CONFIGURATION: batch_mode is set to 'extend_attention', " + "but the model was not exported with extend-attention support.\n" + "Export from `sharktank` with `--use-extend-attention` to use extend-attention batching." + ) + raise ValueError( + "Incompatible server configuration. " + "Extend-attention batch mode requested, but model not exported with `--use-extend-attention`." + ) + @asynccontextmanager async def fastapi_lifespan(self, app: FastAPI): """ diff --git a/shortfin/python/shortfin_apps/llm/components/scheduler.py b/shortfin/python/shortfin_apps/llm/components/scheduler.py index 1c0a0a9d4fc..3a24c9001a4 100644 --- a/shortfin/python/shortfin_apps/llm/components/scheduler.py +++ b/shortfin/python/shortfin_apps/llm/components/scheduler.py @@ -7,7 +7,8 @@ from abc import ABC, abstractmethod import itertools import logging -from typing import Dict, List +import math +from typing import Dict, List, Set import shortfin as sf from .invocation import LlmTaskInput @@ -378,3 +379,167 @@ def handle_completed(self, rid: str) -> bool: next_chunk = self._pending[rid].pop(0) self._ready.append(next_chunk) return False + + +class ExtendAttentionScheduler(AbstractScheduler): + """Scheduler for extend-attention batching with dynamic chunking. + + This scheduler manages requests that are dynamically chunked based on the + number of active requests and a token budget. Each request is processed + in chunks, with chunk sizes calculated to maximize GPU utilization while + respecting page alignment constraints. + """ + + def __init__(self, *, token_budget: int, block_seq_stride: int): + # Pass dummy ideal_batch_size to parent - not used in extend attention + super().__init__(ideal_batch_size=1) + self._block_seq_stride = block_seq_stride + self._token_budget = token_budget + # Track active requests (full task inputs with all tokens) + self._active_requests: Dict[str, LlmTaskInput] = {} + # Track current position (token offset) for each request + self._request_positions: Dict[str, int] = {} + # Track the chunk size used for each request in the last execution + self._last_chunk_sizes: Dict[str, int] = {} + # Track requests that are currently executing to prevent double-scheduling + self._in_flight: Set[str] = set() + + def schedule_job(self, task: LlmTaskInput): + """Add a request to the scheduler. + + The task contains all tokens for the request. We'll dynamically chunk it + at scheduling time based on the number of active requests. + """ + rid = task.rid + if rid in self._active_requests: + logger.warning( + f"Request {rid} is already scheduled. Ignoring duplicate schedule_job call." + ) + return + + # New request - store it and initialize position from task + # (may not be 0 with trie_prefix_matching) + self._active_requests[rid] = task + self._request_positions[rid] = task.start_position + + def should_execute(self, strobe) -> List[List[LlmTaskInput]]: + """Determine which tasks should be executed now. + + Dynamically chunks active requests based on the number of requests and token budget. + Each request gets a page-aligned chunk that fits within the budget. 
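The chunk-size calculation described here, written out as standalone arithmetic (the actual logic lives in should_execute below); 16 is the block_seq_stride used throughout the tests:

    def chunk_size(token_budget: int, num_active: int, block_seq_stride: int) -> int:
        # Split the budget evenly, then round down to a page boundary;
        # never drop below one page even if the budget is oversubscribed.
        per_request = token_budget // num_active
        aligned = (per_request // block_seq_stride) * block_seq_stride
        return aligned or block_seq_stride

    assert chunk_size(1024, 1, 16) == 1024   # a lone request gets the whole budget
    assert chunk_size(1024, 2, 16) == 512
    assert chunk_size(1000, 3, 16) == 320    # 333 rounded down to 20 pages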
+ """ + if not self._active_requests: + return [] + + # Calculate dynamic chunk size based on active requests that are NOT in-flight + available_requests = { + rid: task + for rid, task in self._active_requests.items() + if rid not in self._in_flight + } + + if not available_requests: + # All requests are currently executing + return [] + + num_active = len(available_requests) + tokens_per_request = self._token_budget // num_active + # Align to page boundaries + chunk_size = ( + tokens_per_request // self._block_seq_stride + ) * self._block_seq_stride + + if chunk_size == 0: + # Too many requests for the budget - shouldn't happen but handle gracefully + chunk_size = self._block_seq_stride + + # Create chunks for this batch + batch = [] + for rid, full_task in available_requests.items(): + position = self._request_positions[rid] + all_tokens = full_task.input_tokens + + # Determine how many tokens to take + remaining_tokens = len(all_tokens) - position + tokens_to_take = min(chunk_size, remaining_tokens) + + if tokens_to_take <= 0: + continue + + # Create a chunk from current position + chunk_tokens = all_tokens[position : position + tokens_to_take] + + logger.info( + f"ExtendAttentionScheduler: rid={rid}, position={position}, tokens_to_take={tokens_to_take}, chunk_tokens={chunk_tokens}, all_tokens={all_tokens}" + ) + + # Calculate cumulative seq_len and block_count + cumulative_seq_len = position + len(chunk_tokens) + chunk_block_count = math.ceil(cumulative_seq_len / self._block_seq_stride) + + # Get page_ids up to the current block count + chunk_page_ids = full_task.page_ids[:chunk_block_count] + + # Store the actual chunk size used for this request + self._last_chunk_sizes[rid] = tokens_to_take + + # Mark this request as in-flight + self._in_flight.add(rid) + + # Create the chunk task input + chunk_task = LlmTaskInput( + rid=rid, + instance_id=full_task.instance_id, + block_count=chunk_block_count, + seq_len=cumulative_seq_len, + input_tokens=chunk_tokens, + page_ids=chunk_page_ids, + start_position=position, + ) + batch.append(chunk_task) + + return [batch] if batch else [] + + def handle_scheduler(self, msg) -> bool: + # Handle scheduler messages + return False + + def reserve_workload(self, *, batcher, count, rid): + # Handle workload reservation + pass + + def handle_completed(self, rid: str) -> bool: + """Handle completion of a chunk. + + Updates the position for this request and determines if more tokens remain. + + Returns True if the request is fully complete (no more tokens). + Returns False if there are more tokens to process. 
+ """ + assert ( + rid in self._active_requests + ), f"Request {rid} not found in active requests" + + # Remove from in-flight set to allow next chunk to be scheduled + self._in_flight.discard(rid) + + full_task = self._active_requests[rid] + current_position = self._request_positions[rid] + + # Get the chunk size that was actually used for this request in the last execution + tokens_processed = self._last_chunk_sizes.get(rid, 0) + new_position = current_position + tokens_processed + + # Update position + self._request_positions[rid] = new_position + + # Check if we've processed all tokens + if new_position >= len(full_task.input_tokens): + # Request complete - remove from active requests + del self._active_requests[rid] + del self._request_positions[rid] + del self._last_chunk_sizes[rid] + return True # Request fully complete + + # More tokens to process + return False diff --git a/shortfin/python/shortfin_apps/llm/components/service.py b/shortfin/python/shortfin_apps/llm/components/service.py index 523dfd8b6fd..2f11656f9da 100644 --- a/shortfin/python/shortfin_apps/llm/components/service.py +++ b/shortfin/python/shortfin_apps/llm/components/service.py @@ -117,13 +117,22 @@ def start(self): modules=component_modules, devices=self.sysman.ls.devices ) self.initialize_function_references() + + # Determine batch mode from server config + # Validation is done in lifecycle.py _validate_initialization_args + if self.server_params.batch_mode == "extend_attention": + batch_mode = BatchMode.EXTEND_ATTENTION + else: + batch_mode = BatchMode.DEFAULT + batch_cfg = BatchConfig( - BatchMode.DEFAULT, + batch_mode, self.model_params, self.prefill_functions, self.decode_functions, self.prog_isolation, self.server_params.chunk_block_size, + self.server_params.token_budget, ) self.unified_batcher = BatchingFacade.build_batcher( batch_cfg, self.page_cache, self.prefill_fiber, self.decode_fiber @@ -138,9 +147,12 @@ def shutdown(self): def initialize_function_references(self): self.prefill_functions = {} for bs in self.model_params.prefill_batch_sizes: - self.prefill_functions[bs] = self.inference_program[ - f"{self.model_params.module_name}.prefill_bs{bs}" - ] + # Use different function names for extend-attention mode + if self.model_params.use_extend_attention: + function_name = f"{self.model_params.module_name}.prefill_extend_bs{bs}" + else: + function_name = f"{self.model_params.module_name}.prefill_bs{bs}" + self.prefill_functions[bs] = self.inference_program[function_name] # Resolve decode entrypoints. self.decode_functions = {} for bs in self.model_params.decode_batch_sizes: diff --git a/shortfin/python/shortfin_apps/llm/server.py b/shortfin/python/shortfin_apps/llm/server.py index 800f7fc7b93..5f1ac3b4894 100644 --- a/shortfin/python/shortfin_apps/llm/server.py +++ b/shortfin/python/shortfin_apps/llm/server.py @@ -17,6 +17,7 @@ from .application import get_app from .components.lifecycle import ShortfinLlmLifecycleManager +from .components.batching.config import BatchMode from ..utils import get_system_args logger = logging.getLogger(__name__) @@ -101,6 +102,18 @@ def add_service_args(parser: argparse.ArgumentParser): choices=["none", "trie"], help="Algorithm to use for prefix sharing in KV cache", ) + parser.add_argument( + "--batch_mode", + type=str, + choices=[mode.name.lower() for mode in BatchMode], + help="Batching mode to use. 
'extend_attention' requires model exported with --use-extend-attention", + ) + parser.add_argument( + "--token_budget", + type=int, + default=1024, + help="Token budget to use for extend_attention mode.", + ) parser.add_argument( "--num_beams", type=int, diff --git a/shortfin/tests/apps/llm/components/extend_attention_test.py b/shortfin/tests/apps/llm/components/extend_attention_test.py new file mode 100644 index 00000000000..b63efa0286e --- /dev/null +++ b/shortfin/tests/apps/llm/components/extend_attention_test.py @@ -0,0 +1,750 @@ +# Copyright 2025 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import asyncio +import logging +import math +import pytest + +import shortfin.array as sfnp + +from typing import List +from unittest.mock import patch +from uuid import uuid4 + +from shortfin_apps.llm.components.scheduler import ExtendAttentionScheduler +from shortfin_apps.llm.components.invocation import ExtendAttentionPrefillTask +from shortfin_apps.llm.components.config_struct import ModelParams, PagedKVCacheParams +from shortfin_apps.llm.components.device_array_cache import ( + Allocation, + WrappedAllocation, +) +from shortfin_apps.llm.components.invocation import LlmTaskInput +from shortfin_apps.llm.components.kvcache.attention_cache_abstract import CacheInfo +from shortfin_apps.llm.components.kvcache.page_pool import PageInfo +from shortfin_apps.llm.components.messages import ( + LlmInferenceExecRequest, + InferencePhase, +) + + +logger = logging.getLogger(__name__) + + +class MockVoidFuture: + def __init__(self): + self._event = asyncio.Event() + + def set_success(self): + self._event.set() + + def __await__(self): + return self._event.wait().__await__() + + +@pytest.fixture +def model_params(): + return ModelParams( + max_seq_len=2048, + transformer_block_count=32, + attn_head_dim=128, + prefill_batch_sizes=[1, 2, 4], + decode_batch_sizes=[1, 2, 4], + paged_kv_cache=PagedKVCacheParams( + block_seq_stride=16, + attention_head_count_kv=8, + device_block_count=256, + kv_cache_dtype=sfnp.float16, + ), + ) + + +@pytest.fixture(scope="function") +def extend_attention_exec_requests(cache_ref_count, page_pool): + """Create exec requests with varying token lengths for extend attention testing.""" + with patch( + "shortfin_apps.llm.components.messages.sf.VoidFuture", new=MockVoidFuture + ): + exec_reqs = [] + token_lengths = [64, 128, 256, 512] # Different lengths to test batching + + page_offset = 0 + for idx, token_len in enumerate(token_lengths): + input_tokens = [i + idx * 1000 for i in range(token_len)] + exec_req = LlmInferenceExecRequest( + phase=InferencePhase.PREFILL, + input_token_ids=input_tokens, + rid=str(uuid4()), + ) + exec_reqs.append(exec_req) + + # Allocate pages for this request + exec_req._cache = cache_ref_count + pages = [ + PageInfo(index=page_offset + i, pool=page_pool) + for i in range(math.ceil(len(input_tokens) / 16)) + ] + exec_req.allocated_cache_info = CacheInfo( + num_tokens=len(exec_req.input_token_ids), + tokens=exec_req.input_token_ids, + pages=pages, + pool=page_pool, + last_cached_node=None, + ) + exec_req.page_ids = [page.index for page in pages] + page_offset += len(pages) + + yield exec_reqs + + +def _get_extend_attention_task_inputs( + exec_requests: List[LlmInferenceExecRequest], +) -> List[LlmTaskInput]: + """Convert exec requests to task inputs for extend attention.""" + task_inputs = 
[] + for req in exec_requests: + task_inputs.append( + LlmTaskInput( + rid=req.orig_instance_id, + instance_id=req.instance_id, + block_count=req.block_count, + input_tokens=tuple(req.input_token_ids), + seq_len=len(req.input_token_ids), + page_ids=tuple(req.page_ids), + start_position=0, + ) + ) + return task_inputs + + +@pytest.fixture(scope="function") +def extend_attention_prefill_task( + extend_attention_exec_requests, device_array_cache, page_pool +) -> ExtendAttentionPrefillTask: + """Fixture to create an ExtendAttentionPrefillTask.""" + page_tables = page_pool.acquire_free_pages(len(extend_attention_exec_requests)) + task_inputs = _get_extend_attention_task_inputs(extend_attention_exec_requests) + return ExtendAttentionPrefillTask( + task_inputs=task_inputs, + array_cache=device_array_cache, + page_tables=page_tables, + has_prefill_position=False, + block_seq_stride=16, + ) + + +@pytest.fixture(scope="function") +def extend_attention_prefill_task_w_start_pos( + extend_attention_exec_requests, device_array_cache, page_pool +) -> ExtendAttentionPrefillTask: + """Fixture to create an ExtendAttentionPrefillTask with start positions.""" + page_tables = page_pool.acquire_free_pages(len(extend_attention_exec_requests)) + task_inputs = _get_extend_attention_task_inputs(extend_attention_exec_requests) + + # Set different start positions to test offset prefill + for i, task_input in enumerate(task_inputs): + task_inputs[i] = LlmTaskInput( + rid=task_input.rid, + instance_id=task_input.instance_id, + block_count=task_input.block_count, + input_tokens=task_input.input_tokens, + seq_len=task_input.seq_len, + page_ids=task_input.page_ids, + start_position=i * 16, # Different start positions + ) + + return ExtendAttentionPrefillTask( + task_inputs=task_inputs, + array_cache=device_array_cache, + page_tables=page_tables, + has_prefill_position=True, + block_seq_stride=16, + ) + + +class TestExtendAttentionScheduler: + """Tests for ExtendAttentionScheduler - dynamic chunking logic.""" + + def test_initialization(self): + """Test scheduler initializes with correct parameters.""" + scheduler = ExtendAttentionScheduler(token_budget=1024, block_seq_stride=16) + assert scheduler._token_budget == 1024 + assert scheduler._block_seq_stride == 16 + assert len(scheduler._active_requests) == 0 + assert len(scheduler._request_positions) == 0 + assert len(scheduler._last_chunk_sizes) == 0 + + def test_schedule_job_new_request(self): + """Test scheduling a new job.""" + scheduler = ExtendAttentionScheduler(token_budget=1024, block_seq_stride=16) + task = LlmTaskInput( + rid="req1", + instance_id=0, + block_count=10, + seq_len=160, + input_tokens=tuple(range(160)), + page_ids=tuple(range(10)), + start_position=0, + ) + + scheduler.schedule_job(task) + assert "req1" in scheduler._active_requests + assert scheduler._request_positions["req1"] == 0 + + def test_schedule_job_with_start_position(self): + """Test scheduling a job with non-zero start position (trie prefix matching).""" + scheduler = ExtendAttentionScheduler(token_budget=1024, block_seq_stride=16) + task = LlmTaskInput( + rid="req1", + instance_id=0, + block_count=10, + seq_len=160, + input_tokens=tuple(range(160)), + page_ids=tuple(range(10)), + start_position=32, + ) + + scheduler.schedule_job(task) + assert "req1" in scheduler._active_requests + assert scheduler._request_positions["req1"] == 32 + + def test_schedule_duplicate_job(self, caplog): + """Test scheduling a duplicate job logs warning.""" + scheduler = 
ExtendAttentionScheduler(token_budget=1024, block_seq_stride=16) + task = LlmTaskInput( + rid="req1", + instance_id=0, + block_count=10, + seq_len=160, + input_tokens=tuple(range(160)), + page_ids=tuple(range(10)), + start_position=0, + ) + + scheduler.schedule_job(task) + with caplog.at_level(logging.WARNING): + scheduler.schedule_job(task) + + assert "already scheduled" in caplog.text + + def test_dynamic_chunk_calculation_single_request(self): + """Test dynamic chunk size with single request gets full budget.""" + scheduler = ExtendAttentionScheduler(token_budget=1024, block_seq_stride=16) + task = LlmTaskInput( + rid="req1", + instance_id=0, + block_count=128, + seq_len=2048, + input_tokens=tuple(range(2048)), + page_ids=tuple(range(128)), + start_position=0, + ) + + scheduler.schedule_job(task) + batches = scheduler.should_execute(strobe=0) + + assert len(batches) == 1 + assert len(batches[0]) == 1 + assert len(batches[0][0].input_tokens) == 1024 # Full budget + + def test_dynamic_chunk_calculation_two_requests(self): + """Test dynamic chunk size with two requests splits budget.""" + scheduler = ExtendAttentionScheduler(token_budget=1024, block_seq_stride=16) + + for i in range(2): + task = LlmTaskInput( + rid=f"req{i}", + instance_id=i, + block_count=128, + seq_len=2048, + input_tokens=tuple(range(2048)), + page_ids=tuple(range(128)), + start_position=0, + ) + scheduler.schedule_job(task) + + batches = scheduler.should_execute(strobe=0) + + assert len(batches) == 1 + assert len(batches[0]) == 2 + # Each request gets 512 tokens (1024 / 2, page-aligned) + assert all(len(chunk.input_tokens) == 512 for chunk in batches[0]) + + def test_page_alignment(self): + """Test that chunk sizes are page-aligned.""" + scheduler = ExtendAttentionScheduler(token_budget=1000, block_seq_stride=16) + + for i in range(3): + task = LlmTaskInput( + rid=f"req{i}", + instance_id=i, + block_count=128, + seq_len=2048, + input_tokens=tuple(range(2048)), + page_ids=tuple(range(128)), + start_position=0, + ) + scheduler.schedule_job(task) + + batches = scheduler.should_execute(strobe=0) + + # 1000 / 3 = 333.33 -> 320 (20 pages * 16) + for chunk in batches[0]: + assert len(chunk.input_tokens) % 16 == 0 + + def test_handle_completed_advances_position(self): + """Test that handle_completed advances position correctly.""" + scheduler = ExtendAttentionScheduler(token_budget=512, block_seq_stride=16) + task = LlmTaskInput( + rid="req1", + instance_id=0, + block_count=64, + seq_len=1024, + input_tokens=tuple(range(1024)), + page_ids=tuple(range(64)), + start_position=0, + ) + + scheduler.schedule_job(task) + + # First execution + batches = scheduler.should_execute(strobe=0) + assert batches[0][0].start_position == 0 + assert len(batches[0][0].input_tokens) == 512 + + # Complete first chunk + is_complete = scheduler.handle_completed("req1") + assert is_complete is False + assert scheduler._request_positions["req1"] == 512 + + # Second execution + batches = scheduler.should_execute(strobe=0) + assert batches[0][0].start_position == 512 + assert len(batches[0][0].input_tokens) == 512 + + # Complete second chunk + is_complete = scheduler.handle_completed("req1") + assert is_complete is True + assert "req1" not in scheduler._active_requests + + def test_handle_completed_nonexistent_request(self): + """Test handle_completed asserts on non-existent request.""" + scheduler = ExtendAttentionScheduler(token_budget=1024, block_seq_stride=16) + + with pytest.raises(AssertionError): + scheduler.handle_completed("nonexistent") + + def 
test_cumulative_metadata(self): + """Test that seq_len and block_count are cumulative across chunks.""" + scheduler = ExtendAttentionScheduler(token_budget=512, block_seq_stride=16) + task = LlmTaskInput( + rid="req1", + instance_id=0, + block_count=64, + seq_len=1024, + input_tokens=tuple(range(1024)), + page_ids=tuple(range(64)), + start_position=0, + ) + + scheduler.schedule_job(task) + + # First chunk + batches = scheduler.should_execute(strobe=0) + chunk1 = batches[0][0] + assert chunk1.seq_len == 512 + assert chunk1.block_count == 32 # ceil(512 / 16) + + scheduler.handle_completed("req1") + + # Second chunk + batches = scheduler.should_execute(strobe=0) + chunk2 = batches[0][0] + assert chunk2.seq_len == 1024 # Cumulative + assert chunk2.block_count == 64 # ceil(1024 / 16) + + def test_request_shorter_than_chunk(self): + """Test request with fewer tokens than allocated chunk size.""" + scheduler = ExtendAttentionScheduler(token_budget=1024, block_seq_stride=16) + task = LlmTaskInput( + rid="req1", + instance_id=0, + block_count=19, + seq_len=300, + input_tokens=tuple(range(300)), + page_ids=tuple(range(19)), + start_position=0, + ) + scheduler.schedule_job(task) + + # Should get all 300 tokens (less than budget) + batches = scheduler.should_execute(strobe=0) + assert len(batches[0][0].input_tokens) == 300 + + # Complete - should be done + is_complete = scheduler.handle_completed("req1") + assert is_complete is True + + def test_dynamic_chunk_size_adjustment(self): + """Test that chunk size adjusts as requests complete.""" + scheduler = ExtendAttentionScheduler(token_budget=1024, block_seq_stride=16) + # req1 is long, req2 is short (will complete in first batch) + task1 = LlmTaskInput( + rid="req1", + instance_id=0, + block_count=125, + seq_len=2000, + input_tokens=tuple(range(2000)), + page_ids=tuple(range(125)), + start_position=0, + ) + task2 = LlmTaskInput( + rid="req2", + instance_id=1, + block_count=25, + seq_len=400, + input_tokens=tuple(range(400)), + page_ids=tuple(range(25)), + start_position=0, + ) + + scheduler.schedule_job(task1) + scheduler.schedule_job(task2) + + # First batch: 512 tokens each (budget / 2, page-aligned) + batches1 = scheduler.should_execute(strobe=0) + assert len(batches1[0]) == 2 + # req1 gets 512, req2 gets all 400 (less than allocated 512) + req1_chunk = [c for c in batches1[0] if c.rid == "req1"][0] + req2_chunk = [c for c in batches1[0] if c.rid == "req2"][0] + assert len(req1_chunk.input_tokens) == 512 + assert len(req2_chunk.input_tokens) == 400 + + # Complete both chunks - req2 is done, req1 has more + is_complete_req1 = scheduler.handle_completed("req1") + is_complete_req2 = scheduler.handle_completed("req2") + assert is_complete_req1 is False # More tokens remain + assert is_complete_req2 is True # Request complete + + # Second batch: only req1 active - should get full budget (1024 tokens) + batches2 = scheduler.should_execute(strobe=0) + assert len(batches2[0]) == 1 + assert batches2[0][0].rid == "req1" + assert len(batches2[0][0].input_tokens) == 1024 + assert batches2[0][0].start_position == 512 + + def test_empty_scheduler(self): + """Test getting batch from empty scheduler.""" + scheduler = ExtendAttentionScheduler(token_budget=1024, block_seq_stride=16) + batches = scheduler.should_execute(strobe=0) + assert len(batches) == 0 + + def test_page_ids_grow_with_chunks(self): + """Test that page_ids include all blocks up to current position.""" + scheduler = ExtendAttentionScheduler(token_budget=256, block_seq_stride=16) + task = 
LlmTaskInput( + rid="req1", + instance_id=0, + block_count=63, + seq_len=1000, + input_tokens=tuple(range(1000)), + page_ids=tuple(range(63)), + start_position=0, + ) + scheduler.schedule_job(task) + + # First chunk + batches1 = scheduler.should_execute(strobe=0) + chunk1 = batches1[0][0] + assert len(chunk1.page_ids) == chunk1.block_count + + # Complete and get second chunk + scheduler.handle_completed("req1") + batches2 = scheduler.should_execute(strobe=0) + chunk2 = batches2[0][0] + assert len(chunk2.page_ids) == chunk2.block_count + # Second chunk should have more pages than first + assert chunk2.block_count > chunk1.block_count + + +class TestExtendAttentionPrefillTask: + """Tests for ExtendAttentionPrefillTask - argument preparation.""" + + def test_initialization( + self, extend_attention_prefill_task: ExtendAttentionPrefillTask + ): + """Test task initializes correctly.""" + assert extend_attention_prefill_task._seq_stride == 16 + assert extend_attention_prefill_task.req_count == 4 + + def test_prepare_args_structure( + self, lsys, extend_attention_prefill_task: ExtendAttentionPrefillTask + ): + """Test that prepare_args returns correct structure.""" + + async def _test(): + args = await extend_attention_prefill_task.prepare_args(batch_size=4) + + # Should have tokens, seq_lens, seq_block_ids, plus page tables + assert len(args) >= 3 + assert all(isinstance(arg, Allocation) for arg in args[:3]) + assert all(isinstance(arg, WrappedAllocation) for arg in args[3:]) + + lsys.run(_test()) + + def test_prepare_args_shapes( + self, + lsys, + extend_attention_prefill_task: ExtendAttentionPrefillTask, + extend_attention_exec_requests, + ): + """Test that prepared arguments have correct shapes.""" + + async def _test(): + batch_size = len(extend_attention_exec_requests) + args = await extend_attention_prefill_task.prepare_args( + batch_size=batch_size + ) + + # Tokens should be [batch_size, max_seq_len] + assert args[0].shape[0] == batch_size + # Should be page-aligned + assert args[0].shape[1] % 16 == 0 + + # Seq lens should be [batch_size] + assert args[1].shape[0] == batch_size + + # Seq block ids should be [batch_size, max_blocks] + assert args[2].shape[0] == batch_size + + lsys.run(_test()) + + def test_prepare_args_with_start_position( + self, + lsys, + extend_attention_prefill_task_w_start_pos: ExtendAttentionPrefillTask, + ): + """Test prepare_args with start positions included.""" + + async def _test(): + args = await extend_attention_prefill_task_w_start_pos.prepare_args( + batch_size=4 + ) + + # Should have tokens, start_positions, seq_lens, seq_block_ids + assert len(args) >= 4 + + # Start positions should be [batch_size] + assert args[1].shape[0] == 4 + + # Verify start positions match what we set + start_positions = args[1].host.items.tolist() + assert start_positions == [0, 16, 32, 48] + + lsys.run(_test()) + + def test_padding_to_page_boundaries( + self, + lsys, + extend_attention_prefill_task: ExtendAttentionPrefillTask, + ): + """Test that sequences are padded to page boundaries.""" + + async def _test(): + args = await extend_attention_prefill_task.prepare_args(batch_size=4) + + tokens_alloc = args[0] + max_seq_len = tokens_alloc.shape[1] + + # Max seq len should be divisible by block_seq_stride + assert max_seq_len % 16 == 0 + + # Should accommodate the longest sequence (512 tokens) + # which needs 32 pages, so 32 * 16 = 512 + assert max_seq_len >= 512 + + lsys.run(_test()) + + def test_varying_sequence_lengths( + self, + lsys, + extend_attention_prefill_task: 
ExtendAttentionPrefillTask, + extend_attention_exec_requests, + ): + """Test handling of requests with varying sequence lengths.""" + + async def _test(): + args = await extend_attention_prefill_task.prepare_args(batch_size=4) + + seq_lens_alloc = args[1] + seq_lens = seq_lens_alloc.host.items.tolist() + + # Should match the input token lengths + expected_lens = [ + len(req.input_token_ids) for req in extend_attention_exec_requests + ] + assert seq_lens == expected_lens + + lsys.run(_test()) + + def test_max_blocks_calculation( + self, + lsys, + extend_attention_prefill_task: ExtendAttentionPrefillTask, + extend_attention_exec_requests, + ): + """Test that max_blocks is calculated correctly.""" + + async def _test(): + args = await extend_attention_prefill_task.prepare_args(batch_size=4) + + seq_block_ids_alloc = args[2] + max_blocks = seq_block_ids_alloc.shape[1] + + # Should match the maximum block_count across all requests + expected_max_blocks = max( + req.block_count for req in extend_attention_exec_requests + ) + assert max_blocks == expected_max_blocks + + lsys.run(_test()) + + def test_offset_prefill_padding( + self, + lsys, + device_array_cache, + page_pool, + ): + """Test that seq_block_ids are padded correctly for offset prefill. + + This tests the edge case mentioned in PR review: + Request A: 96 tokens, start_position=0, 6 pages + Request B: 64 tokens, start_position=96, 5 pages + + max_blocks should be 6 (max_block_start) + 4 (write_block_span) = 10 + """ + + async def _test(): + # Request A: 96 tokens from position 0 + task1 = LlmTaskInput( + rid="reqA", + instance_id="0", + block_count=6, # 96 / 16 = 6 + seq_len=96, + input_tokens=tuple(range(96)), + page_ids=tuple(range(6)), + start_position=0, + ) + + # Request B: 64 tokens from position 96 (offset prefill) + task2 = LlmTaskInput( + rid="reqB", + instance_id="1", + block_count=10, # (96 + 64) / 16 = 10 total + seq_len=160, # 96 + 64 + input_tokens=tuple(range(64)), + page_ids=tuple(range(10)), # All pages including history + start_position=96, + ) + + page_tables = page_pool.acquire_free_pages(2) + task = ExtendAttentionPrefillTask( + task_inputs=[task1, task2], + array_cache=device_array_cache, + page_tables=page_tables, + has_prefill_position=True, + block_seq_stride=16, + ) + + args = await task.prepare_args(batch_size=2) + + # seq_block_ids should have shape [2, max_blocks] + # max_blocks = max_block_start (6) + write_block_span (6) = 12 + # write_block_span = max_seq_len / 16 = 96 / 16 = 6 + seq_block_ids_alloc = args[3] + assert seq_block_ids_alloc.shape[0] == 2 # batch size + # max_blocks should accommodate both the max start position and the write span + assert seq_block_ids_alloc.shape[1] >= 10 # At least 10 blocks for reqB + + lsys.run(_test()) + + +class TestExtendAttentionIntegration: + """Integration tests combining scheduler and task.""" + + def test_full_prefill_workflow( + self, lsys, device_array_cache, page_pool, cache_ref_count + ): + """Test complete workflow: schedule -> execute -> complete.""" + + async def _test(): + scheduler = ExtendAttentionScheduler(token_budget=256, block_seq_stride=16) + + # Create two requests + request_ids = [] + for i in range(2): + input_tokens = [j + i * 1000 for j in range(512)] + req = LlmInferenceExecRequest( + phase=InferencePhase.PREFILL, + input_token_ids=input_tokens, + rid=f"req{i}", + ) + req._cache = cache_ref_count + pages = [PageInfo(index=j + i * 32, pool=page_pool) for j in range(32)] + req.allocated_cache_info = CacheInfo( + num_tokens=len(input_tokens), + 
tokens=input_tokens, + pages=pages, + pool=page_pool, + last_cached_node=None, + ) + req.page_ids = [page.index for page in pages] + + # Schedule job using the actual request ID + task_input = LlmTaskInput( + rid=req.orig_instance_id, + instance_id=req.instance_id, + block_count=req.block_count, + input_tokens=tuple(req.input_token_ids), + seq_len=len(req.input_token_ids), + page_ids=tuple(req.page_ids), + start_position=0, + ) + scheduler.schedule_job(task_input) + request_ids.append(req.orig_instance_id) + + # First batch - both requests get 128 tokens (256 / 2, page-aligned) + batches = scheduler.should_execute(strobe=0) + assert len(batches[0]) == 2 + assert all(len(chunk.input_tokens) == 128 for chunk in batches[0]) + + # Create task from batch + page_tables = page_pool.acquire_free_pages(2) + task = ExtendAttentionPrefillTask( + task_inputs=batches[0], + array_cache=device_array_cache, + page_tables=page_tables, + has_prefill_position=False, + block_seq_stride=16, + ) + + # Prepare arguments + args = await task.prepare_args(batch_size=2) + assert len(args) >= 3 + + # Verify tokens allocation + assert args[0].shape[0] == 2 + + # Complete both chunks using actual request IDs + scheduler.handle_completed(request_ids[0]) + scheduler.handle_completed(request_ids[1]) + + # Second batch - both should continue from position 128 + batches = scheduler.should_execute(strobe=0) + assert len(batches[0]) == 2 + assert all(chunk.start_position == 128 for chunk in batches[0]) + + lsys.run(_test())