diff --git a/.ci/scripts/test_llama_lora.sh b/.ci/scripts/test_llama_lora.sh index 5c87cb8da72..6337bbf76a2 100644 --- a/.ci/scripts/test_llama_lora.sh +++ b/.ci/scripts/test_llama_lora.sh @@ -48,8 +48,17 @@ DOWNLOADED_PATH=$( --model_id "${HF_MODEL_REPO}" \ --files "adapter_config.json" "adapter_model.pt" "consolidated.00.pth" "params.json" "tokenizer.model" ) -EXPORTED_MODEL_NAME="llama_3_2_1B_lora.pte" -# Export model. +# Build llama runner. +cmake_install_executorch_libraries +cmake_build_llama_runner + +# Constants. +RUNTIME_ARGS="--tokenizer_path=${DOWNLOADED_PATH}/tokenizer.model --temperature=0 --seq_len=20 --warmup=1" +PROMPT="What happens if you eat watermelon seeds?" +EXPECTED_PREFIX="What happens if you eat watermelon seeds? Watermelon seeds are a good source of vitamin C," + +# Export LoRA PTE file. +MODEL_NAME="llama_3_2_1B_lora" $PYTHON_EXECUTABLE -m extension.llm.export.export_llm \ base.checkpoint="${DOWNLOADED_PATH}/consolidated.00.pth" \ base.params="${DOWNLOADED_PATH}/params.json" \ @@ -61,36 +70,64 @@ $PYTHON_EXECUTABLE -m extension.llm.export.export_llm \ model.dtype_override="fp32" \ backend.xnnpack.enabled=true \ backend.xnnpack.extended_ops=true \ - export.output_name="${EXPORTED_MODEL_NAME}" - -# Build llama runner. -cmake_install_executorch_libraries -cmake_build_llama_runner + export.output_name="${MODEL_NAME}.pte" -PROMPT="What happens if you eat watermelon seeds?" # Run llama runner -RUNTIME_ARGS="--model_path=${EXPORTED_MODEL_NAME} --tokenizer_path=${DOWNLOADED_PATH}/tokenizer.model --temperature=0 --seq_len=20 --warmup=1" - NOW=$(date +"%H:%M:%S") echo "Starting to run llama runner at ${NOW}" # shellcheck source=/dev/null -cmake-out/examples/models/llama/llama_main --prompt="${PROMPT}" ${RUNTIME_ARGS} > result.txt +cmake-out/examples/models/llama/llama_main --model_path=${MODEL_NAME}.pte --prompt="${PROMPT}" ${RUNTIME_ARGS} > result.txt NOW=$(date +"%H:%M:%S") echo "Finished at ${NOW}" RESULT=$(cat result.txt) -EXPECTED_PREFIX="What happens if you eat watermelon seeds? Watermelon seeds are a good source of vitamin C," - if [[ "${RESULT}" == "${EXPECTED_PREFIX}"* ]]; then echo "Expected result prefix: ${EXPECTED_PREFIX}" echo "Actual result: ${RESULT}" + # Do not clean up files if test passes, as they're re-used in the next test. echo "Success" - cleanup_files else echo "Expected result prefix: ${EXPECTED_PREFIX}" echo "Actual result: ${RESULT}" echo "Failure; results not the same" + cleanup_files + exit 1 +fi +# Export LoRA PTE, PTD file. +MODEL_SEPARATE="${MODEL_NAME}_separate" +$PYTHON_EXECUTABLE -m extension.llm.export.export_llm \ + base.checkpoint="${DOWNLOADED_PATH}/consolidated.00.pth" \ + base.params="${DOWNLOADED_PATH}/params.json" \ + base.adapter_checkpoint="${DOWNLOADED_PATH}/adapter_model.pt" \ + base.adapter_config="${DOWNLOADED_PATH}/adapter_config.json" \ + base.tokenizer_path="${DOWNLOADED_PATH}/tokenizer.model" \ + model.use_kv_cache=true \ + model.use_sdpa_with_kv_cache=true \ + model.dtype_override="fp32" \ + backend.xnnpack.enabled=true \ + backend.xnnpack.extended_ops=true \ + export.output_name="${MODEL_SEPARATE}.pte" \ + export.foundation_weights_file="${MODEL_SEPARATE}.ptd" + +# Run llama runner. +NOW=$(date +"%H:%M:%S") +echo "Starting to run llama runner at ${NOW}" +# shellcheck source=/dev/null +cmake-out/examples/models/llama/llama_main --model_path=${MODEL_SEPARATE}.pte --data_path=${MODEL_SEPARATE}.ptd --prompt="${PROMPT}" ${RUNTIME_ARGS} > result2.txt +NOW=$(date +"%H:%M:%S") +echo "Finished at ${NOW}" + +RESULT2=$(cat result2.txt) +if [[ "${RESULT2}" == "${EXPECTED_PREFIX}"* ]]; then + echo "Expected result prefix: ${EXPECTED_PREFIX}" + echo "Actual result: ${RESULT2}" + echo "Success" + cleanup_files +else + echo "Expected result prefix: ${EXPECTED_PREFIX}" + echo "Actual result: ${RESULT2}" + echo "Failure; results not the same" cleanup_files exit 1 fi diff --git a/backends/xnnpack/operators/node_visitor.py b/backends/xnnpack/operators/node_visitor.py index 90a9a3063e3..6a055c9413f 100644 --- a/backends/xnnpack/operators/node_visitor.py +++ b/backends/xnnpack/operators/node_visitor.py @@ -621,8 +621,12 @@ def get_serialized_buffer_index( ConstantDataOffset(offset=UINT64_MAX, size=size, named_key=named_key) ) - external_tag = tensor.meta.get("delegate_constant_tag", None) + custom_meta = tensor.meta.get("custom", None) + external_tag = ( + custom_meta.get("delegate_constant_tag", None) if custom_meta else None + ) if external_tag is not None: + external_tag = custom_meta.get("delegate_constant_tag", None) logging.info( f"Adding constant data with name {tensor.name}, key {named_key} and external_tag {external_tag} to named_data_store" ) diff --git a/examples/models/llama/TARGETS b/examples/models/llama/TARGETS index 9ea683e4174..62c33c6a245 100644 --- a/examples/models/llama/TARGETS +++ b/examples/models/llama/TARGETS @@ -153,10 +153,10 @@ runtime.python_library( "//caffe2:torch", "//executorch/extension/llm/export/config:llm_config", "//executorch/backends/vulkan/_passes:vulkan_passes", + "//executorch/exir/passes:external_constants_pass", "//executorch/exir/passes:init_mutable_pass", "//executorch/examples/models:model_base", "//executorch/examples/models:models", - "//executorch/exir/passes:init_mutable_pass", "//executorch/extension/llm/custom_ops:custom_ops_aot_py", "//executorch/extension/llm/export:export_lib", # one definition has to be included in the user of the libarary diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index a0cb7dab0ea..ca940adb687 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -1078,6 +1078,22 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager: # noqa: C901 llm_config.backend.xnnpack.enabled = True if llm_config.backend.xnnpack.enabled: + if llm_config.export.foundation_weights_file is not None: + gen_tag_fn: Callable[[torch.fx.Node], str] = lambda x: ( + llm_config.export.foundation_weights_file + if "lora" not in x.name + else None + ) + + from executorch.exir.passes.external_constants_pass import ( + delegate_external_constants_pass_unlifted, + ) + + delegate_external_constants_pass_unlifted( + gm=builder_exported.pre_autograd_graph_module, + gen_tag_fn=gen_tag_fn, + ) + builder = _to_edge_and_lower_llama_xnnpack( builder_exported, modelname, diff --git a/exir/passes/external_constants_pass.py b/exir/passes/external_constants_pass.py index d9bba4635ff..414e131d6f5 100644 --- a/exir/passes/external_constants_pass.py +++ b/exir/passes/external_constants_pass.py @@ -113,6 +113,28 @@ def delegate_external_constants_pass( for node in module.graph.nodes: if node.op == "placeholder" and is_param_node(ep, node): if gen_tag_fn is not None: - node.meta["delegate_constant_tag"] = gen_tag_fn(node) + node.meta.setdefault("custom", {}) + node.meta["custom"]["delegate_constant_tag"] = gen_tag_fn(node) + mutated = True + return PassResult(gm, mutated) + + +# Note: this pass must be run on an unlifted graph, e.g. ep.module(), +# and not on a lifted graph, e.g. ep.graph_module. +# This is using 'get_attr' to tag constants, which only appears in +# unlifted graphs. +def delegate_external_constants_pass_unlifted( + gm: GraphModule, + gen_tag_fn: Optional[Callable[[torch.fx.Node], str]] = None, +) -> PassResult: + mutated = False + for module in gm.modules(): + if not isinstance(module, torch.fx.GraphModule): + continue + for node in module.graph.nodes: + if node.op == "get_attr": + if gen_tag_fn is not None: + node.meta.setdefault("custom", {}) + node.meta["custom"]["delegate_constant_tag"] = gen_tag_fn(node) mutated = True return PassResult(gm, mutated) diff --git a/exir/program/_program.py b/exir/program/_program.py index 809565b0709..8df41bed200 100644 --- a/exir/program/_program.py +++ b/exir/program/_program.py @@ -1908,7 +1908,9 @@ def write_tensor_data_to_file(self, outdir) -> None: """ assert self._tensor_data is not None for filename, cord in self._tensor_data.items(): - with open(os.path.join(outdir, f"{filename}.ptd"), "wb") as f: + if not filename.endswith(".ptd"): + filename += ".ptd" + with open(os.path.join(outdir, f"{filename}"), "wb") as f: logging.info(f"Writing data file to {filename}") cord.write_to_file(f) diff --git a/extension/llm/export/config/llm_config.py b/extension/llm/export/config/llm_config.py index ab14a0b4a49..de5564cae4f 100644 --- a/extension/llm/export/config/llm_config.py +++ b/extension/llm/export/config/llm_config.py @@ -211,6 +211,9 @@ class ExportConfig: so_library: Shared library to specify custom quantized operators. export_only: Whether to stop right after torch.export() and just save the exported .pt2 graph file. + foundation_weights_file: configure the foundation weights of a model + to be placed in a separate file, external to the PTE. Pass the + intended file name here. """ max_seq_length: int = 128 @@ -219,6 +222,7 @@ class ExportConfig: output_name: Optional[str] = None so_library: Optional[str] = None export_only: bool = False + foundation_weights_file: Optional[str] = None def __post_init__(self): if self.max_context_length < self.max_seq_length: @@ -545,6 +549,8 @@ def from_args(cls, args: argparse.Namespace) -> "LlmConfig": # noqa: C901 llm_config.export.so_library = args.so_library if hasattr(args, "export_only"): llm_config.export.export_only = args.export_only + if hasattr(args, "foundation_weights_file"): + llm_config.export.foundation_weights_file = args.foundation_weights_file # QuantizationConfig if hasattr(args, "quantization_mode"): diff --git a/runtime/executor/merged_data_map.h b/runtime/executor/merged_data_map.h index 3ed708f1d2b..d5ae97057f2 100644 --- a/runtime/executor/merged_data_map.h +++ b/runtime/executor/merged_data_map.h @@ -37,8 +37,10 @@ class MergedDataMap final : public NamedDataMap { // Check for duplicate keys. for (uint32_t k = 0; k < first->get_num_keys().get(); k++) { const auto key = first->get_key(k).get(); + const auto error = second->get_tensor_layout(key).error(); + // TODO(lfq): add API to check if key exists. ET_CHECK_OR_RETURN_ERROR( - second->get_tensor_layout(key).error() == Error::NotFound, + error == Error::NotFound || error == Error::NotImplemented, InvalidArgument, "Duplicate key %s.", key);