Commits (33)
d2bb0ae
Extend attn prefill
zeeshanhaque21 Oct 15, 2025
f8de260
Merge branch 'main' into extend-attn-shortfin
zeeshanhaque21 Oct 15, 2025
ad46fef
Fix chunked mode
zeeshanhaque21 Oct 15, 2025
25af208
Add tests
zeeshanhaque21 Oct 15, 2025
638361e
Fix tests
zeeshanhaque21 Oct 15, 2025
03be4cb
precommit fix
zeeshanhaque21 Oct 15, 2025
31e3aed
Merge branch 'main' into extend-attn-shortfin
zeeshanhaque21 Oct 15, 2025
3c082fe
Merge branch 'main' into extend-attn-shortfin
zeeshanhaque21 Oct 15, 2025
730591c
Change chunking strategy to dynamically recompute based on number of c…
zeeshanhaque21 Oct 16, 2025
74ee726
Fix tests
zeeshanhaque21 Oct 16, 2025
3496380
precommit
zeeshanhaque21 Oct 16, 2025
66bce01
cleanup
zeeshanhaque21 Oct 16, 2025
0a0896e
Address PR comments
zeeshanhaque21 Oct 17, 2025
c61915f
Refactor scheduler and prefill task
zeeshanhaque21 Oct 17, 2025
10794f0
Add tests for PrefillTask
zeeshanhaque21 Oct 17, 2025
6c15862
Formatting
zeeshanhaque21 Oct 17, 2025
3ad3509
Merge branch 'main' into extend-attn-shortfin
zeeshanhaque21 Oct 17, 2025
759e204
Merge branch 'main' into extend-attn-shortfin
zeeshanhaque21 Oct 27, 2025
9eb0de4
Add parameter
zeeshanhaque21 Oct 27, 2025
249fdf9
Merge branch 'main' into extend-attn-shortfin
zeeshanhaque21 Oct 27, 2025
5446cf1
Modify sharktank to export flags
zeeshanhaque21 Oct 27, 2025
3fdfded
Change min prefill bs to 1 in export
zeeshanhaque21 Oct 27, 2025
b16a5e3
Add debug logs to investigate data corruption
zeeshanhaque21 Oct 27, 2025
709f975
revert back to bs_min of 2 for torch.export
zeeshanhaque21 Oct 27, 2025
eb621cb
Add debug logs
zeeshanhaque21 Oct 27, 2025
856a70a
add use_extend_attention to ServiceConfig & update prefill name
archana-ramalingam Oct 27, 2025
353d8ac
Enable extend attention in default path
archana-ramalingam Oct 28, 2025
13858a5
Merge branch 'main' into update-extend-attn
archana-ramalingam Oct 28, 2025
82bb572
Fix error
archana-ramalingam Oct 28, 2025
7957f6e
Merge branch 'update-extend-attn' of https://github.com/nod-ai/shark-…
archana-ramalingam Oct 28, 2025
c9fdadb
Add debug statements
zeeshanhaque21 Oct 29, 2025
5e245a7
Merge remote-tracking branch 'origin/update-extend-attn' into extend-…
zeeshanhaque21 Oct 29, 2025
1c7bb7b
Merge remote-tracking branch 'origin/main' into extend-attn-shortfin
zeeshanhaque21 Oct 29, 2025
2 changes: 1 addition & 1 deletion sharktank/sharktank/examples/export_paged_llm_v1.py
@@ -64,7 +64,6 @@ def generate_batch_prefill(bs: int):
 
     seq_len_dim = seq_len_blocks_dim * llama_config.block_seq_stride
 
-    start_pos = torch.empty(bs, dtype=torch.int64)
     cache, cache_dynamic_shapes, cache_affinities = model.setup_cache()
 
     dynamic_shapes = {
@@ -95,6 +94,7 @@ def generate_batch_prefill(bs: int):
     if "start_pos" in dynamic_shapes:
         dynamic_shapes["start_pos"][0] = extend_bs
 
+    start_pos = torch.empty(bs_min, dtype=torch.int64)
     seq_block_ids = torch.empty(bs_min, block_dim_min, dtype=torch.int64)
     tokens = torch.empty(
         bs_min,
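For context, a minimal, self-contained sketch of the torch.export placeholder pattern the hunks above rely on: example tensors are allocated at the minimum batch size and dynamic_shapes marks the batch dimension as variable. ToyPrefill, the concrete sizes, and the bs_min comment below are illustrative stand-ins under those assumptions, not code from this repository.

import torch
from torch.export import Dim, export


class ToyPrefill(torch.nn.Module):
    # Stand-in computation; the real export wraps paged-attention prefill.
    def forward(self, tokens: torch.Tensor, start_pos: torch.Tensor):
        return tokens + start_pos.unsqueeze(-1)


bs_min = 2  # size-1 example inputs get specialized by torch.export, hence a minimum of 2
bs = Dim("bs", min=bs_min, max=8)

tokens = torch.zeros(bs_min, 16, dtype=torch.int64)
start_pos = torch.zeros(bs_min, dtype=torch.int64)

exported = export(
    ToyPrefill(),
    (tokens, start_pos),
    dynamic_shapes={"tokens": {0: bs}, "start_pos": {0: bs}},
)
print(exported)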
@@ -25,6 +25,7 @@
 
 class BatchMode(Enum):
     DEFAULT = "Default"
+    EXTEND_ATTENTION = "ExtendAttention"
 
 
 @dataclass(slots=True)
@@ -35,3 +36,4 @@ class BatchConfig:
     decode_functions: dict[int, sf.ProgramFunction]  # type: ignore
     prog_isolation: sf.ProgramIsolation  # type: ignore
     chunk_block_size: Optional[int] = None
+    token_budget: Optional[int] = None
11 changes: 11 additions & 0 deletions shortfin/python/shortfin_apps/llm/components/batching/factory.py
@@ -16,6 +16,7 @@
 from ..kvcache.base_attention_cache import BasePagedAttentionCache
 from .batching_trait import BatchingTrait
 from .modes.default import DefaultBatchingEngine
+from .modes.extend_attention import ExtendAttentionBatchingEngine
 from ..messages import LlmInferenceExecRequest


@@ -61,5 +62,15 @@ def _create_impl(batch_cfg: BatchConfig, page_cache: BasePagedAttentionCache, pr
             ),
             page_cache=page_cache,
         )
+    elif batch_cfg.mode == BatchMode.EXTEND_ATTENTION:
+        return _BatchingEngineImpl(
+            ExtendAttentionBatchingEngine.create(
+                batch_cfg=batch_cfg,
+                page_cache=page_cache,
+                prefill_fiber=prefill_fiber,
+                decode_fiber=decode_fiber,
+            ),
+            page_cache=page_cache,
+        )
 
     raise ValueError(f"Unsupported Batching Mode: {batch_cfg.mode}")
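For readers skimming the diff, a toy, self-contained mirror of the dispatch pattern the hunks above add: a BatchMode value selects which engine a factory constructs, and token_budget is an optional knob on the config. ToyBatchConfig, DefaultEngine, ExtendAttentionEngine, and create_engine are illustrative stand-ins, not shortfin's actual API.

from dataclasses import dataclass
from enum import Enum
from typing import Optional


class BatchMode(Enum):
    DEFAULT = "Default"
    EXTEND_ATTENTION = "ExtendAttention"


@dataclass(slots=True)
class ToyBatchConfig:
    mode: BatchMode
    chunk_block_size: Optional[int] = None
    token_budget: Optional[int] = None  # optional with a default, so existing call sites are unaffected


class DefaultEngine:
    pass


class ExtendAttentionEngine:
    pass


def create_engine(cfg: ToyBatchConfig):
    # Same shape as _create_impl above: dispatch on the mode enum,
    # fall through to a ValueError for unknown modes.
    if cfg.mode == BatchMode.DEFAULT:
        return DefaultEngine()
    elif cfg.mode == BatchMode.EXTEND_ATTENTION:
        return ExtendAttentionEngine()
    raise ValueError(f"Unsupported Batching Mode: {cfg.mode}")


engine = create_engine(ToyBatchConfig(mode=BatchMode.EXTEND_ATTENTION, token_budget=2048))
print(type(engine).__name__)  # ExtendAttentionEngine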