10 changes: 7 additions & 3 deletions sharktank/sharktank/examples/export_paged_llm_v1.py

@@ -75,6 +75,7 @@ def generate_batch_prefill(bs: int):
         }

         bs_min = bs
+        prefill_name = f"prefill_bs{bs}"

         if export_config.has_prefill_position:
             seq_len_blocks_dim_chunked = torch.export.Dim(
@@ -95,6 +96,9 @@ def generate_batch_prefill(bs: int):
             if "start_pos" in dynamic_shapes:
                 dynamic_shapes["start_pos"][0] = extend_bs

+            prefill_name = "prefill_bs_extend"
+            export_config.bs_prefill = None
+
         seq_block_ids = torch.empty(bs_min, block_dim_min, dtype=torch.int64)
         tokens = torch.empty(
             bs_min,
@@ -103,13 +107,13 @@ def generate_batch_prefill(bs: int):
         )
         seq_lens = torch.empty(bs_min, dtype=torch.int64)

-        print(f"Exporting prefill_bs{bs}")
+        print(f"Exporting {prefill_name}")

         if export_config.has_prefill_position:
             arg_devices = model.setup_arg_devices(cache_affinities, len(dynamic_shapes))

             @fxb.export_program(
-                name=f"prefill_bs{bs}",
+                name=prefill_name,
                 args=(tokens, start_pos, seq_lens, seq_block_ids, cache),
                 dynamic_shapes=dynamic_shapes,
                 arg_device=arg_devices,
@@ -132,7 +136,7 @@ def _(
             arg_devices = model.setup_arg_devices(cache_affinities, len(dynamic_shapes))

             @fxb.export_program(
-                name=f"prefill_bs{bs}",
+                name=prefill_name,
                 args=(tokens, seq_lens, seq_block_ids, cache),
                 dynamic_shapes=dynamic_shapes,
                 arg_device=arg_devices,
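Net effect in export_paged_llm_v1.py: the entrypoint name is computed once and reused at every export site, and on what appears to be the extend-attention path (note the extend_bs dynamic dim) it is overwritten to "prefill_bs_extend" while export_config.bs_prefill is cleared to None. A minimal sketch of the naming rule, using a hypothetical helper name that is not part of the diff:

    def prefill_entrypoint_name(bs: int, use_extend_attention: bool) -> str:
        # Extend attention collapses per-batch-size exports into one entrypoint.
        return "prefill_bs_extend" if use_extend_attention else f"prefill_bs{bs}"

    assert prefill_entrypoint_name(4, False) == "prefill_bs4"
    assert prefill_entrypoint_name(4, True) == "prefill_bs_extend"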
3 changes: 2 additions & 1 deletion sharktank/sharktank/models/llm/config.py

@@ -27,12 +27,13 @@ class ServiceConfig:
     max_seq_len: int
     attn_head_dim: int
     prefill_batch_sizes: list[int]
-    has_prefill_position: bool
     decode_batch_sizes: list[int]
     transformer_block_count: int
     logits_normalization: Optional[str]
     top_k: Optional[int]
     paged_kv_cache: KVCacheConfig
+    has_prefill_position: bool = False
+    use_extend_attention: bool = False

     @staticmethod
     def load(fp: Path):
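Moving has_prefill_position below paged_kv_cache is forced, not cosmetic: assuming ServiceConfig is a dataclass (as its field layout and the new "= False" defaults suggest), fields without defaults may not follow fields with them, so both defaulted flags must trail the required fields. The defaults also keep previously written service configs loadable. A toy sketch of that constraint, with an illustrative class rather than the real one:

    from dataclasses import dataclass

    @dataclass
    class ServiceConfigSketch:
        max_seq_len: int                    # required field; must precede defaulted ones
        has_prefill_position: bool = False  # defaulted fields must come last
        use_extend_attention: bool = False  # older configs omit this and still load

    cfg = ServiceConfigSketch(max_seq_len=4096)
    assert cfg.use_extend_attention is False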
3 changes: 2 additions & 1 deletion sharktank/sharktank/models/llm/export.py

@@ -202,10 +202,11 @@ def build_service_config(
         max_seq_len=hp.context_length,
         attn_head_dim=hp.attn_head_dim,
         prefill_batch_sizes=export_config.bs_prefill,
-        has_prefill_position=export_config.has_prefill_position,
         decode_batch_sizes=export_config.bs_decode,
         transformer_block_count=hp.block_count,
         logits_normalization=export_config.logits_normalization,
         top_k=export_config.top_k,
         paged_kv_cache=kv_config,
+        has_prefill_position=export_config.has_prefill_position,
+        use_extend_attention=export_config.use_extend_attention,
     )
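build_service_config threads the same two flags through to ServiceConfig. Because the constructor is called entirely with keyword arguments, relocating has_prefill_position after paged_kv_cache is behavior-neutral and just mirrors the new field order; note also that on the extend path export_paged_llm_v1.py sets export_config.bs_prefill to None, so prefill_batch_sizes can now arrive as None here. A toy illustration of why the reordering is safe:

    from dataclasses import dataclass

    @dataclass
    class Toy:
        a: int
        b: bool = False

    # With keyword arguments, call-site order never matters.
    assert Toy(a=1, b=True) == Toy(b=True, a=1)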
5 changes: 4 additions & 1 deletion sharktank/sharktank/utils/llm_utils.py

@@ -177,7 +177,10 @@ def __init__(
             setattr(self, funcname, func)
             if "prefill_bs" in funcname:
                 self._prefill = func
-                self.prefill_bs = int(funcname[10:])
+                if funcname[10:] == "_extend":
+                    self.prefill_bs = 4
+                else:
+                    self.prefill_bs = int(funcname[10:])
             if "decode_bs" in funcname:
                 self._decode = func
                 self.decode_bs = int(funcname[9:])
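The slicing works because len("prefill_bs") == 10, so funcname[10:] is either a decimal batch size ("prefill_bs4" gives "4") or the literal "_extend", for which the batch size is pinned to the hardcoded 4 above. A standalone sketch of the rule, with a made-up helper name:

    def parse_prefill_bs(funcname: str) -> int:
        # len("prefill_bs") == 10, matching the funcname[10:] slice in the diff.
        suffix = funcname[len("prefill_bs"):]
        # The extend entrypoint has no per-batch-size suffix; the diff pins it to 4.
        return 4 if suffix == "_extend" else int(suffix)

    assert parse_prefill_bs("prefill_bs8") == 8
    assert parse_prefill_bs("prefill_bs_extend") == 4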