Commit 8c8a7df

Update
[ghstack-poisoned]
2 parents f121a8d + bcb0f18 commit 8c8a7df

87 files changed: +2647 / -874 lines changed


.ci/scripts/test_llama_lora.sh

Lines changed: 51 additions & 14 deletions
@@ -48,8 +48,17 @@ DOWNLOADED_PATH=$(
 --model_id "${HF_MODEL_REPO}" \
 --files "adapter_config.json" "adapter_model.pt" "consolidated.00.pth" "params.json" "tokenizer.model"
 )
-EXPORTED_MODEL_NAME="llama_3_2_1B_lora.pte"
-# Export model.
+# Build llama runner.
+cmake_install_executorch_libraries
+cmake_build_llama_runner
+
+# Constants.
+RUNTIME_ARGS="--tokenizer_path=${DOWNLOADED_PATH}/tokenizer.model --temperature=0 --seq_len=20 --warmup=1"
+PROMPT="What happens if you eat watermelon seeds?"
+EXPECTED_PREFIX="What happens if you eat watermelon seeds? Watermelon seeds are a good source of vitamin C,"
+
+# Export LoRA PTE file.
+MODEL_NAME="llama_3_2_1B_lora"
 $PYTHON_EXECUTABLE -m extension.llm.export.export_llm \
 base.checkpoint="${DOWNLOADED_PATH}/consolidated.00.pth" \
 base.params="${DOWNLOADED_PATH}/params.json" \
@@ -61,36 +70,64 @@ $PYTHON_EXECUTABLE -m extension.llm.export.export_llm \
 model.dtype_override="fp32" \
 backend.xnnpack.enabled=true \
 backend.xnnpack.extended_ops=true \
-export.output_name="${EXPORTED_MODEL_NAME}"
-
-# Build llama runner.
-cmake_install_executorch_libraries
-cmake_build_llama_runner
+export.output_name="${MODEL_NAME}.pte"
 
-PROMPT="What happens if you eat watermelon seeds?"
 # Run llama runner
-RUNTIME_ARGS="--model_path=${EXPORTED_MODEL_NAME} --tokenizer_path=${DOWNLOADED_PATH}/tokenizer.model --temperature=0 --seq_len=20 --warmup=1"
-
 NOW=$(date +"%H:%M:%S")
 echo "Starting to run llama runner at ${NOW}"
 # shellcheck source=/dev/null
-cmake-out/examples/models/llama/llama_main --prompt="${PROMPT}" ${RUNTIME_ARGS} > result.txt
+cmake-out/examples/models/llama/llama_main --model_path=${MODEL_NAME}.pte --prompt="${PROMPT}" ${RUNTIME_ARGS} > result.txt
 NOW=$(date +"%H:%M:%S")
 echo "Finished at ${NOW}"
 
 RESULT=$(cat result.txt)
-EXPECTED_PREFIX="What happens if you eat watermelon seeds? Watermelon seeds are a good source of vitamin C,"
-
 if [[ "${RESULT}" == "${EXPECTED_PREFIX}"* ]]; then
 echo "Expected result prefix: ${EXPECTED_PREFIX}"
 echo "Actual result: ${RESULT}"
+# Do not clean up files if test passes, as they're re-used in the next test.
 echo "Success"
-cleanup_files
 else
 echo "Expected result prefix: ${EXPECTED_PREFIX}"
 echo "Actual result: ${RESULT}"
 echo "Failure; results not the same"
+cleanup_files
+exit 1
+fi
 
+# Export LoRA PTE, PTD file.
+MODEL_SEPARATE="${MODEL_NAME}_separate"
+$PYTHON_EXECUTABLE -m extension.llm.export.export_llm \
+base.checkpoint="${DOWNLOADED_PATH}/consolidated.00.pth" \
+base.params="${DOWNLOADED_PATH}/params.json" \
+base.adapter_checkpoint="${DOWNLOADED_PATH}/adapter_model.pt" \
+base.adapter_config="${DOWNLOADED_PATH}/adapter_config.json" \
+base.tokenizer_path="${DOWNLOADED_PATH}/tokenizer.model" \
+model.use_kv_cache=true \
+model.use_sdpa_with_kv_cache=true \
+model.dtype_override="fp32" \
+backend.xnnpack.enabled=true \
+backend.xnnpack.extended_ops=true \
+export.output_name="${MODEL_SEPARATE}.pte" \
+export.foundation_weights_file="${MODEL_SEPARATE}.ptd"
+
+# Run llama runner.
+NOW=$(date +"%H:%M:%S")
+echo "Starting to run llama runner at ${NOW}"
+# shellcheck source=/dev/null
+cmake-out/examples/models/llama/llama_main --model_path=${MODEL_SEPARATE}.pte --data_path=${MODEL_SEPARATE}.ptd --prompt="${PROMPT}" ${RUNTIME_ARGS} > result2.txt
+NOW=$(date +"%H:%M:%S")
+echo "Finished at ${NOW}"
+
+RESULT2=$(cat result2.txt)
+if [[ "${RESULT2}" == "${EXPECTED_PREFIX}"* ]]; then
+echo "Expected result prefix: ${EXPECTED_PREFIX}"
+echo "Actual result: ${RESULT2}"
+echo "Success"
+cleanup_files
+else
+echo "Expected result prefix: ${EXPECTED_PREFIX}"
+echo "Actual result: ${RESULT2}"
+echo "Failure; results not the same"
 cleanup_files
 exit 1
 fi

.github/workflows/pull.yml

Lines changed: 1 addition & 1 deletion
@@ -315,7 +315,7 @@ jobs:
 bash examples/models/moshi/mimi/install_requirements.sh
 
 # reinstall executorch
-bash ./install_executorch.sh
+bash ./install_executorch.sh --minimal
 
 # run python unittest
 python -m unittest examples.models.moshi.mimi.test_mimi

.github/workflows/trunk.yml

Lines changed: 1 addition & 0 deletions
@@ -288,6 +288,7 @@ jobs:
 - test_arm_baremetal: test_models_tosa
 - test_arm_baremetal: test_models_ethos-u55
 - test_arm_baremetal: test_models_ethos-u85
+- test_arm_baremetal: test_smaller_stories_llama
 fail-fast: false
 with:
 runner: linux.2xlarge.memory

.gitignore

Lines changed: 2 additions & 0 deletions
@@ -24,6 +24,7 @@ pip-out/
 # Any exported models and profiling outputs
 *.bin
 *.model
+*.etdump
 tokenizer.json
 *.pte
 *.ptd
@@ -58,6 +59,7 @@ xcuserdata/
 /include/
 /share/
 /version.py
+*.csv
 
 # Android
 *.aar

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -36,6 +36,7 @@
 from .decompose_div_pass import DecomposeDivPass # noqa
 from .decompose_embedding_pass import DecomposeEmbeddingPass # noqa # noqa
 from .decompose_gelu_pass import DecomposeGeluPass # noqa
+from .decompose_glu_pass import DecomposeGluPass # noqa
 from .decompose_grouped_conv import DecomposeGroupedConv # noqa
 from .decompose_groupnorm_pass import DecomposeGroupNormPass # noqa
 from .decompose_layernorm_pass import DecomposeLayerNormPass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 3 additions & 0 deletions
@@ -41,6 +41,7 @@
     DecomposeDivPass,
     DecomposeEmbeddingPass,
     DecomposeGeluPass,
+    DecomposeGluPass,
     DecomposeGroupedConv,
     DecomposeGroupNormPass,
     DecomposeLayerNormPass,
@@ -184,6 +185,7 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         self.add_pass(ConvertSplitToSlicePass())
         self.add_pass(FuseBatchnorm2DPass(exported_program))
         self.add_pass(ConvertMmToBmmPass())
+        self.add_pass(DecomposeGluPass())
         self.add_pass(DecomposeLinearPass())
         self.add_pass(DecomposeLeakyReLUPass())
         self.add_pass(DecomposeGroupNormPass())
@@ -264,6 +266,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
         self.add_pass(DecomposeMeanDimPass(graph_module, self.tosa_spec))
         self.add_pass(DecomposeNotEqualPass())
         self.add_pass(DecomposeCosineSimilarityPass())
+        self.add_pass(DecomposeGluPass())
         self.add_pass(DecomposeDivPass())
         self.add_pass(DecomposeLeakyReLUPass())
         self.add_pass(DecomposeLinearVectorNormPass())

backends/arm/_passes/decompose_glu_pass.py

Lines changed: 75 additions & 0 deletions
@@ -0,0 +1,75 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from executorch.backends.arm._passes import ArmPass
+from executorch.exir.dialects._ops import ops as exir_ops
+
+
+# For FP case
+edge_glu = exir_ops.edge.aten.glu.default
+
+# For INT case
+aten_glu = torch.ops.aten.glu.default
+
+
+def get_ops(op):
+    """Returns the appropriate operator functions based on the input operator."""
+    if op == edge_glu:
+        return (
+            exir_ops.edge.aten.mul.Tensor,
+            exir_ops.edge.aten.sigmoid.default,
+            exir_ops.edge.aten.slice_copy.Tensor,
+        )
+    elif op == aten_glu:
+        return (
+            torch.ops.aten.mul.Tensor,
+            torch.ops.aten.sigmoid.default,
+            torch.ops.aten.slice_copy.Tensor,
+        )
+    else:
+        raise ValueError(f"Unsupported operator: {op}")
+
+
+class DecomposeGluPass(ArmPass):
+    """Decomposes the GLU operator into hadamard product and sigmoid."""
+
+    def call_operator(self, op, args, kwargs, meta):
+        if op not in [edge_glu, aten_glu]:
+            return super().call_operator(op, args, kwargs, meta)
+
+        hadamard_prod, sigmoid, slice_op = get_ops(op)
+        X = args[0]
+
+        dim = args[1] if len(args) > 1 else kwargs.get("dim", -1)
+
+        if "val" not in X.node.meta:
+            raise Exception("Could not get dimension metadata in input.")
+
+        if dim < 0:
+            dim += X.node.meta["val"].dim()
+
+        n = X.node.meta["val"].size(dim)
+
+        if n % 2:
+            raise RuntimeError(
+                f"glu expects an even split along dim={dim}, got size {n}"
+            )
+
+        middle = n // 2
+
+        T1 = super().call_operator(
+            slice_op, (X, dim, 0, middle), {}, meta, updated=True
+        )
+
+        T2 = super().call_operator(
+            slice_op, (X, dim, middle, n), {}, meta, updated=True
+        )
+
+        T2_sigmoid = super().call_operator(sigmoid, (T2,), {}, meta, updated=True)
+
+        return super().call_operator(
+            hadamard_prod, (T1, T2_sigmoid), {}, meta, updated=True
+        )
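
For reference, the decomposition this new pass performs corresponds to the identity glu(x, dim) = a * sigmoid(b), where a and b are the first and second halves of x along dim. Below is a minimal standalone sketch of the same slice/sigmoid/mul sequence using plain aten ops, checked against torch.nn.functional.glu; the helper name glu_decomposed is illustrative and not part of the pass.

import torch

def glu_decomposed(x: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # Mirror the pass: normalize dim, require an even split, then slice/sigmoid/mul.
    if dim < 0:
        dim += x.dim()
    n = x.size(dim)
    if n % 2:
        raise RuntimeError(f"glu expects an even split along dim={dim}, got size {n}")
    middle = n // 2
    t1 = torch.ops.aten.slice_copy.Tensor(x, dim, 0, middle)  # first half
    t2 = torch.ops.aten.slice_copy.Tensor(x, dim, middle, n)  # second half
    return torch.ops.aten.mul.Tensor(t1, torch.ops.aten.sigmoid.default(t2))

x = torch.randn(2, 8)
assert torch.allclose(glu_decomposed(x, dim=-1), torch.nn.functional.glu(x, dim=-1))

Because the decomposed form matches F.glu exactly, the pass can substitute it for the original operator and lower GLU through ops the backend already supports.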

backends/arm/_passes/fuse_equal_placeholders_pass.py

Lines changed: 43 additions & 44 deletions
@@ -3,6 +3,9 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import hashlib
+from collections import defaultdict
+
 import torch
 from executorch.backends.arm._passes.arm_pass_utils import (
     get_constant_placeholder_kind,
@@ -21,7 +24,7 @@ class FuseEqualPlaceholdersPass(ExportPass):
     """
     This pass optimizes memory usage by finding constant placeholders
     pointing to identical tensors and fusing them to one single placeholder
-    with multiple users.
+    with multiple users, using a cache for faster comparison.
     """
 
     def __init__(self, exported_program: ExportedProgram):
@@ -30,58 +33,54 @@ def __init__(self, exported_program: ExportedProgram):
 
     def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
         modified = False
-        const_placeholder_nodes = []
-        for node in graph_module.graph.nodes:
-            if is_param_node(self.exported_program, node):
-                const_placeholder_nodes.append(node)
-
-        while const_placeholder_nodes:
 
-            # Find equal tensors
-            node1 = const_placeholder_nodes.pop()
-            eq_nodes = [node1]
-            tensor1 = get_param_tensor(self.exported_program, node1)
-            if tensor1 is None:
+        # Build a cache of params: mapping hash_key -> list of (node, tensor)
+        hash_buckets = defaultdict(list)
+        for node in graph_module.graph.nodes:
+            if not is_param_node(self.exported_program, node):
                 continue
+            tensor = get_param_tensor(self.exported_program, node)
+            if tensor is None:
+                continue
+            # Create a lightweight fingerprint: dtype + shape + SHA1 of raw bytes
+            # Ensure tensor is on CPU and contiguous
+            t_cpu = tensor.detach().cpu().contiguous()
+            data_bytes = t_cpu.numpy().tobytes()
+            key = (
+                str(t_cpu.dtype),
+                tuple(t_cpu.shape),
+                hashlib.sha1(data_bytes).hexdigest(),
+            )
+            hash_buckets[key].append((node, t_cpu))
 
-            for node2 in const_placeholder_nodes:
-                tensor2 = get_param_tensor(self.exported_program, node2)
-                if tensor2 is None:
-                    continue
-
-                if (
-                    tensor1.dtype == tensor2.dtype
-                    and tensor1.shape == tensor2.shape
-                    and torch.allclose(tensor1, tensor2, atol=1e-08)
-                ):
-                    eq_nodes.append(node2)
+        # For each bucket with more than one entry, fuse:
+        for nodes_tensors in hash_buckets.values():
+            if len(nodes_tensors) < 2:
+                continue
 
-            if len(eq_nodes) > 1:
-                common_name = node1.name + "_common"
-                common_kind = get_constant_placeholder_kind(
-                    self.exported_program, node1
+            # Create a new placeholder from first in list of equal placeholders.
+            rep_node, rep_tensor = nodes_tensors[0]
+            common_name = rep_node.name + "_common"
+            common_kind = get_constant_placeholder_kind(self.exported_program, rep_node)
+            common_persistent = True
+            with graph_module.graph.inserting_before(rep_node):
+                common_node = create_constant_placeholder(
+                    self.exported_program,
+                    graph_module.graph,
+                    common_name,
+                    common_kind,
+                    rep_tensor,
+                    common_persistent,
                 )
-                common_persisten_buffer = True
-
-                with graph_module.graph.inserting_before(node1):
-                    common_node = create_constant_placeholder(
-                        self.exported_program,
-                        graph_module.graph,
-                        common_name,
-                        common_kind,
-                        tensor1,
-                        common_persisten_buffer,
-                    )
-
-                for eq_node in eq_nodes:
-                    eq_node.replace_all_uses_with(common_node)
-                    delete_constant_placeholder(self.exported_program, eq_node)
-                    if eq_node != node1:
-                        const_placeholder_nodes.remove(eq_node)
 
+            # Replace uses and delete duplicates
+            for node, _ in nodes_tensors:
+                node.replace_all_uses_with(common_node)
+                delete_constant_placeholder(self.exported_program, node)
             modified = True
 
         if modified:
             graph_module.recompile()
             graph_module = super().call(graph_module).graph_module
+
         return PassResult(graph_module=graph_module, modified=modified)
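
The key change above is the fingerprinting step: instead of comparing every pair of placeholder tensors with torch.allclose, placeholders are bucketed by a (dtype, shape, SHA1-of-raw-bytes) key, so only byte-identical tensors land in the same bucket and pairwise comparison is avoided. A minimal standalone sketch of that bucketing on plain named tensors (the function name bucket_equal_tensors is illustrative, not part of the pass):

import hashlib
from collections import defaultdict

import torch

def bucket_equal_tensors(named_tensors):
    """Group tensor names whose dtype, shape, and raw bytes match exactly."""
    buckets = defaultdict(list)
    for name, tensor in named_tensors.items():
        t_cpu = tensor.detach().cpu().contiguous()
        key = (
            str(t_cpu.dtype),
            tuple(t_cpu.shape),
            hashlib.sha1(t_cpu.numpy().tobytes()).hexdigest(),
        )
        buckets[key].append(name)
    # Buckets with more than one entry are candidates for fusing into one placeholder.
    return [names for names in buckets.values() if len(names) > 1]

weights = {
    "w1": torch.ones(4, 4),
    "w2": torch.ones(4, 4),  # byte-identical to w1 -> same bucket
    "w3": torch.zeros(4, 4),
}
print(bucket_equal_tensors(weights))  # [['w1', 'w2']]

Note that hashing raw bytes is stricter than the previous torch.allclose(..., atol=1e-08) comparison: tensors that are merely close but not bit-identical will no longer be fused.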

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 2 additions & 0 deletions
@@ -258,6 +258,7 @@ def is_node_supported(
             exir_ops.edge.aten.masked_fill.Scalar,
             exir_ops.edge.aten.asinh.default,
             exir_ops.edge.aten.cosh.default,
+            exir_ops.edge.aten.glu.default,
         ]
 
         return supported
@@ -299,6 +300,7 @@ def is_node_supported(
             exir_ops.edge.aten.leaky_relu.default: None,
             exir_ops.edge.aten.round.default: None,
             exir_ops.edge.aten.addmm.default: None,
+            exir_ops.edge.aten.glu.default: None,
         }
 
         if node.target in needs_decomp_dict:
