Commits
33 commits
d2bb0ae
Extend attn prefill
zeeshanhaque21 Oct 15, 2025
f8de260
Merge branch 'main' into extend-attn-shortfin
zeeshanhaque21 Oct 15, 2025
ad46fef
Fix chunked mode
zeeshanhaque21 Oct 15, 2025
25af208
Add tests
zeeshanhaque21 Oct 15, 2025
638361e
Fix tests
zeeshanhaque21 Oct 15, 2025
03be4cb
precommit fix
zeeshanhaque21 Oct 15, 2025
31e3aed
Merge branch 'main' into extend-attn-shortfin
zeeshanhaque21 Oct 15, 2025
3c082fe
Merge branch 'main' into extend-attn-shortfin
zeeshanhaque21 Oct 15, 2025
730591c
Change chunking strategy to dynamically recompute based on number of c…
zeeshanhaque21 Oct 16, 2025
74ee726
Fix tests
zeeshanhaque21 Oct 16, 2025
3496380
precommit
zeeshanhaque21 Oct 16, 2025
66bce01
cleanup
zeeshanhaque21 Oct 16, 2025
0a0896e
Address PR comments
zeeshanhaque21 Oct 17, 2025
c61915f
Refactor scheduler and prefill task
zeeshanhaque21 Oct 17, 2025
10794f0
Add tests for PrefillTask
zeeshanhaque21 Oct 17, 2025
6c15862
Formatting
zeeshanhaque21 Oct 17, 2025
3ad3509
Merge branch 'main' into extend-attn-shortfin
zeeshanhaque21 Oct 17, 2025
759e204
Merge branch 'main' into extend-attn-shortfin
zeeshanhaque21 Oct 27, 2025
9eb0de4
Add parameter
zeeshanhaque21 Oct 27, 2025
249fdf9
Merge branch 'main' into extend-attn-shortfin
zeeshanhaque21 Oct 27, 2025
5446cf1
Modify sharktank to export flags
zeeshanhaque21 Oct 27, 2025
3fdfded
Change min prefill bs to 1 in export
zeeshanhaque21 Oct 27, 2025
b16a5e3
Add debug logs to investigate data corruption
zeeshanhaque21 Oct 27, 2025
709f975
revert back to bs_min of 2 for torch.export
zeeshanhaque21 Oct 27, 2025
eb621cb
Add debug logs
zeeshanhaque21 Oct 27, 2025
856a70a
add use_extend_attention to ServiceConfig & update prefill name
archana-ramalingam Oct 27, 2025
353d8ac
Enable extend attention in default path
archana-ramalingam Oct 28, 2025
13858a5
Merge branch 'main' into update-extend-attn
archana-ramalingam Oct 28, 2025
82bb572
Fix error
archana-ramalingam Oct 28, 2025
7957f6e
Merge branch 'update-extend-attn' of https://github.com/nod-ai/shark-…
archana-ramalingam Oct 28, 2025
c9fdadb
Add debug statements
zeeshanhaque21 Oct 29, 2025
5e245a7
Merge remote-tracking branch 'origin/update-extend-attn' into extend-…
zeeshanhaque21 Oct 29, 2025
1c7bb7b
Merge remote-tracking branch 'origin/main' into extend-attn-shortfin
zeeshanhaque21 Oct 29, 2025
12 changes: 8 additions & 4 deletions sharktank/sharktank/examples/export_paged_llm_v1.py
@@ -64,7 +64,6 @@ def generate_batch_prefill(bs: int):

seq_len_dim = seq_len_blocks_dim * llama_config.block_seq_stride

start_pos = torch.empty(bs, dtype=torch.int64)
cache, cache_dynamic_shapes, cache_affinities = model.setup_cache()

dynamic_shapes = {
@@ -95,6 +94,7 @@ def generate_batch_prefill(bs: int):
if "start_pos" in dynamic_shapes:
dynamic_shapes["start_pos"][0] = extend_bs

start_pos = torch.empty(bs_min, dtype=torch.int64)
seq_block_ids = torch.empty(bs_min, block_dim_min, dtype=torch.int64)
tokens = torch.empty(
bs_min,
@@ -103,13 +103,17 @@
)
seq_lens = torch.empty(bs_min, dtype=torch.int64)

print(f"Exporting prefill_bs{bs}")
# Use different naming for extend-attention mode to avoid confusion
Contributor: What's the reasoning for adding this change?

function_name = (
f"prefill_extend_bs{bs}" if export_config.use_extend_attention else f"prefill_bs{bs}"
)
print(f"Exporting {function_name}")

if export_config.has_prefill_position:
arg_devices = model.setup_arg_devices(cache_affinities, len(dynamic_shapes))

@fxb.export_program(
name=f"prefill_bs{bs}",
name=function_name,
args=(tokens, start_pos, seq_lens, seq_block_ids, cache),
dynamic_shapes=dynamic_shapes,
arg_device=arg_devices,
@@ -132,7 +136,7 @@ def _(
arg_devices = model.setup_arg_devices(cache_affinities, len(dynamic_shapes))

@fxb.export_program(
name=f"prefill_bs{bs}",
name=function_name,
args=(tokens, seq_lens, seq_block_ids, cache),
dynamic_shapes=dynamic_shapes,
arg_device=arg_devices,
11 changes: 6 additions & 5 deletions sharktank/sharktank/models/llm/config.py
@@ -28,11 +28,12 @@ class ServiceConfig:
attn_head_dim: int
prefill_batch_sizes: list[int]
has_prefill_position: bool
decode_batch_sizes: list[int]
transformer_block_count: int
logits_normalization: Optional[str]
top_k: Optional[int]
paged_kv_cache: KVCacheConfig
use_extend_attention: bool = False
decode_batch_sizes: list[int] = field(default_factory=list)
transformer_block_count: int = 0
logits_normalization: Optional[str] = None
top_k: Optional[int] = None
paged_kv_cache: Optional[KVCacheConfig] = None

@staticmethod
def load(fp: Path):
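Note on the ServiceConfig change above: inserting use_extend_attention with a default value appears to be why the fields declared after it also gained defaults, since a dataclass cannot declare a non-default field after a defaulted one, and a list-typed default has to go through default_factory rather than a bare literal. A minimal sketch of those two rules (illustrative only, not the real ServiceConfig):

from dataclasses import dataclass, field
from typing import Optional


@dataclass
class ExampleConfig:
    attn_head_dim: int                    # required fields come before defaulted ones
    use_extend_attention: bool = False    # once a default appears, later fields need defaults too
    decode_batch_sizes: list[int] = field(default_factory=list)  # fresh list per instance
    top_k: Optional[int] = None


cfg = ExampleConfig(attn_head_dim=128)
cfg.decode_batch_sizes.append(4)          # mutates only this instance's list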
1 change: 1 addition & 0 deletions sharktank/sharktank/models/llm/export.py
@@ -203,6 +203,7 @@ def build_service_config(
attn_head_dim=hp.attn_head_dim,
prefill_batch_sizes=export_config.bs_prefill,
has_prefill_position=export_config.has_prefill_position,
use_extend_attention=export_config.use_extend_attention,
decode_batch_sizes=export_config.bs_decode,
transformer_block_count=hp.block_count,
logits_normalization=export_config.logits_normalization,
9 changes: 6 additions & 3 deletions shortfin/python/shortfin_apps/llm/cli.py
@@ -173,9 +173,12 @@ def send_error(
self.ensure_response()

def _print_response(self, response):
response_data = json.loads(response.decode("utf-8"))
responses_array = response_data["responses"][0].get("responses", [])
print(json.dumps(responses_array, indent=1))
if isinstance(response, bytes):
response_data = json.loads(response.decode("utf-8"))
responses_array = response_data["responses"][0].get("responses", [])
print(json.dumps(responses_array, indent=1))
else:
print(response)

def send_response(self, response):
logger.info(f"{self.name} Sending response")
2 changes: 2 additions & 0 deletions shortfin/python/shortfin_apps/llm/components/batching/config.py
@@ -25,6 +25,7 @@

class BatchMode(Enum):
DEFAULT = "Default"
EXTEND_ATTENTION = "ExtendAttention"


@dataclass(slots=True)
@@ -35,3 +36,4 @@ class BatchConfig:
decode_functions: dict[int, sf.ProgramFunction] # type: ignore
prog_isolation: sf.ProgramIsolation # type: ignore
chunk_block_size: Optional[int] = None
token_budget: Optional[int] = None
11 changes: 11 additions & 0 deletions shortfin/python/shortfin_apps/llm/components/batching/factory.py
@@ -16,6 +16,7 @@
from ..kvcache.base_attention_cache import BasePagedAttentionCache
from .batching_trait import BatchingTrait
from .modes.default import DefaultBatchingEngine
from .modes.extend_attention import ExtendAttentionBatchingEngine
from ..messages import LlmInferenceExecRequest


@@ -61,5 +62,15 @@ def _create_impl(batch_cfg: BatchConfig, page_cache: BasePagedAttentionCache, pr
),
page_cache=page_cache,
)
elif batch_cfg.mode == BatchMode.EXTEND_ATTENTION:
return _BatchingEngineImpl(
ExtendAttentionBatchingEngine.create(
batch_cfg=batch_cfg,
page_cache=page_cache,
prefill_fiber=prefill_fiber,
decode_fiber=decode_fiber,
),
page_cache=page_cache,
)

raise ValueError(f"Unsupported Batching Mode: {batch_cfg.mode}")
206 changes: 206 additions & 0 deletions shortfin/python/shortfin_apps/llm/components/batching/modes/extend_attention.py
@@ -0,0 +1,206 @@
# Copyright 2025 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import logging
from typing import List

import shortfin as sf

from shortfin import Fiber

from ..batching_trait import BatchingTrait
from ..config import BatchConfig
from ...config_struct import ModelParams
from ...invocation import (
ExtendAttentionPrefillTask,
LlmInvocationProcess,
LlmTask,
LlmTaskInput,
)
from ...kvcache.base_attention_cache import BasePagedAttentionCache
from ...messages import InferencePhase, LlmInferenceExecRequest
from ...scheduler import ExtendAttentionScheduler

from .default import (
LlmBatcherProcess,
PrefillTaskResponder,
DecodeBatcherProcess,
)

logger = logging.getLogger(__name__)


class ExtendAttentionPrefillBatcherProcess(LlmBatcherProcess):
"""Batcher process optimized for extend-attention prefill."""

STROBE_SHORT_DELAY = 0.065
STROBE_LONG_DELAY = 0.065

def __init__(
self,
fiber: Fiber,
page_cache: BasePagedAttentionCache,
model_params: ModelParams,
prefill_functions: dict[int, sf.ProgramFunction],
program_isolation: str,
token_budget: int,
):
# Use the extend-attention aware scheduler
block_seq_stride = model_params.paged_kv_cache.block_seq_stride

scheduler = ExtendAttentionScheduler(
token_budget=token_budget, block_seq_stride=block_seq_stride
)

llm_task_responder = PrefillTaskResponder(scheduler=scheduler)

        # ideal_batch_size is not critical here; we set it to the maximum
        # number of requests that can be batched together within the token budget.
ideal_batch_size = token_budget // block_seq_stride

super().__init__(
name="extend_attention_prefill",
fiber=fiber,
page_cache=page_cache,
model_params=model_params,
functions=prefill_functions,
ideal_batch_size=ideal_batch_size,
program_isolation=program_isolation,
scheduler=scheduler,
llm_task_responder=llm_task_responder,
)

def make_task_inputs(
self, exec_request: LlmInferenceExecRequest
) -> List[LlmTaskInput]:
"""Create a single task input containing all tokens.

The scheduler will dynamically chunk this request at scheduling time based
on the number of active requests and the token budget.
"""
total_tokens = len(exec_request.input_token_ids)

# Return a single task with ALL tokens
# The scheduler will chunk it dynamically
return [
LlmTaskInput(
rid=exec_request.orig_instance_id,
instance_id=exec_request.instance_id,
block_count=exec_request.block_count,
seq_len=total_tokens,
input_tokens=tuple(exec_request.input_token_ids),
page_ids=tuple(exec_request.page_ids),
start_position=0
if exec_request.start_position is None
else exec_request.start_position,
)
]

def make_task(
self,
task_inputs: List[LlmTaskInput],
page_cache: BasePagedAttentionCache,
) -> LlmTask:
"""Create an extend-attention aware prefill task."""
return ExtendAttentionPrefillTask(
task_inputs=task_inputs,
array_cache=self.array_cache,
page_tables=page_cache.page_pool.page_tables,
has_prefill_position=self.model_params.has_prefill_position,
block_seq_stride=self.page_seq_stride,
)

def make_invoker(
self,
page_cache: BasePagedAttentionCache,
fiber: Fiber,
task_inputs: list[LlmTaskInput],
) -> LlmInvocationProcess:
"""Create invoker for extend-attention prefill."""
return LlmInvocationProcess(
name="extend_attention_prefill_invocation",
fiber=fiber,
llm_task=self.make_task(task_inputs, page_cache),
functions=self.functions,
program_isolation=self.program_isolation,
responder=self._llm_task_responder,
)


class ExtendAttentionBatchingEngine(BatchingTrait):
"""Batching engine that uses extend-attention for improved prefill batching."""

def __init__(
self,
prefill_lane: ExtendAttentionPrefillBatcherProcess,
decode_lane: DecodeBatcherProcess,
):
self.prefill_lane = prefill_lane
self.decode_lane = decode_lane

def submit(self, request: LlmInferenceExecRequest):
if request.phase == InferencePhase.PREFILL:
self.prefill_lane.submit(request)
elif request.phase == InferencePhase.DECODE:
self.decode_lane.submit(request)
else:
            raise ValueError(
                "Unsupported batching lane requested: only prefill and decode are supported."
            )

def launch(self):
self.prefill_lane.launch()
self.decode_lane.launch()

def shutdown(self):
self.prefill_lane.shutdown()
self.decode_lane.shutdown()

def reserve_workload(self, rid: str, count: int):
self.decode_lane.reserve_workload(rid=rid, count=count)

def get_model_params(self) -> ModelParams:
return self.prefill_lane.model_params

@staticmethod
def create(
batch_cfg: BatchConfig,
page_cache: BasePagedAttentionCache,
prefill_fiber: sf.Fiber,
decode_fiber: sf.Fiber,
):
"""Create an extend-attention batching engine."""

# Check if the model was exported with extend-attention support
if not batch_cfg.model_params.use_extend_attention:
raise ValueError(
"Model was not exported with extend-attention support. "
"Please export the model with --use-extend-attention flag."
)
assert batch_cfg.token_budget is not None
token_budget = batch_cfg.token_budget

prefill_batcher = ExtendAttentionPrefillBatcherProcess(
fiber=prefill_fiber,
page_cache=page_cache,
model_params=batch_cfg.model_params,
prefill_functions=batch_cfg.prefill_functions,
program_isolation=batch_cfg.prog_isolation,
token_budget=token_budget,
)

decode_batcher = DecodeBatcherProcess(
fiber=decode_fiber,
page_cache=page_cache,
model_params=batch_cfg.model_params,
decode_functions=batch_cfg.decode_functions,
program_isolation=batch_cfg.prog_isolation,
)

return ExtendAttentionBatchingEngine(
prefill_lane=prefill_batcher,
decode_lane=decode_batcher,
)
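The docstring in make_task_inputs above describes the chunking behavior only at a high level: the whole prompt is handed to the scheduler as one task, and the scheduler splits it at scheduling time based on the number of active requests and the token budget, aligned to the KV-cache block size. The snippet below is a hypothetical sketch of one such policy; the actual ExtendAttentionScheduler may chunk differently.

# Hypothetical sketch of a token-budget chunking policy; not the real
# ExtendAttentionScheduler implementation.
def next_chunk_len(remaining_tokens: int, active_requests: int,
                   token_budget: int, block_seq_stride: int) -> int:
    # Split the shared budget evenly across active requests, then round the
    # share down to a whole KV-cache block so page boundaries stay aligned.
    share = token_budget // max(active_requests, 1)
    share = max((share // block_seq_stride) * block_seq_stride, block_seq_stride)
    return min(remaining_tokens, share)


# Example: with a 2048-token budget, 4 active requests, and 32-token blocks,
# each request prefills 512 tokens per scheduling step.
assert next_chunk_len(5000, 4, 2048, 32) == 512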
7 changes: 7 additions & 0 deletions shortfin/python/shortfin_apps/llm/components/config_struct.py
@@ -142,6 +142,9 @@ class ModelParams:
# Whether the model was exported with `start_positions` for prefill.
has_prefill_position: bool = False

# Whether the model was exported with extend attention support.
use_extend_attention: bool = False

# Name of the IREE module implementing the model.
module_name: str = "module"

@@ -239,6 +242,10 @@ class ServerParams:
# KV cache configuration
prefix_sharing_algorithm: str = "none" # none or trie

# Batching configuration
batch_mode: str = "default" # default or extend_attention
token_budget: Optional[int] = None # Token budget for extend_attention mode

# Program isolation configuration
program_isolation: str = "per_call"

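For context, the new ServerParams fields would presumably be mapped onto the BatchMode enum and the BatchConfig.token_budget consumed by factory.py above. The sketch below is assumed glue code, not the server's actual wiring, and the import path is also an assumption.

from shortfin_apps.llm.components.batching.config import BatchMode  # assumed import path


def resolve_batch_mode(batch_mode: str) -> BatchMode:
    # Map the ServerParams string ("default" or "extend_attention") onto the
    # enum that factory.py dispatches on.
    if batch_mode == "extend_attention":
        return BatchMode.EXTEND_ATTENTION
    return BatchMode.DEFAULT


# e.g. batch_mode="extend_attention" plus token_budget=2048 would produce a BatchConfig
# with mode=BatchMode.EXTEND_ATTENTION and token_budget=2048, which the factory routes
# to ExtendAttentionBatchingEngine.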