diff --git a/backends/qualcomm/qnn_preprocess.py b/backends/qualcomm/qnn_preprocess.py index 034b75fa6d0..4e9cda21d02 100644 --- a/backends/qualcomm/qnn_preprocess.py +++ b/backends/qualcomm/qnn_preprocess.py @@ -178,6 +178,11 @@ def preprocess_multimethod( if len(py_op_wrapper_list) == len(edge_programs.values()): qnn_context_binary = qnn_manager.Compile(graph_name, py_op_wrapper_list) + if option.saver: + # TODO: Currently, only the first method is saved. Update this logic if saving multiple methods becomes necessary in the future. + exit( + f"Record all QNN API calls from saver backend at: {option.saver_output_dir}" + ) assert ( len(qnn_context_binary) != 0 ), "Failed to generate Qnn context binary." diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index c61173ad852..6957ff2e898 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -3384,6 +3384,38 @@ def test_qnn_backend_rewrite_prepared_observer(self): quantized_module = convert_pt2e(prepared) self.lower_module_and_test_output(quantized_module, sample_input) + def test_qnn_backend_saver_backend(self): + backend_options = generate_htp_compiler_spec(use_fp16=False) + TestQNN.compiler_specs = generate_qnn_executorch_compiler_spec( + soc_model=self.chipset_table[TestQNN.model], + backend_options=backend_options, + saver=True, + ) + module = Relu() # noqa: F405 + sample_input = (torch.randn([2, 5, 1, 3]),) + module = self.get_qdq_module(module, sample_input) + + from executorch.backends.qualcomm.serialization.qc_schema_serialize import ( + flatbuffer_to_option, + option_to_flatbuffer, + ) + + with tempfile.TemporaryDirectory() as tmp_dir: + option = flatbuffer_to_option(TestQNN.compiler_specs[0].value) + option.saver_output_dir = f"{tmp_dir}/saver_output" + TestQNN.compiler_specs[0].value = option_to_flatbuffer(option) + + with self.assertRaises(SystemExit): + self.lower_module_and_test_output(module, sample_input) + self.assertTrue( + os.path.isfile(f"{tmp_dir}/saver_output/params.bin"), + "failed to find params.bin", + ) + self.assertTrue( + os.path.isfile(f"{tmp_dir}/saver_output/saver_output.c"), + "failed to find saver_output.c", + ) + def test_qnn_backend_skip_node_id_partitioner(self): module = SimpleModel() # noqa: F405 sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28))