Skip to content

Commit 14e7cdc

Browse files
committed
Update base for Update on " [ExecuTorch][BE] Split kv cache and SDPA for better code sharing"
Summary: Why? We have coupled SDPA with kv cache for a while. Initially this was done as we implemented sdpa_with_kv_cache custom op to reduce multiple copy overheads from kv cache update. (This could have been done by having separate custom kv cache update and custom sdpa op. Recent changes enabled this.) As a result of SDPA module owning kv cache, we get a) non-composable implementation and b) harder to reuse model definition and components from repos like tune. Output of this is that we have multiple definitions of the same model, llama, lying around in ET, TorchChat and Tune. This diff and subsequent ones will try to move in the direction where custom kv cache and custom sdpa become decoupled and composable, making it more module-swap friendly with tune's model definition. How? Earlier PRs decoupled kv cache update from sdpa. So now 1. Decouple SDPA nn.Module from KV cache. 2. Standardize on KVCache and SDPA interface. That is KVCache and SDPA both operate on q, k, v in [B, # heads, seq_len, head_dim] formatted tensors. 3. 2 will introduce multiple transposes when KVCache and SDPA are replaced by custom modules, but we will write a graph pass to undo those. Test Plan: Existing tests. Make sure perf doesn't regress Differential Revision: [D67914054](https://our.internmc.facebook.com/intern/diff/D67914054) [ghstack-poisoned]
2 parents 75044ad + af7613c commit 14e7cdc

File tree

11 files changed

+42
-56
lines changed

11 files changed

+42
-56
lines changed

.lintrunner.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,7 @@ include_patterns = [
294294
'build/**/*.py',
295295
'codegen/**/*.py',
296296
# 'devtools/**/*.py',
297+
'devtools/visualization/**/*.py',
297298
'docs/**/*.py',
298299
# 'examples/**/*.py',
299300
# 'exir/**/*.py',

backends/arm/arm_backend.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -122,16 +122,6 @@ def dump_intermediate_artifacts_to(
122122
self.path_for_intermediates = output_path
123123
return self
124124

125-
def set_input_order(
126-
self, input_order: Optional[str] = None
127-
) -> "ArmCompileSpecBuilder":
128-
"""
129-
Reorder the inputs coming in. This may be required when inputs > 1.
130-
And while using the U55/U85 CompileSpec.
131-
"""
132-
self.input_order = input_order
133-
return self
134-
135125
def build(self) -> List[CompileSpec]:
136126
"""
137127
Generate a list of compile spec objects from the builder

backends/arm/test/common.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -85,33 +85,28 @@ def get_tosa_compile_spec_unbuilt(
8585

8686
def get_u55_compile_spec(
8787
custom_path=None,
88-
reorder_inputs=None,
8988
) -> list[CompileSpec]:
9089
"""
9190
Default compile spec for Ethos-U55 tests.
9291
"""
9392
return get_u55_compile_spec_unbuilt(
9493
custom_path=custom_path,
95-
reorder_inputs=reorder_inputs,
9694
).build()
9795

9896

9997
def get_u85_compile_spec(
10098
custom_path=None,
101-
reorder_inputs=None,
10299
) -> list[CompileSpec]:
103100
"""
104101
Default compile spec for Ethos-U85 tests.
105102
"""
106103
return get_u85_compile_spec_unbuilt(
107104
custom_path=custom_path,
108-
reorder_inputs=reorder_inputs,
109105
).build()
110106

111107

112108
def get_u55_compile_spec_unbuilt(
113109
custom_path=None,
114-
reorder_inputs=None,
115110
) -> ArmCompileSpecBuilder:
116111
"""Get the ArmCompileSpecBuilder for the Ethos-U55 tests, to modify
117112
the compile spec before calling .build() to finalize it.
@@ -128,14 +123,12 @@ def get_u55_compile_spec_unbuilt(
128123
extra_flags="--debug-force-regor --output-format=raw",
129124
)
130125
.dump_intermediate_artifacts_to(artifact_path)
131-
.set_input_order(reorder_inputs)
132126
)
133127
return compile_spec
134128

135129

136130
def get_u85_compile_spec_unbuilt(
137131
custom_path=None,
138-
reorder_inputs=None,
139132
) -> list[CompileSpec]:
140133
"""Get the ArmCompileSpecBuilder for the Ethos-U85 tests, to modify
141134
the compile spec before calling .build() to finalize it.
@@ -150,7 +143,6 @@ def get_u85_compile_spec_unbuilt(
150143
extra_flags="--output-format=raw",
151144
)
152145
.dump_intermediate_artifacts_to(artifact_path)
153-
.set_input_order(reorder_inputs)
154146
)
155147
return compile_spec
156148

backends/arm/test/test_arm_baremetal.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,12 +96,12 @@ test_run_ethosu_fvp() { # End to End model tests
9696
# Ethos-U55
9797
echo "${TEST_SUITE_NAME}: Test ethos-u target Ethos-U55"
9898
examples/arm/run.sh --target=ethos-u55-128 --model_name=mv2
99-
examples/arm/run.sh --target=ethos-u55-128 --model_name=lstm --reorder_inputs=1,0,2
99+
examples/arm/run.sh --target=ethos-u55-128 --model_name=lstm
100100

101101
# Ethos-U85
102102
echo "${TEST_SUITE_NAME}: Test ethos-u target Ethos-U85"
103103
examples/arm/run.sh --target=ethos-u85-128 --model_name=mv2
104-
examples/arm/run.sh --target=ethos-u85-128 --model_name=lstm --reorder_inputs=1,0,2
104+
examples/arm/run.sh --target=ethos-u85-128 --model_name=lstm
105105
}
106106

107107
${TEST_SUITE}

devtools/install_requirements.sh

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
#!/bin/bash
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
# Conflict: this requires numpy<2 whereas ExecuTorch core requires numpy>=2
9+
# Follow https://github.com/google-ai-edge/model-explorer/issues/277 for potential
10+
# resolution
11+
# Quote the requirement: unquoted, the shell parses '>=0.1.16' as a stdout
# redirection to a file named '=0.1.16' and the version constraint is lost.
pip install "ai-edge-model-explorer>=0.1.16"

devtools/visualization/visualization_utils.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,16 @@
88
import time
99

1010
from executorch.exir import EdgeProgramManager, ExecutorchProgramManager
11-
from model_explorer import config, consts, visualize_from_config # type: ignore
1211
from torch.export.exported_program import ExportedProgram
1312

13+
# model_explorer is an optional dependency that is installed separately from
# the core requirements, so point the user at the installer script before
# re-raising the original ImportError.
try:
    from model_explorer import config, consts, visualize_from_config  # type: ignore
except ImportError:
    print(
        "Error: 'model_explorer' is not installed. Install using devtools/install_requirements.sh"
    )
    raise
20+
1421

1522
class SingletonModelExplorerServer:
1623
"""Singleton context manager for starting a model-explorer server.

devtools/visualization/visualization_utils_test.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,14 @@
1717
visualize,
1818
)
1919
from executorch.exir import ExportedProgram
20-
from model_explorer.config import ModelExplorerConfig # type: ignore
20+
21+
# model_explorer is an optional dependency that is installed separately from
# the core requirements, so point the user at the installer script before
# re-raising the original ImportError.
try:
    from model_explorer.config import ModelExplorerConfig  # type: ignore
except ImportError:
    print(
        "Error: 'model_explorer' is not installed. Install using devtools/install_requirements.sh"
    )
    raise
2128

2229

2330
@pytest.fixture

examples/arm/aot_arm_compiler.py

Lines changed: 10 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -259,34 +259,25 @@ def get_calibration_data(
259259
def get_compile_spec(
260260
target: str,
261261
intermediates: Optional[str] = None,
262-
reorder_inputs: Optional[str] = None,
263262
system_config: Optional[str] = None,
264263
memory_mode: Optional[str] = None,
265264
) -> list[CompileSpec]:
266265
spec_builder = None
267266
if target == "TOSA":
268267
spec_builder = ArmCompileSpecBuilder().tosa_compile_spec("TOSA-0.80+BI")
269268
elif "ethos-u55" in target:
270-
spec_builder = (
271-
ArmCompileSpecBuilder()
272-
.ethosu_compile_spec(
273-
target,
274-
system_config=system_config,
275-
memory_mode=memory_mode,
276-
extra_flags="--debug-force-regor --output-format=raw --verbose-operators --verbose-cycle-estimate",
277-
)
278-
.set_input_order(reorder_inputs)
269+
spec_builder = ArmCompileSpecBuilder().ethosu_compile_spec(
270+
target,
271+
system_config=system_config,
272+
memory_mode=memory_mode,
273+
extra_flags="--debug-force-regor --output-format=raw --verbose-operators --verbose-cycle-estimate",
279274
)
280275
elif "ethos-u85" in target:
281-
spec_builder = (
282-
ArmCompileSpecBuilder()
283-
.ethosu_compile_spec(
284-
target,
285-
system_config=system_config,
286-
memory_mode=memory_mode,
287-
extra_flags="--output-format=raw --verbose-operators --verbose-cycle-estimate",
288-
)
289-
.set_input_order(reorder_inputs)
276+
spec_builder = ArmCompileSpecBuilder().ethosu_compile_spec(
277+
target,
278+
system_config=system_config,
279+
memory_mode=memory_mode,
280+
extra_flags="--output-format=raw --verbose-operators --verbose-cycle-estimate",
290281
)
291282

292283
if intermediates is not None:
@@ -429,14 +420,6 @@ def get_args():
429420
required=False,
430421
help="Location for outputs, if not the default of cwd.",
431422
)
432-
parser.add_argument(
433-
"-r",
434-
"--reorder_inputs",
435-
type=str,
436-
required=False,
437-
default=None,
438-
help="Provide the order of the inputs. This can be required when inputs > 1.",
439-
)
440423
parser.add_argument(
441424
"--system_config",
442425
required=False,
@@ -519,7 +502,6 @@ def get_args():
519502
compile_spec = get_compile_spec(
520503
args.target,
521504
args.intermediates,
522-
args.reorder_inputs,
523505
args.system_config,
524506
args.memory_mode,
525507
)

examples/arm/run.sh

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ build_with_etdump=false
2929
build_type="Release"
3030
extra_build_flags=""
3131
build_only=false
32-
reorder_inputs=""
3332
system_config=""
3433
memory_mode=""
3534

@@ -46,7 +45,6 @@ help() {
4645
echo " --extra_build_flags Extra flags to pass to cmake like -DET_ARM_BAREMETAL_METHOD_ALLOCATOR_POOL_SIZE=60000 Default: none "
4746
echo " --build_only Only build, don't run FVP"
4847
echo " --scratch-dir=<FOLDER> Path to your Ethos-U scrach dir if you not using default"
49-
echo " --reorder_inputs=<FLAGS> Reorder the inputs. This can be required when inputs > 1."
5048
echo " --system_config=<CONFIG> System configuration to select from the Vela configuration file (see vela.ini). Default: Ethos_U55_High_End_Embedded for EthosU55 targets, Ethos_U85_SYS_DRAM_Mid for EthosU85 targets."
5149
echo " NOTE: If given, this option must match the given target. This option also sets timing adapter values customized for specific hardware, see ./executor_runner/CMakeLists.txt."
5250
echo " --memory_mode=<MODE> Memory mode to select from the Vela configuration file (see vela.ini), e.g. Shared_Sram/Sram_Only. Default: 'Shared_Sram' for Ethos-U55 targets, 'Sram_Only' for Ethos-U85 targets"
@@ -66,7 +64,6 @@ for arg in "$@"; do
6664
--extra_build_flags=*) extra_build_flags="${arg#*=}";;
6765
--build_only) build_only=true ;;
6866
--scratch-dir=*) root_dir="${arg#*=}";;
69-
--reorder_inputs=*) reorder_inputs="${arg#*=}";;
7067
--system_config=*) system_config="${arg#*=}";;
7168
--memory_mode=*) memory_mode="${arg#*=}";;
7269
*)
@@ -151,7 +148,7 @@ function generate_pte_file() {
151148
# We are using the aot_lib from build_quantization_aot_lib below
152149
SO_LIB=$(find cmake-out-aot-lib -name libquantized_ops_aot_lib.${SO_EXT})
153150
154-
local ARM_AOT_CMD="python3 -m examples.arm.aot_arm_compiler --model_name=${model} --target=${target} ${model_compiler_flags} --reorder_inputs=${reorder_inputs} --output ${output_folder} --so_library=$SO_LIB --system_config=${system_config} --memory_mode=${memory_mode}"
151+
local ARM_AOT_CMD="python3 -m examples.arm.aot_arm_compiler --model_name=${model} --target=${target} ${model_compiler_flags} --output ${output_folder} --so_library=$SO_LIB --system_config=${system_config} --memory_mode=${memory_mode}"
155152
echo "CALL ${ARM_AOT_CMD}" >&2
156153
${ARM_AOT_CMD} 1>&2
157154
@@ -372,7 +369,6 @@ if [[ -z "$model_name" ]]; then
372369
else
373370
test_model=( "$model_name" )
374371
model_compiler_flags=( "$aot_arm_compiler_flags" )
375-
reorder_inputs=( "$reorder_inputs" )
376372
fi
377373
378374
# loop over running the AoT flow and executing the model on device

install_requirements.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,6 @@ def python_is_compatible():
170170
"tomli", # Imported by extract_sources.py when using python < 3.11.
171171
"wheel", # For building the pip package archive.
172172
"zstd", # Imported by resolve_buck.py.
173-
"ai-edge-model-explorer>=0.1.16", # For visualizing ExportedPrograms
174173
]
175174

176175
# Assemble the list of requirements to actually install.

0 commit comments

Comments
 (0)