Skip to content

Commit fca9d38

Browse files
committed
Update on "[llm] Add a generic text only LLM runner"
Introducing `text_llm_runner`. This can be used to run all text-only, decoder-only LLM models supported by ExecuTorch. * Metadata is read from the .pte file and used to construct the runner object. * examples/models/llama/runner.h[.cpp] now contains only a simple wrapper around `text_llm_runner.h[.cpp]`. In the next PRs I will move examples/models/phi-3-mini/runner to use the generic runner. I will look into the QNN and MediaTek runners as well. Differential Revision: [D75910889](https://our.internmc.facebook.com/intern/diff/D75910889/) [ghstack-poisoned]
2 parents 652f613 + 0845100 commit fca9d38

File tree

145 files changed

+2711
-1711
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

145 files changed

+2711
-1711
lines changed

.ci/scripts/test_llama_torchao_lowbit.sh

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@ cmake --build cmake-out -j16 --target install --config Release
4040

4141
# Install llama runner with torchao
4242
cmake -DPYTHON_EXECUTABLE=python \
43-
-DCMAKE_PREFIX_PATH=$(python -c 'from distutils.sysconfig import get_python_lib; print(get_python_lib())') \
4443
-DCMAKE_BUILD_TYPE=Release \
4544
-DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON \
4645
-DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \

.ci/scripts/test_model.sh

Lines changed: 19 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -49,14 +49,24 @@ prepare_artifacts_upload() {
4949
}
5050

5151
build_cmake_executor_runner() {
52+
local backend_string_select="${1:-}"
5253
echo "Building executor_runner"
5354
rm -rf ${CMAKE_OUTPUT_DIR}
54-
cmake -DCMAKE_BUILD_TYPE=Debug \
55-
-DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \
56-
-DPYTHON_EXECUTABLE="$PYTHON_EXECUTABLE" \
57-
-B${CMAKE_OUTPUT_DIR} .
58-
59-
cmake --build ${CMAKE_OUTPUT_DIR} -j4 --config Debug
55+
mkdir ${CMAKE_OUTPUT_DIR}
56+
if [[ "$backend_string_select" == "XNNPACK" ]]; then
57+
echo "Backend $backend_string_select selected"
58+
(cd ${CMAKE_OUTPUT_DIR} \
59+
&& cmake -DCMAKE_BUILD_TYPE=Release \
60+
-DEXECUTORCH_BUILD_XNNPACK=ON \
61+
-DPYTHON_EXECUTABLE="$PYTHON_EXECUTABLE" ..)
62+
cmake --build ${CMAKE_OUTPUT_DIR} -j4
63+
else
64+
cmake -DCMAKE_BUILD_TYPE=Debug \
65+
-DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \
66+
-DPYTHON_EXECUTABLE="$PYTHON_EXECUTABLE" \
67+
-B${CMAKE_OUTPUT_DIR} .
68+
cmake --build ${CMAKE_OUTPUT_DIR} -j4 --config Debug
69+
fi
6070
}
6171

6272
run_portable_executor_runner() {
@@ -111,19 +121,6 @@ test_model() {
111121
run_portable_executor_runner
112122
}
113123

114-
build_cmake_xnn_executor_runner() {
115-
echo "Building xnn_executor_runner"
116-
117-
(rm -rf ${CMAKE_OUTPUT_DIR} \
118-
&& mkdir ${CMAKE_OUTPUT_DIR} \
119-
&& cd ${CMAKE_OUTPUT_DIR} \
120-
&& retry cmake -DCMAKE_BUILD_TYPE=Release \
121-
-DEXECUTORCH_BUILD_XNNPACK=ON \
122-
-DPYTHON_EXECUTABLE="$PYTHON_EXECUTABLE" ..)
123-
124-
cmake --build ${CMAKE_OUTPUT_DIR} -j4
125-
}
126-
127124
test_model_with_xnnpack() {
128125
WITH_QUANTIZATION=$1
129126
WITH_DELEGATION=$2
@@ -148,12 +145,11 @@ test_model_with_xnnpack() {
148145

149146
# Run test model
150147
if [[ "${BUILD_TOOL}" == "buck2" ]]; then
148+
# TODO eventually buck should also use consolidated executor runners
151149
buck2 run //examples/xnnpack:xnn_executor_runner -- --model_path "${OUTPUT_MODEL_PATH}"
152150
elif [[ "${BUILD_TOOL}" == "cmake" ]]; then
153-
if [[ ! -f ${CMAKE_OUTPUT_DIR}/backends/xnnpack/xnn_executor_runner ]]; then
154-
build_cmake_xnn_executor_runner
155-
fi
156-
./${CMAKE_OUTPUT_DIR}/backends/xnnpack/xnn_executor_runner --model_path "${OUTPUT_MODEL_PATH}"
151+
build_cmake_executor_runner "XNNPACK"
152+
./${CMAKE_OUTPUT_DIR}/executor_runner --model_path "${OUTPUT_MODEL_PATH}"
157153
else
158154
echo "Invalid build tool ${BUILD_TOOL}. Only buck2 and cmake are supported atm"
159155
exit 1

.ci/scripts/utils.sh

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -158,8 +158,7 @@ build_executorch_runner() {
158158
cmake_install_executorch_lib() {
159159
echo "Installing libexecutorch.a and libportable_kernels.a"
160160
clean_executorch_install_folders
161-
retry cmake -DBUCK2="$BUCK" \
162-
-DCMAKE_INSTALL_PREFIX=cmake-out \
161+
retry cmake -DCMAKE_INSTALL_PREFIX=cmake-out \
163162
-DCMAKE_BUILD_TYPE=Release \
164163
-DPYTHON_EXECUTABLE="$PYTHON_EXECUTABLE" \
165164
-Bcmake-out .

.github/scripts/label_utils.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,7 @@
2222

2323
LABEL_ERR_MSG_TITLE = "This PR needs a `release notes:` label"
2424
LABEL_ERR_MSG = f"""# {LABEL_ERR_MSG_TITLE}
25-
If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with `release notes:`.
26-
27-
If not, please add the `release notes: none` label.
25+
If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with `release notes:`. This helps us keep track and include your important work in the next release notes.
2826
2927
To add a label, you can comment to pytorchbot, for example
3028
`@pytorchbot label "release notes: none"`

.github/scripts/trymerge.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -59,12 +59,7 @@
5959
patterns_to_regex,
6060
retries_decorator,
6161
)
62-
from label_utils import (
63-
gh_add_labels,
64-
gh_remove_label,
65-
has_required_labels,
66-
LABEL_ERR_MSG,
67-
)
62+
from label_utils import gh_add_labels, gh_remove_label
6863
from trymerge_explainer import get_revert_message, TryMergeExplainer
6964

7065
# labels
@@ -2116,9 +2111,6 @@ def merge(
21162111
# Check for approvals
21172112
find_matching_merge_rule(pr, repo, skip_mandatory_checks=True)
21182113

2119-
if not has_required_labels(pr):
2120-
raise RuntimeError(LABEL_ERR_MSG.lstrip(" #"))
2121-
21222114
if ignore_current:
21232115
checks = pr.get_checkrun_conclusions()
21242116
_, failing, _ = categorize_checks(

.github/workflows/check-labels.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,4 +51,4 @@ jobs:
5151
PR_NUM: ${{ github.event.number || github.event.inputs.pr_number }}
5252
run: |
5353
set -ex
54-
python3 .github/scripts/check_labels.py --exit-non-zero "${PR_NUM}"
54+
python3 .github/scripts/check_labels.py "${PR_NUM}"

backends/cadence/aot/replace_ops.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2065,11 +2065,10 @@ def call_operator(
20652065
return super().call_operator(op, args, kwargs, meta)
20662066

20672067

2068-
@register_cadence_pass(CadencePassAttribute(opt_level=2))
2069-
class ReplaceGeluWithApproximateGeluPass(ExportPass):
2068+
@register_cadence_pass(CadencePassAttribute(opt_level=0))
2069+
class ReplaceAtenApproxGeluWithApproxGeluPass(ExportPass):
20702070
"""
2071-
Replace the gelu op with an approximate gelu op. The approximate gelu op
2072-
is more efficient on DSP backends.
2071+
Replace the aten gelu op with an approximate arg with an approximate gelu op.
20732072
"""
20742073

20752074
def call_operator(
@@ -2079,6 +2078,9 @@ def call_operator(
20792078
kwargs: Dict[str, Argument],
20802079
meta: NodeMetadata,
20812080
) -> ProxyValue:
2081+
if "approximate" not in kwargs:
2082+
return super().call_operator(op, args, kwargs, meta)
2083+
20822084
if op not in {
20832085
exir_ops.edge.aten.gelu.default,
20842086
}:
@@ -2414,7 +2416,7 @@ class CadenceReplaceOpsInGraph:
24142416
ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass,
24152417
ReplaceAtenAvgPoolWithJarvisAvgPoolPass,
24162418
ReplaceWhereWithFullArgsWithWhereScalar,
2417-
ReplaceGeluWithApproximateGeluPass,
2419+
ReplaceAtenApproxGeluWithApproxGeluPass,
24182420
ReplaceSplitWithSlicePass,
24192421
ReplacePowWithMulPass,
24202422
]

backends/cadence/aot/tests/test_replace_ops_passes.py

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,13 @@
2626
ForceChannelLastForConvPass,
2727
MakeSliceAndCatDimOutermostPass,
2828
ReplaceAddMMWithLinearPass,
29+
ReplaceAtenApproxGeluWithApproxGeluPass,
2930
ReplaceAtenConvolutionWithJarvisConvolutionPass,
3031
ReplaceConstantPadNdWithSlicePass,
3132
ReplaceConvolutionOptionalArgsWithConcreteArgsPass,
3233
ReplaceConvWithIm2RowAndLinear,
3334
ReplaceEmptyTensorsWithFullPass,
3435
ReplaceFunctionallyEquivalentOpTargets,
35-
ReplaceGeluWithApproximateGeluPass,
3636
ReplaceIm2RowWithViewPass,
3737
ReplaceLinearWithFullyConnectedOpPass,
3838
ReplaceMatmulWithTransposedMatmulPass,
@@ -1287,17 +1287,41 @@ def forward(self, cond: torch.Tensor):
12871287
1,
12881288
)
12891289

1290-
def test_replace_aten_gelu_with_approximate_gelu(self):
1291-
class Gelu(torch.nn.Module):
1292-
def forward(self, input):
1293-
return torch.nn.functional.gelu(input)
1290+
def test_no_replace_aten_gelu_with_approximate_gelu(self):
1291+
inputs = torch.randn(2, 1, 64)
1292+
1293+
gm = single_op_builder(
1294+
placeholders=(inputs,),
1295+
op=exir_ops.edge.aten.gelu.default,
1296+
args=(inputs,),
1297+
)
1298+
gm = ExportPass().call(gm).graph_module
1299+
1300+
p = ReplaceAtenApproxGeluWithApproxGeluPass()
1301+
graph_after_passes = p.call(gm).graph_module
12941302

1303+
# Assert that aten.gelu op was not decomposed, since it didn't have an approximate argument
1304+
self.assertEqual(
1305+
count_node(
1306+
graph_after_passes,
1307+
exir_ops.edge.aten.gelu.default,
1308+
),
1309+
1,
1310+
)
1311+
1312+
def test_replace_aten_approximate_gelu_with_approximate_gelu(self):
12951313
inputs = torch.randn(2, 1, 64)
12961314

1297-
graph_module = export_to_edge(Gelu(), (inputs,)).exported_program().graph_module
1315+
gm = single_op_builder(
1316+
placeholders=(inputs,),
1317+
op=exir_ops.edge.aten.gelu.default,
1318+
args=(inputs,),
1319+
kwargs={"approximate": "tanh"},
1320+
)
1321+
gm = ExportPass().call(gm).graph_module
12981322

1299-
p = ReplaceGeluWithApproximateGeluPass()
1300-
graph_after_passes = cast(PassResult, p(graph_module)).graph_module
1323+
p = ReplaceAtenApproxGeluWithApproxGeluPass()
1324+
graph_after_passes = p.call(gm).graph_module
13011325

13021326
# Assert that aten.gelu op was decomposed
13031327
self.assertEqual(

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3880,6 +3880,41 @@ def test_conv_former(self):
38803880
self.assertGreaterEqual(msg["top_1"], 60)
38813881
self.assertGreaterEqual(msg["top_5"], 80)
38823882

3883+
def test_deit(self):
3884+
if not self.required_envs([self.image_dataset]):
3885+
self.skipTest("missing required envs")
3886+
cmds = [
3887+
"python",
3888+
f"{self.executorch_root}/examples/qualcomm/oss_scripts/deit.py",
3889+
"--dataset",
3890+
self.image_dataset,
3891+
"--artifact",
3892+
self.artifact_dir,
3893+
"--build_folder",
3894+
self.build_folder,
3895+
"--device",
3896+
self.device,
3897+
"--model",
3898+
self.model,
3899+
"--ip",
3900+
self.ip,
3901+
"--port",
3902+
str(self.port),
3903+
]
3904+
if self.host:
3905+
cmds.extend(["--host", self.host])
3906+
3907+
p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
3908+
with Listener((self.ip, self.port)) as listener:
3909+
conn = listener.accept()
3910+
p.communicate()
3911+
msg = json.loads(conn.recv())
3912+
if "Error" in msg:
3913+
self.fail(msg["Error"])
3914+
else:
3915+
self.assertGreaterEqual(msg["top_1"], 75)
3916+
self.assertGreaterEqual(msg["top_5"], 90)
3917+
38833918
def test_dino_v2(self):
38843919
if not self.required_envs([self.image_dataset]):
38853920
self.skipTest("missing required envs")

backends/vulkan/_passes/fuse_quantized_ops.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from executorch.exir import ExportedProgram
1818
from executorch.exir.dialects._ops import ops as exir_ops
1919
from executorch.exir.pass_base import ExportPass, PassResult
20+
from executorch.exir.passes import dead_code_elimination_pass
2021

2122
#################
2223
## linear_qcnw ##
@@ -224,6 +225,8 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
224225
)
225226

226227
graph_module.recompile()
227-
graph_module = super().call(graph_module).graph_module
228+
dead_code_elimination_pass(graph_module)
228229

230+
# Re-trace the graph since new nodes were (potentially) inserted
231+
graph_module = super().call(graph_module).graph_module
229232
return PassResult(graph_module, True)

0 commit comments

Comments
 (0)