
Commit 22472c1

Update on "[Executorch][llm] Enable local global attention in export_llama script"
Added a new option, --local_global_attention, that takes a pattern of window sizes determining which layers use local sliding-window attention. For example, [0, 256, 256, 0, 256, 256] can be used for a 6-layer transformer, or you can pass [0, 256, 256] as the pattern to repeat across layers.

Differential Revision: [D73891423](https://our.internmc.facebook.com/intern/diff/D73891423/)

cc larryliu0820 mergennachin cccclai helunwencser jackzhxng

[ghstack-poisoned]
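To make the repeat semantics concrete, here is a minimal illustrative sketch (expand_pattern is a hypothetical helper, not the actual implementation; the real pattern is consumed by the replace_kv_cache_with_ring_kv_cache transform shown in the diff below). A 0 entry means regular global attention, and a positive entry is that layer's sliding-window size:

    # Hypothetical helper, for illustration only: tile a local/global
    # attention pattern across n_layers.
    def expand_pattern(pattern: list[int], n_layers: int) -> list[int]:
        # 0 = global attention; positive value = sliding-window size.
        return [pattern[i % len(pattern)] for i in range(n_layers)]

    # [0, 256, 256] repeated over a 6-layer transformer matches spelling
    # out [0, 256, 256, 0, 256, 256] directly.
    assert expand_pattern([0, 256, 256], 6) == [0, 256, 256, 0, 256, 256]
    assert expand_pattern([16], 4) == [16, 16, 16, 16]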
2 parents: e294da2 + e301bfc

File tree

2 files changed (+20, -2 lines)


examples/models/llama/export_llama_lib.py

Lines changed: 4 additions & 2 deletions
@@ -385,7 +385,9 @@ def build_args_parser() -> argparse.ArgumentParser:
         "--local_global_attention",
         type=parse_list_of_ints,
         default=None,
-        help="List of integers specifying local and global attention pattern, e.g., [0, 16, 0, 16].",
+        help="List of integers specifying local and global attention pattern, e.g., [0, 16, 0, 16] to specify that every other layer uses a sliding window of 16."
+        " A [0, 16, 32] pattern specifies that the 2nd and 3rd layers have sliding windows of 16 and 32 respectively."
+        " A [16] pattern specifies that all layers have a sliding window of 16.",
     )
 
     parser.add_argument("-2", "--fairseq2", action="store_true")
@@ -1332,7 +1334,7 @@ def _get_source_transforms(  # noqa
     if args.vulkan:
         transforms.append(replace_with_vulkan_rotary_emb)
 
-    if args.local_global_attention:
+    if getattr(args, "local_global_attention", None) is not None:
         transforms.append(
             partial(
                 replace_kv_cache_with_ring_kv_cache,
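The guard change uses getattr so that an args namespace created without this attribute (for example, one constructed programmatically in tests) does not raise AttributeError, while the explicit is-not-None check keeps the default of None from enabling the transform. A hypothetical invocation with the new flag might look like the following (the export_llama entry point and the quoted-list syntax are assumptions based on the help text above; all other required flags are elided):

    python -m examples.models.llama.export_llama --local_global_attention "[0, 16, 0, 16]" ...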

examples/models/llama/tests/TARGETS

Lines changed: 16 additions & 0 deletions
@@ -85,3 +85,19 @@ python_unittest(
         "//executorch/examples/models/llama:sdpa",
     ],
 )
+
+python_unittest(
+    name = "test_export_llama_lib",
+    srcs = [
+        "test_export_llama_lib.py",
+    ],
+    preload_deps = [
+        "//executorch/extension/llm/custom_ops:custom_ops_aot_lib",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/examples/models/llama:export_library",
+        "//executorch/examples/models/llama:llama_transformer",
+        "//executorch/extension/pybindings:portable_lib",
+    ],
+)
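With this TARGETS entry, the new unit test can be run with Buck; assuming the standard executorch cell layout implied by the //executorch/... labels above, the command would be:

    buck2 test //executorch/examples/models/llama/tests:test_export_llama_lib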
