diff --git a/backends/qualcomm/aot/python/PyQnnWrapperAdaptor.cpp b/backends/qualcomm/aot/python/PyQnnWrapperAdaptor.cpp index 97d4491bc6a..5e794dde323 100644 --- a/backends/qualcomm/aot/python/PyQnnWrapperAdaptor.cpp +++ b/backends/qualcomm/aot/python/PyQnnWrapperAdaptor.cpp @@ -66,6 +66,37 @@ std::unique_ptr CreateQuantizationParamWrapper( return quantize_param_wrapper; } +std::string GetScalarValue(const Qnn_Scalar_t& scalar) { + switch (scalar.dataType) { + case QNN_DATATYPE_FLOAT_32: + return std::to_string(scalar.floatValue); + case QNN_DATATYPE_FLOAT_64: + return std::to_string(scalar.doubleValue); + case QNN_DATATYPE_UINT_64: + return std::to_string(scalar.uint64Value); + case QNN_DATATYPE_INT_64: + return std::to_string(scalar.int64Value); + case QNN_DATATYPE_UINT_32: + return std::to_string(scalar.uint32Value); + case QNN_DATATYPE_INT_32: + return std::to_string(scalar.int32Value); + case QNN_DATATYPE_UINT_16: + return std::to_string(scalar.uint16Value); + case QNN_DATATYPE_INT_16: + return std::to_string(scalar.int16Value); + case QNN_DATATYPE_UINT_8: + return std::to_string(scalar.uint8Value); + case QNN_DATATYPE_INT_8: + return std::to_string(scalar.int8Value); + case QNN_DATATYPE_BOOL_8: + return std::to_string(static_cast(scalar.bool8Value)); + case QNN_DATATYPE_STRING: + return std::string(scalar.stringValue); + default: + return "QNN_DATATYPE_UNDEFINED"; + } +} + std::shared_ptr CreateTensorWrapper( const std::string& tensor_name, Qnn_TensorType_t tensor_type, @@ -176,11 +207,60 @@ PYBIND11_MODULE(PyQnnWrapperAdaptor, m) { Qnn_QuantizationEncoding_t:: QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET) .export_values(); + py::class_>(m, "OpWrapper") .def(py::init< const std::string&, const std::string&, - const std::string&>()); + const std::string&>()) + .def( + "GetInputTensors", + &OpWrapper::GetInputTensors, + "A function which gets input tensors") + .def( + "GetOutputTensors", + &OpWrapper::GetOutputTensors, + "A function which gets output tensors") + .def("GetOpType", &OpWrapper::GetOpType, "A function which gets op type") + .def("GetName", &OpWrapper::GetName, "A function which gets name") + .def( + "GetPackageName", + &OpWrapper::GetPackageName, + "A function which gets package name") + .def( + "GetParams", &OpWrapper::GetRawParams, "A function which gets params") + // lambda function + // python: op_wrapper.GetOpConfig() + .def( + "GetOpConfig", + [](OpWrapper& self) { + auto op_config = self.GetOpConfig(); + py::dict result; + py::list params_list; + py::list input_tensors_list; + py::list output_tensors_list; + result["version"] = op_config.version; + result["name"] = op_config.v1.name; + result["packageName"] = op_config.v1.packageName; + result["typeName"] = op_config.v1.typeName; + result["numOfParams"] = op_config.v1.numOfParams; + for (size_t i = 0; i < op_config.v1.numOfParams; ++i) { + params_list.append(op_config.v1.params[i]); + } + result["params"] = params_list; + result["numOfInputs"] = op_config.v1.numOfInputs; + for (size_t i = 0; i < op_config.v1.numOfInputs; ++i) { + input_tensors_list.append(op_config.v1.inputTensors[i]); + } + result["inputTensors"] = input_tensors_list; + result["numOfOutputs"] = op_config.v1.numOfOutputs; + for (size_t i = 0; i < op_config.v1.numOfOutputs; ++i) { + output_tensors_list.append(op_config.v1.outputTensors[i]); + } + result["outputTensors"] = output_tensors_list; + return result; + }, + "Get operator configuration"); py::class_>(m, "TensorWrapper") .def(py::init(py::overload_cast< @@ -197,7 +277,9 @@ PYBIND11_MODULE(PyQnnWrapperAdaptor, m) { py::class_(m, "QuantizeParamsWrapper"); py::class_(m, "Qnn_ScaleOffset_t") - .def(py::init()); + .def(py::init()) + .def_readonly("scale", &Qnn_ScaleOffset_t::scale) + .def_readonly("offset", &Qnn_ScaleOffset_t::offset); py::class_>( m, "PyQnnOpWrapper") @@ -248,6 +330,158 @@ PYBIND11_MODULE(PyQnnWrapperAdaptor, m) { .def("GetDataType", &PyQnnTensorWrapper::GetDataType) .def("GetName", &PyQnnTensorWrapper::GetName) .def("GetEncodings", &PyQnnTensorWrapper::GetEncodings); + + py::class_(m, "Qnn_OpConfig") + .def_readonly("version", &Qnn_OpConfig_t::version) + // getter + // python: op_wrapper.GetOpConfig().v1 + .def_property_readonly( + "v1", [](const Qnn_OpConfig_t& config) -> const Qnn_OpConfigV1_t& { + return config.v1; + }); + + py::enum_(m, "Qnn_OpConfigVersion") + .value("QNN_OPCONFIG_VERSION_1", QNN_OPCONFIG_VERSION_1) + .value("QNN_OPCONFIG_VERSION_UNDEFINED", QNN_OPCONFIG_VERSION_UNDEFINED) + .export_values(); + + py::class_(m, "Qnn_OpConfigV1") + .def_readonly("name", &Qnn_OpConfigV1_t::name) + .def_readonly("packageName", &Qnn_OpConfigV1_t::packageName) + .def_readonly("typeName", &Qnn_OpConfigV1_t::typeName) + .def_readonly("numOfParams", &Qnn_OpConfigV1_t::numOfParams) + .def_readonly("params", &Qnn_OpConfigV1_t::params) + .def_readonly("numOfInputs", &Qnn_OpConfigV1_t::numOfInputs) + .def_readonly("inputTensors", &Qnn_OpConfigV1_t::inputTensors) + .def_readonly("numOfOutputs", &Qnn_OpConfigV1_t::numOfOutputs) + .def_readonly("outputTensors", &Qnn_OpConfigV1_t::outputTensors); + + py::class_(m, "Qnn_Param") + .def_readonly("paramType", &Qnn_Param_t::paramType) + .def_readonly("name", &Qnn_Param_t::name) + .def_property_readonly( + "scalarParam", + [](const Qnn_Param_t& param) -> const Qnn_Scalar_t& { + if (param.paramType == Qnn_ParamType_t::QNN_PARAMTYPE_SCALAR) { + return param.scalarParam; + } + throw std::runtime_error("ParamType is not scalar."); + }) + .def_property_readonly( + "tensorParam", [](const Qnn_Param_t& param) -> const Qnn_Tensor_t& { + if (param.paramType == Qnn_ParamType_t::QNN_PARAMTYPE_TENSOR) { + return param.tensorParam; + } + throw std::runtime_error("ParamType is not tensor."); + }); + + py::enum_(m, "Qnn_ParamType_t") + .value("QNN_PARAMTYPE_SCALAR", Qnn_ParamType_t::QNN_PARAMTYPE_SCALAR) + .value("QNN_PARAMTYPE_TENSOR", Qnn_ParamType_t::QNN_PARAMTYPE_TENSOR) + .value( + "QNN_PARAMTYPE_UNDEFINED", Qnn_ParamType_t::QNN_PARAMTYPE_UNDEFINED) + .export_values(); + + py::class_(m, "Qnn_Scalar_t") + .def_readonly("dataType", &Qnn_Scalar_t::dataType) + .def("value", &GetScalarValue, "Get the value of the scalar as a string"); + + py::class_(m, "Qnn_Tensor_t") + .def_readonly("version", &Qnn_Tensor_t::version) + .def_property_readonly( + "v1", + [](Qnn_Tensor_t& t) -> Qnn_TensorV1_t& { + if (t.version == QNN_TENSOR_VERSION_1) { + return t.v1; + } + throw std::runtime_error("Tensor version is not V1."); + }) + .def_property_readonly("v2", [](Qnn_Tensor_t& t) -> Qnn_TensorV2_t& { + if (t.version == QNN_TENSOR_VERSION_2) { + return t.v2; + } + throw std::runtime_error("Tensor version is not V2."); + }); + + py::enum_(m, "Qnn_TensorVersion_t") + .value("QNN_TENSOR_VERSION_1", Qnn_TensorVersion_t::QNN_TENSOR_VERSION_1) + .value("QNN_TENSOR_VERSION_2", Qnn_TensorVersion_t::QNN_TENSOR_VERSION_2) + .value( + "QNN_TENSOR_VERSION_UNDEFINED", + Qnn_TensorVersion_t::QNN_TENSOR_VERSION_UNDEFINED) + .export_values(); + + py::class_(m, "QnnTensorV1") + .def_readonly("id", &Qnn_TensorV1_t::id) + .def_readonly("name", &Qnn_TensorV1_t::name) + .def_readonly("type", &Qnn_TensorV1_t::type) + .def_readonly("dataFormat", &Qnn_TensorV1_t::dataFormat) + .def_readonly("dataType", &Qnn_TensorV1_t::dataType) + .def_readonly("quantizeParams", &Qnn_TensorV1_t::quantizeParams) + .def_readonly("rank", &Qnn_TensorV1_t::rank) + // change dimensions pointer to vector(begin to rank) + .def_property_readonly( + "dimensions", + [](const Qnn_TensorV1_t& t) { + return std::vector(t.dimensions, t.dimensions + t.rank); + }) + .def_readonly("memType", &Qnn_TensorV1_t::memType); + + py::enum_(m, "Qnn_TensorMemType_t") + .value( + "QNN_TENSORMEMTYPE_RAW", Qnn_TensorMemType_t::QNN_TENSORMEMTYPE_RAW) + .value( + "QNN_TENSORMEMTYPE_MEMHANDLE", + Qnn_TensorMemType_t::QNN_TENSORMEMTYPE_MEMHANDLE) + .value( + "QNN_TENSORMEMTYPE_UNDEFINED", + Qnn_TensorMemType_t::QNN_TENSORMEMTYPE_UNDEFINED) + .export_values(); + + py::class_(m, "QnnQuantizeParams") + .def_readonly( + "encodingDefinition", &Qnn_QuantizeParams_t::encodingDefinition) + .def_readonly( + "quantizationEncoding", &Qnn_QuantizeParams_t::quantizationEncoding) + .def_property_readonly( + "scaleOffsetEncoding", + [](const Qnn_QuantizeParams_t& qp) { + if (qp.quantizationEncoding == + QNN_QUANTIZATION_ENCODING_SCALE_OFFSET) { + return qp.scaleOffsetEncoding; + } + throw std::runtime_error( + "Invalid quantization encoding type for scaleOffsetEncoding."); + }) + .def_property_readonly( + "axisScaleOffsetEncoding", [](const Qnn_QuantizeParams_t& qp) { + if (qp.quantizationEncoding == + QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET) { + return qp.axisScaleOffsetEncoding; + } + throw std::runtime_error( + "Invalid quantization encoding type for axisScaleOffsetEncoding."); + }); + + py::enum_(m, "QnnDefinition") + .value( + "QNN_DEFINITION_IMPL_GENERATED", + Qnn_Definition_t::QNN_DEFINITION_IMPL_GENERATED) + .value("QNN_DEFINITION_DEFINED", Qnn_Definition_t::QNN_DEFINITION_DEFINED) + .value( + "QNN_DEFINITION_UNDEFINED", + Qnn_Definition_t::QNN_DEFINITION_UNDEFINED) + .export_values(); + + py::class_(m, "QnnAxisScaleOffset") + .def_readonly("axis", &Qnn_AxisScaleOffset_t::axis) + .def_readonly("numScaleOffsets", &Qnn_AxisScaleOffset_t::numScaleOffsets) + .def_property_readonly( + "scaleOffset", [](const Qnn_AxisScaleOffset_t& aso) { + return std::vector( + aso.scaleOffset, aso.scaleOffset + aso.numScaleOffsets); + }); + // op_wrapper.GetParams() get std::vector } } // namespace qnn } // namespace backends diff --git a/backends/qualcomm/aot/wrappers/OpWrapper.h b/backends/qualcomm/aot/wrappers/OpWrapper.h index 338d8c5fede..3f0b742c0b3 100644 --- a/backends/qualcomm/aot/wrappers/OpWrapper.h +++ b/backends/qualcomm/aot/wrappers/OpWrapper.h @@ -102,6 +102,19 @@ class OpWrapper final { const std::string GetOpType() { return op_type_; } + const std::string GetName() { + return name_; + } + const std::string GetPackageName() { + return package_name_; + } + std::vector GetRawParams() const { + std::vector raw_params; + for (const auto& param : params_) { + raw_params.push_back(param.get()); + } + return raw_params; + } Qnn_OpConfig_t GetOpConfig(); private: diff --git a/backends/qualcomm/debugger/utils.py b/backends/qualcomm/debugger/utils.py new file mode 100644 index 00000000000..690517ea1b6 --- /dev/null +++ b/backends/qualcomm/debugger/utils.py @@ -0,0 +1,181 @@ +import os +import shutil +import tempfile + +import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import pandas as pd +from graphviz import Digraph + + +class DrawGraph: + def __init__( + self, + filename: str, + directory: str, + py_op_wrapper_list: [PyQnnWrapper.PyQnnOpWrapper], + dot_string=False, + ): + self.filename = filename + self.directory = directory + self.py_op_wrapper_list = py_op_wrapper_list + self.dot = Digraph(filename, format="svg") + self.dot.attr(rankdir="TB") + self.dot_string = dot_string + self.draw() + + def dfs_add_edges(self, node_name, visited, node_list): + if node_name in visited: + return + visited.add(node_name) + + input_list = node_list[node_name]["input_list"] + for input_node_name in input_list: + self.dot.edge(input_node_name, node_name) + self.dfs_add_edges(input_node_name, visited, node_list) + + def get_dot_graph(self): + visited = set() + node_list = {} + excel_data = [] + + self.get_node(node_list) + self.add_node(node_list, excel_data) + self.to_excel(excel_data) + + # add edge + for node_name, _ in node_list.items(): + if node_name not in visited: + self.dfs_add_edges(node_name, visited, node_list) + + return self.dot + + def get_node(self, node_list): + for py_op_wrapper in self.py_op_wrapper_list: + op_wrapper = py_op_wrapper.GetOpWrapper() + # TODO: multi output + for i in range(op_wrapper.GetOpConfig()["numOfOutputs"]): + if op_wrapper.GetOpConfig()["outputTensors"][0].version == 1: + node = op_wrapper.GetOpConfig()["outputTensors"][i].v1 + node_name = node.name + input_list = [] + for j in range(op_wrapper.GetOpConfig()["numOfInputs"]): + if op_wrapper.GetOpConfig()["inputTensors"][j].version == 1: + input_node = op_wrapper.GetOpConfig()["inputTensors"][j].v1 + input_node_name = input_node.name + if input_node_name not in node_list: + node_list[input_node_name] = { + "node": input_node, + "input_list": [], + } + input_list.append(input_node_name) + # TODO: tensor v2 + elif op_wrapper.GetOpConfig()["outputTensors"][j].version == 2: + raise ValueError("Unsupported tensor version: 2") + if node_name not in node_list: + node_list[node_name] = {"node": node, "input_list": input_list} + else: + node_list[node_name]["input_list"] = input_list + # TODO: tensor v2 + elif op_wrapper.GetOpConfig()["outputTensors"][i].version == 2: + raise ValueError("Unsupported tensor version: 2") + + def add_node(self, node_list, excel_data): + for node_name, tensor in node_list.items(): + node = tensor["node"] + name = node_name + data_type = node.dataType + tensor_type = node.type + dims = node.dimensions + quantization_encoding = node.quantizeParams.quantizationEncoding + scale = [] + offset = [] + if ( + quantization_encoding + == PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_SCALE_OFFSET + ): + scale.append(node.quantizeParams.scaleOffsetEncoding.scale) + offset.append(node.quantizeParams.scaleOffsetEncoding.offset) + elif ( + quantization_encoding + == PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET + ): + for i in range( + node.quantizeParams.axisScaleOffsetEncoding.numScaleOffsets + ): + scale.append( + node.quantizeParams.axisScaleOffsetEncoding.scaleOffset[i].scale + ) + offset.append( + node.quantizeParams.axisScaleOffsetEncoding.scaleOffset[ + i + ].offset + ) + excel_data.append( + { + "name": name, + "tensor_type": tensor_type, + "scale": scale, + "offset": offset, + } + ) + # Default color for intermediate nodes + bg_color = "white" + if "input" in node_name or "output" in node_name: + bg_color = "lightgreen" + elif tensor_type == 4: + bg_color = "lightpink" + label = f"""< + + + + + + + """ + label += "
name: {name}
data_type: {data_type}
tensor_type: {tensor_type}
dims: {dims}
quantization_encoding: {quantization_encoding}
>" + self.dot.node( + node_name, + label, + shape="box", + style="rounded", + fillcolor="transparent", + color="black", + ) + + def to_excel(self, excel_data): + param_rows = [] + activation_rows = [] + + for entry in excel_data: + name = entry["name"] + scale = entry["scale"] + offset = entry["offset"] + if ( + entry["tensor_type"] + == PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC + ): + param_rows.append({"name": name, "scale": scale, "offset": offset}) + else: + activation_rows.append({"name": name, "scale": scale, "offset": offset}) + param_df = pd.DataFrame(param_rows) + scale_df = pd.DataFrame(activation_rows) + output_file = f"{self.filename}.xlsx" + + with pd.ExcelWriter(output_file) as writer: + param_df.to_excel(writer, sheet_name="Parameters", index=False) + scale_df.to_excel(writer, sheet_name="Scales", index=False) + + def draw(self): + graph = self.get_dot_graph() + with tempfile.TemporaryDirectory() as tmp_dir: + temp_directory = f"{tmp_dir}/outputs" + graph.render( + self.filename, temp_directory, format="svg", cleanup=not self.dot_string + ) + source_file = os.path.join(temp_directory, f"{self.filename}.svg") + destination_file = os.path.join(".", f"{self.filename}.svg") + shutil.move(source_file, destination_file) + if self.dot_string: + dot_file = os.path.join(temp_directory, f"{self.filename}") + dot_dest_file = os.path.join(".", f"{self.filename}.dot") + shutil.move(dot_file, dot_dest_file) diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py index 0ed66329c33..411c146be70 100644 --- a/backends/qualcomm/tests/models.py +++ b/backends/qualcomm/tests/models.py @@ -451,6 +451,23 @@ def forward(self, x): return x / 10 +class DrawGraphModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.relu1 = torch.nn.ReLU() + self.relu2 = torch.nn.ReLU() + kernel_sz = 32 + self.conv1 = torch.nn.Conv2d(kernel_sz, kernel_sz, 3, padding=1, bias=True) + self.conv2 = torch.nn.Conv2d(kernel_sz, kernel_sz, 3, padding=1, bias=True) + + def forward(self, x): + x1 = self.conv1(x) + x2 = self.conv2(x) + y1 = self.relu1(x1) + y2 = self.relu1(x2) + return y1 + y2 + + class EinsumBilinear(torch.nn.Module): def __init__(self): super().__init__() diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index 10917cdd6bf..07e0cee4bea 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -44,6 +44,20 @@ from executorch.backends.qualcomm.tests.models import * # noqa: F403 +import os +import random + +from collections import defaultdict +from typing import List + +from executorch.backends.qualcomm._passes.fuse_consecutive_transpose import ( + FuseConsecutiveTranspose, +) +from executorch.backends.qualcomm._passes.insert_io_qdq import InsertIOQDQ +from executorch.backends.qualcomm._passes.insert_requantize import InsertRequantize +from executorch.backends.qualcomm._passes.layout_transform import LayoutTransform +from executorch.backends.qualcomm.builders.node_visitor import get_node_visitors +from executorch.backends.qualcomm.debugger.utils import DrawGraph from executorch.examples.models.deeplab_v3 import DeepLabV3ResNet101Model from executorch.examples.models.edsr import EdsrModel from executorch.examples.models.inception_v3 import InceptionV3Model @@ -57,6 +71,7 @@ from executorch.examples.models.wav2letter import Wav2LetterModel from executorch.exir import to_edge from executorch.exir.backend.backend_api import disable_validation +from executorch.exir.passes import PassManager class TestQNNFloatingPointOperator(TestQNN): @@ -1675,6 +1690,156 @@ def test_qnn_backend_context_direct(self): bundle_program["edge_program_manager"].to_executorch(), ) + def test_qnn_backend_draw_graph(self): + golden_data = """digraph test { + rankdir=TB + input_0_x_0 [label=< + + + + + + +
name: input_0_x_0
data_type: Qnn_DataType_t.QNN_DATATYPE_FLOAT_32
tensor_type: Qnn_TensorType_t.QNN_TENSOR_TYPE_APP_WRITE
dims: [1, 28, 28, 32]
quantization_encoding: Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_UNDEFINED
> color=black fillcolor=transparent shape=box style=rounded] + p_conv2_weight_0 [label=< + + + + + + +
name: p_conv2_weight_0
data_type: Qnn_DataType_t.QNN_DATATYPE_FLOAT_32
tensor_type: Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC
dims: [3, 3, 32, 32]
quantization_encoding: Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_UNDEFINED
> color=black fillcolor=transparent shape=box style=rounded] + p_conv2_bias_0 [label=< + + + + + + +
name: p_conv2_bias_0
data_type: Qnn_DataType_t.QNN_DATATYPE_FLOAT_32
tensor_type: Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC
dims: [32]
quantization_encoding: Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_UNDEFINED
> color=black fillcolor=transparent shape=box style=rounded] + aten_convolution_default_1_0 [label=< + + + + + + +
name: aten_convolution_default_1_0
data_type: Qnn_DataType_t.QNN_DATATYPE_FLOAT_32
tensor_type: Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE
dims: [1, 28, 28, 32]
quantization_encoding: Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_UNDEFINED
> color=black fillcolor=transparent shape=box style=rounded] + aten_relu_default_0 [label=< + + + + + + +
name: aten_relu_default_0
data_type: Qnn_DataType_t.QNN_DATATYPE_FLOAT_32
tensor_type: Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE
dims: [1, 28, 28, 32]
quantization_encoding: Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_UNDEFINED
> color=black fillcolor=transparent shape=box style=rounded] + aten_relu_default_1_0 [label=< + + + + + + +
name: aten_relu_default_1_0
data_type: Qnn_DataType_t.QNN_DATATYPE_FLOAT_32
tensor_type: Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE
dims: [1, 28, 28, 32]
quantization_encoding: Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_UNDEFINED
> color=black fillcolor=transparent shape=box style=rounded] + output_aten_add_tensor_0 [label=< + + + + + + +
name: output_aten_add_tensor_0
data_type: Qnn_DataType_t.QNN_DATATYPE_FLOAT_32
tensor_type: Qnn_TensorType_t.QNN_TENSOR_TYPE_APP_READ
dims: [1, 28, 28, 32]
quantization_encoding: Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_UNDEFINED
> color=black fillcolor=transparent shape=box style=rounded] + p_conv1_weight_0 [label=< + + + + + + +
name: p_conv1_weight_0
data_type: Qnn_DataType_t.QNN_DATATYPE_FLOAT_32
tensor_type: Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC
dims: [3, 3, 32, 32]
quantization_encoding: Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_UNDEFINED
> color=black fillcolor=transparent shape=box style=rounded] + p_conv1_bias_0 [label=< + + + + + + +
name: p_conv1_bias_0
data_type: Qnn_DataType_t.QNN_DATATYPE_FLOAT_32
tensor_type: Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC
dims: [32]
quantization_encoding: Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_UNDEFINED
> color=black fillcolor=transparent shape=box style=rounded] + aten_convolution_default_0 [label=< + + + + + + +
name: aten_convolution_default_0
data_type: Qnn_DataType_t.QNN_DATATYPE_FLOAT_32
tensor_type: Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE
dims: [1, 28, 28, 32]
quantization_encoding: Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_UNDEFINED
> color=black fillcolor=transparent shape=box style=rounded] + input_0_x_0 -> aten_convolution_default_1_0 + p_conv2_weight_0 -> aten_convolution_default_1_0 + p_conv2_bias_0 -> aten_convolution_default_1_0 + aten_convolution_default_0 -> aten_relu_default_0 + input_0_x_0 -> aten_convolution_default_0 + p_conv1_weight_0 -> aten_convolution_default_0 + p_conv1_bias_0 -> aten_convolution_default_0 + aten_convolution_default_1_0 -> aten_relu_default_1_0 + aten_relu_default_0 -> output_aten_add_tensor_0 + aten_relu_default_1_0 -> output_aten_add_tensor_0 + } + """ + module = DrawGraphModel() # noqa: F405 + sample_input = (torch.randn(1, 32, 28, 28),) + delegated_program = capture_program(module, sample_input) + + """ + This piece of code simulates the behavior of the final preprocessing step to obtain the op wrapper list. + In practice, users need to set a breakpoint in the preprocessing step and use the DrawGraph tool to visualize the graph. + """ + qnn_compiler_passes = PassManager( + passes=[ + InsertRequantize(delegated_program.exported_program), + InsertIOQDQ(delegated_program.exported_program), + LayoutTransform( + delegated_program.exported_program, insert_permute=True + ), + FuseConsecutiveTranspose(), + ] + ) + + pass_result = qnn_compiler_passes( + delegated_program.exported_program.graph_module + ) + nodes_to_wrappers = defaultdict(dict) + node_visitors = get_node_visitors( + delegated_program.exported_program, enable_tensor_dump=False + ) + + py_op_wrapper_list = [] + for node in pass_result.graph_module.graph.nodes: + if node.op == "call_function": + if node.target.__name__ in node_visitors: + py_op_wrapper = node_visitors[node.target.__name__].define_node( + node, nodes_to_wrappers + ) + if py_op_wrapper is not None: + if isinstance(py_op_wrapper, List): + py_op_wrapper_list.extend(py_op_wrapper) + else: + py_op_wrapper_list.append(py_op_wrapper) + elif node.op in [ + "get_attr", + "placeholder", + "output", + ]: + continue + # random py_op_wrapper_list to check it's correctness + random.shuffle(py_op_wrapper_list) + DrawGraph("test", ".", py_op_wrapper_list, dot_string=True) + test_file = os.path.join(".", "test.dot") + with open(test_file, "r") as test: + test_data = test.read() + assert sorted(golden_data.split()) == sorted( + test_data.split() + ), "Generated .dot file does not match the golden file." + class TestQNNQuantizedUtils(TestQNN): # TODO: refactor to support different backends @@ -1997,6 +2162,175 @@ def test_qnn_backend_context_direct(self): bundle_program["edge_program_manager"].to_executorch(), ) + def test_qnn_backend_draw_graph(self): + golden_data = """digraph test { + rankdir=TB + aten_convolution_default_0 [label=< + + + + + + +
name: aten_convolution_default_0
data_type: Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_8
tensor_type: Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE
dims: [1, 28, 28, 32]
quantization_encoding: Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_SCALE_OFFSET
> color=black fillcolor=transparent shape=box style=rounded] + aten_relu_default_0 [label=< + + + + + + +
name: aten_relu_default_0
data_type: Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_8
tensor_type: Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE
dims: [1, 28, 28, 32]
quantization_encoding: Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_SCALE_OFFSET
> color=black fillcolor=transparent shape=box style=rounded] + quantized_decomposed_quantize_per_tensor_default_8_0 [label=< + + + + + + +
name: quantized_decomposed_quantize_per_tensor_default_8_0
data_type: Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_8
tensor_type: Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE
dims: [1, 32, 28, 28]
quantization_encoding: Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_SCALE_OFFSET
> color=black fillcolor=transparent shape=box style=rounded] + b__frozen_param2_0 [label=< + + + + + + +
name: b__frozen_param2_0
data_type: Qnn_DataType_t.QNN_DATATYPE_SFIXED_POINT_8
tensor_type: Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC
dims: [3, 3, 32, 32]
quantization_encoding: Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET
> color=black fillcolor=transparent shape=box style=rounded] + b__frozen_param3_0 [label=< + + + + + + +
name: b__frozen_param3_0
data_type: Qnn_DataType_t.QNN_DATATYPE_SFIXED_POINT_32
tensor_type: Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC
dims: [32]
quantization_encoding: Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET
> color=black fillcolor=transparent shape=box style=rounded] + aten_convolution_default_1_0 [label=< + + + + + + +
name: aten_convolution_default_1_0
data_type: Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_8
tensor_type: Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE
dims: [1, 28, 28, 32]
quantization_encoding: Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_SCALE_OFFSET
> color=black fillcolor=transparent shape=box style=rounded] + aten_relu_default_1_0 [label=< + + + + + + +
name: aten_relu_default_1_0
data_type: Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_8
tensor_type: Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE
dims: [1, 28, 28, 32]
quantization_encoding: Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_SCALE_OFFSET
> color=black fillcolor=transparent shape=box style=rounded] + aten_add_tensor_0 [label=< + + + + + + +
name: aten_add_tensor_0
data_type: Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_8
tensor_type: Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE
dims: [1, 28, 28, 32]
quantization_encoding: Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_SCALE_OFFSET
> color=black fillcolor=transparent shape=box style=rounded] + output_quantized_decomposed_dequantize_per_tensor_tensor_0 [label=< + + + + + + +
name: output_quantized_decomposed_dequantize_per_tensor_tensor_0
data_type: Qnn_DataType_t.QNN_DATATYPE_FLOAT_32
tensor_type: Qnn_TensorType_t.QNN_TENSOR_TYPE_APP_READ
dims: [1, 32, 28, 28]
quantization_encoding: Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_UNDEFINED
> color=black fillcolor=transparent shape=box style=rounded] + input_0_x_0 [label=< + + + + + + +
name: input_0_x_0
data_type: Qnn_DataType_t.QNN_DATATYPE_FLOAT_32
tensor_type: Qnn_TensorType_t.QNN_TENSOR_TYPE_APP_WRITE
dims: [1, 32, 28, 28]
quantization_encoding: Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_UNDEFINED
> color=black fillcolor=transparent shape=box style=rounded] + b__frozen_param0_0 [label=< + + + + + + +
name: b__frozen_param0_0
data_type: Qnn_DataType_t.QNN_DATATYPE_SFIXED_POINT_8
tensor_type: Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC
dims: [3, 3, 32, 32]
quantization_encoding: Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET
> color=black fillcolor=transparent shape=box style=rounded] + b__frozen_param1_0 [label=< + + + + + + +
name: b__frozen_param1_0
data_type: Qnn_DataType_t.QNN_DATATYPE_SFIXED_POINT_32
tensor_type: Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC
dims: [32]
quantization_encoding: Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET
> color=black fillcolor=transparent shape=box style=rounded] + quantized_decomposed_quantize_per_tensor_default_8_0 -> aten_convolution_default_0 + input_0_x_0 -> quantized_decomposed_quantize_per_tensor_default_8_0 + b__frozen_param0_0 -> aten_convolution_default_0 + b__frozen_param1_0 -> aten_convolution_default_0 + aten_convolution_default_0 -> aten_relu_default_0 + quantized_decomposed_quantize_per_tensor_default_8_0 -> aten_convolution_default_1_0 + b__frozen_param2_0 -> aten_convolution_default_1_0 + b__frozen_param3_0 -> aten_convolution_default_1_0 + aten_convolution_default_1_0 -> aten_relu_default_1_0 + aten_relu_default_0 -> aten_add_tensor_0 + aten_relu_default_1_0 -> aten_add_tensor_0 + aten_add_tensor_0 -> output_quantized_decomposed_dequantize_per_tensor_tensor_0 + } + """ + module = DrawGraphModel() # noqa: F405 + sample_input = (torch.randn(1, 32, 28, 28),) + module = self.get_qdq_module(module, sample_input) + delegated_program = capture_program(module, sample_input) + + """ + This piece of code simulates the behavior of the final preprocessing step to obtain the op wrapper list. + In practice, users need to set a breakpoint in the preprocessing step and use the DrawGraph tool to visualize the graph. + """ + qnn_compiler_passes = PassManager( + passes=[ + InsertRequantize(delegated_program.exported_program), + InsertIOQDQ(delegated_program.exported_program), + LayoutTransform( + delegated_program.exported_program, insert_permute=True + ), + FuseConsecutiveTranspose(), + ] + ) + + pass_result = qnn_compiler_passes( + delegated_program.exported_program.graph_module + ) + nodes_to_wrappers = defaultdict(dict) + node_visitors = get_node_visitors( + delegated_program.exported_program, enable_tensor_dump=False + ) + + py_op_wrapper_list = [] + for node in pass_result.graph_module.graph.nodes: + if node.op == "call_function": + if node.target.__name__ in node_visitors: + py_op_wrapper = node_visitors[node.target.__name__].define_node( + node, nodes_to_wrappers + ) + if py_op_wrapper is not None: + if isinstance(py_op_wrapper, List): + py_op_wrapper_list.extend(py_op_wrapper) + else: + py_op_wrapper_list.append(py_op_wrapper) + elif node.op in [ + "get_attr", + "placeholder", + "output", + ]: + continue + # random py_op_wrapper_list to check it's correctness + random.shuffle(py_op_wrapper_list) + DrawGraph("test", ".", py_op_wrapper_list, dot_string=True) + test_file = os.path.join(".", "test.dot") + with open(test_file, "r") as test: + test_data = test.read() + assert sorted(golden_data.split()) == sorted( + test_data.split() + ), "Generated .dot file does not match the golden file." + class TestExampleOssScript(TestQNN): def required_envs(self, conditions=None) -> bool: