
Commit e87537c

Update base for Update on "[ExecuTorch][Llama] Decouple input sequence length from kv cache context length"

Decouple the max sequence length used for shape dynamism in torch.export from the sequence length used to size the kv cache.

Differential Revision: [D68448334](https://our.internmc.facebook.com/intern/diff/D68448334/)

[ghstack-poisoned]
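
The idea behind the stack: the bound used for shape dynamism in torch.export can be chosen independently of the context length that sizes the kv cache. Below is a minimal sketch of that separation using plain torch.export; the module, the names max_seq_len and max_context_len, and the cache layout are illustrative assumptions, not the ExecuTorch Llama export code.

import torch
from torch.export import Dim, export

max_seq_len = 128       # bound on the dynamic input (prompt chunk) length for export
max_context_len = 2048  # total context the kv cache is sized for


class TinyCacheModule(torch.nn.Module):
    """Illustrative stand-in for an attention block with a kv cache."""

    def __init__(self) -> None:
        super().__init__()
        # The cache is allocated from max_context_len, independent of max_seq_len.
        self.register_buffer("k_cache", torch.zeros(1, max_context_len, 64))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Stub computation: a real block would write x into the cache at the
        # current position and attend over the cached entries.
        return x + self.k_cache[:, : x.shape[1], :]


example_tokens = torch.zeros(1, 16, 64)
seq_dim = Dim("seq_len", min=2, max=max_seq_len)  # only the input length is bounded here
exported = export(TinyCacheModule(), (example_tokens,), dynamic_shapes=({1: seq_dim},))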
1 parent 388d2ae commit e87537c

File tree

2 files changed: +23 −4 lines

backends/vulkan/vulkan_preprocess.py

Lines changed: 7 additions & 1 deletion

@@ -6,6 +6,8 @@
 
 # pyre-strict
 
+from functools import partial
+
 from typing import Any, Dict, final, List
 
 import executorch.backends.vulkan.utils as utils
@@ -18,6 +20,9 @@
 from executorch.backends.transforms.fuse_dequant_linear import FuseDequantLinearPass
 from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform
 
+from executorch.exir.memory_planning import (
+    greedy,
+)
 from executorch.backends.vulkan._passes import (
     insert_prepack_nodes,
     RemoveLocalScalarDenseOpsTransform,
@@ -189,11 +194,12 @@ def preprocess( # noqa: C901
 
         # Finally, apply dynamic shape passes and memory planning pass. These passes
         # must be applied only when the graph structure is finalized.
+        greedy_memory_planning = partial(greedy, allow_overlapping_allocations=False)
         program = apply_passes(
             program,
             [
                 ConstraintBasedSymShapeEvalPass(),
-                MemoryPlanningPass(),
+                MemoryPlanningPass(memory_planning_algo=greedy_memory_planning),
             ],
         )
 
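The Vulkan change above amounts to binding the new flag into the greedy planner before handing it to the memory planning pass. A minimal sketch of that wiring follows; the MemoryPlanningPass import path is assumed from the rest of the file.

from functools import partial

from executorch.exir.memory_planning import greedy
from executorch.exir.passes import MemoryPlanningPass  # assumed import path

# Bind the flag so the pass calls greedy(..., allow_overlapping_allocations=False),
# i.e. tensors with overlapping lifetimes never share a storage at different offsets.
greedy_memory_planning = partial(greedy, allow_overlapping_allocations=False)

memory_planning_pass = MemoryPlanningPass(memory_planning_algo=greedy_memory_planning)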

exir/memory_planning.py

Lines changed: 16 additions & 3 deletions

@@ -547,7 +547,9 @@ def _find_max_overlapping_allocations_offset(
 
 
 def pick_shared_obj(
-    shared_objects: List[SharedObject], spec: TensorSpec
+    shared_objects: List[SharedObject],
+    spec: TensorSpec,
+    allow_overlapping_allocations: bool = True,
 ) -> SharedObject:
     r"""
     Pick the available shared object to which to assign this spec,
@@ -611,7 +613,7 @@ def pick_shared_obj(
             picked.allocations.append(allocation_spec)
             break
 
-    if picked is None:
+    if picked is None and allow_overlapping_allocations:
         for sobj in shared_objects:
             max_offset = _find_max_overlapping_allocations_offset(sobj, spec)
             if max_offset > 0:
@@ -673,7 +675,16 @@ def greedy(
     graph_signature: Optional[ExportGraphSignature] = None,
     alloc_graph_input: bool = True,
     alloc_graph_output: bool = True,
+    allow_overlapping_allocations: bool = True,
 ) -> List[int]:
+    r"""Greedy algorithm to allocate memory for tensors in the graph.
+    alloc_graph_input: If set to true, the algorithm will allocate memory for graph input.
+    alloc_graph_output: If set to true, the algorithm will allocate memory for graph output.
+    allow_overlapping_allocations: If set to true, allows for allocations that overlap
+        in their lifetime but are at different offsets in the storage. By default true.
+        This flag is added to allow for Vulkan to use MemoryPlanningPass with overlapping
+        allocations disabled.
+    """
     spec2obj = {}
     shared_objects = defaultdict(list)
     # Don't do assertion in collect_specs_from_nodes if we have already encountered
@@ -699,7 +710,9 @@ def greedy(
         if spec.mem_id is None:
             spec.mem_id = 1
         spec.realign(alignment)
-        spec2obj[spec] = pick_shared_obj(shared_objects[spec.mem_id], spec)
+        spec2obj[spec] = pick_shared_obj(
+            shared_objects[spec.mem_id], spec, allow_overlapping_allocations
+        )
 
     if len(shared_objects) == 0:
         # Cannot find any tensor in the graph that needs to be allocated.
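
To make the flag concrete: with allow_overlapping_allocations disabled, a spec whose lifetime overlaps every existing shared object gets a brand-new object instead of being packed into a live one at a non-zero offset. Below is a simplified, self-contained sketch of that decision; Storage, Allocation, and place are hypothetical stand-ins, not the SharedObject/pick_shared_obj code above.

from dataclasses import dataclass, field
from typing import List, Tuple

# Each allocation is (lifetime_start, lifetime_end, offset, size).
Allocation = Tuple[int, int, int, int]


@dataclass
class Storage:
    size: int = 0
    allocations: List[Allocation] = field(default_factory=list)


def _lifetimes_overlap(a: Allocation, start: int, end: int) -> bool:
    return not (a[1] < start or end < a[0])


def place(
    storages: List[Storage],
    start: int,
    end: int,
    size: int,
    allow_overlapping_allocations: bool = True,
) -> Storage:
    # Prefer a storage none of whose tenants are alive during [start, end].
    for s in storages:
        if not any(_lifetimes_overlap(a, start, end) for a in s.allocations):
            s.size = max(s.size, size)
            s.allocations.append((start, end, 0, size))
            return s
    # Only if allowed, reuse a live storage at an offset past its current tenants.
    if allow_overlapping_allocations and storages:
        s = storages[0]
        offset = max(a[2] + a[3] for a in s.allocations)
        s.size = max(s.size, offset + size)
        s.allocations.append((start, end, offset, size))
        return s
    # Otherwise open a fresh storage, which is the behavior the Vulkan backend wants.
    s = Storage(size=size, allocations=[(start, end, 0, size)])
    storages.append(s)
    return s


storages: List[Storage] = []
place(storages, 0, 5, 1024)
place(storages, 3, 8, 512, allow_overlapping_allocations=False)  # overlapping lifetime
assert len(storages) == 2  # second tensor got its own storage instead of an offset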
