diff --git a/backends/qualcomm/_passes/annotate_quant_attrs.py b/backends/qualcomm/_passes/annotate_quant_attrs.py
index 632e67569f7..b9efcd7aa6c 100644
--- a/backends/qualcomm/_passes/annotate_quant_attrs.py
+++ b/backends/qualcomm/_passes/annotate_quant_attrs.py
@@ -9,10 +9,16 @@
import torch
from executorch.backends.qualcomm.builders.utils import get_parameter, set_parameter
from executorch.backends.qualcomm.utils.constants import (
+ QCOM_AXIS,
+ QCOM_DTYPE,
QCOM_ENCODING,
QCOM_QUANT_ATTRS,
+ QCOM_QUANT_MAX,
+ QCOM_QUANT_MIN,
QCOM_REQUANTIZE,
+ QCOM_SCALE,
QCOM_SCALES,
+ QCOM_ZERO_POINT,
QCOM_ZERO_POINTS,
)
from executorch.exir.dialects._ops import ops as exir_ops
@@ -52,45 +58,59 @@ def _expand(self, tensor, dim, axis) -> torch.Tensor:
order[axis], order[0] = order[0], order[axis]
return tensor.permute(order)
- # Find the the last dq node between regular op nodes
+ # Find the last dq nodes between regular op nodes
# Return dq2 in example below when q1 is given as node parameter:
# ... -> n1 -> q1 -> dq1 -> q2 -> dq2 -> n2 -> ...
- def _find_last_dq_node(self, node: torch.fx.node.Node) -> torch.fx.node.Node:
- if list(node.users)[0].target in q_ops.union(dq_ops):
- return self._find_last_dq_node(list(node.users)[0])
- return node
+ def _find_last_dq_nodes(self, node: torch.fx.node.Node) -> list:
+ if node is None:
+ return []
+
+ # If the node is the last dq between regular op nodes, return it in a list
+ if node.target in dq_ops:
+ if all(user.target not in q_ops for user in node.users):
+ return [node]
+
+ last_dq_nodes = []
+ for user in list(node.users):
+ last_dq_nodes.extend(self._find_last_dq_nodes(user))
+
+ return last_dq_nodes
def _annotate_requant(self, n):
# Record requant attributes:
- # node1 -> q_ui8 -> dq_ui8 -> q_int32 -> dq_int32 -> node2 -> ....
- # We store quant info for dq_ui8 and q_int32 in node1.meta
+ # node1 -> q_ui8 (n) -> dq_ui8 -> q_int32 -> dq_int32 -> node2 -> ....
+ # We store {node2's name: quant attrs of dq_int32} in node1.meta
if n.target in q_ops and n.args[0].target not in dq_ops:
- dq_node = self._find_last_dq_node(n)
+ dq_nodes = self._find_last_dq_nodes(n)
q_attrs = get_quant_attrs(self.edge_program, n)
- dq_attrs = get_quant_attrs(self.edge_program, dq_node)
-
- # TODO: Store multiple pairs of requantize attributes when we have an op builder
- # that has multiple outputs that requires quant attributes.
- if self.skip_advanced_requant:
- if q_attrs["dtype"] != dq_attrs["dtype"]:
- dq_attrs[QCOM_ENCODING] = q_attrs[QCOM_ENCODING]
- n.args[0].meta[QCOM_REQUANTIZE] = dq_attrs
- else:
- # When dtype is the same but other specs such as scale and offset are different,
- # insert requant to improve accuracy.
- # Users can turn this feature off if any inference speed drop is observed.
- if any(
- q_attrs[attr] != dq_attrs[attr]
- for attr in [
- "scale",
- "zero_point",
- "quant_min",
- "quant_max",
- "dtype",
- ]
- ):
- dq_attrs[QCOM_ENCODING] = q_attrs[QCOM_ENCODING]
- n.args[0].meta[QCOM_REQUANTIZE] = dq_attrs
+ for dq_node in dq_nodes:
+ dq_attrs = get_quant_attrs(self.edge_program, dq_node)
+ # TODO: Store multiple pairs of requantize attributes when we have an op builder
+ # that has multiple outputs that require quant attributes.
+ if self.skip_advanced_requant:
+ if q_attrs[QCOM_DTYPE] != dq_attrs[QCOM_DTYPE]:
+ dq_attrs[QCOM_ENCODING] = q_attrs[QCOM_ENCODING]
+ user_node = list(dq_node.users)[0]
+ n.args[0].meta.setdefault(QCOM_REQUANTIZE, {})
+ n.args[0].meta[QCOM_REQUANTIZE][user_node.name] = dq_attrs
+ else:
+ # When dtype is the same but other specs such as scale and offset are different,
+ # insert requant to improve accuracy.
+ # Users can turn this feature off if any inference speed drop is observed.
+ if any(
+ q_attrs[attr] != dq_attrs[attr]
+ for attr in [
+ QCOM_SCALE,
+ QCOM_ZERO_POINT,
+ QCOM_QUANT_MIN,
+ QCOM_QUANT_MAX,
+ QCOM_DTYPE,
+ ]
+ ):
+ dq_attrs[QCOM_ENCODING] = q_attrs[QCOM_ENCODING]
+ user_node = list(dq_node.users)[0]
+ n.args[0].meta.setdefault(QCOM_REQUANTIZE, {})
+ n.args[0].meta[QCOM_REQUANTIZE][user_node.name] = dq_attrs
# Dequant all the fold_quant parameters back to fp32.
# If an operation is not supported by QNN and got fallback, it will expect a fp32 param.
@@ -98,14 +118,14 @@ def _dequant_fold_params(self, n, quant_attrs, param):
if quant_attrs[QCOM_ENCODING] in [
exir_ops.edge.quantized_decomposed.dequantize_per_channel.default
]:
- dim, axis = param.dim(), quant_attrs["axis"]
+ dim, axis = param.dim(), quant_attrs[QCOM_AXIS]
scales = self._expand(quant_attrs[QCOM_SCALES], dim, axis)
offsets = self._expand(quant_attrs[QCOM_ZERO_POINTS], dim, axis)
param = param.sub(offsets).mul(scales).to(torch.float32).contiguous()
set_parameter(param, n.args[0], self.edge_program)
else:
- scale = quant_attrs["scale"]
- offset = quant_attrs["zero_point"]
+ scale = quant_attrs[QCOM_SCALE]
+ offset = quant_attrs[QCOM_ZERO_POINT]
param = param.sub(offset).mul(scale).to(torch.float32).contiguous()
set_parameter(param, n.args[0], self.edge_program)
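For orientation, here is a minimal sketch (not part of the patch) of the metadata shape that `AnnotateQuantAttrs` now records under `QCOM_REQUANTIZE`: a dict keyed by the consumer node's name rather than a single attribute dict. The node name, inner keys, and values below are illustrative placeholders; the plain-string keys stand in for the `QCOM_*` constants.

```python
import torch

# Illustrative only: structure stored in node1.args[0].meta[QCOM_REQUANTIZE]
# after this pass runs. Outer keys are consumer-node names (node2 in the
# comment above); inner keys stand in for the QCOM_* constants.
requantize_meta = {
    "aten_convolution_default_1": {  # hypothetical consumer node name
        "scale": 0.02,               # QCOM_SCALE
        "zero_point": 128,           # QCOM_ZERO_POINT
        "quant_min": 0,              # QCOM_QUANT_MIN
        "quant_max": 255,            # QCOM_QUANT_MAX
        "dtype": torch.uint8,        # QCOM_DTYPE
    },
}
```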
diff --git a/backends/qualcomm/_passes/insert_requantize.py b/backends/qualcomm/_passes/insert_requantize.py
index 5291edeb9fa..11aad02a0cf 100644
--- a/backends/qualcomm/_passes/insert_requantize.py
+++ b/backends/qualcomm/_passes/insert_requantize.py
@@ -4,6 +4,9 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
+from collections import defaultdict
+from typing import Dict, List
+
import torch
from executorch.backends.qualcomm.utils.constants import (
@@ -38,6 +41,42 @@ def __init__(
super(InsertRequantize, self).__init__()
self.edge_program = edge_program
+ def _make_hashable(self, value):
+ if isinstance(value, dict):
+ return tuple(sorted(value.items()))
+ return value
+
+ def _invert_dict(self, requantize_dict):
+ inverted_dict = defaultdict(list)
+ for user_node_name, quant_attr in requantize_dict.items():
+ hashable_quant_attr = self._make_hashable(quant_attr)
+ inverted_dict[hashable_quant_attr].append(user_node_name)
+ return inverted_dict
+
+ def _insert_to_copy(
+ self,
+ graph_module: torch.fx.GraphModule,
+ node: torch.fx.node,
+ quant_attr: Dict,
+ user_nodes: List[str],
+ ):
+ with graph_module.graph.inserting_after(node):
+ users = list(node.users.keys())
+ inserted_n = graph_module.graph.create_node(
+ "call_function",
+ exir_ops.edge.aten._to_copy.default,
+ (node,),
+ )
+ inserted_n.meta["val"] = node.meta["val"]
+ inserted_n.meta[QCOM_QUANT_ATTRS] = quant_attr
+
+ # propagate quantized IO flag and rewire the selected users to the new node
+ if node.meta.get(QCOM_QUANTIZED_IO):
+ inserted_n.meta[QCOM_QUANTIZED_IO] = node.meta[QCOM_QUANTIZED_IO]
+
+ for user in filter(lambda u: u.name in user_nodes, users):
+ user.replace_input_with(node, inserted_n)
+
# TODO: Implement this function when we have an op with
# multiple outputs that requires quant attributes.
def _multi_output_annotation(self) -> None:
@@ -46,21 +85,20 @@ def _multi_output_annotation(self) -> None:
def _single_output_annotation(
self, gm: torch.fx.GraphModule, n: torch.fx.node
) -> None:
- with gm.graph.inserting_after(n):
- users = list(n.users.keys())
- inserted_n = gm.graph.create_node(
- "call_function",
- exir_ops.edge.aten._to_copy.default,
- (n,),
- )
-
- inserted_n.meta["val"] = n.meta["val"]
- inserted_n.meta[QCOM_QUANT_ATTRS] = n.meta.pop(QCOM_REQUANTIZE)
- if n.meta.get(QCOM_QUANTIZED_IO):
- inserted_n.meta[QCOM_QUANTIZED_IO] = n.meta[QCOM_QUANTIZED_IO]
+ # {user_node_name: quant_attr}
+ requantize_dict = n.meta.pop(QCOM_REQUANTIZE)
+ # {quant_attr: user_node_name_list}
+ group_quant_attr_dict = self._invert_dict(requantize_dict)
+ # TODO: If the node's users include the output node, we replace the node
+ # with a to_copy op. However, this would be a problem when the node has
+ # multiple to_copy ops
+ add_output = len(group_quant_attr_dict) == 1
- for user in users:
- user.replace_input_with(n, inserted_n)
+ for hashable_quant_attr, user_nodes in group_quant_attr_dict.items():
+ user_nodes_copy = user_nodes.copy()
+ if add_output:
+ user_nodes_copy.append("output")
+ self._insert_to_copy(gm, n, dict(hashable_quant_attr), user_nodes_copy)
def _insert(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
for n in graph_module.graph.nodes:
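As a quick, self-contained illustration (not part of the patch) of the grouping that `_invert_dict` performs above: consumers sharing identical requantize attributes are collected under one hashable key, so a single `_to_copy` node can serve all of them. The node names and attribute values are made up.

```python
from collections import defaultdict

# Hypothetical {user_node_name: quant_attrs} as produced by AnnotateQuantAttrs.
requantize_dict = {
    "node_a": {"scale": 0.02, "zero_point": 128},
    "node_b": {"scale": 0.02, "zero_point": 128},
    "node_c": {"scale": 0.10, "zero_point": 0},
}

inverted = defaultdict(list)
for user_name, attrs in requantize_dict.items():
    # dicts are unhashable, so sort items into a tuple (mirrors _make_hashable)
    inverted[tuple(sorted(attrs.items()))].append(user_name)

# Two distinct attribute groups -> two _to_copy nodes would be inserted:
#   (('scale', 0.02), ('zero_point', 128)) -> ['node_a', 'node_b']
#   (('scale', 0.10), ('zero_point', 0))   -> ['node_c']
print(dict(inverted))
```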
diff --git a/backends/qualcomm/_passes/layout_transform.py b/backends/qualcomm/_passes/layout_transform.py
index 851b547eb6c..66bada86dc1 100644
--- a/backends/qualcomm/_passes/layout_transform.py
+++ b/backends/qualcomm/_passes/layout_transform.py
@@ -14,7 +14,6 @@
QCOM_INSERTED_PERMUTE,
QCOM_LAYOUT_CHANGE,
QCOM_QUANT_ATTRS,
- QCOM_REQUANTIZE,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
@@ -133,8 +132,6 @@ def is_layout_agnostic(self, node: torch.fx.Node) -> bool:
# if dimemsion is not kept, we'll have no clue how to do layout transform
if len(node.args) < 3 or not node.args[2]:
return False
- if node.target in self.qdq_opset:
- return QCOM_REQUANTIZE in node.meta
return node.target in self.layout_agnostic_ops
def is_edge_condition(self, node):
diff --git a/backends/qualcomm/builders/README.md b/backends/qualcomm/builders/README.md
index a81df0d6def..3a97e8d6d6a 100644
--- a/backends/qualcomm/builders/README.md
+++ b/backends/qualcomm/builders/README.md
@@ -206,21 +206,21 @@ Now, we can start to fill in function body step by step:
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
+ node,
input_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=True,
)
```
Through the information in [Check Operator Spec](#check-operator-spec) section, we could easily extract the desired nodes.
The `get_tensor` method is responsible for retrieving torch tensor in correct axis order if `layout_transform` pass happened to apply.
The `define_tensor` method is for generating tensor object for QNN API and will be memorized by aforementioned `node_to_wrappers`.
 And yet, there are arguments worth addressing further:
- - **node**: current graph node
+ - **tensor_source_node**: the graph node that produces the tensor
+ - **target_build_node**: the node currently being built, which is important for fixed point mixed-precision to work properly
- **tensor**: torch tensor emitted by node
 - **tensor_type**: type compatible with QNN SDK, often `QNN_TENSOR_TYPE_NATIVE` for intermediate outputs and `QNN_TENSOR_TYPE_STATIC` for constant parameters
- **nodes_to_wrappers**: dictionary of graph node and its output tensor (note: the tensor here is not a torch tensor but a wrapped object for QNN)
- - **is_input_tensor**: flag to tell if current tensor is input activation or parameter, which is important for fixed point mixed-precision to work properly
- **node_name**: (optional) tensor name for user to specify
- **wrapper_idx**: (optional) defaults to zero if node is not a tuple, otherwise it acts as an indexer to output tensors. e.g. when slicing input tensor into multiple outputs, `wrapper_idx` is necessary for getting correct wrapped tensor object
@@ -230,23 +230,24 @@ Now, we can start to fill in function body step by step:
weight_tensor = get_parameter(weight_node, self.edge_program)
weight_tensor_wrapper = self.define_tensor(
weight_node,
+ node,
weight_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
nodes_to_wrappers,
- is_input_tensor=False,
)
bias_node = node.args[3]
bias_tensor = get_parameter(bias_node, self.edge_program)
bias_tensor_wrapper = self.define_tensor(
bias_node,
+ node,
bias_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
nodes_to_wrappers,
- is_input_tensor=False,
)
```
- The logic should be similar and straightforward. Please carefully set arguments `tensor_type`, `is_input_tensor` according to tensors' property.
+ The logic should be similar and straightforward. Please carefully set the `tensor_type`
+ argument according to each tensor's property.
3. Define parameters:
```python
@@ -266,11 +267,11 @@ Now, we can start to fill in function body step by step:
```python
output_tensor = self.get_tensor(node, node, 0)
output_tensor_wrapper = self.define_tensor(
+ node,
node,
output_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=False,
)
```
 Although the input / output activations might map to the graph IOs (a.k.a. user inputs / outputs) with corresponding types `QNN_TENSOR_TYPE_APP_READ` / `QNN_TENSOR_TYPE_APP_WRITE`, users are still expected to use `QNN_TENSOR_TYPE_NATIVE` for all nodes' IOs and leave the detection logic handled inside the `define_tensor` method.
diff --git a/backends/qualcomm/builders/node_visitor.py b/backends/qualcomm/builders/node_visitor.py
index 95b147aba51..672cf8b623c 100644
--- a/backends/qualcomm/builders/node_visitor.py
+++ b/backends/qualcomm/builders/node_visitor.py
@@ -173,16 +173,19 @@ def make_qnn_per_tensor_config(self, quant_attrs: Dict):
)
def get_quant_encoding_conf(
- self, node: torch.fx.Node, is_input_tensor: bool = False
+ self, node: torch.fx.Node, target_node: torch.fx.Node
) -> Tuple[Any, Dict]:
if not node.meta.get(QCOM_QUANT_ATTRS, None):
return (
PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_UNDEFINED,
{},
)
+ is_input_tensor = node != target_node
quant_attrs = (
- node.meta[QCOM_REQUANTIZE]
- if QCOM_REQUANTIZE in node.meta and is_input_tensor
+ node.meta[QCOM_REQUANTIZE][target_node.name]
+ if QCOM_REQUANTIZE in node.meta
+ and is_input_tensor
+ and target_node.name in node.meta[QCOM_REQUANTIZE]
else node.meta[QCOM_QUANT_ATTRS]
)
if quant_attrs[QCOM_ENCODING] in PER_CHANNEL_ENCODING:
@@ -282,11 +285,11 @@ def define_custom_tensor_wrapper(
def define_tensor(
self,
- node: torch.fx.Node,
+ tensor_source_node: torch.fx.Node,
+ target_build_node: torch.fx.Node,
tensor: torch.Tensor,
tensor_type: PyQnnWrapper.Qnn_TensorType_t,
nodes_to_wrappers: Dict[str, Dict[int, PyQnnWrapper.TensorWrapper]],
- is_input_tensor: bool,
node_name: str = None,
wrapper_idx: int = 0,
) -> PyQnnWrapper.TensorWrapper:
@@ -294,28 +297,32 @@ def define_tensor(
 Convert torch.Tensor to TensorWrapper
Args:
- node: EdgeIR Node
+ tensor_source_node: EdgeIR node that produces the tensor
+ target_build_node: the node currently being built
tensor: EdgeIR Tensor
tensor_type: QNN tensor type
nodes_to_wrappers: Set contains edge_graph values(node targets)
- is_input_tensor: Whether tensor is a fake input tensor relatively to
- the op builder that is calling this function
"""
if node_name is None:
- node_name = node.name
+ node_name = tensor_source_node.name
if cached := nodes_to_wrappers[node_name].get(wrapper_idx, None):
return cached
- tensor_name = f"{node.name}_{wrapper_idx}"
- if is_graph_input(node, self.edge_program):
- tensor_name = "input_" + str(self.external_ids[node]) + "_" + tensor_name
- if is_graph_output(node):
+ tensor_name = f"{tensor_source_node.name}_{wrapper_idx}"
+ if is_graph_input(tensor_source_node, self.edge_program):
+ tensor_name = (
+ "input_"
+ + str(self.external_ids[tensor_source_node])
+ + "_"
+ + tensor_name
+ )
+ if is_graph_output(tensor_source_node):
tensor_name = "output_" + tensor_name
dims = [1] if len(tensor.size()) == 0 else tensor.size()
- tensor_type = self.get_tensor_type(node, tensor_type)
+ tensor_type = self.get_tensor_type(tensor_source_node, tensor_type)
quant_encoding, quant_configs = self.get_quant_encoding_conf(
- node, is_input_tensor
+ tensor_source_node, target_build_node
)
dtype = self.get_data_type(tensor, quant_configs)
if isinstance(tensor, torch._subclasses.fake_tensor.FakeTensor):
@@ -334,7 +341,7 @@ def define_tensor(
if quant_configs:
tensor = self.get_quant_tensor_value(
tensor,
- node.meta[QCOM_QUANT_ATTRS],
+ tensor_source_node.meta[QCOM_QUANT_ATTRS],
quant_configs,
)
tensor_wrapper = PyQnnWrapper.TensorWrapper(
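To summarize the consumer-aware selection that `get_quant_encoding_conf` now performs, here is a hedged, standalone sketch (not the backend's API): when the tensor's source node carries per-consumer requantize attributes and the tensor feeds the node being built, the entry keyed by that node's name wins; otherwise the source node's own quant attributes are used. The helper name and plain-string keys below are assumptions for illustration.

```python
# Hypothetical helper mirroring the selection logic above; "requantize" and
# "quant_attrs" stand in for QCOM_REQUANTIZE and QCOM_QUANT_ATTRS.
def pick_quant_attrs(source_meta: dict, is_input_tensor: bool, target_name: str) -> dict:
    requantize = source_meta.get("requantize", {})
    if is_input_tensor and target_name in requantize:
        return requantize[target_name]    # per-consumer requantize override
    return source_meta["quant_attrs"]     # node's own encoding

# Example: one producer feeds two ops, only one of which needs requantization.
meta = {
    "quant_attrs": {"scale": 0.05, "zero_point": 0},
    "requantize": {"op_add_1": {"scale": 0.02, "zero_point": 128}},
}
print(pick_quant_attrs(meta, True, "op_add_1"))  # -> requantized attrs
print(pick_quant_attrs(meta, True, "op_mul_2"))  # -> node's own attrs
```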
diff --git a/backends/qualcomm/builders/op_add.py b/backends/qualcomm/builders/op_add.py
index 1cc5ae7fe6f..b5edfd7bb52 100644
--- a/backends/qualcomm/builders/op_add.py
+++ b/backends/qualcomm/builders/op_add.py
@@ -27,11 +27,11 @@ def define_node(
) -> PyQnnWrapper.PyQnnOpWrapper:
out_tensor = self.get_tensor(node, node)
output_tensor_wrapper = self.define_tensor(
+ node,
node,
out_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=False,
)
add_output_tensors = [output_tensor_wrapper]
@@ -43,10 +43,10 @@ def define_node(
input_tensor_wrapper = self.define_tensor(
input_node,
+ node,
input_tensor,
tensor_type,
nodes_to_wrappers,
- is_input_tensor=True,
)
add_input_tensors.append(input_tensor_wrapper)
diff --git a/backends/qualcomm/builders/op_avg_pool2d.py b/backends/qualcomm/builders/op_avg_pool2d.py
index 5ad3fc36c99..394d4008587 100644
--- a/backends/qualcomm/builders/op_avg_pool2d.py
+++ b/backends/qualcomm/builders/op_avg_pool2d.py
@@ -32,19 +32,19 @@ def define_node(
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
+ node,
input_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=True,
)
output_tensor = self.get_tensor(node, node)
output_tensor_wrapper = self.define_tensor(
+ node,
node,
output_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=False,
)
# kernel info
filter_size = cast(List[int], node.args[1])
diff --git a/backends/qualcomm/builders/op_batch_norm.py b/backends/qualcomm/builders/op_batch_norm.py
index 9ca299e7432..aa14df0cb74 100644
--- a/backends/qualcomm/builders/op_batch_norm.py
+++ b/backends/qualcomm/builders/op_batch_norm.py
@@ -49,10 +49,10 @@ def define_node(
input_tensor_wrapper = self.define_tensor(
input_node,
+ node,
input_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=True,
)
bias_node = node.args[2]
@@ -65,20 +65,20 @@ def define_node(
self.update_encoding(bias_node, bias_tensor, eps)
bias_tensor_wrapper = self.define_tensor(
bias_node,
+ node,
bias_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
nodes_to_wrappers,
- is_input_tensor=False,
)
filter_tensor = filter_tensor / torch.sqrt(var_tensor + eps)
self.update_encoding(filter_node, filter_tensor, eps)
filter_tensor_wrapper = self.define_tensor(
filter_node,
+ node,
filter_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
nodes_to_wrappers,
- is_input_tensor=False,
)
batch_norm_input_tensors = [
@@ -89,11 +89,11 @@ def define_node(
output_tensor = self.get_tensor(node, node, 0)
output_tensor_wrapper = self.define_tensor(
+ node,
node,
output_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=False,
)
batch_norm_output_tensors = [output_tensor_wrapper]
diff --git a/backends/qualcomm/builders/op_bmm.py b/backends/qualcomm/builders/op_bmm.py
index 794d66991d3..46fbff1cc7e 100644
--- a/backends/qualcomm/builders/op_bmm.py
+++ b/backends/qualcomm/builders/op_bmm.py
@@ -32,20 +32,20 @@ def define_node(
input_tensor_wrapper = self.define_tensor(
input_node,
+ node,
input_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=True,
)
bmm_input_tensors.append(input_tensor_wrapper)
output_tensor = self.get_tensor(node, node)
output_tensor_wrapper = self.define_tensor(
+ node,
node,
output_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=False,
)
bmm_output_tensors = [output_tensor_wrapper]
diff --git a/backends/qualcomm/builders/op_cat.py b/backends/qualcomm/builders/op_cat.py
index cf18690498d..7f160856390 100644
--- a/backends/qualcomm/builders/op_cat.py
+++ b/backends/qualcomm/builders/op_cat.py
@@ -36,10 +36,10 @@ def define_node(
list_of_tensor_wrappers.append(
self.define_tensor(
tensor_input,
+ node,
input_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=True,
)
)
@@ -52,11 +52,11 @@ def define_node(
output_tensor = self.get_tensor(node, node)
output_tensor_wrapper = self.define_tensor(
+ node,
node,
output_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=False,
)
# node args[1] might not exist
diff --git a/backends/qualcomm/builders/op_ceil.py b/backends/qualcomm/builders/op_ceil.py
index 883befbccf4..19fe14d6392 100644
--- a/backends/qualcomm/builders/op_ceil.py
+++ b/backends/qualcomm/builders/op_ceil.py
@@ -29,19 +29,19 @@ def define_node(
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
+ node,
input_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=True,
)
output_tensor = self.get_tensor(node, node)
output_tensor_wrapper = self.define_tensor(
+ node,
node,
output_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=False,
)
ceil_op = PyQnnWrapper.PyQnnOpWrapper(
diff --git a/backends/qualcomm/builders/op_clamp.py b/backends/qualcomm/builders/op_clamp.py
index 0c69a8d333c..0f9a9ffa196 100644
--- a/backends/qualcomm/builders/op_clamp.py
+++ b/backends/qualcomm/builders/op_clamp.py
@@ -31,10 +31,10 @@ def define_node(
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
+ node,
input_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=True,
)
# default value of output_min and output_max
@@ -51,11 +51,11 @@ def define_node(
output_tensor = self.get_tensor(node, node)
output_tensor_wrapper = self.define_tensor(
+ node,
node,
output_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=False,
)
clamp_op = PyQnnWrapper.PyQnnOpWrapper(
diff --git a/backends/qualcomm/builders/op_conv2d.py b/backends/qualcomm/builders/op_conv2d.py
index 30207a03920..9daeab6d4bf 100644
--- a/backends/qualcomm/builders/op_conv2d.py
+++ b/backends/qualcomm/builders/op_conv2d.py
@@ -119,16 +119,16 @@ def _define_conv1d(
op_wrapper_list = [] # op_wrapper to return
unsqueeze_input_node = node.args[0]
input_quant_encoding, input_quant_configs = self.get_quant_encoding_conf(
- unsqueeze_input_node,
+ unsqueeze_input_node, node
)
unsqueeze_input_tensor = self.get_tensor(unsqueeze_input_node, node)
unsqueeze_input_tensor_wrapper = self.define_tensor(
unsqueeze_input_node,
+ node,
unsqueeze_input_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=True,
)
unsqueeze_output_tensor = unsqueeze_input_tensor.unsqueeze(1).contiguous()
dtype = self.get_data_type(unsqueeze_output_tensor, input_quant_configs)
@@ -165,10 +165,10 @@ def _define_conv1d(
filter_tensor = filter_tensor.permute(dims=filter_axis_order).contiguous()
filter_tensor_wrapper = self.define_tensor(
filter_node,
+ node,
filter_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
nodes_to_wrappers,
- is_input_tensor=False,
)
conv_input_tensors = [unsqueeze_output_tensor_wrapper, filter_tensor_wrapper]
if node.args[2] is not None:
@@ -176,10 +176,10 @@ def _define_conv1d(
bias_tensor = get_parameter(bias_node, self.edge_program)
bias_tensor_wrapper = self.define_tensor(
bias_node,
+ node,
bias_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
nodes_to_wrappers,
- is_input_tensor=False,
)
conv_input_tensors.append(bias_tensor_wrapper)
@@ -249,11 +249,11 @@ def _define_conv1d(
)
squeeze_output_tensor = self.get_tensor(node, node)
squeeze_output_tensor_wrapper = self.define_tensor(
+ node,
node,
squeeze_output_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=False,
node_name=node.name,
)
squeeze_op.AddInputTensors([conv_output_tensor_wrapper])
@@ -274,10 +274,10 @@ def define_node(
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
+ node,
input_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=True,
)
filter_node = node.args[1]
@@ -288,10 +288,10 @@ def define_node(
filter_tensor = filter_tensor.permute(dims=filter_axis_order).contiguous()
filter_tensor_wrapper = self.define_tensor(
filter_node,
+ node,
filter_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
nodes_to_wrappers,
- is_input_tensor=False,
)
conv_input_tensors = [input_tensor_wrapper, filter_tensor_wrapper]
@@ -300,20 +300,20 @@ def define_node(
bias_tensor = get_parameter(bias_node, self.edge_program)
bias_tensor_wrapper = self.define_tensor(
bias_node,
+ node,
bias_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
nodes_to_wrappers,
- is_input_tensor=False,
)
conv_input_tensors.append(bias_tensor_wrapper)
output_tensor = self.get_tensor(node, node)
output_tensor_wrapper = self.define_tensor(
+ node,
node,
output_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=False,
)
conv_output_tensors = [output_tensor_wrapper]
diff --git a/backends/qualcomm/builders/op_cos.py b/backends/qualcomm/builders/op_cos.py
index 98caed10d18..3858a947d93 100644
--- a/backends/qualcomm/builders/op_cos.py
+++ b/backends/qualcomm/builders/op_cos.py
@@ -30,19 +30,19 @@ def define_node(
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
+ node,
input_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=True,
)
output_tensor = self.get_tensor(node, node)
output_tensor_wrapper = self.define_tensor(
+ node,
node,
output_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=False,
)
cos_op = PyQnnWrapper.PyQnnOpWrapper(
diff --git a/backends/qualcomm/builders/op_depth_to_space.py b/backends/qualcomm/builders/op_depth_to_space.py
index e7343720987..56c57b4bd5e 100644
--- a/backends/qualcomm/builders/op_depth_to_space.py
+++ b/backends/qualcomm/builders/op_depth_to_space.py
@@ -32,19 +32,19 @@ def define_node(
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
+ node,
input_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=True,
)
output_tensor = self.get_tensor(node, node)
output_tensor_wrapper = self.define_tensor(
+ node,
node,
output_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=False,
)
block_size = []
diff --git a/backends/qualcomm/builders/op_dequantize.py b/backends/qualcomm/builders/op_dequantize.py
index f80103b4b89..507ecc4e3e3 100644
--- a/backends/qualcomm/builders/op_dequantize.py
+++ b/backends/qualcomm/builders/op_dequantize.py
@@ -27,20 +27,20 @@ def define_node(
input_tensor = self.get_tensor(input_node, node)
inp_tensor_wrapper = self.define_tensor(
input_node,
+ node,
input_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=True,
)
dequant_input_tensors.append(inp_tensor_wrapper)
output_tensor = self.get_tensor(node, node)
output_tensor_wrapper = self.define_tensor(
+ node,
node,
output_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=False,
)
dequant_output_tensors = [output_tensor_wrapper]
diff --git a/backends/qualcomm/builders/op_div.py b/backends/qualcomm/builders/op_div.py
index d3eabb33562..ce3f96abc7f 100644
--- a/backends/qualcomm/builders/op_div.py
+++ b/backends/qualcomm/builders/op_div.py
@@ -27,11 +27,11 @@ def define_node(
) -> PyQnnWrapper.PyQnnOpWrapper:
out_tensor = self.get_tensor(node, node)
output_tensor_wrapper = self.define_tensor(
+ node,
node,
out_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=False,
)
div_output_tensors = [output_tensor_wrapper]
@@ -43,10 +43,10 @@ def define_node(
input_tensor_wrapper = self.define_tensor(
input_node,
+ node,
input_tensor,
tensor_type,
nodes_to_wrappers,
- is_input_tensor=True,
)
div_input_tensors.append(input_tensor_wrapper)
diff --git a/backends/qualcomm/builders/op_embedding.py b/backends/qualcomm/builders/op_embedding.py
index 8ae3b64fbfa..5b0d1600393 100644
--- a/backends/qualcomm/builders/op_embedding.py
+++ b/backends/qualcomm/builders/op_embedding.py
@@ -32,31 +32,31 @@ def define_node(
weight_tensor = get_parameter(weight_node, self.edge_program)
weight_tensor_wrapper = self.define_tensor(
weight_node,
+ node,
weight_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
nodes_to_wrappers,
- is_input_tensor=True,
)
indices_node = node.args[1]
indices_tensor = self.get_tensor(indices_node, node)
indices_tensor_wrapper = self.define_tensor(
indices_node,
+ node,
indices_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=True,
)
gather_input_tensors = [weight_tensor_wrapper, indices_tensor_wrapper]
output_tensor = self.get_tensor(node, node)
output_tensor_wrapper = self.define_tensor(
+ node,
node,
output_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=False,
)
gather_output_tensors = [output_tensor_wrapper]
diff --git a/backends/qualcomm/builders/op_expand.py b/backends/qualcomm/builders/op_expand.py
index 3f5c266cdd8..c098ed00c94 100644
--- a/backends/qualcomm/builders/op_expand.py
+++ b/backends/qualcomm/builders/op_expand.py
@@ -31,19 +31,19 @@ def define_node(
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
+ node,
input_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=True,
)
output_tensor = self.get_tensor(node, node)
output_tensor_wrapper = self.define_tensor(
+ node,
node,
output_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=False,
)
sizes = cast(List[int], node.args[1])
diff --git a/backends/qualcomm/builders/op_gelu.py b/backends/qualcomm/builders/op_gelu.py
index 9bd050cf09e..c178740448e 100644
--- a/backends/qualcomm/builders/op_gelu.py
+++ b/backends/qualcomm/builders/op_gelu.py
@@ -30,19 +30,19 @@ def define_node(
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
+ node,
input_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=True,
)
output_tensor = self.get_tensor(node, node)
output_tensor_wrapper = self.define_tensor(
+ node,
node,
output_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=False,
)
gelu_op = PyQnnWrapper.PyQnnOpWrapper(
diff --git a/backends/qualcomm/builders/op_group_norm.py b/backends/qualcomm/builders/op_group_norm.py
index 44a07e4e588..d498b202d71 100644
--- a/backends/qualcomm/builders/op_group_norm.py
+++ b/backends/qualcomm/builders/op_group_norm.py
@@ -32,41 +32,41 @@ def define_node(
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
+ node,
input_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=True,
)
weight_node = node.args[1]
weight_tensor = get_parameter(weight_node, self.edge_program)
weight_tensor_wrapper = self.define_tensor(
weight_node,
+ node,
weight_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
nodes_to_wrappers,
- is_input_tensor=False,
)
bias_node = node.args[2]
bias_tensor = get_parameter(bias_node, self.edge_program)
bias_tensor_wrapper = self.define_tensor(
bias_node,
+ node,
bias_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
nodes_to_wrappers,
- is_input_tensor=False,
)
group = node.args[6]
epsilon = node.args[7]
output_tensor = self.get_tensor(node, node, 0)
output_tensor_wrapper = self.define_tensor(
+ node,
node,
output_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=False,
)
group_norm_op = PyQnnWrapper.PyQnnOpWrapper(
diff --git a/backends/qualcomm/builders/op_hardsigmoid.py b/backends/qualcomm/builders/op_hardsigmoid.py
index 196777d6287..1acc08a387d 100644
--- a/backends/qualcomm/builders/op_hardsigmoid.py
+++ b/backends/qualcomm/builders/op_hardsigmoid.py
@@ -32,19 +32,19 @@ def define_node(
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
+ node,
input_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=True,
)
output_tensor = self.get_tensor(node, node)
output_tensor_wrapper = self.define_tensor(
+ node,
node,
output_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=False,
)
hardsigmoid_op = PyQnnWrapper.PyQnnOpWrapper(
diff --git a/backends/qualcomm/builders/op_hardswish.py b/backends/qualcomm/builders/op_hardswish.py
index 36eda8b3425..ed28ff95f78 100644
--- a/backends/qualcomm/builders/op_hardswish.py
+++ b/backends/qualcomm/builders/op_hardswish.py
@@ -30,19 +30,19 @@ def define_node(
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
+ node,
input_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=True,
)
output_tensor = self.get_tensor(node, node)
output_tensor_wrapper = self.define_tensor(
+ node,
node,
output_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=False,
)
hardswish_op = PyQnnWrapper.PyQnnOpWrapper(
diff --git a/backends/qualcomm/builders/op_hardtanh.py b/backends/qualcomm/builders/op_hardtanh.py
index 8d903852779..68bafaaab8b 100644
--- a/backends/qualcomm/builders/op_hardtanh.py
+++ b/backends/qualcomm/builders/op_hardtanh.py
@@ -32,10 +32,10 @@ def define_node(
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
+ node,
input_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=True,
)
# default value of output_min and output_max
@@ -50,11 +50,11 @@ def define_node(
output_tensor = self.get_tensor(node, node)
output_tensor_wrapper = self.define_tensor(
+ node,
node,
output_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=False,
)
hardtanh_op = PyQnnWrapper.PyQnnOpWrapper(
diff --git a/backends/qualcomm/builders/op_index.py b/backends/qualcomm/builders/op_index.py
index 6f8dc558fe5..4ddab23aeae 100644
--- a/backends/qualcomm/builders/op_index.py
+++ b/backends/qualcomm/builders/op_index.py
@@ -31,10 +31,10 @@ def define_node(
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
+ node,
input_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=True,
)
if len(node.args[1]) > 1:
@@ -47,21 +47,21 @@ def define_node(
indices_tensor_wrapper = self.define_tensor(
indices_node,
+ node,
indices_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=True,
)
gather_input_tensors = [input_tensor_wrapper, indices_tensor_wrapper]
output_tensor = self.get_tensor(node, node)
output_tensor_wrapper = self.define_tensor(
+ node,
node,
output_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=False,
)
gather_output_tensors = [output_tensor_wrapper]
diff --git a/backends/qualcomm/builders/op_index_put.py b/backends/qualcomm/builders/op_index_put.py
index af5311dfb2a..c317cc0a8b7 100644
--- a/backends/qualcomm/builders/op_index_put.py
+++ b/backends/qualcomm/builders/op_index_put.py
@@ -24,10 +24,10 @@ def define_node(
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
+ node,
input_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=True,
)
indicies_node = node.args[1]
indices_list = [
@@ -45,10 +45,10 @@ def define_node(
indices_tensor_wrapper = self.define_tensor(
indice_node[0],
+ node,
indices_qnn,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=True,
)
value_node = node.args[2]
@@ -56,18 +56,18 @@ def define_node(
value_tensor_wrapper = self.define_tensor(
value_node,
+ node,
value_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=True,
)
output_tensor = self.get_tensor(node, node)
output_tensor_wrapper = self.define_tensor(
+ node,
node,
output_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=False,
)
index_put_op = PyQnnWrapper.PyQnnOpWrapper(
diff --git a/backends/qualcomm/builders/op_layer_norm.py b/backends/qualcomm/builders/op_layer_norm.py
index 635e12d2ee1..2006c716489 100644
--- a/backends/qualcomm/builders/op_layer_norm.py
+++ b/backends/qualcomm/builders/op_layer_norm.py
@@ -34,10 +34,10 @@ def define_node(
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
+ node,
input_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=True,
)
normalized_shapes = node.args[1]
@@ -57,31 +57,31 @@ def define_node(
weight_tensor = get_parameter(weight_node, self.edge_program)
weight_tensor_wrapper = self.define_tensor(
weight_node,
+ node,
weight_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
nodes_to_wrappers,
- is_input_tensor=False,
)
bias_node = node.args[3]
bias_tensor = get_parameter(bias_node, self.edge_program)
bias_tensor_wrapper = self.define_tensor(
bias_node,
+ node,
bias_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
nodes_to_wrappers,
- is_input_tensor=False,
)
epsilon = node.args[4]
output_tensor = self.get_tensor(node, node, 0)
output_tensor_wrapper = self.define_tensor(
+ node,
node,
output_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=False,
)
layer_norm_op = PyQnnWrapper.PyQnnOpWrapper(
diff --git a/backends/qualcomm/builders/op_linear.py b/backends/qualcomm/builders/op_linear.py
index e4f16d4473d..a16bbd28c98 100644
--- a/backends/qualcomm/builders/op_linear.py
+++ b/backends/qualcomm/builders/op_linear.py
@@ -38,10 +38,10 @@ def define_node(
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
+ node,
input_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=True,
)
linear_input_tensors.append(input_tensor_wrapper)
@@ -59,10 +59,10 @@ def define_node(
weight_tensor = get_parameter(weight_node, self.edge_program)
weight_tensor_wrapper = self.define_tensor(
weight_node,
+ node,
weight_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
nodes_to_wrappers,
- is_input_tensor=False,
)
linear_input_tensors.append(weight_tensor_wrapper)
@@ -78,20 +78,20 @@ def define_node(
bias_tensor = get_parameter(bias_node, self.edge_program)
bias_tensor_wrapper = self.define_tensor(
bias_node,
+ node,
bias_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
nodes_to_wrappers,
- is_input_tensor=False,
)
linear_input_tensors.append(bias_tensor_wrapper)
output_tensor = self.get_tensor(node, node)
output_tensor_wrapper = self.define_tensor(
+ node,
node,
output_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=False,
)
linear_op = PyQnnWrapper.PyQnnOpWrapper(
diff --git a/backends/qualcomm/builders/op_log_softmax.py b/backends/qualcomm/builders/op_log_softmax.py
index fdd298f988b..d395d5eb66e 100644
--- a/backends/qualcomm/builders/op_log_softmax.py
+++ b/backends/qualcomm/builders/op_log_softmax.py
@@ -32,20 +32,20 @@ def define_node(
log_softmax_inp_tensor_wrapper = self.define_tensor(
input_node,
+ node,
input_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=True,
)
log_softmax_input_tensors = [log_softmax_inp_tensor_wrapper]
output_tensor = self.get_tensor(node, node)
log_softmax_output_tensor_wrapper = self.define_tensor(
+ node,
node,
output_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=False,
)
log_softmax_output_tensors = [log_softmax_output_tensor_wrapper]
diff --git a/backends/qualcomm/builders/op_matmul.py b/backends/qualcomm/builders/op_matmul.py
index 2ea6798e26c..c9215d11b4d 100644
--- a/backends/qualcomm/builders/op_matmul.py
+++ b/backends/qualcomm/builders/op_matmul.py
@@ -32,20 +32,20 @@ def define_node(
input_tensor_wrapper = self.define_tensor(
input_node,
+ node,
input_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=True,
)
matmul_input_tensors.append(input_tensor_wrapper)
output_tensor = self.get_tensor(node, node)
output_tensor_wrapper = self.define_tensor(
+ node,
node,
output_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=False,
)
matmul_output_tensors = [output_tensor_wrapper]
diff --git a/backends/qualcomm/builders/op_max_pool2d.py b/backends/qualcomm/builders/op_max_pool2d.py
index 27f14889bf3..8d0087eb2c6 100644
--- a/backends/qualcomm/builders/op_max_pool2d.py
+++ b/backends/qualcomm/builders/op_max_pool2d.py
@@ -32,10 +32,10 @@ def define_node(
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
+ node,
input_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=True,
)
users = list(node.users.keys())
@@ -51,11 +51,11 @@ def define_node(
output_tensor = self.get_tensor(node, node, 0)
output_tensor_wrapper = self.define_tensor(
+ node,
node,
output_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=False,
)
# kernel info
filter_size = cast(List[int], node.args[1])
diff --git a/backends/qualcomm/builders/op_mean_dim.py b/backends/qualcomm/builders/op_mean_dim.py
index e60e3e790b9..313b24420db 100644
--- a/backends/qualcomm/builders/op_mean_dim.py
+++ b/backends/qualcomm/builders/op_mean_dim.py
@@ -32,10 +32,10 @@ def define_node(
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
+ node,
input_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=True,
)
# mean dims and keep dims
@@ -51,11 +51,11 @@ def define_node(
output_tensor = self.get_tensor(node, node)
output_tensor_wrapper = self.define_tensor(
+ node,
node,
output_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=False,
)
reduce_mean_op = PyQnnWrapper.PyQnnOpWrapper(
diff --git a/backends/qualcomm/builders/op_mul.py b/backends/qualcomm/builders/op_mul.py
index db9808bb105..3138d3b8c9b 100644
--- a/backends/qualcomm/builders/op_mul.py
+++ b/backends/qualcomm/builders/op_mul.py
@@ -27,11 +27,11 @@ def define_node(
) -> PyQnnWrapper.PyQnnOpWrapper:
out_tensor = self.get_tensor(node, node)
output_tensor_wrapper = self.define_tensor(
+ node,
node,
out_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=False,
)
mul_output_tensors = [output_tensor_wrapper]
@@ -43,10 +43,10 @@ def define_node(
input_tensor_wrapper = self.define_tensor(
input_node,
+ node,
input_tensor,
tensor_type,
nodes_to_wrappers,
- is_input_tensor=True,
)
mul_input_tensors.append(input_tensor_wrapper)
diff --git a/backends/qualcomm/builders/op_pad.py b/backends/qualcomm/builders/op_pad.py
index 9ca385ff850..10948859be9 100644
--- a/backends/qualcomm/builders/op_pad.py
+++ b/backends/qualcomm/builders/op_pad.py
@@ -31,20 +31,20 @@ def define_node(
input_tensor = self.get_tensor(input_node, node)
pad_inp_tensor_wrapper = self.define_tensor(
input_node,
+ node,
input_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=True,
)
pad_input_tensors = [pad_inp_tensor_wrapper]
output_tensor = self.get_tensor(node, node)
output_tensor_wrapper = self.define_tensor(
+ node,
node,
output_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=False,
)
pad_output_tensors = [output_tensor_wrapper]
diff --git a/backends/qualcomm/builders/op_pow.py b/backends/qualcomm/builders/op_pow.py
index b4153d458ba..cf5b7595697 100644
--- a/backends/qualcomm/builders/op_pow.py
+++ b/backends/qualcomm/builders/op_pow.py
@@ -30,11 +30,11 @@ def define_node(
) -> PyQnnWrapper.PyQnnOpWrapper:
out_tensor = self.get_tensor(node, node)
output_tensor_wrapper = self.define_tensor(
+ node,
node,
out_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=False,
)
pow_output_tensors = [output_tensor_wrapper]
@@ -46,10 +46,10 @@ def define_node(
input_tensor_wrapper = self.define_tensor(
input_node,
+ node,
input_tensor,
tensor_type,
nodes_to_wrappers,
- is_input_tensor=True,
)
# scalar input
@@ -77,10 +77,10 @@ def define_node(
scalar_tensor_wrapper = self.define_tensor(
scalar_node,
+ node,
scalar_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
nodes_to_wrappers,
- is_input_tensor=False,
)
pow_input_tensors = [input_tensor_wrapper, scalar_tensor_wrapper]
diff --git a/backends/qualcomm/builders/op_prelu.py b/backends/qualcomm/builders/op_prelu.py
index 5da017b8b72..4057b3d5559 100644
--- a/backends/qualcomm/builders/op_prelu.py
+++ b/backends/qualcomm/builders/op_prelu.py
@@ -38,10 +38,10 @@ def define_node(
input_tensor = self.get_tensor(input_node, node)
prelu_inp_tensor_wrapper = self.define_tensor(
input_node,
+ node,
input_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=True,
)
if node.target.__name__ == "aten.leaky_relu.default":
@@ -89,20 +89,20 @@ def define_node(
scalar_tensor_wrapper = self.define_tensor(
scalar_node,
+ node,
coeff_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
nodes_to_wrappers,
- is_input_tensor=True,
)
prelu_input_tensors = [prelu_inp_tensor_wrapper, scalar_tensor_wrapper]
output_tensor = self.get_tensor(node, node)
output_tensor_wrapper = self.define_tensor(
+ node,
node,
output_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=False,
)
prelu_output_tensors = [output_tensor_wrapper]
diff --git a/backends/qualcomm/builders/op_quantize.py b/backends/qualcomm/builders/op_quantize.py
index 9d53d655712..4921f96b467 100644
--- a/backends/qualcomm/builders/op_quantize.py
+++ b/backends/qualcomm/builders/op_quantize.py
@@ -28,10 +28,10 @@ def define_node(
input_tensor = self.get_tensor(input_node, node)
inp_tensor_wrapper = self.define_tensor(
input_node,
+ node,
input_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=True,
)
quant_input_tensors.append(inp_tensor_wrapper)
@@ -43,11 +43,11 @@ def define_node(
output_tensor = self.get_tensor(node, node)
output_tensor_wrapper = self.define_tensor(
+ node,
node,
output_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=False,
)
quant_output_tensors = [output_tensor_wrapper]
diff --git a/backends/qualcomm/builders/op_relu.py b/backends/qualcomm/builders/op_relu.py
index 8ddc842e5ea..29335797e28 100644
--- a/backends/qualcomm/builders/op_relu.py
+++ b/backends/qualcomm/builders/op_relu.py
@@ -29,20 +29,20 @@ def define_node(
input_tensor = self.get_tensor(input_node, node)
relu_inp_tensor_wrapper = self.define_tensor(
input_node,
+ node,
input_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=True,
)
relu_input_tensors = [relu_inp_tensor_wrapper]
output_tensor = self.get_tensor(node, node)
output_tensor_wrapper = self.define_tensor(
+ node,
node,
output_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=False,
)
relu_output_tensors = [output_tensor_wrapper]
diff --git a/backends/qualcomm/builders/op_reshape.py b/backends/qualcomm/builders/op_reshape.py
index fde6e7647cd..ff4a603fa5b 100644
--- a/backends/qualcomm/builders/op_reshape.py
+++ b/backends/qualcomm/builders/op_reshape.py
@@ -29,18 +29,18 @@ def define_node(
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
+ node,
input_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=True,
)
output_tensor_wrapper = self.define_tensor(
+ node,
node,
node.meta["val"],
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=False,
)
reshape_op = PyQnnWrapper.PyQnnOpWrapper(
diff --git a/backends/qualcomm/builders/op_rms_norm.py b/backends/qualcomm/builders/op_rms_norm.py
index 3a5101b12da..d1daa6c1e54 100644
--- a/backends/qualcomm/builders/op_rms_norm.py
+++ b/backends/qualcomm/builders/op_rms_norm.py
@@ -36,10 +36,10 @@ def define_node(
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
+ node,
input_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=True,
)
# should be a immutable list
@@ -60,10 +60,10 @@ def define_node(
weight_tensor = get_parameter(weight_node, self.edge_program)
weight_tensor_wrapper = self.define_tensor(
weight_node,
+ node,
weight_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
nodes_to_wrappers,
- is_input_tensor=False,
)
# Fake node, nn moudle seems to be inconsistant with document
@@ -80,10 +80,10 @@ def define_node(
bias_node.meta[QCOM_QUANT_ATTRS] = quant_attrs
bias_tensor_wrapper = self.define_tensor(
bias_node,
+ node,
bias_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
nodes_to_wrappers,
- is_input_tensor=False,
)
epsilon = node.args[3]
@@ -97,11 +97,11 @@ def define_node(
output_tensor = self.get_tensor(node, node)
output_tensor_wrapper = self.define_tensor(
+ node,
node,
output_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=False,
)
rms_nrom_op = PyQnnWrapper.PyQnnOpWrapper(
diff --git a/backends/qualcomm/builders/op_rsqrt.py b/backends/qualcomm/builders/op_rsqrt.py
index 34086dda48a..162b485e9e5 100644
--- a/backends/qualcomm/builders/op_rsqrt.py
+++ b/backends/qualcomm/builders/op_rsqrt.py
@@ -29,20 +29,20 @@ def define_node(
input_tensor = self.get_tensor(input_node, node)
rsqrt_inp_tensor_wrapper = self.define_tensor(
input_node,
+ node,
input_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=True,
)
rsqrt_input_tensors = [rsqrt_inp_tensor_wrapper]
output_tensor = self.get_tensor(node, node)
output_tensor_wrapper = self.define_tensor(
+ node,
node,
output_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=False,
)
rsqrt_output_tensors = [output_tensor_wrapper]
diff --git a/backends/qualcomm/builders/op_select_copy.py b/backends/qualcomm/builders/op_select_copy.py
index fdeec3845ee..148888f1497 100644
--- a/backends/qualcomm/builders/op_select_copy.py
+++ b/backends/qualcomm/builders/op_select_copy.py
@@ -32,19 +32,19 @@ def define_node(
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
+ node,
input_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=True,
)
output_tensor = self.get_tensor(node, node)
output_tensor_wrapper = self.define_tensor(
+ node,
node,
output_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=False,
)
dim = cast(int, node.args[1])
diff --git a/backends/qualcomm/builders/op_sigmoid.py b/backends/qualcomm/builders/op_sigmoid.py
index 92ba447d437..ae6e6709c0a 100644
--- a/backends/qualcomm/builders/op_sigmoid.py
+++ b/backends/qualcomm/builders/op_sigmoid.py
@@ -29,20 +29,20 @@ def define_node(
input_tensor = self.get_tensor(input_node, node)
sigmoid_inp_tensor_wrapper = self.define_tensor(
input_node,
+ node,
input_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=True,
)
sigmoid_input_tensors = [sigmoid_inp_tensor_wrapper]
output_tensor = self.get_tensor(node, node)
output_tensor_wrapper = self.define_tensor(
+ node,
node,
output_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=False,
)
sigmoid_output_tensors = [output_tensor_wrapper]
diff --git a/backends/qualcomm/builders/op_sin.py b/backends/qualcomm/builders/op_sin.py
index 40e466f59ef..89fce6bee9c 100644
--- a/backends/qualcomm/builders/op_sin.py
+++ b/backends/qualcomm/builders/op_sin.py
@@ -30,19 +30,19 @@ def define_node(
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
+ node,
input_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=True,
)
output_tensor = self.get_tensor(node, node)
output_tensor_wrapper = self.define_tensor(
+ node,
node,
output_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=False,
)
sin_op = PyQnnWrapper.PyQnnOpWrapper(
diff --git a/backends/qualcomm/builders/op_slice_copy.py b/backends/qualcomm/builders/op_slice_copy.py
index 3a294e35486..8d12e03c0bb 100644
--- a/backends/qualcomm/builders/op_slice_copy.py
+++ b/backends/qualcomm/builders/op_slice_copy.py
@@ -32,19 +32,19 @@ def define_node(
input_tensor_wrapper = self.define_tensor(
input_node,
+ node,
input_tensor,
tensor_type,
nodes_to_wrappers,
- is_input_tensor=True,
)
output_tensor = self.get_tensor(node, node)
output_tensor_wrapper = self.define_tensor(
+ node,
node,
output_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=False,
)
dim = cast(int, node.args[1])
diff --git a/backends/qualcomm/builders/op_softmax.py b/backends/qualcomm/builders/op_softmax.py
index cda40aed458..f6f826e2a40 100644
--- a/backends/qualcomm/builders/op_softmax.py
+++ b/backends/qualcomm/builders/op_softmax.py
@@ -31,20 +31,20 @@ def define_node(
input_tensor = self.get_tensor(input_node, node)
softmax_inp_tensor_wrapper = self.define_tensor(
input_node,
+ node,
input_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=True,
)
softmax_input_tensors = [softmax_inp_tensor_wrapper]
output_tensor = self.get_tensor(node, node)
output_tensor_wrapper = self.define_tensor(
+ node,
node,
output_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=False,
)
softmax_output_tensors = [output_tensor_wrapper]
diff --git a/backends/qualcomm/builders/op_space_to_depth.py b/backends/qualcomm/builders/op_space_to_depth.py
index a9b61c520ed..0282cf3f15a 100644
--- a/backends/qualcomm/builders/op_space_to_depth.py
+++ b/backends/qualcomm/builders/op_space_to_depth.py
@@ -32,19 +32,19 @@ def define_node(
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
+ node,
input_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=True,
)
output_tensor = self.get_tensor(node, node)
output_tensor_wrapper = self.define_tensor(
+ node,
node,
output_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=False,
)
block_size = []
diff --git a/backends/qualcomm/builders/op_split_with_sizes.py b/backends/qualcomm/builders/op_split_with_sizes.py
index 58503ff3f87..8e75fd3c10d 100644
--- a/backends/qualcomm/builders/op_split_with_sizes.py
+++ b/backends/qualcomm/builders/op_split_with_sizes.py
@@ -33,10 +33,10 @@ def define_node(
input_tensor_wrapper = self.define_tensor(
input_node,
+ node,
input_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=True,
)
input_tensor_wrappers = [input_tensor_wrapper]
@@ -45,11 +45,11 @@ def define_node(
for index in range(len(node.meta["val"])):
output_tensor = self.get_tensor(node, node, index)
output_tensor_wrapper = self.define_tensor(
+ node,
node,
output_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=False,
wrapper_idx=index,
)
output_tensor_wrappers.append(output_tensor_wrapper)
diff --git a/backends/qualcomm/builders/op_sqrt.py b/backends/qualcomm/builders/op_sqrt.py
index 7847d00e8b8..dc6691460ca 100644
--- a/backends/qualcomm/builders/op_sqrt.py
+++ b/backends/qualcomm/builders/op_sqrt.py
@@ -31,20 +31,20 @@ def define_node(
input_tensor_wrapper = self.define_tensor(
input_node,
+ node,
input_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=True,
)
sqrt_input_tensors = [input_tensor_wrapper]
out_tensor = self.get_tensor(node, node)
output_tensor_wrapper = self.define_tensor(
+ node,
node,
out_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=False,
)
sqrt_output_tensors = [output_tensor_wrapper]
diff --git a/backends/qualcomm/builders/op_squeeze.py b/backends/qualcomm/builders/op_squeeze.py
index 00cbda54cf0..b828bb7b0b9 100644
--- a/backends/qualcomm/builders/op_squeeze.py
+++ b/backends/qualcomm/builders/op_squeeze.py
@@ -30,19 +30,19 @@ def define_node(
input_tensor_wrapper = self.define_tensor(
input_node,
+ node,
input_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=True,
)
output_tensor = self.get_tensor(node, node)
output_tensor_wrapper = self.define_tensor(
+ node,
node,
output_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=False,
)
squeeze_op = PyQnnWrapper.PyQnnOpWrapper(
diff --git a/backends/qualcomm/builders/op_sub.py b/backends/qualcomm/builders/op_sub.py
index 66c028253c6..954ca9d3917 100644
--- a/backends/qualcomm/builders/op_sub.py
+++ b/backends/qualcomm/builders/op_sub.py
@@ -27,11 +27,11 @@ def define_node(
) -> PyQnnWrapper.PyQnnOpWrapper:
out_tensor = self.get_tensor(node, node)
output_tensor_wrapper = self.define_tensor(
+ node,
node,
out_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=False,
)
sub_output_tensors = [output_tensor_wrapper]
@@ -43,10 +43,10 @@ def define_node(
input_tensor_wrapper = self.define_tensor(
input_node,
+ node,
input_tensor,
tensor_type,
nodes_to_wrappers,
- is_input_tensor=True,
)
sub_input_tensors.append(input_tensor_wrapper)
diff --git a/backends/qualcomm/builders/op_sum_int_list.py b/backends/qualcomm/builders/op_sum_int_list.py
index abe35c22445..74181f46cb3 100644
--- a/backends/qualcomm/builders/op_sum_int_list.py
+++ b/backends/qualcomm/builders/op_sum_int_list.py
@@ -32,10 +32,10 @@ def define_node(
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
+ node,
input_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=True,
)
sum_input_tensors = [input_tensor_wrapper]
@@ -50,11 +50,11 @@ def define_node(
output_tensor = self.get_tensor(node, node)
output_tensor_wrapper = self.define_tensor(
+ node,
node,
output_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=False,
)
sum_output_tensors = [output_tensor_wrapper]
sum_op = PyQnnWrapper.PyQnnOpWrapper(
diff --git a/backends/qualcomm/builders/op_tanh.py b/backends/qualcomm/builders/op_tanh.py
index f82ff92f442..ddc9fd2a2a6 100644
--- a/backends/qualcomm/builders/op_tanh.py
+++ b/backends/qualcomm/builders/op_tanh.py
@@ -30,19 +30,19 @@ def define_node(
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
+ node,
input_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=True,
)
output_tensor = self.get_tensor(node, node)
output_tensor_wrapper = self.define_tensor(
+ node,
node,
output_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=False,
)
tanh_op = PyQnnWrapper.PyQnnOpWrapper(
diff --git a/backends/qualcomm/builders/op_to.py b/backends/qualcomm/builders/op_to.py
index e17ee2790b9..f5cfd4ecf6e 100644
--- a/backends/qualcomm/builders/op_to.py
+++ b/backends/qualcomm/builders/op_to.py
@@ -45,11 +45,12 @@ def is_cast_node(self, node):
return True
input_tensor = self.get_tensor(input_node, node)
- _, inp_qconfs = self.get_quant_encoding_conf(input_node, False)
+ # Get real quant conf of input node
+ _, inp_qconfs = self.get_quant_encoding_conf(input_node, input_node)
inp_dtype = self.get_data_type(input_tensor, inp_qconfs)
output_tensor = self.get_tensor(node, node)
- _, out_qconfs = self.get_quant_encoding_conf(node, False)
+ _, out_qconfs = self.get_quant_encoding_conf(node, node)
out_dtype = self.get_data_type(output_tensor, out_qconfs)
is_qparam_castable = (
lambda o1, o2, s1, s2, diff: abs(s1 - s2) < self.epsilon
@@ -84,20 +85,20 @@ def define_node(
input_tensor_wrapper = self.define_tensor(
input_node,
+ node,
input_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=True,
)
output_tensor = self.get_tensor(node, node)
output_tensor_wrapper = self.define_tensor(
+ node,
node,
output_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=False,
)
qnn_op = OpCast if self.is_cast_node(node) else OpConvert
diff --git a/backends/qualcomm/builders/op_topk.py b/backends/qualcomm/builders/op_topk.py
index 84c29925f27..1bbf19c84bd 100644
--- a/backends/qualcomm/builders/op_topk.py
+++ b/backends/qualcomm/builders/op_topk.py
@@ -33,10 +33,10 @@ def define_node(
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
+ node,
input_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
nodes_to_wrappers,
- is_input_tensor=True,
)
k = cast(int, node.args[1])
@@ -62,21 +62,21 @@ def define_node(
# QNN constraint, topk output_0 requires having the same quant config as input
node.meta["quant_attrs"] = input_node.meta.get("quant_attrs")
output_val_tensor_wrapper = self.define_tensor(
+ node,
node,
output_val_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=False,
)
# topk output_1 is index, do not quantize it.
node.meta.pop("quant_attrs", None)
output_index_tensor_wrapper = self.define_tensor(
+ node,
node,
output_idx_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=False,
wrapper_idx=1,
)
topk_output_tensors = [output_val_tensor_wrapper, output_index_tensor_wrapper]
diff --git a/backends/qualcomm/builders/op_transpose.py b/backends/qualcomm/builders/op_transpose.py
index 20e30da3358..d29fc73084c 100644
--- a/backends/qualcomm/builders/op_transpose.py
+++ b/backends/qualcomm/builders/op_transpose.py
@@ -33,10 +33,10 @@ def define_node(
input_tensor = self.get_tensor(input_node, permute_node)
input_tensor_wrapper = self.define_tensor(
input_node,
+ node,
input_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=True,
)
# permutation
@@ -45,11 +45,11 @@ def define_node(
output_tensor = input_tensor.permute(permute_order)
output_tensor_wrapper = self.define_tensor(
+ node,
node,
output_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=False,
)
transpose_op = PyQnnWrapper.PyQnnOpWrapper(
diff --git a/backends/qualcomm/builders/op_unsqueeze.py b/backends/qualcomm/builders/op_unsqueeze.py
index 48c9207398c..55790129462 100644
--- a/backends/qualcomm/builders/op_unsqueeze.py
+++ b/backends/qualcomm/builders/op_unsqueeze.py
@@ -30,19 +30,19 @@ def define_node(
input_tensor_wrapper = self.define_tensor(
input_node,
+ node,
input_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=True,
)
output_tensor = self.get_tensor(node, node)
output_tensor_wrapper = self.define_tensor(
+ node,
node,
output_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=False,
)
unsqueeze_op = PyQnnWrapper.PyQnnOpWrapper(
diff --git a/backends/qualcomm/builders/op_upsample_bilinear2d.py b/backends/qualcomm/builders/op_upsample_bilinear2d.py
index 53291786a8a..160f15494d8 100644
--- a/backends/qualcomm/builders/op_upsample_bilinear2d.py
+++ b/backends/qualcomm/builders/op_upsample_bilinear2d.py
@@ -30,19 +30,19 @@ def define_node(
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
+ node,
input_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=True,
)
output_tensor = self.get_tensor(node, node)
output_tensor_wrapper = self.define_tensor(
+ node,
node,
output_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=False,
)
reisze_bilinear_op = PyQnnWrapper.PyQnnOpWrapper(
diff --git a/backends/qualcomm/builders/op_upsample_nearest2d.py b/backends/qualcomm/builders/op_upsample_nearest2d.py
index 75e61d77e53..6b7949716cb 100644
--- a/backends/qualcomm/builders/op_upsample_nearest2d.py
+++ b/backends/qualcomm/builders/op_upsample_nearest2d.py
@@ -30,19 +30,19 @@ def define_node(
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
+ node,
input_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=True,
)
output_tensor = self.get_tensor(node, node)
output_tensor_wrapper = self.define_tensor(
+ node,
node,
output_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
- is_input_tensor=False,
)
reisze_nearest_op = PyQnnWrapper.PyQnnOpWrapper(
diff --git a/backends/qualcomm/quantizer/custom_annotation.py b/backends/qualcomm/quantizer/custom_annotation.py
index c58d0844b40..6167df64b9e 100644
--- a/backends/qualcomm/quantizer/custom_annotation.py
+++ b/backends/qualcomm/quantizer/custom_annotation.py
@@ -27,6 +27,12 @@ def annotate_matmul_16a8w( # noqa: C901
) -> None:
"""
This function is specific for matmul op 16a8w.
+ For k, we tag as shown below, and
+ for v, we tag 8a until the conv op.
+ q (16 bits) ──┬─> matmul op (16 bits)
+ past k (8 bits) ┬─> cat op (8 bits) ─┘
+ rotary add (16 bits) ─┬> cat op (new k) (8 bits) ┘
+ rotary sub (16 bits) ─┘
"""
def annotate_matmul(node: Node, quantization_config: QuantizationConfig):
@@ -65,6 +71,21 @@ def annotate_cat(node: Node, quantization_config: QuantizationConfig):
_annotated=True,
)
+ def annotate_conv2d(node: Node, quantization_config: QuantizationConfig) -> None:
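+ # Annotate the conv input activation, weight, and output with the given quantization config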
+ input_qspec_map = {}
+ input_act = node.args[0]
+ input_spec = quantization_config.input_activation
+ input_qspec_map[input_act] = input_spec
+
+ weight = node.args[1]
+ input_qspec_map[weight] = quantization_config.weight
+
+ node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
+ input_qspec_map=input_qspec_map,
+ output_qspec=quantization_config.output_activation,
+ _annotated=True,
+ )
+
def annotate_single_in_single_out(
node: Node, quantization_config: QuantizationConfig
) -> None:
@@ -83,17 +104,37 @@ def annotate_matmul_input1(node: Node):
quantization_config_8a8w = get_8a8w_qnn_ptq_config(
act_symmetric=True, act_observer=MinMaxObserver
)
+ quantization_config_8a4w_per_channel = get_ptq_per_channel_quant_config(
+ act_dtype=torch.uint8,
+ weight_dtype="int4",
+ act_observer=MinMaxObserver,
+ act_symmetric=True,
+ )
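+ # Walk up from the second input of matmul: reshape-like ops keep 8a8w activations,
+ # conv2d gets the 8a4w per-channel config, and add/sub (rotary embedding) ends the traversal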
while isinstance(node, Node) and node.op == "call_function":
if node.target in [
torch.ops.aten.permute.default,
+ torch.ops.aten.squeeze.dim,
torch.ops.aten.transpose.int,
+ torch.ops.aten.view.default,
+ torch.ops.aten.reshape.default,
]:
annotate_single_in_single_out(node, quantization_config_8a8w)
node = node.args[0]
elif node.target == torch.ops.aten.cat.default:
annotate_cat(node, quantization_config_8a8w)
- node = node.args[0][0]
+ # For v, we tag 8a until the conv op.
+ # For k, we tag 8a until the add or sub op (rotary embedding).
+ # The arguments of the cat op are (past kv cache, new kv cache).
+ node = node.args[0][1]
+ elif node.target == torch.ops.aten.conv2d.default:
+ annotate_conv2d(
+ node, quantization_config=quantization_config_8a4w_per_channel
+ )
+ break
+ elif node.target in [torch.ops.aten.add.Tensor, torch.ops.aten.sub.Tensor]:
+ break
else:
+ print(f"Unexpected node ({node}) in input1 of the matmul")
node = node.args[0]
quantization_config_16a8w = get_16a8w_qnn_ptq_config(act_observer=MinMaxObserver)
diff --git a/backends/qualcomm/quantizer/qconfig.py b/backends/qualcomm/quantizer/qconfig.py
index abe51066ba0..a6c551e2413 100644
--- a/backends/qualcomm/quantizer/qconfig.py
+++ b/backends/qualcomm/quantizer/qconfig.py
@@ -249,6 +249,7 @@ def get_ptq_per_channel_quant_config(
act_quantization_spec = QuantizationSpec(
dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
qscheme=torch.per_tensor_symmetric,
+ ch_axis=0,
observer_or_fake_quant_ctr=act_observer.with_args(**extra_args),
)
else:
diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py
index 4b489ea5157..43b78d341cd 100644
--- a/backends/qualcomm/tests/test_qnn_delegate.py
+++ b/backends/qualcomm/tests/test_qnn_delegate.py
@@ -1655,8 +1655,12 @@ def test_qnn_backend_multi_graphs(self):
to_backend(edge_prog.exported_program, QnnPartitioner(compiler_specs[i]))
for i, edge_prog in enumerate(edge_progs)
]
- prog_mgr = generate_multi_graph_program(
- compiler_specs=compiler_specs[0], exported_programs=exported_programs
+ prog_mgr, _ = generate_multi_graph_program(
+ compiler_specs=compiler_specs[0],
+ processed_bytes=[
+ prog.graph_module.lowered_module_0.processed_bytes
+ for prog in exported_programs
+ ],
)
for index, module in enumerate(modules):
self.verify_output(
@@ -2120,9 +2124,12 @@ def test_qnn_backend_multi_graphs(self):
to_backend(edge_prog.exported_program, QnnPartitioner(compiler_specs[i]))
for i, edge_prog in enumerate(edge_progs)
]
- prog_mgr = generate_multi_graph_program(
+ prog_mgr, _ = generate_multi_graph_program(
compiler_specs=compiler_specs[0],
- exported_programs=exported_programs,
+ processed_bytes=[
+ prog.graph_module.lowered_module_0.processed_bytes
+ for prog in exported_programs
+ ],
)
for index, module in enumerate(modules):
self.verify_output(
diff --git a/backends/qualcomm/utils/utils.py b/backends/qualcomm/utils/utils.py
index 2e0ee4f7c63..3d2a9f8c85d 100644
--- a/backends/qualcomm/utils/utils.py
+++ b/backends/qualcomm/utils/utils.py
@@ -6,6 +6,7 @@
import operator
import re
+import time
import warnings
from collections import OrderedDict
from typing import Any, Callable, Dict, FrozenSet, List, Optional, Tuple
@@ -740,7 +741,7 @@ def preprocess_binary(ctx_bin, compiler_specs):
for k, v in type_map.items():
dtype_map.setdefault(v, k)
- qnn_in_order, executorch_in_order, executorch_out_order = [], [], []
+ qnn_in_order, executorch_in_order, executorch_out_order = None, None, None
if custom_info is not None:
# since some context binaries might fail to open on host
# if they are compiled with special flags:
@@ -748,9 +749,9 @@ def preprocess_binary(ctx_bin, compiler_specs):
# use custom information here instead
inputs = build_tensor(custom_info["graph_inputs"], dtype_map)
outputs = build_tensor(custom_info["graph_outputs"], dtype_map)
- qnn_in_order = custom_info["qnn_in_order"]
- executorch_in_order = custom_info["executorch_in_order"]
- executorch_out_order = custom_info["executorch_out_order"]
+ qnn_in_order = custom_info.get("qnn_in_order", None)
+ executorch_in_order = custom_info.get("executorch_in_order", None)
+ executorch_out_order = custom_info.get("executorch_out_order", None)
graph_name = custom_info["graph_name"]
else:
# get context-binary io tensor info through qnn manager
@@ -800,7 +801,9 @@ def draw_graph(title, path, graph_module: torch.fx.GraphModule):
def generate_multi_graph_program(
compiler_specs: List[CompileSpec],
- exported_programs: List[ExportedProgram] = None,
+ processed_bytes: List[bytes],
+ input_nodes_dict: Dict[str, List[torch.fx.Node]] = None,
+ output_nodes_dict: Dict[str, List[torch.fx.Node]] = None,
backend_config: ExecutorchBackendConfig = None,
constant_methods: Optional[Dict[str, Any]] = None,
) -> ExecutorchProgramManager:
@@ -813,10 +816,6 @@ def generate_multi_graph_program(
executorch_in_order,
executorch_out_order,
) = ({}, {}, {}, {}, {})
-
- processed_bytes = [
- prog.graph_module.lowered_module_0.processed_bytes for prog in exported_programs
- ]
qnn_mgr = PyQnnManagerAdaptor.QnnManager(
generate_qnn_executorch_option(compiler_specs), processed_bytes
)
@@ -829,38 +828,36 @@ def generate_multi_graph_program(
graph_outputs[graph_name] = qnn_mgr.GetGraphOutputs(graph_name)
# We need to obtain the order of the IOs to correctly map QNN with nn.module
- for i, graph_name in enumerate(graph_names):
- # input
- input_names = [
- node.name
- for node in exported_programs[i].graph_module.graph.nodes
- if node.op == "placeholder"
- ]
- qnn_input_names = [wrapper.GetName() for wrapper in graph_inputs[graph_name]]
- input_order_list = []
- for input_name in input_names:
- # e.g., input_0_tokens_0
- pattern = rf"^input_(\d+)_({input_name})_(\d+)$"
- for j in range(len(qnn_input_names)):
- if re.match(pattern, qnn_input_names[j]):
- input_order_list.append(j)
- break
- assert (
- len(input_order_list) == len(input_names) == len(qnn_input_names)
- ), "Order list length is different from names"
- executorch_in_order[graph_name] = input_order_list
- qnn_in_order[graph_name] = sorted(
- range(len(input_order_list)), key=lambda k: input_order_list[k]
- )
-
- # output
- get_item_list = [
- node
- for node in exported_programs[i].graph_module.graph.nodes
- if node.op == "output"
- ][0].args[0]
- output_order_list = [item.args[1] for item in get_item_list]
- executorch_out_order[graph_name] = output_order_list
+ for graph_name in graph_names:
+ if input_nodes_dict:
+ # input
+ input_names = [node.name for node in input_nodes_dict[graph_name]]
+ qnn_input_names = [
+ wrapper.GetName() for wrapper in graph_inputs[graph_name]
+ ]
+ # The inputs of an intermediate module may include call_function nodes,
+ # which cannot be reordered by node name
+ if len(input_names) == len(qnn_input_names):
+ input_order_list = []
+ for input_name in input_names:
+ # e.g., input_0_tokens_0
+ pattern = rf"^input_(\d+)_({input_name})_(\d+)$"
+ for j in range(len(qnn_input_names)):
+ if re.match(pattern, qnn_input_names[j]):
+ input_order_list.append(j)
+ break
+ assert len(input_order_list) == len(
+ input_names
+ ), "Order list length is different from names"
+ executorch_in_order[graph_name] = input_order_list
+ qnn_in_order[graph_name] = sorted(
+ range(len(input_order_list)), key=lambda k: input_order_list[k]
+ )
+ if output_nodes_dict:
+ # output
+ get_item_list = output_nodes_dict[graph_name][0].args[0]
+ output_order_list = [item.args[1] for item in get_item_list]
+ executorch_out_order[graph_name] = output_order_list
qnn_mgr.Destroy()
@@ -869,15 +866,15 @@ def generate_multi_graph_program(
bundle_progs = [
from_context_binary(
ctx_path=binary_info,
- op_name=f"loader_{graph_name}",
+ op_name=f"loader_{graph_name}_{int(time.time())}",
soc_model=compiler_options.soc_info.soc_model,
custom_info={
"graph_inputs": graph_inputs[graph_name],
"graph_outputs": graph_outputs[graph_name],
"graph_name": graph_name,
- "qnn_in_order": qnn_in_order[graph_name],
- "executorch_in_order": executorch_in_order[graph_name],
- "executorch_out_order": executorch_out_order[graph_name],
+ "qnn_in_order": qnn_in_order.get(graph_name, None),
+ "executorch_in_order": executorch_in_order.get(graph_name, None),
+ "executorch_out_order": executorch_out_order.get(graph_name, None),
},
)
for graph_name in graph_names
@@ -900,9 +897,101 @@ def generate_multi_graph_program(
break
edge_prog_mgr = edge_prog_mgr.to_backend(QnnPartitioner(compiler_specs))
- return edge_prog_mgr.to_executorch(
+ exec_prog = edge_prog_mgr.to_executorch(
+ config=backend_config or ExecutorchBackendConfig()
+ )
+ return exec_prog, bundle_progs
+
+
+def generate_composite_llama_program(
+ graph_names: List[str],
+ sample_inputs_list: List[Tuple[Any]],
+ lower_module_dict: Dict[str, List[LoweredBackendModule]],
+ call_delegate_node_name_dict: Dict[str, List[str]],
+ call_delegate_inputs_dict: Dict[str, List[Tuple[str, int | None]]],
+ outputs_dict: Dict[str, List[Tuple[str, int]]],
+ backend_config: ExecutorchBackendConfig = None,
+ constant_methods: Optional[Dict[str, Any]] = None,
+) -> ExecutorchProgramManager:
+ class CompositeLlamaModule(torch.nn.Module):
+ def __init__(
+ self,
+ lower_module_list,
+ call_delegate_node_name_list,
+ call_delegate_inputs_list,
+ outputs_list,
+ ) -> None:
+ super().__init__()
+ self.lower_module_list = lower_module_list
+ self.call_delegate_node_name_list = call_delegate_node_name_list
+ self.call_delegate_inputs_list = call_delegate_inputs_list
+ self.outputs_list = outputs_list
+
+ def reorder(
+ self,
+ call_delegate_inputs: List[Tuple[str, int | None]],
+ module_inputs: dict[str, torch.Tensor],
+ all_ret: dict[str, torch.Tensor],
+ ) -> Tuple[torch.Tensor]:
+ ret = []
+ for name, index in call_delegate_inputs:
+ if index is not None:
+ # Get tensor from previous results
+ ret.append(all_ret[name][index])
+ else:
+ # Get tensor from the inputs of module
+ ret.append(module_inputs[name])
+ return tuple(ret)
+
+ def forward(
+ self,
+ tokens: torch.Tensor,
+ atten_mask: torch.Tensor,
+ input_pos: Optional[torch.Tensor] = None,
+ *args,
+ ) -> Tuple[torch.Tensor]:
+ all_ret = {}
+ module_input_dict = {
+ "tokens": tokens,
+ "atten_mask": atten_mask,
+ "input_pos": input_pos,
+ }
+ for num, arg in enumerate(args):
+ module_input_dict[f"args_{num}"] = arg
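+ # Run the lowered modules in recorded order; each result is keyed by its
+ # delegate node name so later calls can look up their inputs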
+ for lower_module, call_delegate_node_name, call_delegate_inputs in zip(
+ self.lower_module_list,
+ self.call_delegate_node_name_list,
+ self.call_delegate_inputs_list,
+ ):
+ inp = self.reorder(call_delegate_inputs, module_input_dict, all_ret)
+ ret = lower_module(*inp)
+ all_ret[call_delegate_node_name] = ret
+ llama_outputs = []
+ for output_src_name, index in self.outputs_list:
+ llama_outputs.append(all_ret[output_src_name][index])
+ return tuple(llama_outputs)
+
+ progs_dict = {}
+ for graph_name, sample_inputs in zip(graph_names, sample_inputs_list):
+ composite_llama_module = CompositeLlamaModule(
+ lower_module_dict[graph_name],
+ call_delegate_node_name_dict[graph_name],
+ call_delegate_inputs_dict[graph_name],
+ outputs_dict[graph_name],
+ )
+ prog = torch.export.export(composite_llama_module, sample_inputs)
+ progs_dict[graph_name] = prog
+ # leverage ExecutorchProgramManager to generate a pte with multiple methods
+ edge_prog_mgr = to_edge(
+ progs_dict,
+ constant_methods=constant_methods,
+ # do not alter name for custom op
+ compile_config=EdgeCompileConfig(_check_ir_validity=False, _use_edge_ops=False),
+ )
+ exec_prog = edge_prog_mgr.to_executorch(
config=backend_config or ExecutorchBackendConfig()
)
+ return exec_prog
def generate_htp_compiler_spec(
diff --git a/examples/qualcomm/oss_scripts/llama2/model/static_llama.py b/examples/qualcomm/oss_scripts/llama2/model/static_llama.py
index 6894537b82f..d1b618ed071 100755
--- a/examples/qualcomm/oss_scripts/llama2/model/static_llama.py
+++ b/examples/qualcomm/oss_scripts/llama2/model/static_llama.py
@@ -11,9 +11,8 @@
import torch
import torch.nn as nn
-
+import torch.nn.functional as F
from executorch.examples.models.llama.llama_transformer import (
- FeedForward,
ModelArgs,
precompute_freqs_cis,
)
@@ -58,37 +57,44 @@ def __init__(self, config: ModelArgs, output_new_cache_only=False):
def prepare_sha(self):
self.wq_sha = nn.ModuleList(
[
- nn.Linear(self.dim, self.head_dim, bias=False)
+ nn.Conv2d(self.dim, self.head_dim, 1, bias=False)
for _ in range(self.n_heads)
]
)
self.wk_sha = nn.ModuleList(
[
- nn.Linear(self.dim, self.head_dim, bias=False)
+ nn.Conv2d(self.dim, self.head_dim, 1, bias=False)
for _ in range(self.n_kv_heads)
]
)
self.wv_sha = nn.ModuleList(
[
- nn.Linear(self.dim, self.head_dim, bias=False)
+ nn.Conv2d(self.dim, self.head_dim, 1, bias=False)
for _ in range(self.n_kv_heads)
]
)
+ self.wo_sha = nn.Conv2d(self.n_heads * self.head_dim, self.dim, 1, bias=False)
self.forward_mha = self.forward
self.forward = self.forward_sha
-
for i in range(self.n_heads):
self.wq_sha[i].weight.data.copy_(
- self.wq.weight[i * self.head_dim : (i + 1) * self.head_dim]
+ self.wq.weight[
+ i * self.head_dim : (i + 1) * self.head_dim, :, None, None
+ ]
)
for i in range(self.n_kv_heads):
self.wk_sha[i].weight.data.copy_(
- self.wk.weight[i * self.head_dim : (i + 1) * self.head_dim]
+ self.wk.weight[
+ i * self.head_dim : (i + 1) * self.head_dim, :, None, None
+ ]
)
self.wv_sha[i].weight.data.copy_(
- self.wv.weight[i * self.head_dim : (i + 1) * self.head_dim]
+ self.wv.weight[
+ i * self.head_dim : (i + 1) * self.head_dim, :, None, None
+ ]
)
+ self.wo_sha.weight.data.copy_(self.wo.weight[:, :, None, None])
def forward_sha(
self,
@@ -99,9 +105,22 @@ def forward_sha(
k_caches: Optional[List[torch.Tensor]] = None,
v_caches: Optional[List[torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
- q = [wq_sha(hidden_states) for wq_sha in self.wq_sha]
- k = [wk_sha(hidden_states) for wk_sha in self.wk_sha]
- v = [wv_sha(hidden_states) for wv_sha in self.wv_sha]
+ bsz, seq_len, _ = hidden_states.shape
+ hidden_states = torch.reshape(
+ hidden_states, (bsz, seq_len, 1, self.dim)
+ ).transpose(1, 3)
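+ # hidden_states is now (bsz, dim, 1, seq_len); each per-head 1x1 conv projects along the channel axis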
+ q = [
+ wq_sha(hidden_states).reshape(bsz, self.head_dim, seq_len).transpose(1, 2)
+ for wq_sha in self.wq_sha
+ ]
+ k = [
+ wk_sha(hidden_states).reshape(bsz, self.head_dim, seq_len).transpose(1, 2)
+ for wk_sha in self.wk_sha
+ ]
+ v = [
+ wv_sha(hidden_states).reshape(bsz, self.head_dim, seq_len).transpose(1, 2)
+ for wv_sha in self.wv_sha
+ ]
for i in range(len(q)):
q[i] = apply_rotary_emb_single(q[i], freqs_cos, freqs_sin)
for i in range(len(k)):
@@ -129,7 +148,11 @@ def forward_sha(
output_y.append(y)
y = torch.concat(output_y, dim=-1)
- y = self.wo(y)
+ y = y.reshape(bsz, seq_len, 1, -1)
+ y = y.transpose(1, 3)
+ y = self.wo_sha(y)
+ y = y.transpose(1, 3)
+ y = y.reshape(bsz, seq_len, -1)
if self.output_new_cache_only:
if k_caches and v_caches:
@@ -148,12 +171,12 @@ def forward(
k_caches: List[torch.Tensor],
v_caches: List[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
- bsz, seqlen, _ = hidden_states.shape
+ bsz, seq_len, _ = hidden_states.shape
q, k, v = self.wq(hidden_states), self.wk(hidden_states), self.wv(hidden_states)
- q = q.view(bsz, seqlen, self.n_heads, self.head_dim)
- k = k.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
- v = v.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
+ q = q.view(bsz, seq_len, self.n_heads, self.head_dim)
+ k = k.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
+ v = v.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
q = apply_rotary_emb_single(q, freqs_cos, freqs_sin)
k = apply_rotary_emb_single(k, freqs_cos, freqs_sin).permute(0, 2, 3, 1)
@@ -203,6 +226,45 @@ def forward(
return y, output_kh, output_vh
+class FeedForward(nn.Module):
+ def __init__(self, args: ModelArgs):
+ super().__init__()
+ assert args.hidden_dim is not None
+ self.hidden_dim: int = args.hidden_dim
+ self.dim: int = args.dim
+ self.w1 = nn.Linear(self.dim, self.hidden_dim, bias=False)
+ self.w2 = nn.Linear(self.hidden_dim, self.dim, bias=False)
+ self.w3 = nn.Linear(self.dim, self.hidden_dim, bias=False)
+
+ def prepare_feedfoward_conv(self):
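+ # Swap the Linear projections for equivalent 1x1 Conv2d layers; weights are
+ # copied below and the original Linear modules are deleted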
+ self.w1_conv = nn.Conv2d(self.dim, self.hidden_dim, 1, bias=False)
+ self.w2_conv = nn.Conv2d(self.hidden_dim, self.dim, 1, bias=False)
+ self.w3_conv = nn.Conv2d(self.dim, self.hidden_dim, 1, bias=False)
+
+ self.forward_no_conv = self.forward
+ self.forward = self.forward_feedfoward_conv
+
+ self.w1_conv.weight.data.copy_(self.w1.weight[:, :, None, None])
+ self.w2_conv.weight.data.copy_(self.w2.weight[:, :, None, None])
+ self.w3_conv.weight.data.copy_(self.w3.weight[:, :, None, None])
+
+ del self.w1
+ del self.w2
+ del self.w3
+
+ def forward_feedfoward_conv(self, x):
+ bsz, _, _ = x.size()
+ x = torch.reshape(x, (bsz, -1, self.dim, 1))
+ x = x.transpose(1, 2) # Transpose right before and after Conv
+ x = self.w2_conv(F.silu(self.w1_conv(x)) * self.w3_conv(x))
+ x = x.transpose(1, 2)
+ x = torch.reshape(x, (bsz, -1, self.dim))
+ return x
+
+ def forward(self, x):
+ return self.w2(F.silu(self.w1(x)) * self.w3(x))
+
+
class LlamaDecoderLayer(nn.Module):
def __init__(self, config: ModelArgs, output_new_cache_only=False):
super().__init__()
@@ -268,6 +330,22 @@ def __init__(self, config: ModelArgs, output_new_cache_only=True):
self.register_buffer("freqs_cos", freqs_cos, persistent=False)
self.register_buffer("freqs_sin", freqs_sin, persistent=False)
+ def prepare_output_conv(self):
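+ # Replace the output Linear projection with an equivalent 1x1 Conv2d wrapped in a
+ # reshape/transpose helper; weights are copied and the original module is deleted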
+ def forward_output_conv(x):
+ bsz, _, _ = x.size()
+ x = torch.reshape(x, (bsz, -1, 1, self.dim))
+ x = x.transpose(1, 3) # Transpose right before and after Conv
+ x = self.output_conv(x)
+ x = x.transpose(1, 3)
+ x = torch.reshape(x, (bsz, -1, self.vocab_size))
+ return x
+
+ self.output_conv = nn.Conv2d(self.dim, self.vocab_size, 1, bias=False)
+ self.output_conv.weight.data.copy_(self.output.weight[:, :, None, None])
+
+ del self.output
+ self.output = forward_output_conv
+
def forward(
self,
tokens: torch.Tensor,
diff --git a/examples/qualcomm/oss_scripts/llama3_2/llama.py b/examples/qualcomm/oss_scripts/llama3_2/llama.py
index 13fb99a4202..a18690e941d 100755
--- a/examples/qualcomm/oss_scripts/llama3_2/llama.py
+++ b/examples/qualcomm/oss_scripts/llama3_2/llama.py
@@ -18,7 +18,6 @@
from multiprocessing.connection import Client
import torch
-from executorch.backends.qualcomm._passes.build_quant_io import BuildQuantIo
from executorch.backends.qualcomm.partition.qnn_partitioner import QnnPartitioner
@@ -32,10 +31,12 @@
from executorch.backends.qualcomm.utils.utils import (
capture_program,
convert_linear_to_conv2d,
+ generate_composite_llama_program,
generate_htp_compiler_spec,
generate_multi_graph_program,
generate_qnn_executorch_compiler_spec,
get_soc_to_chipset_map,
+ update_spill_fill_size,
)
from executorch.examples.qualcomm.oss_scripts.llama2.model.static_llama import (
LlamaModel,
@@ -78,7 +79,9 @@ def _kv_calibrate(
# TODO: change criteria & support batch inputs if necessary
pos = torch.tensor(0, dtype=torch.int32)
max_cache_len = max_seq_len - 1
- token_list = sp_model.encode(user_prompts, bos=True, eos=False)
+ token_list = sp_model.encode(
+ user_prompts, bos=True, eos=False, allowed_special="all"
+ )
with torch.no_grad():
while token_list[-1] != sp_model.eos_id and pos < max_cache_len:
@@ -118,28 +121,30 @@ def _prefill_calibrate(
max_cache_len = max_seq_len - 1
# TODO: change criteria & support batch inputs if necessary
- token_list = sp_model.encode(user_prompts, bos=True, eos=False)
- token_list = torch.tensor(token_list)[:max_cache_len].reshape(1, -1)
- last_prompt_pos = token_list.numel()
- if last_prompt_pos < max_cache_len:
- token_list = torch.cat(
- [
- token_list,
- torch.zeros((1, max_cache_len - last_prompt_pos), dtype=torch.int32),
- ],
- dim=1,
- )
- else:
- token_list = token_list[:, :max_cache_len]
+ token_list = sp_model.encode(
+ user_prompts, bos=True, eos=False, allowed_special="all"
+ )
+ pos = len(token_list)
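+ # Autoregressively append generated tokens until EOS or the cache is full,
+ # padding each step's input to max_cache_len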
with torch.no_grad():
- logits, new_k_caches, new_v_caches = module(
- token_list,
- atten_mask,
- )
- predict = [torch.argmax(logits[:, last_prompt_pos - 1], dim=-1).item()]
+ while token_list[-1] != sp_model.eos_id and pos < max_cache_len:
+ tmp_token_list = torch.tensor(token_list).reshape(1, -1)
+ if pos < max_cache_len:
+ tmp_token_list = torch.cat(
+ [
+ tmp_token_list,
+ torch.zeros((1, max_cache_len - pos), dtype=torch.int32),
+ ],
+ dim=1,
+ )
+ logits, new_k_caches, new_v_caches = module(
+ tmp_token_list,
+ atten_mask,
+ )
+ token_list.append(torch.argmax(logits[:, pos - 1], dim=-1).item())
+ pos += 1
- print(f"calibration data:\n{sp_model.decode(predict)}")
+ print(f"calibration data:\n{sp_model.decode(token_list)}")
def calibrate(
@@ -186,41 +191,94 @@ def __init__(self, llama_model, pte_filename) -> None:
tokens, atten_mask = self.get_example_inputs(use_kv_cache=False)
self.inputs = (tokens, atten_mask)
- def _tag_kv_ios(self, gm: torch.fx.GraphModule, kv_type, sharding_type):
+ def _tag_ios(self, gm: torch.fx.GraphModule, fixed_point_type):
if not self.has_quant_io:
return
# shape of k caches and v caches
- input_cache_shape = {
+ kv_cache_shape = {
+ # single head, kv mode input
(self.llama_meta["get_head_dim"], self.llama_meta["get_max_seq_len"]),
(self.llama_meta["get_max_seq_len"], self.llama_meta["get_head_dim"]),
+ # single head, kv mode output
+ (self.llama_meta["get_head_dim"], 1),
+ (1, self.llama_meta["get_head_dim"]),
+ # single head, bert mode
+ (self.llama_meta["get_head_dim"], self.llama_meta["get_max_seq_len"] - 1),
+ (self.llama_meta["get_max_seq_len"] - 1, self.llama_meta["get_head_dim"]),
}
+ io_shape = {
+ # kv mode
+ (
+ self.llama_meta["get_max_batch_size"],
+ 1,
+ self.llama_meta["get_vocab_size"],
+ ),
+ # bert mode
+ (
+ self.llama_meta["get_max_batch_size"],
+ self.llama_meta["get_max_seq_len"] - 1,
+ self.llama_meta["get_vocab_size"],
+ ),
+ }
+
+ atten_mask_shape = {
+ # kv mode
+ (self.llama_meta["get_max_batch_size"], self.llama_meta["get_max_seq_len"]),
+ # bert mode
+ (
+ self.llama_meta["get_max_seq_len"] - 1,
+ self.llama_meta["get_max_seq_len"] - 1,
+ ),
+ }
+
+ freq_shape = {
+ # kv mode
+ (1, self.llama_meta["get_head_dim"] // 2),
+ # bert mode
+ (
+ self.llama_meta["get_max_seq_len"] - 1,
+ self.llama_meta["get_head_dim"] // 2,
+ ),
+ }
+
+ freq_op = {
+ # kv mode
+ exir_ops.edge.aten.select.int,
+ # bert mode
+ exir_ops.edge.aten.slice_copy.Tensor,
+ }
+
for n in gm.graph.nodes:
- if (
- n.op == "placeholder"
- and len(users := list(n.users)) == 1
- and users[0].meta["val"].size()[-2:] in input_cache_shape
- ):
- n.meta[QCOM_QUANTIZED_IO] = kv_type
+ if n.op == "placeholder":
+ if (
+ len(users := list(n.users)) == 1
+ and users[0].meta["val"].size()[-2:] in kv_cache_shape
+ ):
+ n.meta[QCOM_QUANTIZED_IO] = fixed_point_type["kv_type"]
+ elif n.meta["val"].size() in io_shape:
+ n.meta[QCOM_QUANTIZED_IO] = fixed_point_type["io_type"]
+ elif n.meta["val"].size() in atten_mask_shape:
+ n.meta[QCOM_QUANTIZED_IO] = fixed_point_type["io_type"]
elif n.op == "output":
for a in n.args[0]:
- # single head, kv mode
- if (
- a.meta["val"].flatten().size()[0]
- == self.llama_meta["get_head_dim"]
- ):
- a.meta[QCOM_QUANTIZED_IO] = kv_type
- # single head, prefill mode
- elif a.meta["val"].flatten().size()[0] == self.llama_meta[
- "get_head_dim"
- ] * (self.llama_meta["get_max_seq_len"] - 1):
- a.meta[QCOM_QUANTIZED_IO] = kv_type
+ if a.meta["val"].size()[-2:] in kv_cache_shape:
+ a.meta[QCOM_QUANTIZED_IO] = fixed_point_type["kv_type"]
+ elif a.meta["val"].size() in io_shape:
+ a.meta[QCOM_QUANTIZED_IO] = fixed_point_type["io_type"]
+ quant_attrs = a.meta["quant_attrs"]
# Tag sharding io
if exir_ops.edge.llama.fallback.default in [
u.target for u in list(n.users.keys())
] + [n.target]:
- n.meta[QCOM_QUANTIZED_IO] = sharding_type
+ n.meta[QCOM_QUANTIZED_IO] = fixed_point_type["io_type"]
+
+ # Tag the select op for freq_sin and freq_cos as quantized tensors; this is needed because of sharding
+ if n.target in freq_op and n.meta["val"].size() in freq_shape:
+ n.meta[QCOM_QUANTIZED_IO] = fixed_point_type["io_type"]
+
+ return quant_attrs
def quantize(self, quant_dtype, args, custom_annotations=()):
self.quant_dtype = quant_dtype
@@ -254,16 +312,12 @@ def quantize(self, quant_dtype, args, custom_annotations=()):
def lowering_modules(
self,
work_space,
- kv_type=torch.uint8,
- sharding_type=torch.uint16,
+ fixed_point_type,
use_fp16=False,
soc_model=QcomChipset.SM8650,
num_sharding=0,
):
executorch_config = ExecutorchBackendConfig(
- passes=[
- BuildQuantIo(),
- ],
# For shared buffer, user must pass the memory address
# which is allocated by RPC memory to executor runner.
# Therefore, won't want to pre-allocate
@@ -276,7 +330,9 @@ def lowering_modules(
)
with torch.no_grad():
# backend option
- backend_options = generate_htp_compiler_spec(use_fp16=use_fp16)
+ backend_options = generate_htp_compiler_spec(
+ use_fp16=use_fp16, use_multi_contexts=num_sharding > 0
+ )
compiler_specs = generate_qnn_executorch_compiler_spec(
soc_model=soc_model,
backend_options=backend_options,
@@ -297,10 +353,9 @@ def lowering_modules(
shares=num_sharding,
)
- self._tag_kv_ios(
+ self.quant_attrs = self._tag_ios(
edge_prog.exported_program.graph_module,
- kv_type=kv_type,
- sharding_type=sharding_type,
+ fixed_point_type=fixed_point_type,
)
edge_prog_mgr = EdgeProgramManager(
edge_programs={"forward": edge_prog.exported_program},
@@ -308,13 +363,18 @@ def lowering_modules(
compile_config=EdgeCompileConfig(_check_ir_validity=False),
)
edge_prog_mgr = edge_prog_mgr.to_backend(partitioner)
+ if num_sharding > 0:
+ update_spill_fill_size(edge_prog_mgr.exported_program())
exec_prog_mgr = edge_prog_mgr.to_executorch(config=executorch_config)
- with open(f"{work_space}/{self.pte_filename}.pte", "wb") as file:
+ with open(f"{work_space}/{self.pte_filename}.pte", "wb") as file:
exec_prog_mgr.write_to_file(file)
def get_example_inputs(self, use_kv_cache=True):
return self.llama_model.get_example_inputs(use_kv_cache)
+ def get_quant_attrs(self):
+ return self.quant_attrs
+
def compile(args, pte_filename):
os.makedirs(args.artifact, exist_ok=True)
@@ -371,24 +431,25 @@ def compile(args, pte_filename):
for layer in llama_instance.layers:
if getattr(layer.attention, "prepare_sha", None):
layer.attention.prepare_sha()
-
- use_fp16 = False
- if args.ptq != None:
- kv_type = torch.uint8
+ if getattr(layer.feed_forward, "prepare_feedfoward_conv", None):
+ layer.feed_forward.prepare_feedfoward_conv()
+
+ use_fp16 = True
+ fixed_point_type = {"kv_type": torch.float32, "io_type": torch.float32}
+ if args.ptq:
+ use_fp16 = False
+ fixed_point_type["kv_type"] = torch.uint8
if args.ptq == "8a8w":
- sharding_type = torch.uint8
+ fixed_point_type["io_type"] = torch.uint8
elif args.ptq == "16a4w":
- sharding_type = torch.uint16
+ fixed_point_type["io_type"] = torch.uint16
else:
assert args.ptq in [
"8a8w",
"16a4w",
], f"No support for quant type {args.ptq}. Support 8a8w and 16a4w."
quant_dtype = getattr(QuantDtype, f"use_{args.ptq}")
- else:
- use_fp16 = True
- kv_type = torch.float32
- sharding_type = torch.float32
+
assert args.tokenizer_model is not None, "Need tokenizer model for calibration"
if args.dtype_override is not None:
@@ -404,7 +465,7 @@ def compile(args, pte_filename):
llama_instance_list[i].eval(), pte_filename
)
- if args.ptq != None:
+ if args.ptq:
start_quantize_ts = time.time()
for llama_instance in llama_instance_list:
llama_instance.quantize(
@@ -421,16 +482,17 @@ def compile(args, pte_filename):
logging.info(f"Time for quantizing: {end_quantize_ts - start_quantize_ts}")
start_lowering_ts = time.time()
+ quant_attrs = None
if len(llama_instance_list) == 1:
llama_instance_list[0].lowering_modules(
args.artifact,
- kv_type=kv_type,
- sharding_type=sharding_type,
+ fixed_point_type,
use_fp16=use_fp16,
soc_model=get_soc_to_chipset_map()[args.model],
num_sharding=args.num_sharding,
)
+ quant_attrs = llama_instance_list[0].get_quant_attrs()
else:
sample_inputs_list = [
llama_instace.inputs for llama_instace in llama_instance_list
@@ -451,12 +513,13 @@ def compile(args, pte_filename):
)
for i in range(len(llama_instance_list)):
- llama_instance_list[i]._tag_kv_ios(
+ quant_attrs = llama_instance_list[i]._tag_ios(
edge_progs[i].exported_program.graph_module,
- kv_type=kv_type,
- sharding_type=sharding_type,
+ fixed_point_type,
)
- backend_options = generate_htp_compiler_spec(use_fp16=use_fp16)
+ backend_options = generate_htp_compiler_spec(
+ use_fp16=use_fp16, use_multi_contexts=args.num_sharding > 0
+ )
graph_names = ["prefill_forward", "kv_forward"]
compiler_specs = [
generate_qnn_executorch_compiler_spec(
@@ -468,15 +531,19 @@ def compile(args, pte_filename):
)
for graph_name in graph_names
]
+ skip_node_op_set = {"llama.fallback.default"}
exported_programs = [
- to_backend(edge_prog.exported_program, QnnPartitioner(compiler_specs[i]))
+ to_backend(
+ edge_prog.exported_program,
+ QnnPartitioner(compiler_specs[i], skip_node_op_set=skip_node_op_set),
+ )
for i, edge_prog in enumerate(edge_progs)
]
+ if args.num_sharding > 0:
+ for exported_program in exported_programs:
+ update_spill_fill_size(exported_program)
executorch_config = ExecutorchBackendConfig(
- passes=[
- BuildQuantIo(),
- ],
# For shared buffer, user must pass the memory address
# which is allocated by RPC memory to executor runner.
# Therefore, won't want to pre-allocate
@@ -488,20 +555,124 @@ def compile(args, pte_filename):
extract_delegate_segments=True,
)
- prog_mgr = generate_multi_graph_program(
- compiler_specs=compiler_specs[0],
- exported_programs=exported_programs,
- backend_config=executorch_config,
- constant_methods=llama_instance_list[1].llama_meta, # kv method meta
- )
- with open(f"{args.artifact}/{pte_filename}.pte", "wb") as file:
- prog_mgr.write_to_file(file)
+ lower_module_dict = {name: [] for name in graph_names}
+ call_delegate_inputs_dict = {name: [] for name in graph_names}
+ call_delegate_node_name_dict = {name: [] for name in graph_names}
+ outputs_dict = {name: [] for name in graph_names}
+ input_nodes_dict = {name: [] for name in graph_names}
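+ # Walk each lowered graph to record delegate call order, their input sources,
+ # and the final outputs so the composite program can rewire them later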
+ for prog, graph_name in zip(exported_programs, graph_names):
+ for node in prog.graph_module.graph.nodes:
+ if (
+ node.op == "call_function"
+ and "executorch_call_delegate" in node.name
+ ):
+ call_delegate_node_name_dict[graph_name].append(node.name)
+ call_delegate_inputs_list = []
+ for arg in node.args:
+ if arg.op == "call_function":
+ while "getitem" not in arg.name:
+ arg = arg.args[0]
+ call_delegate_inputs_list.append(
+ (arg.args[0].name, arg.args[1])
+ )
+ elif arg.op == "placeholder":
+ call_delegate_inputs_list.append((arg.name, None))
+ # Nothing extra to do for get_attr nodes
+ call_delegate_inputs_dict[graph_name].append(
+ call_delegate_inputs_list
+ )
+ elif node.op == "output":
+ for arg in node.args[0]:
+ outputs_dict[graph_name].append((arg.args[0].name, arg.args[1]))
+
+ if args.num_sharding > 0:
+ bundle_progs_list = []
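+ # Build one multi-graph context binary per shard (iterating shard indices in
+ # reverse) and collect the lowered modules for each graph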
+ for num in range(args.num_sharding - 1, -1, -1):
+ processed_bytes = []
+ for prog, graph_name in zip(exported_programs, graph_names):
+ processed_bytes.append(
+ getattr(
+ prog.graph_module, f"lowered_module_{num}"
+ ).processed_bytes
+ )
+
+ call_delegate_node = [
+ list(node.users.keys())[0]
+ for node in prog.graph_module.graph.nodes
+ if node.op == "get_attr"
+ and node.name == f"lowered_module_{num}"
+ ]
+ input_nodes_dict[graph_name] = [
+ node
+ for node in call_delegate_node[0].args
+ if node.op == "placeholder"
+ ]
+
+ prog_mgr, bundle_progs = generate_multi_graph_program(
+ compiler_specs=compiler_specs[0],
+ processed_bytes=processed_bytes,
+ input_nodes_dict=input_nodes_dict,
+ backend_config=executorch_config,
+ constant_methods=llama_instance_list[
+ 1
+ ].llama_meta, # kv method meta
+ )
+ bundle_progs_list.append(bundle_progs)
+ for graph_name in graph_names:
+ lower_module_dict[graph_name].append(
+ prog_mgr.exported_program(graph_name).graph_module._modules.get(
+ "lowered_module_0"
+ )
+ )
+
+ exec_prog = generate_composite_llama_program(
+ graph_names=graph_names,
+ sample_inputs_list=sample_inputs_list,
+ lower_module_dict=lower_module_dict,
+ call_delegate_node_name_dict=call_delegate_node_name_dict,
+ call_delegate_inputs_dict=call_delegate_inputs_dict,
+ outputs_dict=outputs_dict,
+ backend_config=executorch_config,
+ constant_methods=llama_instance_list[1].llama_meta, # kv method meta
+ )
+ with open(f"{args.artifact}/{pte_filename}.pte", "wb") as file:
+ exec_prog.write_to_file(file)
+ else:
+ processed_bytes = []
+ input_nodes_dict = {name: [] for name in graph_names}
+ output_nodes_dict = {name: [] for name in graph_names}
+ for prog, graph_name in zip(exported_programs, graph_names):
+ processed_bytes.append(
+ prog.graph_module.lowered_module_0.processed_bytes
+ )
+ input_nodes_dict[graph_name] = [
+ node
+ for node in prog.graph_module.graph.nodes
+ if node.op == "placeholder"
+ ]
+ output_nodes_dict[graph_name] = [
+ node
+ for node in prog.graph_module.graph.nodes
+ if node.op == "output"
+ ]
+
+ prog_mgr, _ = generate_multi_graph_program(
+ compiler_specs=compiler_specs[0],
+ processed_bytes=processed_bytes,
+ input_nodes_dict=input_nodes_dict,
+ output_nodes_dict=output_nodes_dict,
+ backend_config=executorch_config,
+ constant_methods=llama_instance_list[1].llama_meta, # kv method meta
+ )
+ with open(f"{args.artifact}/{pte_filename}.pte", "wb") as file:
+ prog_mgr.write_to_file(file)
end_lowering_ts = time.time()
logging.info(f"Time for compiling: {end_lowering_ts - start_lowering_ts}")
+ return quant_attrs
-def inference(args, pte_filename, pre_gen_pte=""):
+def inference(args, quant_attrs, pte_filename, pre_gen_pte=""):
workspace = f"/data/local/tmp/{getpass.getuser()}/executorch/single_llama"
if args.model_mode == "prefill":
@@ -524,6 +695,8 @@ def inference(args, pte_filename, pre_gen_pte=""):
f"--eval_mode {eval_mode}",
f"--temperature {args.temperature}",
f"--system_prompt '{args.system_prompt}'",
+ f"--logits_scale {quant_attrs['scale']}",
+ f"--logits_offset {quant_attrs['zero_point']}",
]
)
runner_cmd = " ".join(
@@ -706,16 +879,42 @@ def main():
raise RuntimeError(f"No such model_mode {args.model_mode}.")
if args.pre_gen_pte:
- inference(args, pte_filename, args.pre_gen_pte)
+ quant_attrs = json.load(
+ open(f"{args.pre_gen_pte}/{pte_filename}_quant_attrs.txt")
+ )
+ inference(args, quant_attrs, pte_filename, args.pre_gen_pte)
exit(f"Finish the running pre_gen_pte from {args.pre_gen_pte}")
if args.compile_only:
- compile(args, pte_filename)
+ quant_attrs = compile(args, pte_filename)
+ if quant_attrs:
+ json.dump(
+ {
+ "scale": quant_attrs["scale"],
+ "zero_point": quant_attrs["zero_point"],
+ },
+ open(f"{args.artifact}/{pte_filename}_quant_attrs.txt", "w"),
+ )
+ else:
+ logging.warning("Quant attributes of the logit is None.")
exit(f"Finish compile_only and save to {args.artifact}")
try:
- compile(args, pte_filename)
- inference(args, pte_filename)
+ quant_attrs = compile(args, pte_filename)
+ if quant_attrs:
+ logging.info(
+ f"Logit scale: {quant_attrs['scale']}; Logit offset: {quant_attrs['zero_point']}"
+ )
+ json.dump(
+ {
+ "scale": quant_attrs["scale"],
+ "zero_point": quant_attrs["zero_point"],
+ },
+ open(f"{args.artifact}/{pte_filename}_quant_attrs.txt", "w"),
+ )
+ else:
+ logging.warning("Quant attributes of the logit is None.")
+ inference(args, quant_attrs, pte_filename)
except Exception as e:
if args.ip and args.port != -1:
with Client((args.ip, args.port)) as conn:
diff --git a/examples/qualcomm/oss_scripts/llama3_2/qnn_llama3_2_runner.cpp b/examples/qualcomm/oss_scripts/llama3_2/qnn_llama3_2_runner.cpp
index d05def243ba..2af882580e1 100644
--- a/examples/qualcomm/oss_scripts/llama3_2/qnn_llama3_2_runner.cpp
+++ b/examples/qualcomm/oss_scripts/llama3_2/qnn_llama3_2_runner.cpp
@@ -48,6 +48,8 @@ DEFINE_int32(
eval_mode,
0,
"0: PromptProcessor(prefill) / 1: TokenGenerator(kv) / 2: HybridMode (prefill+kv)");
+DEFINE_double(logits_scale, 0.0, "Logits scale");
+DEFINE_int32(logits_offset, 0, "Logits offset");
int main(int argc, char** argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
@@ -56,6 +58,8 @@ int main(int argc, char** argv) {
example::Runner runner(
{FLAGS_model_path},
FLAGS_tokenizer_path.c_str(),
+ FLAGS_logits_scale,
+ FLAGS_logits_offset,
FLAGS_temperature,
FLAGS_eval_mode);
std::vector buf;
diff --git a/examples/qualcomm/oss_scripts/llama3_2/runner/io_memory.cpp b/examples/qualcomm/oss_scripts/llama3_2/runner/io_memory.cpp
index ccf386309c9..941ff97685b 100644
--- a/examples/qualcomm/oss_scripts/llama3_2/runner/io_memory.cpp
+++ b/examples/qualcomm/oss_scripts/llama3_2/runner/io_memory.cpp
@@ -134,7 +134,7 @@ void HybridMemory::init_io() {
auto init_kv = [&]() {
ptr->kv_logits.resize(vocab_size_);
- ptr->kv_attention_mask.resize((kv_cache_len_ + 1), -255);
+ ptr->kv_attention_mask.resize((kv_cache_len_ + 1), 0);
ptr->k_cache.reserve(num_layers_);
for (int layer = 0; layer < num_layers_; layer++) {
ptr->k_cache.emplace_back();
@@ -315,9 +315,9 @@ void HybridMemory::prepare_prefill_io(
for (int i = 0; i < prefill_cache_len_; ++i) {
for (int j = 0; j < prefill_cache_len_; ++j) {
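+ // Causal mask in uint16: 0 masks future positions, 65535 marks attendable ones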
if (i < j) {
- ptr->prefill_atten_mask[i * prefill_cache_len_ + j] = -255;
- } else {
ptr->prefill_atten_mask[i * prefill_cache_len_ + j] = 0;
+ } else {
+ ptr->prefill_atten_mask[i * prefill_cache_len_ + j] = 65535;
}
}
}
@@ -458,7 +458,7 @@ void HybridMemory::update_kv_io(
// update position_ids
ptr->input_pos = static_cast<int32_t>(pos);
// update causal mask for next token
- ptr->kv_attention_mask[kv_cache_len_ - pos] = 0;
+ ptr->kv_attention_mask[kv_cache_len_ - pos] = 65535;
// update v_cache
auto& v_cache_in = v_cache_in_[kv_forward_name_];
@@ -496,4 +496,13 @@ void HybridMemory::update_kv_io(
}
}
+void HybridMemory::update_prefill_io(
+ int64_t cur_token,
+ int64_t pos,
+ std::vector<std::vector<executorch::aten::Tensor>>& output_tensors) {
+ (void)output_tensors;
+ IO* ptr = static_cast<IO*>(data_ptr_.get());
+ ptr->prefill_input_toks[pos] = static_cast<int32_t>(cur_token);
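+ // Write the newly sampled token into the prefill input buffer at position pos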
+}
+
} // namespace example
diff --git a/examples/qualcomm/oss_scripts/llama3_2/runner/io_memory.h b/examples/qualcomm/oss_scripts/llama3_2/runner/io_memory.h
index ca3a8848871..bb107ffd77e 100644
--- a/examples/qualcomm/oss_scripts/llama3_2/runner/io_memory.h
+++ b/examples/qualcomm/oss_scripts/llama3_2/runner/io_memory.h
@@ -43,6 +43,10 @@ class Memory {
int64_t cur_token,
int64_t pos,
std::vector>& output_tensors) = 0;
+ virtual void update_prefill_io(
+ int64_t cur_token,
+ int64_t pos,
+ std::vector<std::vector<executorch::aten::Tensor>>& output_tensors) = 0;
void* get_mutable_ptr();
std::vector get_input_tensors(
int shard_index,
@@ -97,17 +101,22 @@ class HybridMemory : public Memory {
int64_t pos,
std::vector>& output_tensors)
override;
+ void update_prefill_io(
+ int64_t cur_token,
+ int64_t pos,
+ std::vector<std::vector<executorch::aten::Tensor>>& output_tensors)
+ override;
struct IO {
int32_t input_tok;
int32_t input_pos;
std::vector>> k_cache;
std::vector> v_cache;
std::vector> k_cache_out;
- std::vector<float> kv_attention_mask;
- std::vector<float> kv_logits;
+ std::vector<uint16_t> kv_attention_mask;
+ std::vector<uint16_t> kv_logits;
std::vector<int32_t> prefill_input_toks;
- std::vector<float> prefill_atten_mask;
- std::vector<float> prefill_logits;
+ std::vector<uint16_t> prefill_atten_mask;
+ std::vector<uint16_t> prefill_logits;
};
private:
diff --git a/examples/qualcomm/oss_scripts/llama3_2/runner/runner.cpp b/examples/qualcomm/oss_scripts/llama3_2/runner/runner.cpp
index e87240dfdfe..02a53861b89 100644
--- a/examples/qualcomm/oss_scripts/llama3_2/runner/runner.cpp
+++ b/examples/qualcomm/oss_scripts/llama3_2/runner/runner.cpp
@@ -40,11 +40,15 @@ std::string statsToJsonString(const Runner::Stats& stats);
Runner::Runner(
const std::vector& models_path,
const std::string& tokenizer_path,
+ const float logits_scale,
+ const int32_t logits_offset,
const float temperature,
const int eval_mode)
: n_bos_(1),
n_eos_(1),
tokenizer_path_(tokenizer_path),
+ logits_scale_(logits_scale),
+ logits_offset_(logits_offset),
temperature_(temperature),
eval_mode_(static_cast<EvalMode>(eval_mode)) {
for (size_t i = 0; i < models_path.size(); ++i) {
@@ -205,13 +209,23 @@ T Runner::getMetadataHelper(std::string method_name, T default_val) {
return res;
}
-template <typename T>
-int32_t Runner::logitsToToken(const Tensor& logits_tensor) {
- T* logits = logits_tensor.mutable_data_ptr<T>();
-
+int32_t Runner::logitsToToken(const Tensor& logits_tensor, int64_t pos) {
+ static std::vector<float> logits_f(vocab_size_);
+ const uint16_t* logits = logits_tensor.data_ptr<uint16_t>();
// Since the logits are for all tokens, get the last token probabilities
- T* logits_last = logits;
- return sampler_->sample(logits_last);
+ auto* logits_last = logits;
+
+ // offset to the meaningful logit we want.
+ if (logits_tensor.sizes().data()[1] > 1) {
+ auto vocab_size = logits_tensor.size(2);
+ logits_last += pos * vocab_size;
+ }
+
+ // dequantize
+ for (int i = 0; i < vocab_size_; i++) {
+ logits_f[i] = (logits_last[i] - logits_offset_) * logits_scale_;
+ }
+ return sampler_->sample(logits_f.data());
}
void Runner::run_model_step(
@@ -266,11 +280,11 @@ Error Runner::generate(
if (!system_prompt.empty()) {
prompt_.append("<|start_header_id|>system<|end_header_id|>\n\n");
prompt_.append(system_prompt);
- prompt_.append("<|eot_id|>\n");
+ prompt_.append("<|eot_id|>");
}
prompt_.append("<|start_header_id|>user<|end_header_id|>\n\n");
prompt_.append(prompt);
- prompt_.append("<|eot_id|><|start_header_id|>assistant<|end_header_id|>");
+ prompt_.append("<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n");
if (token_callback) {
token_callback("<|begin_of_text|>");
@@ -308,32 +322,48 @@ Error Runner::generate(
auto prefill_execute = [&](const std::string& method_name) {
for (int i = 0; i < num_prompt_tokens; i++) {
ptr->prefill_input_toks[i] = static_cast<int32_t>(prompt_tokens[i]);
- auto piece_res = tokenizer_->decode(prompt_tokens[i], prompt_tokens[i]);
- token_callback(piece_res.get());
}
- // inference
- run_model_step(method_name, inputs[method_name]);
- Tensor& logits_tensor = output_tensors[method_name].back()[0];
- // offset to the meaningful logit we want.
- float* logits = logits_tensor.mutable_data_ptr<float>() +
- (num_prompt_tokens - 1) * vocab_size_;
- prev_token = prompt_tokens[num_prompt_tokens - 1];
- long sample_start_time_ms = time_in_ms();
- cur_token = sampler_->sample(logits);
- stats_.aggregate_sampling_time_ms += time_in_ms() - sample_start_time_ms;
- stats_.first_token_ms = time_in_ms();
- stats_.prompt_eval_end_ms = time_in_ms();
- auto piece_res = tokenizer_->decode(prev_token, cur_token);
- ET_CHECK(piece_res.ok());
if (token_callback) {
- token_callback(piece_res.get().c_str());
+ token_callback(prompt_);
+ }
+
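+ // Start from the last prompt token; in hybrid mode the prefill graph runs only
+ // one step before handing off to kv mode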
+ pos = num_prompt_tokens - 1;
+ cur_token = prompt_tokens[pos];
+ while (pos < seq_len - 1) {
+ // inference
+ run_model_step(method_name, inputs[method_name]);
+ Tensor& logits_tensor = output_tensors[method_name].back()[0];
+ prev_token = cur_token;
+ long sample_start_time_ms = time_in_ms();
+ cur_token = logitsToToken(logits_tensor, pos);
+ stats_.aggregate_sampling_time_ms += time_in_ms() - sample_start_time_ms;
+
+ io_mem_->update_prefill_io(cur_token, ++pos, output_tensors[method_name]);
+ auto piece_res = tokenizer_->decode(prev_token, cur_token);
+ ET_CHECK(piece_res.ok());
+ if (token_callback) {
+ token_callback(piece_res.get().c_str());
+ }
+
+ if (pos == num_prompt_tokens) {
+ stats_.first_token_ms = time_in_ms();
+ stats_.prompt_eval_end_ms = time_in_ms();
+ }
+
+ if (pos >= num_prompt_tokens && eos_id_.count(cur_token) > 0) {
+ ET_LOG(Info, "\nReached the end of generation");
+ break;
+ }
+ // prefill model inferences once for prompt in the hybrid mode
+ if (eval_mode_ == EvalMode::kHybrid) {
+ break;
+ }
}
- pos += num_prompt_tokens;
};
auto kv_execute = [&](const std::string& method_name) {
ptr->input_tok = static_cast(cur_token);
- ptr->kv_attention_mask[kv_cache_len_] = 0;
+ ptr->kv_attention_mask[kv_cache_len_] = 65535;
while (pos < seq_len - 1) {
// inference
run_model_step(method_name, inputs[method_name]);
@@ -347,10 +377,9 @@ Error Runner::generate(
stats_.prompt_eval_end_ms = time_in_ms();
}
}
-
prev_token = cur_token;
long sample_start_time_ms = time_in_ms();
- cur_token = logitsToToken(logits_tensor);
+ cur_token = logitsToToken(logits_tensor, pos);
stats_.aggregate_sampling_time_ms += time_in_ms() - sample_start_time_ms;
if (pos < num_prompt_tokens - 1) {
diff --git a/examples/qualcomm/oss_scripts/llama3_2/runner/runner.h b/examples/qualcomm/oss_scripts/llama3_2/runner/runner.h
index 79b8370982b..75ad6402199 100644
--- a/examples/qualcomm/oss_scripts/llama3_2/runner/runner.h
+++ b/examples/qualcomm/oss_scripts/llama3_2/runner/runner.h
@@ -29,6 +29,8 @@ class Runner {
explicit Runner(
const std::vector& models_path,
const std::string& tokenizer_path,
+ const float logits_scale,
+ const int32_t logits_offset,
const float temperature,
const int eval_mode);
@@ -73,8 +75,9 @@ class Runner {
private:
template <typename T>
T getMetadataHelper(std::string method_name, T default_val);
- template <typename T>
- int32_t logitsToToken(const executorch::aten::Tensor& logits_tensor);
+ int32_t logitsToToken(
+ const executorch::aten::Tensor& logits_tensor,
+ int64_t pos);
void run_model_step(
const std::string& method_name,
std::vector>& inputs);
@@ -90,6 +93,8 @@ class Runner {
const int32_t n_eos_;
std::vector> modules_;
std::string tokenizer_path_;
+ float logits_scale_;
+ int32_t logits_offset_;
float temperature_;
std::unique_ptr tokenizer_;
std::unique_ptr sampler_;