diff --git a/sharktank/sharktank/examples/export_paged_llm_v1.py b/sharktank/sharktank/examples/export_paged_llm_v1.py
index 4c23b8316fc..adc639be12e 100644
--- a/sharktank/sharktank/examples/export_paged_llm_v1.py
+++ b/sharktank/sharktank/examples/export_paged_llm_v1.py
@@ -75,6 +75,7 @@ def generate_batch_prefill(bs: int):
         }
 
         bs_min = bs
+        prefill_name = f"prefill_bs{bs}"
 
         if export_config.has_prefill_position:
             seq_len_blocks_dim_chunked = torch.export.Dim(
@@ -95,6 +96,9 @@ def generate_batch_prefill(bs: int):
             if "start_pos" in dynamic_shapes:
                 dynamic_shapes["start_pos"][0] = extend_bs
 
+            prefill_name = "prefill_bs_extend"
+            export_config.bs_prefill = None
+
         seq_block_ids = torch.empty(bs_min, block_dim_min, dtype=torch.int64)
         tokens = torch.empty(
             bs_min,
@@ -103,13 +107,13 @@ def generate_batch_prefill(bs: int):
         )
         seq_lens = torch.empty(bs_min, dtype=torch.int64)
 
-        print(f"Exporting prefill_bs{bs}")
+        print(f"Exporting {prefill_name}")
 
         if export_config.has_prefill_position:
             arg_devices = model.setup_arg_devices(cache_affinities, len(dynamic_shapes))
 
             @fxb.export_program(
-                name=f"prefill_bs{bs}",
+                name=prefill_name,
                 args=(tokens, start_pos, seq_lens, seq_block_ids, cache),
                 dynamic_shapes=dynamic_shapes,
                 arg_device=arg_devices,
@@ -132,7 +136,7 @@ def _(
             arg_devices = model.setup_arg_devices(cache_affinities, len(dynamic_shapes))
 
             @fxb.export_program(
-                name=f"prefill_bs{bs}",
+                name=prefill_name,
                 args=(tokens, seq_lens, seq_block_ids, cache),
                 dynamic_shapes=dynamic_shapes,
                 arg_device=arg_devices,
diff --git a/sharktank/sharktank/models/llm/config.py b/sharktank/sharktank/models/llm/config.py
index 0f8a0cb6047..bcb6ab5d586 100644
--- a/sharktank/sharktank/models/llm/config.py
+++ b/sharktank/sharktank/models/llm/config.py
@@ -27,12 +27,13 @@ class ServiceConfig:
     max_seq_len: int
     attn_head_dim: int
     prefill_batch_sizes: list[int]
-    has_prefill_position: bool
     decode_batch_sizes: list[int]
     transformer_block_count: int
     logits_normalization: Optional[str]
     top_k: Optional[int]
     paged_kv_cache: KVCacheConfig
+    has_prefill_position: bool = False
+    use_extend_attention: bool = False
 
     @staticmethod
     def load(fp: Path):
diff --git a/sharktank/sharktank/models/llm/export.py b/sharktank/sharktank/models/llm/export.py
index 8a06b815663..4a9e62b5a18 100644
--- a/sharktank/sharktank/models/llm/export.py
+++ b/sharktank/sharktank/models/llm/export.py
@@ -202,10 +202,11 @@ def build_service_config(
         max_seq_len=hp.context_length,
         attn_head_dim=hp.attn_head_dim,
         prefill_batch_sizes=export_config.bs_prefill,
-        has_prefill_position=export_config.has_prefill_position,
         decode_batch_sizes=export_config.bs_decode,
         transformer_block_count=hp.block_count,
         logits_normalization=export_config.logits_normalization,
         top_k=export_config.top_k,
         paged_kv_cache=kv_config,
+        has_prefill_position=export_config.has_prefill_position,
+        use_extend_attention=export_config.use_extend_attention,
     )
diff --git a/sharktank/sharktank/utils/llm_utils.py b/sharktank/sharktank/utils/llm_utils.py
index 98522b1db0b..e9eb63fd440 100644
--- a/sharktank/sharktank/utils/llm_utils.py
+++ b/sharktank/sharktank/utils/llm_utils.py
@@ -177,7 +177,10 @@ def __init__(
             setattr(self, funcname, func)
             if "prefill_bs" in funcname:
                 self._prefill = func
-                self.prefill_bs = int(funcname[10:])
+                if funcname[10:] == "_extend":
+                    self.prefill_bs = 4
+                else:
+                    self.prefill_bs = int(funcname[10:])
             if "decode_bs" in funcname:
                 self._decode = func
                 self.decode_bs = int(funcname[9:])