diff --git a/backends/qualcomm/qnn_preprocess.py b/backends/qualcomm/qnn_preprocess.py
index 21b16a29c58..034b75fa6d0 100644
--- a/backends/qualcomm/qnn_preprocess.py
+++ b/backends/qualcomm/qnn_preprocess.py
@@ -78,10 +78,7 @@ def _build_op_wrappers(
                         )
                         assert node.target == context_loader_target, err_msg
                         # if graph has context binary loader node, return directly
-                        return PreprocessResult(
-                            processed_bytes=node.meta[OpContextLoader.meta_ctx_bin],
-                            debug_handle_map={},
-                        )
+                        return node.meta[OpContextLoader.meta_ctx_bin]
                     except:
                         raise RuntimeError(err_msg)
 
@@ -161,7 +158,7 @@ def preprocess_multimethod(
                 generate_qnn_executorch_option(compile_spec)
             )
             qnn_manager.Init()
-            py_op_wrapper_list = []
+            py_op_wrapper_list, ctx_binary_list = [], []
             for j, programs in enumerate(edge_programs.values()):
                 logger.info(f"Processing Method({j}): ({i+1}/{num_sub_graphs})")
                 py_op_wrappers = QnnBackend._build_op_wrappers(
@@ -169,22 +166,36 @@ def preprocess_multimethod(
                     qnn_manager.IsTensorDump(),
                     option.op_package_options.op_package_infos,
                 )
-                py_op_wrapper_list.append(
-                    [py_op_wrapper.GetOpWrapper() for py_op_wrapper in py_op_wrappers]
-                )
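+                # a bytes return value from _build_op_wrappers indicates the graph
+                # wraps a pre-generated context binary instead of lowered op wrappers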
+                if isinstance(py_op_wrappers, bytes):
+                    ctx_binary_list.append(py_op_wrappers)
+                else:
+                    py_op_wrapper_list.append(
+                        [
+                            py_op_wrapper.GetOpWrapper()
+                            for py_op_wrapper in py_op_wrappers
+                        ]
+                    )
 
-            qnn_context_binary = qnn_manager.Compile(graph_name, py_op_wrapper_list)
-            assert (
-                len(qnn_context_binary) != 0
-            ), "Failed to generate Qnn context binary."
-            qnn_manager.Destroy()
-            # methods should share the same context binary for current partition
-            for key in edge_programs.keys():
-                all_processed_results[key].append(
-                    PreprocessResult(
-                        processed_bytes=bytes(qnn_context_binary),
-                        debug_handle_map={},
+            if len(py_op_wrapper_list) == len(edge_programs.values()):
+                qnn_context_binary = qnn_manager.Compile(graph_name, py_op_wrapper_list)
+                assert (
+                    len(qnn_context_binary) != 0
+                ), "Failed to generate Qnn context binary."
+                qnn_manager.Destroy()
+                # methods should share the same context binary for current partition
+                for key in edge_programs.keys():
+                    all_processed_results[key].append(
+                        PreprocessResult(
+                            processed_bytes=bytes(qnn_context_binary),
+                            debug_handle_map={},
+                        )
                     )
-                )
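+            # every method supplied a pre-generated context binary; emit them directly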
+            elif len(ctx_binary_list) == len(edge_programs.values()):
+                for i, key in enumerate(edge_programs.keys()):
+                    all_processed_results[key].append(
+                        PreprocessResult(processed_bytes=ctx_binary_list[i])
+                    )
+            else:
+                raise RuntimeError(
+                    "Mixing pre-generated context binaries with lowered graphs "
+                    "(hybrid compilation) is not supported"
+                )
 
         return all_processed_results
diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py
index 747a6804957..7163ce88c27 100644
--- a/backends/qualcomm/tests/test_qnn_delegate.py
+++ b/backends/qualcomm/tests/test_qnn_delegate.py
@@ -5622,6 +5622,68 @@ def test_debugger_generate_optrace(self):
                         qhas_data = json.load(qhas_file)
                         self.assertIn("data", qhas_data)
 
+    def test_cli(self):
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            sample_input = torch.randn(1, 2, 3, 4)
+            ep = torch.export.export(Relu(), (sample_input,))  # noqa: F405
+            torch.export.save(ep, f"{tmp_dir}/relu.pt2")
+            torch.save(sample_input, f"{tmp_dir}/input_0_0.pt")
+            with open(f"{tmp_dir}/input_list", "w") as f:
+                f.write(f"{tmp_dir}/input_0_0.pt\n")
+
+            # quantize
+            cmds = [
+                "python",
+                "-m",
+                "examples.qualcomm.util_scripts.cli",
+                "quantize",
+                "--artifact",
+                f"{tmp_dir}/relu.pt2",
+                "--output_folder",
+                f"{tmp_dir}/q_out",
+                "--input_list",
+                f"{tmp_dir}/input_list",
+            ]
+            subprocess.run(cmds, stdout=subprocess.DEVNULL)
+            self.assertTrue(os.path.isfile(f"{tmp_dir}/q_out/relu_quantized.pt2"))
+            # compile
+            cmds = [
+                "python",
+                "-m",
+                "examples.qualcomm.util_scripts.cli",
+                "compile",
+                "--artifact",
+                f"{tmp_dir}/q_out/relu_quantized.pt2",
+                "--output_folder",
+                f"{tmp_dir}/c_out",
+                "--model",
+                self.model,
+            ]
+            subprocess.run(cmds, stdout=subprocess.DEVNULL)
+            self.assertTrue(os.path.isfile(f"{tmp_dir}/c_out/relu_quantized.pte"))
+            self.assertTrue(os.path.isfile(f"{tmp_dir}/c_out/relu_quantized.svg"))
+            # execute
+            cmds = [
+                "python",
+                "-m",
+                "examples.qualcomm.util_scripts.cli",
+                "execute",
+                "--artifact",
+                f"{tmp_dir}/c_out/relu_quantized.pte",
+                "--output_folder",
+                f"{tmp_dir}/e_out",
+                "--model",
+                self.model,
+                "--device",
+                self.device,
+                "--build_folder",
+                self.build_folder,
+                "--input_list",
+                f"{tmp_dir}/input_list",
+            ]
+            subprocess.run(cmds, stdout=subprocess.DEVNULL)
+            self.assertTrue(os.path.isfile(f"{tmp_dir}/e_out/output_0_0.pt"))
+
 
 def setup_environment():
     parser = setup_common_args_and_variables()
diff --git a/examples/qualcomm/qaihub_scripts/utils/export.py b/examples/qualcomm/qaihub_scripts/utils/export.py
index 4d252175dbb..2ee1968dd82 100644
--- a/examples/qualcomm/qaihub_scripts/utils/export.py
+++ b/examples/qualcomm/qaihub_scripts/utils/export.py
@@ -18,7 +18,6 @@
 from executorch.backends.qualcomm.serialization.qc_schema import QcomChipset
 from executorch.backends.qualcomm.utils.utils import (
     draw_graph,
-    ExecutorchBackendConfig,
     from_context_binary,
     generate_htp_compiler_spec,
     generate_qnn_executorch_compiler_spec,
@@ -26,6 +25,7 @@
 )
 from executorch.examples.qualcomm.qaihub_scripts.utils.utils import preprocess_binary
 from executorch.examples.qualcomm.utils import make_output_dir, SimpleADB
+from executorch.exir import ExecutorchBackendConfig
 from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass
 
 
diff --git a/examples/qualcomm/util_scripts/README.md b/examples/qualcomm/util_scripts/README.md
new file mode 100644
index 00000000000..712bbcd4277
--- /dev/null
+++ b/examples/qualcomm/util_scripts/README.md
@@ -0,0 +1,79 @@
+# CLI Tool to Quantize / Compile / Deploy PyTorch Models with the QNN Backend
+
+An easy-to-use tool for quantizing / compiling / executing .pte programs with Qualcomm AI Engine Direct. The tool is verified with the [host environment](../../../docs/source/backends-qualcomm.md#host-os).
+
+## Description
+
+This tool is intended for users who want to deploy models with the ExecuTorch runtime. It allows them to produce a .pte program in a few steps.
+
+### Quantizing Model
+
+* Save the torch.nn.Module in .pt2 format & prepare input data
+  ```bash
+  # create workspace for following operations
+  cd path/to/executorch
+  mkdir cli_example
+  ```
+  ```python
+  # take SimpleModel as an example
+  import torch
+  from executorch.backends.qualcomm.tests.models import SimpleModel
+  from pathlib import Path
+  # make example inputs
+  example_inputs = (torch.randn(1, 32, 28, 28), torch.randn(1, 32, 28, 28))
+  # generate ExportedProgram
+  ep = torch.export.export(SimpleModel(), example_inputs)
+  # save to workspace
+  ws = f"{Path().cwd()}/cli_example"
+  torch.export.save(ep, f"{ws}/simple_model.pt2")
+  # prepare calibration dataset: 2 sets of data with 2 inputs each
+  input_list = ""
+  for i in range(2):
+      current_input = ""
+      for j in range(2):
+          file_name = f"{ws}/input_{i}_{j}.pt"
+          torch.save(torch.randn(1, 32, 28, 28), file_name)
+          current_input += f"{file_name} "
+      input_list += f"{current_input.strip()}\n"
+
+  with open(f"{ws}/input_list", 'w') as f:
+      f.write(input_list)
+  ```
+
+* Quantize
+  ```bash 
+  # users can get more information via: PYTHONPATH=.. python -m examples.qualcomm.util_scripts.cli quantize -h
+  PYTHONPATH=.. python -m examples.qualcomm.util_scripts.cli quantize -a cli_example/simple_model.pt2 -o cli_example/quantize_output -c use_8a8w -i cli_example/input_list --per_channel
+  ```
+* Generated artifact: quantized .pt2 file
+  - `cli_example/quantize_output/simple_model_quantized.pt2`
+
+
+### Compiling Program
+
+* Compile .pt2 to .pte program
+  ```bash
+  # `pip install pydot` if package is missing
+  # users can get more information via: PYTHONPATH=.. python -m examples.qualcomm.util_scripts.cli compile -h
+  PYTHONPATH=.. python -m examples.qualcomm.util_scripts.cli compile -a cli_example/quantize_output/simple_model_quantized.pt2 -o cli_example/compile_output -m SM8750
+  ```
+* (Optional) Compile pre-generated context binary to .pte program
+  ```bash
+  # `pip install pydot` if package is missing
+  # users can get more information via: PYTHONPATH=.. python -m examples.qualcomm.util_scripts.cli compile -h
+  PYTHONPATH=.. python -m examples.qualcomm.util_scripts.cli compile -a model.bin -o path/to/model/output -m SM8750
+  ```
+* Generated artifacts: .pte file and a figure of the graph
+  - `cli_example/compile_output/simple_model_quantized.pte`
+  - `cli_example/compile_output/simple_model_quantized.svg`
+
+### Executing Program
+
+* Execute .pte program
+  ```bash
+  # users can get more information via: PYTHONPATH=.. python -m examples.qualcomm.util_scripts.cli execute -h
+  PYTHONPATH=.. python -m examples.qualcomm.util_scripts.cli execute -a cli_example/compile_output/simple_model_quantized.pte -o cli_example/execute_output -i cli_example/input_list -s $DEVICE_SERIAL -b build-android -m SM8750
+  ```
+* Generated artifacts: output data
+  - `cli_example/execute_output/output_{data_index}_{output_index}.pt`.
+  `data_index` is the index of the dataset entry, `output_index` is the index of the graph output.
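+* (Optional) Inspect output data
+  The collected outputs are plain tensors saved with `torch.save`. A minimal sketch for loading them back, assuming the `cli_example` workspace from the steps above:
+  ```python
+  import torch
+
+  # first dataset entry, first graph output
+  output = torch.load("cli_example/execute_output/output_0_0.pt")
+  print(output.shape, output.dtype)
+  ```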
diff --git a/examples/qualcomm/util_scripts/cli.py b/examples/qualcomm/util_scripts/cli.py
new file mode 100644
index 00000000000..e4c4c5dcaf8
--- /dev/null
+++ b/examples/qualcomm/util_scripts/cli.py
@@ -0,0 +1,504 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+import importlib
+import logging
+import os
+import re
+from pathlib import Path
+
+import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManagerAdaptor
+import numpy as np
+
+import torch
+
+from executorch.backends.qualcomm._passes.qnn_pass_manager import (
+    get_capture_program_passes,
+)
+from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
+from executorch.backends.qualcomm.serialization.qc_schema import QcomChipset
+from executorch.backends.qualcomm.utils.constants import QCOM_PASS_ACTIVATE_KEY
+from executorch.backends.qualcomm.utils.utils import (
+    draw_graph,
+    dump_context_from_pte,
+    from_context_binary,
+    generate_htp_compiler_spec,
+    generate_qnn_executorch_compiler_spec,
+    generate_qnn_executorch_option,
+    QNN_QUANT_TYPE_MAP,
+    QNN_TENSOR_TYPE_MAP,
+    to_edge_transform_and_lower_to_qnn,
+)
+from executorch.examples.qualcomm.qaihub_scripts.utils.utils import preprocess_binary
+from executorch.examples.qualcomm.utils import (
+    make_output_dir,
+    make_quantizer,
+    SimpleADB,
+)
+from executorch.exir import ExecutorchBackendConfig
+from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass
+from torchao.quantization import pt2e
+from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
+
+
+def get_logger():
+    logger = logging.getLogger("examples.qualcomm.util_scripts.cli")
+    handler = logging.StreamHandler()
+    handler.setFormatter(
+        logging.Formatter(
+            fmt="[%(asctime)s %(prefix)s] %(levelname)-8s: %(message)s",
+            datefmt="%Y-%m-%d %H:%M:%S",
+        )
+    )
+    logger.addHandler(handler)
+    logger.setLevel(logging.INFO)
+    logger.propagate = False
+    return logging.LoggerAdapter(logger, extra={"prefix": "QNN_BACKEND"})
+
+
+def get_io_info(pte_path, compiler_specs):
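+    # build a reverse lookup from QNN tensor data types back to torch dtypes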
+    dtype_map = {}
+    for type_map in (QNN_QUANT_TYPE_MAP, QNN_TENSOR_TYPE_MAP):
+        for k, v in type_map.items():
+            dtype_map.setdefault(v, k)
+
+    def fill_tensor_info(info, qnn_tensors, category):
+        for tensor in qnn_tensors:
+            encoding = tensor.GetEncodings()
+            quantization_info = {
+                "scale": encoding.data["scale"].tolist(),
+                "offset": encoding.data["offset"].tolist(),
+                "axis": encoding.axis,
+            }
+            info[category].append(
+                {
+                    "name": tensor.GetName(),
+                    "shape": tensor.GetDims().tolist(),
+                    "dtype": dtype_map[tensor.GetDataType()],
+                    "encoding": quantization_info,
+                }
+            )
+
+    in_key, out_key = "inputs", "outputs"
+    tensor_info = {in_key: [], out_key: []}
+
+    path_of_pte = Path(pte_path)
+    dump_context_from_pte(path_of_pte.absolute())
+    ctx_bin = [f for f in os.listdir(path_of_pte.parent) if Path(f).suffix == ".bin"][0]
+    # assume graph is fully delegated or it will be too hard to handle
+    with open(f"{path_of_pte.parent}/{ctx_bin}", "rb") as f:
+        ctx_bin = preprocess_binary(f.read(), compiler_specs)
+        # leverage QNN pybind interface to retrieve tensor encodings
+        qnn_mgr = PyQnnManagerAdaptor.QnnManager(
+            generate_qnn_executorch_option(compiler_specs), ctx_bin
+        )
+        assert qnn_mgr.Init().value == 0, "failed to load context binary"
+        graph_name = qnn_mgr.GetGraphNames()[0]
+        qnn_mgr.AllocateTensor(graph_name)
+        fill_tensor_info(tensor_info, qnn_mgr.GetGraphInputs(graph_name), in_key)
+        fill_tensor_info(tensor_info, qnn_mgr.GetGraphOutputs(graph_name), out_key)
+        qnn_mgr.Destroy()
+
+    return tensor_info
+
+
+def quantize(args):
+    logger = get_logger()
+
+    # get corresponding QnnQuantizer
+    try:
+        quant_dtype = getattr(QuantDtype, args.config)
+        act_observer = getattr(pt2e, args.activation_observer)
+        quantizer = make_quantizer(
+            quant_dtype=quant_dtype,
+            per_channel_conv=args.per_channel,
+            per_channel_linear=args.per_row,
+            act_observer=act_observer,
+        )
+    except Exception:
+        logger.error(
+            f"Failed to retrieve expected config {args.config} / {args.activation_observer}."
+        )
+        exit(1)
+
+    # step 0: load saved model
+    ep = torch.export.load(args.artifact)
+    # step 1: use prepare_pt2e to annotate QDQ pairs
+    ep_prepared = prepare_pt2e(ep.module(), quantizer)
+    logger.info(f"perform calibration on {args.artifact}")
+    # step 2: perform calibration
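+    # each line of input_list holds space-separated tensor file paths, one per graph input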
+    with open(args.input_list, "r") as f:
+        for line in f.read().split("\n")[:-1]:
+            inputs = [torch.load(t, weights_only=True) for t in line.split(" ")]
+            ep_prepared(*inputs)
+    # step 3: use convert_pt2e to fix encodings of QDQ pairs
+    logger.info(f"saving calibrated model for {args.artifact}")
+    ep_converted = convert_pt2e(ep_prepared)
+    ep_quantized = torch.export.export(ep_converted, tuple(inputs))
+    make_output_dir(args.output_folder)
+    torch.export.save(
+        ep_quantized, f"{args.output_folder}/{Path(args.artifact).stem}_quantized.pt2"
+    )
+
+
+def compile(args):
+    logger = get_logger()
+
+    # setup memory planning
+    memory_planning_pass = MemoryPlanningPass(
+        alloc_graph_input=args.shared_buffer is None,
+        alloc_graph_output=args.shared_buffer is None,
+    )
+
+    file_name, extension = Path(args.artifact).stem, Path(args.artifact).suffix
+    make_output_dir(args.output_folder)
+    # setup compiler spec dedicated to QNN HTP backend
+    backend_options = generate_htp_compiler_spec(use_fp16=True)
+    # setup general compiler spec for QNN
+    compiler_specs = generate_qnn_executorch_compiler_spec(
+        soc_model=getattr(QcomChipset, args.model),
+        backend_options=backend_options,
+        is_from_context_binary=extension == ".bin",
+    )
+    if extension == ".bin":
+        custom_op_name = f"ctx_loader_{file_name}"
+        # step 1: generate ExportedProgram with custom op as a binary loader & lower it w/QnnBackend
+        logger.info(f"exporting program for {args.artifact}")
+        prog_info = from_context_binary(
+            args.artifact, custom_op_name, getattr(QcomChipset, args.model)
+        )
+        # step 2: write pte files and store final graph
+        logger.info(f"exporting {file_name}.pte")
+        with open(f"{args.output_folder}/{file_name}.pte", "wb") as f:
+            prog_info["edge_program_manager"].to_executorch(
+                config=ExecutorchBackendConfig(
+                    memory_planning_pass=memory_planning_pass
+                )
+            ).write_to_file(f)
+        logger.info(f"exporting network graph with {file_name}.svg")
+        draw_graph(file_name, args.output_folder, prog_info["exported_program"])
+    elif extension == ".pt2":
+        # step 0: prepare exported_program
+        ep = torch.export.load(args.artifact)
+        sample_inputs = ep.example_inputs[0]
+        # step 1: start lowering to QnnBackend
+        logger.info(f"start lowering program for {args.artifact}")
+        passes, user_passes = get_capture_program_passes(), []
+        if args.pass_job is not None:
+            for job in args.pass_job:
+                try:
+                    user_passes.append(
+                        getattr(
+                            importlib.import_module(
+                                "executorch.backends.qualcomm._passes"
+                            ),
+                            job,
+                        )
+                    )
+                except Exception:
+                    logger.error(f"failed to extract designated pass '{job}'")
+
+        for user_pass in user_passes:
+            passes[user_pass][QCOM_PASS_ACTIVATE_KEY] = True
+
+        edge_prog_mgr = to_edge_transform_and_lower_to_qnn(
+            module=ep.module(),
+            inputs=sample_inputs,
+            compiler_specs=compiler_specs,
+            passes_job=passes,
+        )
+        # step 2: write pte files and store final graph
+        logger.info(f"exporting {file_name}.pte")
+        with open(f"{args.output_folder}/{file_name}.pte", "wb") as f:
+            edge_prog_mgr.to_executorch(
+                config=ExecutorchBackendConfig(
+                    memory_planning_pass=memory_planning_pass
+                )
+            ).write_to_file(f)
+        logger.info(f"exporting network graph with {file_name}.svg")
+        draw_graph(file_name, args.output_folder, edge_prog_mgr.exported_program())
+    else:
+        logger.error(f"unsupported file extension for '{args.artifact}'")
+
+
+def execute(args):
+    logger = get_logger()
+
+    pte_name = Path(args.artifact).stem
+
+    # load input files
+    logger.info("loading user inputs")
+    user_inputs, input_list = [], ""
+    with open(args.input_list, "r") as f:
+        for line in f.read().split("\n")[:-1]:
+            inputs, input_names = [], ""
+            for data in line.split(" "):
+                input_names += f"{Path(data).stem}.raw "
+                inputs.append(torch.load(data, weights_only=True))
+            user_inputs.append(inputs)
+            input_list += input_names.strip() + "\n"
+
+    logger.info("retrieving graph I/O")
+    # setup compiler spec dedicated to QNN HTP backend
+    backend_options = generate_htp_compiler_spec(use_fp16=True)
+    # setup general compiler spec for QNN
+    compiler_specs = generate_qnn_executorch_compiler_spec(
+        soc_model=getattr(QcomChipset, args.model),
+        backend_options=backend_options,
+    )
+    io_info = get_io_info(args.artifact, compiler_specs)
+
+    logger.info("preparing ADB connection")
+    # leverage SimpleADB for e2e inference
+    adb = SimpleADB(
+        qnn_sdk=os.getenv("QNN_SDK_ROOT"),
+        build_path=args.build_folder,
+        pte_path=args.artifact,
+        workspace=f"/data/local/tmp/executorch/{pte_name}",
+        device_id=args.device,
+        soc_model=args.model,
+        host_id=args.host,
+        shared_buffer=args.shared_buffer,
+    )
+
+    logger.info("pushing QNN libraries & other artifacts")
+    adb.push(inputs=user_inputs, input_list=input_list)
+
+    logger.info("starting inference")
+    adb.execute()
+
+    def post_process():
+        torch_to_numpy_dtype_dict = {
+            torch.bool: np.dtype("bool"),
+            torch.uint8: np.dtype("uint8"),
+            torch.int8: np.dtype("int8"),
+            torch.int16: np.dtype("int16"),
+            torch.int32: np.dtype("int32"),
+            torch.int64: np.dtype("int64"),
+            torch.float16: np.dtype("float16"),
+            torch.float32: np.dtype("float32"),
+            torch.float64: np.dtype("float64"),
+            torch.complex64: np.dtype("complex64"),
+            torch.complex128: np.dtype("complex128"),
+        }
+        output_info = io_info["outputs"]
+        output_folder = f"{args.output_folder}/outputs"
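+        # raw files pulled from device are expected to end with _{data_index}_{output_index}.raw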
+        for f in os.listdir(output_folder):
+            filename = os.path.join(output_folder, f)
+            match_res = re.match(r".*?([0-9]+)_([0-9]+)\.raw$", filename)
+            data_index, output_index = int(match_res.group(1)), int(match_res.group(2))
+            output = np.fromfile(
+                filename,
+                dtype=torch_to_numpy_dtype_dict[output_info[output_index]["dtype"]],
+            )
+            output = torch.from_numpy(
+                output.reshape(output_info[output_index]["shape"])
+            )
+            torch.save(
+                output, f"{args.output_folder}/output_{data_index}_{output_index}.pt"
+            )
+
+    logger.info("collecting output data")
+    make_output_dir(args.output_folder)
+    adb.pull(args.output_folder, post_process)
+    logger.info(f"execution finished, please check {args.output_folder} for results")
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description=(
+            "Utility to quantize / compile / execute models via Qualcomm backend"
+        ),
+    )
+    subparsers = parser.add_subparsers(
+        title="subcommands",
+        description=(
+            "[quantize]: Perform PTQ with QnnQuantizer for models in .pt2 extension. "
+            "[compile]: Compile model in .pt2 extenstion / context binary into .pte file. "
+            "[execute]: Perform on-device inference with given .pte."
+        ),
+    )
+
+    sub_quantize = subparsers.add_parser(
+        name="quantize",
+        help=(
+            "e.g. python -m executorch.example.qualcomm.util_scripts.cli quantize "
+            "-a model.pt2 -c use_8a8w -i calibration_data"
+        ),
+    )
+    sub_quantize.add_argument(
+        "-a",
+        "--artifact",
+        type=str,
+        required=True,
+        help="Path to saved .pt2 model in floating point precision.",
+    )
+    sub_quantize.add_argument(
+        "-o",
+        "--output_folder",
+        type=str,
+        default="./output_quantized",
+        help="Path to output artifact, store in 'output_quantized' if not given.",
+    )
+    sub_quantize.add_argument(
+        "-c",
+        "--config",
+        type=str,
+        default="use_8a8w",
+        help=(f"Configuration to be applied: {list(QuantDtype.__members__.keys())}."),
+    )
+    sub_quantize.add_argument(
+        "-i",
+        "--input_list",
+        type=str,
+        required=True,
+        help=(
+            "List of input files specified for calibration. "
+            'e.g. File content with: "input_0_0.pt input_0_1.pt\\ninput_1_0.pt input_1_1.pt" '
+            "means there are 2 sets of data for calibration on a graph with 2 inputs."
+        ),
+    )
+    sub_quantize.add_argument(
+        "--per_channel",
+        action="store_true",
+        help="Use per_channel encoding for operator convolution and its' families.",
+    )
+    sub_quantize.add_argument(
+        "--per_row",
+        action="store_true",
+        help="Use per_row encoding for operator linear.",
+    )
+    sub_quantize.add_argument(
+        "--activation_observer",
+        type=str,
+        default="MovingAverageMinMaxObserver",
+        help=(
+            "Activation observer for PTQ "
+            "(MinMaxObserver / MovingAverageMinMaxObserver / HistogramObserver)."
+        ),
+    )
+    sub_quantize.set_defaults(callback=quantize)
+
+    sub_compile = subparsers.add_parser(
+        name="compile",
+        help=(
+            "e.g. python -m executorch.example.qualcomm.util_scripts.cli compile "
+            "-a model.(pt2 / bin) -m SM8750"
+        ),
+    )
+    sub_compile.add_argument(
+        "-a",
+        "--artifact",
+        type=str,
+        required=True,
+        help="Path to saved .pt2 model or pre-generated context binary.",
+    )
+    sub_compile.add_argument(
+        "-m",
+        "--model",
+        type=str,
+        required=True,
+        help="SoC model. e.g. SM8750",
+    )
+    sub_compile.add_argument(
+        "-o",
+        "--output_folder",
+        type=str,
+        default="./output_pte",
+        help="Path to output artifacts, store in 'output_pte' if not given.",
+    )
+    sub_compile.add_argument(
+        "-p",
+        "--pass_job",
+        nargs="+",
+        type=str,
+        help=(
+            'Add extra passes for model lowering. e.g. "ExpandBroadcastTensorShape".'
+        ),
+    )
+    sub_compile.add_argument(
+        "--shared_buffer",
+        help=(
+            "Enable usage of shared buffer between application and backend for graph I/O."
+        ),
+        action="store_true",
+    )
+    sub_compile.set_defaults(callback=compile)
+
+    sub_execute = subparsers.add_parser(
+        name="execute",
+        help=(
+            "e.g. python -m executorch.example.qualcomm.util_scripts.cli "
+            "execute -p model.pte -i execution_data -s device_serial"
+        ),
+    )
+    sub_execute.add_argument(
+        "-a",
+        "--artifact",
+        type=str,
+        required=True,
+        help="Path to .pte file generated from 'compile' subcommand.",
+    )
+    sub_execute.add_argument(
+        "-i",
+        "--input_list",
+        type=str,
+        help=(
+            "List of input files specified for execution. "
+            'e.g. File content with: "input_0_0.pt input_0_1.pt\\ninput_1_0.pt input_1_1.pt" '
+            "means there are 2 sets of data for execution on a graph with 2 inputs.\n"
+        ),
+    )
+    sub_execute.add_argument(
+        "-m",
+        "--model",
+        type=str,
+        required=True,
+        help="SoC model. e.g. SM8750",
+    )
+    sub_execute.add_argument(
+        "-s",
+        "--device",
+        type=str,
+        required=True,
+        help="Serial no of device which could be obtained by 'adb devices'.",
+    )
+    sub_execute.add_argument(
+        "-o",
+        "--output_folder",
+        type=str,
+        default="./output_data",
+        help="Path to output data, store in 'output_data' if not given.",
+    )
+    sub_execute.add_argument(
+        "-b",
+        "--build_folder",
+        help="Path to cmake binary directory for android, e.g., /path/to/build-android",
+        type=str,
+        required=True,
+    )
+    sub_execute.add_argument(
+        "-H",
+        "--host",
+        type=str,
+        help="Gateway hostname.",
+    )
+    sub_execute.add_argument(
+        "--shared_buffer",
+        help=(
+            "Enable usage of shared buffer between application and backend for graph I/O."
+            " Please use with `--shared_buffer` in compile command."
+        ),
+        action="store_true",
+    )
+    sub_execute.set_defaults(callback=execute)
+
+    args = parser.parse_args()
+    args.callback(args)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/qualcomm/utils.py b/examples/qualcomm/utils.py
index 6d9a6653ec7..e70510b0b70 100755
--- a/examples/qualcomm/utils.py
+++ b/examples/qualcomm/utils.py
@@ -9,6 +9,7 @@
 
 import argparse
 import os
+import shutil
 import subprocess
 import sys
 import tempfile
@@ -395,9 +396,7 @@ def build_executorch_binary(
 
 def make_output_dir(path: str):
     if os.path.exists(path):
-        for f in os.listdir(path):
-            os.remove(os.path.join(path, f))
-        os.removedirs(path)
+        shutil.rmtree(path, ignore_errors=True)
     os.makedirs(path)