diff --git a/backends/qualcomm/_passes/__init__.py b/backends/qualcomm/_passes/__init__.py index aaf65afd279..eeb5084ef60 100644 --- a/backends/qualcomm/_passes/__init__.py +++ b/backends/qualcomm/_passes/__init__.py @@ -10,6 +10,7 @@ from .annotate_unbind import AnnotateUnbind from .convert_bmm_to_matmul import ConvertBmmToMatmul from .convert_conv1d_to_conv2d import ConvertConv1dToConv2d +from .convert_linear_to_conv2d import ConvertLinearToConv2d from .convert_square_to_pow import ConvertSquareToPow from .decompose_any import DecomposeAny from .decompose_cdist import DecomposeCDist @@ -48,6 +49,7 @@ AnnotateUnbind, ConvertBmmToMatmul, ConvertConv1dToConv2d, + ConvertLinearToConv2d, ConvertSquareToPow, DecomposeAny, DecomposeCDist, diff --git a/backends/qualcomm/_passes/build_quant_io.py b/backends/qualcomm/_passes/build_quant_io.py index b34d50c4e24..bff8dfacac5 100644 --- a/backends/qualcomm/_passes/build_quant_io.py +++ b/backends/qualcomm/_passes/build_quant_io.py @@ -39,6 +39,11 @@ def _build(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule: if QCOM_QUANTIZED_IO in n.meta: n.meta["val"] = n.meta["val"].to(dtype=n.meta[QCOM_QUANTIZED_IO]) + spec = [] + for user in list(call_delegate[0].users): + spec.append(self._make_spec(user.meta["val"])) + call_delegate[0].meta["spec"] = tuple(spec) + def call(self, graph_module: torch.fx.GraphModule): self._build(graph_module) graph_module.graph.eliminate_dead_code() diff --git a/backends/qualcomm/_passes/convert_conv1d_to_conv2d.py b/backends/qualcomm/_passes/convert_conv1d_to_conv2d.py index 6c29924defa..d09113ad42a 100644 --- a/backends/qualcomm/_passes/convert_conv1d_to_conv2d.py +++ b/backends/qualcomm/_passes/convert_conv1d_to_conv2d.py @@ -9,7 +9,7 @@ from executorch.backends.qualcomm.utils.constants import QCOM_REQUANTIZE from executorch.exir.pass_base import ExportPass, PassResult -from .utils import copy_meta +from .utils import append_qdq, copy_meta class ConvertConv1dToConv2d(ExportPass): @@ -26,31 +26,6 @@ def __init__(self, edge_program: torch.export.ExportedProgram): torch.ops.aten.conv_transpose1d.default: torch.ops.aten.conv_transpose2d.input, } - def append_qdq( - self, - graph_module: torch.fx.GraphModule, - node: torch.fx.Node, - qdq_node: torch.fx.Node, - ): - q_op = torch.ops.quantized_decomposed.quantize_per_tensor.default - dq_op = torch.ops.quantized_decomposed.dequantize_per_tensor.default - if qdq_node.target not in {q_op, dq_op}: - return node - - with graph_module.graph.inserting_after(node): - q_args = (node, *qdq_node.args[1:]) - q_node = graph_module.graph.create_node("call_function", q_op, q_args) - q_node.meta = copy_meta(node.meta) - q_node.meta["val"] = q_node.meta["val"].to(q_args[-1]) - with graph_module.graph.inserting_after(q_node): - dq_args = (q_node, *qdq_node.args[1:]) - dq_node = graph_module.graph.create_node( - "call_function", dq_op, dq_args - ) - dq_node.meta = copy_meta(node.meta) - - return dq_node - def call(self, graph_module: torch.fx.GraphModule): graph = graph_module.graph for node in graph.nodes: @@ -69,7 +44,7 @@ def call(self, graph_module: torch.fx.GraphModule): unsqueeze_node.meta = copy_meta( input_node.meta, lambda m: {**m, "val": m["val"].unsqueeze(2)} ) - qdq_node_after_unsqueeze = self.append_qdq( + qdq_node_after_unsqueeze = append_qdq( graph_module=graph_module, node=unsqueeze_node, qdq_node=input_node, @@ -139,7 +114,7 @@ def call(self, graph_module: torch.fx.GraphModule): conv2d_node.meta = copy_meta( node.meta, lambda m: {**m, "val": m["val"].unsqueeze(2)} ) 
- qdq_node_after_conv2d = self.append_qdq( + qdq_node_after_conv2d = append_qdq( graph_module=graph_module, node=conv2d_node, qdq_node=list(node.users)[0], diff --git a/backends/qualcomm/_passes/convert_linear_to_conv2d.py b/backends/qualcomm/_passes/convert_linear_to_conv2d.py new file mode 100644 index 00000000000..19caa7017d5 --- /dev/null +++ b/backends/qualcomm/_passes/convert_linear_to_conv2d.py @@ -0,0 +1,232 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from executorch.backends.qualcomm._passes.utils import append_qdq, copy_meta +from executorch.backends.qualcomm.builders.utils import get_parameter, set_parameter +from executorch.exir.pass_base import ExportPass, PassResult +from torch.fx import GraphModule +from torchao.quantization.pt2e.utils import get_new_attr_name_with_prefix + + +def _pad_list_to_4(lst): + return lst + [1] * (4 - len(lst)) if len(lst) < 4 else lst[:4] + + +class ConvertLinearToConv2d(ExportPass): + """ + Replace aten.linear.default with equivalent 1x1 conv2d using call_function nodes. + """ + + def __init__(self, edge_program: torch.export.ExportedProgram): + super().__init__() + self.edge_program = edge_program + self.per_block_dq = torch.ops.torchao.dequantize_affine.default + + def _register_tensor( + self, + gm: torch.fx.GraphModule, + node: torch.fx.Node, + tensor_constant: torch.Tensor, + ) -> torch.fx.Node: + new_node_name = get_new_attr_name_with_prefix(node.name)(gm) + gm.register_buffer(new_node_name, tensor_constant) + + with gm.graph.inserting_before(node): + get_attr_node = gm.graph.get_attr(new_node_name) + get_attr_node.meta["val"] = tensor_constant + return get_attr_node + + def _append_dq( + self, + graph_module: torch.fx.GraphModule, + node: torch.fx.Node, + qdq_node: torch.fx.Node, + ): + q_op = torch.ops.quantized_decomposed.quantize_per_tensor.default + dq_op = torch.ops.quantized_decomposed.dequantize_per_tensor.default + + if qdq_node.target not in {q_op, dq_op}: + return node + + with graph_module.graph.inserting_after(node): + dq_args = (node, *qdq_node.args[1:]) + dq_node = graph_module.graph.create_node("call_function", dq_op, dq_args) + dq_node.meta = copy_meta(node.meta) + return dq_node + + def _create_node( + self, graph_module, target, args, meta_node, new_meta_val, qdq_node + ): + new_node = graph_module.graph.call_function(target, args) + new_node.meta = copy_meta( + meta_node.meta, + lambda m, new_meta_val=new_meta_val: { + **m, + "val": new_meta_val, + }, + ) + dq_node = append_qdq( + graph_module=graph_module, + node=new_node, + qdq_node=qdq_node, + ) + return dq_node + + def _reshape_weight(self, graph_module, weight_node, dq_node): + # After export, constant node will be placeholder from edge_program + weight_val = get_parameter(weight_node, self.edge_program) + assert weight_val is not None, "Cannot get the weight in linear node." + + weight_val = weight_val.reshape(*weight_val.shape, 1, 1) + # Create the new weight node when several node share the same weight + # such as embedding and lm_head in LLM. 
+ if len(list(weight_node.users)) > 1: + weight_node = self._register_tensor(graph_module, weight_node, weight_val) + dq_node = self._append_dq(graph_module, weight_node, dq_node) + else: + set_parameter( + ( + torch.nn.Parameter(weight_val) + if weight_val.dtype == torch.float + else weight_val + ), + weight_node, + self.edge_program, + ) + + # Update node meta val + weight_node.meta["val"] = weight_node.meta["val"].reshape(weight_val.shape) + dq_node.meta["val"] = dq_node.meta["val"].reshape(weight_val.shape) + # Update block size for per-block quant + if dq_node.target == self.per_block_dq: + new_args = list(dq_node.args) + # pad block size + new_args[1] = _pad_list_to_4(list(new_args[1])) + dq_node.args = tuple(new_args) + + return dq_node + + def call(self, graph_module: GraphModule): + graph = graph_module.graph + + for node in list(graph.nodes): + if node.target == torch.ops.aten.linear.default: + input_node = node.args[0] + # In quantization flow, weight_arg will be dq node. + weight_arg = node.args[1] + weight_node = ( + weight_arg if weight_arg.op == "placeholder" else weight_arg.args[0] + ) + bias_arg = node.args[2] if len(node.args) > 2 else None + + input_meta_val = input_node.meta["val"] + output_meta_val = node.meta["val"] + if bias_arg: + bias_meta_val = bias_arg.meta["val"] + + rank = input_meta_val.ndim + with graph.inserting_before(node): + # Step 1: reshape input + # rank = 2: (dim, C) -> (1, C, 1, dim) + # rank = 3: (N, dim, C) -> (N, C, 1, dim) + # rank = 4: (N, H, W, C) -> (N, C, H, W) + order = (0, 3, 1, 2) + if rank <= 3: + # (dim, C) -> (1, C, 1, dim) + # (N, dim, C) -> (N, C, 1, dim) + shape = ( + (1, *input_meta_val.shape, 1) + if rank == 2 + else (*input_meta_val.shape, 1) + ) + x_meta_val = input_meta_val.reshape(shape) + input_node = self._create_node( + graph_module, + torch.ops.aten.reshape.default, + (input_node, shape), + node, + x_meta_val, + input_node, + ) + order = (0, 2, 3, 1) + + x_meta_val = x_meta_val.permute(order) + x = self._create_node( + graph_module, + torch.ops.aten.permute.default, + (input_node, order), + node, + x_meta_val, + input_node, + ) + + # Step 2: reshape weight + weight_arg = self._reshape_weight( + graph_module, weight_node, weight_arg + ) + weight_meta_val = weight_arg.meta["val"] + + conv_args = [x, weight_arg] + conv_args_meta_val = [x_meta_val, weight_meta_val] + if bias_arg: + conv_args.append(bias_arg) + conv_args_meta_val.append(bias_meta_val) + else: + conv_args.append(None) + conv_args_meta_val.append(None) + + conv_args.extend( + [[1, 1], [0, 0], [1, 1], 1] + ) # stride, padding, dilation, groups + conv_node_val = torch.nn.functional.conv2d( + *conv_args_meta_val, + stride=(1, 1), + padding=(0, 0), + dilation=(1, 1), + groups=1, + ) + conv_node = self._create_node( + graph_module, + torch.ops.aten.conv2d.default, + tuple(conv_args), + node, + conv_node_val, + list(node.users)[0], + ) + + # Step 3: restore shape + # rank = 2: (1, C, 1, dim) -> (dim, C) + # rank = 3: (N, C, 1, dim) -> (N, dim C) + # rank = 4: (N, C, H, W) -> (N, H, W, C) + order = (0, 2, 3, 1) if rank == 4 else (0, 3, 1, 2) + y_meta_val = conv_node_val.permute(order) + y = self._create_node( + graph_module, + torch.ops.aten.permute.default, + (conv_node, order), + node, + y_meta_val, + list(node.users)[0], + ) + if rank <= 3: + target_shape = output_meta_val.shape + y_meta_val = y_meta_val.reshape(target_shape) + y = self._create_node( + graph_module, + torch.ops.aten.reshape.default, + (y, target_shape), + node, + y_meta_val, + list(node.users)[0], + ) 
+ + node.replace_all_uses_with(y) + graph.erase_node(node) + + graph.eliminate_dead_code() + graph_module.recompile() + return PassResult(graph_module, True) diff --git a/backends/qualcomm/_passes/qnn_pass_manager.py b/backends/qualcomm/_passes/qnn_pass_manager.py index 152433195cd..461fb07fb16 100644 --- a/backends/qualcomm/_passes/qnn_pass_manager.py +++ b/backends/qualcomm/_passes/qnn_pass_manager.py @@ -15,6 +15,7 @@ AnnotateUnbind, ConvertBmmToMatmul, ConvertConv1dToConv2d, + ConvertLinearToConv2d, ConvertSquareToPow, DecomposeAny, DecomposeCDist, @@ -82,7 +83,6 @@ def get_capture_program_passes(): (AnnotateStack, True), (AnnotateUnbind, True), (ConvertBmmToMatmul, False), - (ConvertConv1dToConv2d, True), (DecomposeAny, True), (DecomposeColIm, True), (DecomposeMinMaxDim, True), @@ -92,7 +92,7 @@ def get_capture_program_passes(): (I64toI32, True), (LayoutTransform, True), (RecomposePixelUnshuffle, True), - (RecomposeRmsNorm, False), + (RecomposeRmsNorm, True), (Remove0DTensor, True), (RemoveRedundancy, True), (TagQuantIO, False), @@ -190,6 +190,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule): self.add_pass(RemoveRedundancy(quantization_capture=True)) self.add_pass(ReduceDynamicRange()) self.add_pass(RecomposePixelUnshuffle(quantization_capture=True)) + self.add_pass(RecomposeRmsNorm(quantization_capture=True)) self.add_pass(ReplaceArangeArgs()) self.add_pass(DecomposeCDist()) self.add_pass(DecomposeScaledDotProductAttention()) @@ -203,7 +204,9 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule): self.add_pass(LiftConstantScalarOperands()) return self._transform(graph_module) - def transform_for_export_pipeline(self, exported_program: ExportedProgram): + def transform_for_export_pipeline( + self, exported_program: ExportedProgram, convert_linear_to_conv2d: bool = False + ): self.add_pass(DecomposeCDist()) self.add_pass(DecomposeScaledDotProductAttention()) self.add_pass(DecomposeRoll()) @@ -213,6 +216,8 @@ def transform_for_export_pipeline(self, exported_program: ExportedProgram): # this pass will rewrite state_dict, it needs to be accomplished before # to_edge_transform_and_lower self.add_pass(ConvertConv1dToConv2d(exported_program)) + if convert_linear_to_conv2d: + self.add_pass(ConvertLinearToConv2d(exported_program)) self.add_pass(ConvertSquareToPow()) self.add_pass(LiftConstantScalarOperands()) self._transform(exported_program.graph_module) diff --git a/backends/qualcomm/_passes/recompose_rms_norm.py b/backends/qualcomm/_passes/recompose_rms_norm.py index 2e2063cdf6e..7c5b864377e 100644 --- a/backends/qualcomm/_passes/recompose_rms_norm.py +++ b/backends/qualcomm/_passes/recompose_rms_norm.py @@ -4,70 +4,139 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
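Reviewer note on `ConvertLinearToConv2d` above: the rewrite relies on a fully connected layer being numerically equivalent to a 1x1 convolution applied to a reshaped and permuted input. A minimal eager-mode sketch of that equivalence is below; the shapes (3x512 input, 32 outputs) are illustrative only and simply mirror the updated `Linear` test module, not values taken from the pass itself.

```python
import torch

# Illustrative shapes only: 3 tokens x 512 features, 32 output features.
x = torch.randn(3, 512)
w = torch.randn(32, 512)
b = torch.randn(32)

ref = torch.nn.functional.linear(x, w, b)  # (3, 32)

# rank-2 input: (dim, C) -> (1, C, 1, dim); weight: (O, I) -> (O, I, 1, 1)
x4 = x.reshape(1, *x.shape, 1).permute(0, 2, 3, 1)  # (1, 512, 1, 3)
w4 = w.reshape(*w.shape, 1, 1)                      # (32, 512, 1, 1)
y4 = torch.nn.functional.conv2d(x4, w4, b)          # (1, 32, 1, 3)

# restore: (1, C_out, 1, dim) -> (dim, C_out)
out = y4.permute(0, 3, 1, 2).reshape(ref.shape)

assert torch.allclose(ref, out, rtol=1e-4, atol=1e-4)
```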
import torch - +from executorch.backends.qualcomm._passes.utils import find_patterns from executorch.backends.qualcomm.builders.node_visitor import dq_ops -from executorch.backends.qualcomm.builders.utils import get_parameter, is_parameter from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult -from torch.fx.passes.utils.source_matcher_utils import get_source_partitions + + +def _is_node(node): + return isinstance(node, torch.fx.Node) + + +def _is_call(node): + return _is_node(node) and node.op == "call_function" + + +def _is_placeholder(node): + return _is_node(node) and node.op == "placeholder" + + +def _is_get_attr(node): + return _is_node(node) and node.op == "get_attr" + + +def _is_add(node): + return _is_call(node) and node.target in [ + exir_ops.edge.aten.add.Tensor, + exir_ops.edge.aten.add.Scalar, + torch.ops.aten.add.Tensor, + torch.ops.aten.add.Scalar, + ] + + +def _is_dq(node): + return _is_call(node) and node.target in dq_ops + + +def _is_mean(node): + return _is_call(node) and node.target in [ + exir_ops.edge.aten.mean.dim, + torch.ops.aten.mean.dim, + ] + + +def _is_mul(node): + return _is_call(node) and node.target in [ + exir_ops.edge.aten.mul.Tensor, + torch.ops.aten.mul.Tensor, + ] + + +def _is_pow(node): + return _is_call(node) and node.target in [ + exir_ops.edge.aten.pow.Tensor_Tensor, + exir_ops.edge.aten.pow.Tensor_Scalar, + torch.ops.aten.pow.Tensor_Scalar, + ] + + +def _is_rsqrt(node): + return _is_call(node) and node.target in [ + exir_ops.edge.aten.rsqrt.default, + torch.ops.aten.rsqrt.default, + ] class RecomposeRmsNorm(ExportPass): """ Merge decomposed operators back to one super node. - TODO: After replacing export_to_edge with to_edge_transform_and_lowering - in examples/models/llama/export_llama_lib.py, this pass can be removed """ - def __init__(self, edge_program: torch.export.ExportedProgram): + def __init__(self, quantization_capture=False): super(RecomposeRmsNorm, self).__init__() - self.edge_program = edge_program + self.rms_norm_target = exir_ops.edge.aten.rms_norm.default + self.skip_targets = [ + exir_ops.edge.aten.to.dtype, + ] + self.quantization_capture = quantization_capture + if quantization_capture: + self.rms_norm_target = torch.ops.aten.rms_norm.default + self.skip_targets = [ + torch.ops.aten.to.dtype, + ] - def _get_eps_node(self, nodes): - # eps: one of inputs of add node - add_node = [n for n in nodes if hasattr(n, "name") and "add" in n.name][0] - for a in add_node.args: - if isinstance(a, float) or a.op != "call_function": - return a - - def _get_gamma_node(self, output_node): - # gamma: one of inputs of output node - for a in output_node.args: - if a.op != "call_function" or a.target in dq_ops: - return a + def _get_input_node(self, node): + input_node = node.args[0] + while input_node.target in self.skip_targets: + input_node = input_node.args[0] + return input_node def call(self, graph_module: torch.fx.GraphModule): graph = graph_module.graph - partitions = get_source_partitions( - graph, [torch.nn.RMSNorm, torch.ops.aten.rms_norm.default] - ) - for _, src_partitions in partitions.items(): - for src_partition in src_partitions: - input_len = len(src_partition.input_nodes) - if input_len == 1: - input_node = src_partition.input_nodes[0] - elif input_len == 2: - inp_0, inp_1 = src_partition.input_nodes - input_node = inp_0 if len(inp_0.users) == 2 else inp_1 - else: - raise RuntimeError( - f"Found a edge case of rms_node partition {src_partition}, which has {input_len} 
inputs" - ) - output_node = src_partition.output_nodes[0] - eps = self._get_eps_node(src_partition.nodes) - if isinstance(eps, torch.fx.Node) and is_parameter( - eps, self.edge_program - ): - eps = get_parameter(eps, self.edge_program).item() - gamma_node = self._get_gamma_node(output_node) + # Root Mean Square normalization math equivalent implementation + patterns = [ + # transformers.models.qwen2.modeling_qwen2.Qwen2RMSNorm + [_is_mul, "*", _is_mul, _is_rsqrt, _is_add, _is_mean, _is_pow], + # executorch.examples.models.llama.norm.RMSNorm + [_is_mul, "*", _is_mul, _is_rsqrt, _is_add, _is_mean, _is_mul], + ] + + for node in graph.nodes: + if not _is_mul(node): + continue + + rms_norm_patterns = [ + pattern + for pattern in find_patterns(node, patterns) + if pattern is not None + ] + + if len(rms_norm_patterns) > 0: + # Use first matched pattern + rms_norm_pattern = rms_norm_patterns[0][0] + last_mul_node = rms_norm_pattern[0] + gamma_node = None + # weight should be a constant + for arg in last_mul_node.args: + if _is_get_attr(arg) or _is_placeholder(arg) or _is_dq(arg): + gamma_node = arg + + if gamma_node is None: + continue + + eps = rms_norm_pattern[4].args[1] + if isinstance(eps, torch.fx.Node): + eps = eps.meta["val"].constant.item() + input_node = self._get_input_node(rms_norm_pattern[6]) - with graph.inserting_before(output_node): + with graph.inserting_before(last_mul_node): # args schema # (Tensor input, int[] normalized_shape, Tensor? weight=None, float? eps=None) -> Tensor rms_node = graph.create_node( "call_function", - exir_ops.edge.aten.rms_norm.default, + self.rms_norm_target, ( input_node, list(gamma_node.meta["val"].shape), @@ -75,11 +144,11 @@ def call(self, graph_module: torch.fx.GraphModule): eps, ), ) - users = output_node.users.copy() + users = last_mul_node.users.copy() for user in users: - user.replace_input_with(output_node, rms_node) + user.replace_input_with(last_mul_node, rms_node) # copy metadata - rms_node.meta = output_node.meta + rms_node.meta = last_mul_node.meta graph.eliminate_dead_code() graph_module.recompile() diff --git a/backends/qualcomm/_passes/utils.py b/backends/qualcomm/_passes/utils.py index ae11ba7b325..20492495156 100755 --- a/backends/qualcomm/_passes/utils.py +++ b/backends/qualcomm/_passes/utils.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-from typing import Dict +from typing import Callable, Dict, List import torch from executorch.backends.qualcomm.builders.utils import get_parameter @@ -121,3 +121,118 @@ def is_float_tensor(node: torch.fx.Node) -> bool: if "val" not in node.meta or not isinstance(node.meta["val"], FakeTensor): return False return node.meta["val"].dtype == torch.float32 + + +def _is_node(node): + return isinstance(node, torch.fx.Node) + + +def _pred(node, pat): + return isinstance(pat, Callable) and pat(node) + + +def _next(node, from_args=True): + if from_args: + yield from [i for i in node.args if _is_node(i)] + else: + yield from list(node.users) + + +def _find_pattern( + node: torch.fx.Node, + pattern: List[Callable[[torch.fx.Node], bool] | str], + from_args: bool = True, + max_wildcard_life: int = 3, + verbose: bool = False, +): + """ + Implement wildcard pattern matching + - node: fx.Node + - pattern: predicate list, can contain followings + Callable(fx.node): predicate + '*': wildcard + - from_args: if True find from node.args, otherwise from node.users + - max_wildcard_life: max number of skips for wildcard + + If not matched, return None. + Otherwise, return list of matched node list, which is the same length as pattern + """ + + asterisk = "*" + + def _probe( + cur, hist, pat_idx, asterisk_life_count=max_wildcard_life, verbose=verbose + ): + if pat_idx == len(pattern): + # Expected len(hist) is equal to pat_idx + assert len(hist) == len(pattern) + if list(hist) not in matched: + matched.append(list(hist)) + return + if verbose: + print( + f"cur:{cur}, idx:{pat_idx}, life={asterisk_life_count}, pattern:{pattern[pat_idx]} hist={hist}" + ) + if _pred(cur, pattern[pat_idx]): + hist.append(cur) + for child in _next(cur, from_args): + _probe(child, hist, pat_idx + 1) + hist.pop() + elif pattern[pat_idx] == asterisk and asterisk_life_count > 0: + # 3 cases: ignore/consume/keep asterisk + # 1, Ignore asterisk + hist.append(None) + _probe(cur, hist, pat_idx + 1) + hist.pop() + + # 2. Consume asterisk + hist.append(None) + for child in _next(cur, from_args): + _probe(child, hist, pat_idx + 1) + hist.pop() + + # 3. 
keep asterisk and skip to next node + for child in _next(cur, from_args): + _probe(child, hist, pat_idx, asterisk_life_count - 1) + + # Check if pattern is valid + assert all( + isinstance(i, Callable) or (isinstance(i, str) and i == "*") for i in pattern + ), f"Invalid pattern: {pattern}" + + # Start probing + matched = [] + _probe(node, [], 0) + return matched if matched else None + + +def find_patterns(node, patterns, **kwargs): + assert isinstance(patterns, list) and isinstance(patterns[0], list) + results = [] + for pattern in patterns: + result = _find_pattern(node, pattern, **kwargs) + results.append(result) + return results + + +def append_qdq( + graph_module: torch.fx.GraphModule, + node: torch.fx.Node, + qdq_node: torch.fx.Node, +): + q_op = torch.ops.quantized_decomposed.quantize_per_tensor.default + dq_op = torch.ops.quantized_decomposed.dequantize_per_tensor.default + + if qdq_node.target not in {q_op, dq_op}: + return node + + with graph_module.graph.inserting_after(node): + q_args = (node, *qdq_node.args[1:]) + q_node = graph_module.graph.create_node("call_function", q_op, q_args) + q_node.meta = copy_meta(node.meta) + q_node.meta["val"] = q_node.meta["val"].to(q_args[-1]) + with graph_module.graph.inserting_after(q_node): + dq_args = (q_node, *qdq_node.args[1:]) + dq_node = graph_module.graph.create_node("call_function", dq_op, dq_args) + dq_node.meta = copy_meta(node.meta) + return dq_node diff --git a/backends/qualcomm/builders/__init__.py b/backends/qualcomm/builders/__init__.py index 68873d15b3e..4b9dbeffa73 100644 --- a/backends/qualcomm/builders/__init__.py +++ b/backends/qualcomm/builders/__init__.py @@ -25,6 +25,7 @@ op_ceil, op_clamp, op_conv2d, + op_copy, op_cos, op_cum_sum, op_depth_to_space, @@ -85,6 +86,7 @@ op_sin, op_skip_ops, op_slice_copy, + op_slice_scatter, op_softmax, op_space_to_depth, op_split_with_sizes, @@ -126,6 +128,7 @@ op_ceil, op_clamp, op_conv2d, + op_copy, op_cos, op_cum_sum, op_depth_to_space, @@ -186,6 +189,7 @@ op_sin, op_skip_ops, op_slice_copy, + op_slice_scatter, op_softmax, op_space_to_depth, op_split_with_sizes, diff --git a/backends/qualcomm/builders/node_visitor.py b/backends/qualcomm/builders/node_visitor.py index e81a80b3517..bc2b62c8c0b 100644 --- a/backends/qualcomm/builders/node_visitor.py +++ b/backends/qualcomm/builders/node_visitor.py @@ -59,6 +59,8 @@ QNN_TENSOR_TYPE_MAP = { torch.bool: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8, torch.float32: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32, + # Note that there is no float64 tensor data type in Qnn. + torch.float64: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32, torch.int8: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_8, torch.int16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_16, torch.int32: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_32, @@ -176,6 +178,10 @@ def make_qnn_per_block_config(self, node: torch.fx.Node, quant_attrs: Dict): # OIHW (pytorch) -> HWIO (QNN) quant_config[QCOM_AXIS] = 3 quant_config[QCOM_AXIS_ORDER] = (2, 3, 1, 0) + elif "linear" in user_0.target.__name__: + # OI (pytorch) -> OI (QNN) + quant_config[QCOM_AXIS] = 0 + quant_config[QCOM_AXIS_ORDER] = (0, 1) else: raise AttributeError("undetermined axis for block quantization") diff --git a/backends/qualcomm/builders/op_copy.py b/backends/qualcomm/builders/op_copy.py new file mode 100644 index 00000000000..164c910835e --- /dev/null +++ b/backends/qualcomm/builders/op_copy.py @@ -0,0 +1,64 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. 
+# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +from typing import Dict + +import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper + +import torch +from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS + +from .node_visitor import NodeVisitor +from .node_visitor_manager import register_node_visitor +from .qnn_constants import OpReshape, QNN_OP_PACKAGE_NAME_QTI_AISW + + +@register_node_visitor +class Copy(NodeVisitor): + target = ["aten.copy.default"] + + def __init__(self, *args) -> None: + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], + ) -> PyQnnWrapper.PyQnnOpWrapper: + input_node = self.get_node(node.args[1]) + input_tensor = self.get_tensor(input_node, node) + copy_inp_tensor_wrapper = self.define_tensor( + input_node, + node, + input_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + ) + + copy_input_tensors = [copy_inp_tensor_wrapper] + + if quant_attrs := input_node.meta.get(QCOM_QUANT_ATTRS): + quant_attrs = quant_attrs.copy() + # Because there is no output after convert_pt2e, the QCOM_QUANT_ATTRS of node is none + node.meta[QCOM_QUANT_ATTRS] = quant_attrs + output_tensor = self.get_tensor(node, node) + output_tensor_wrapper = self.define_tensor( + node, + node, + output_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + ) + copy_output_tensors = [output_tensor_wrapper] + + copy_op = PyQnnWrapper.PyQnnOpWrapper( + node.name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + OpReshape.op_name, + ) + copy_op.AddInputTensors(copy_input_tensors) + copy_op.AddOutputTensors(copy_output_tensors) + + return copy_op diff --git a/backends/qualcomm/builders/op_slice_scatter.py b/backends/qualcomm/builders/op_slice_scatter.py new file mode 100644 index 00000000000..9fa162d6653 --- /dev/null +++ b/backends/qualcomm/builders/op_slice_scatter.py @@ -0,0 +1,122 @@ +from typing import cast, Dict + +import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import torch + +from executorch.exir.dialects._ops import ops as exir_ops + +from .node_visitor import NodeVisitor +from .node_visitor_manager import register_node_visitor +from .qnn_constants import OpScatterNd, QNN_OP_PACKAGE_NAME_QTI_AISW + + +@register_node_visitor +class SliceScatterVisitor(NodeVisitor): + target = ["aten.slice_scatter.default"] + + def __init__(self, *args) -> None: + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], + ) -> PyQnnWrapper.PyQnnOpWrapper: + input_node = self.get_node(node.args[0]) + input_tensor = self.get_tensor(input_node, node) + input_tensor_wrapper = self.define_tensor( + input_node, + node, + input_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + ) + + value_node = self.get_node(node.args[1]) + value_tensor = self.get_tensor(value_node, node) + value_tensor_wrapper = self.define_tensor( + value_node, + node, + value_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + ) + + output_tensor = self.get_tensor(node, node) + output_tensor_wrapper = self.define_tensor( + node, + node, + output_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + ) + dim = cast(int, node.args[2]) if 
len(node.args) > 2 else 0 + if dim < 0: + dim = dim % len(input_tensor.shape) + + start = ( + cast(int, node.args[3]) + if len(node.args) > 3 and node.args[3] is not None + else 0 + ) + if start < 0: + start = start % input_tensor.shape[dim] + + if len(node.args) > 4: + end = min(cast(int, node.args[4]), input_tensor.shape[dim]) + if end < 0: + end = end % input_tensor.shape[dim] + else: + end = input_tensor.shape[dim] + + step = node.args[5] if len(node.args) > 5 else 1 + + target_index_shape = [] + ranges = [] + # Collect the index + for i in range(dim + 1): + if i == dim: + target_range = torch.tensor(range(start, end, step), dtype=torch.int32) + target_index_shape.append(target_range.size(-1)) + ranges.append(target_range) + else: + size = input_tensor.size(i) + target_index_shape.append(size) + ranges.append(torch.arange(size, dtype=torch.int32)) + # last dim means x-tuple index + target_index_shape.append(dim + 1) + target_index_tensor = ( + torch.cartesian_prod(*ranges).reshape(target_index_shape).contiguous() + ) + + target_index_node = torch.fx.Node( + node.graph, + node.name + "_target_index", + "call_function", + exir_ops.edge.aten.tensor.default, + (), # args + {}, # kwargs + ) + target_index_tensor_wrapper = self.define_tensor( + target_index_node, + node, + target_index_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, + nodes_to_wrappers, + ) + + slice_scatter_op = PyQnnWrapper.PyQnnOpWrapper( + node.name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + OpScatterNd.op_name, + ) + slice_scatter_op.AddInputTensors( + [ + input_tensor_wrapper, + target_index_tensor_wrapper, + value_tensor_wrapper, + ] + ) + slice_scatter_op.AddOutputTensors([output_tensor_wrapper]) + + return slice_scatter_op diff --git a/backends/qualcomm/partition/common_defs.py b/backends/qualcomm/partition/common_defs.py index 3576cbab63a..c7d9aaf1030 100644 --- a/backends/qualcomm/partition/common_defs.py +++ b/backends/qualcomm/partition/common_defs.py @@ -11,8 +11,6 @@ not_supported_operator = [ exir_ops.edge.aten.clone.default, - exir_ops.edge.aten.slice_scatter.default, - exir_ops.edge.aten.copy.default, exir_ops.edge.quantized_decomposed.embedding_4bit.dtype, ] diff --git a/backends/qualcomm/quantizer/annotators.py b/backends/qualcomm/quantizer/annotators.py index 721fc85362f..78fb8d15a4e 100644 --- a/backends/qualcomm/quantizer/annotators.py +++ b/backends/qualcomm/quantizer/annotators.py @@ -162,6 +162,11 @@ def annotate_single_in_single_out( ) +@register_annotator([torch.ops.aten.to.dtype]) +def annotate_to_dtype(node: Node, quantization_config: QuantizationConfig) -> None: + annotate_single_in_single_out(node, quantization_config) + + @register_annotator([torch.ops.aten.atan.default]) def annotate_atan(node: Node, quantization_config: QuantizationConfig) -> None: annotate_single_in_single_out(node, quantization_config) @@ -678,6 +683,22 @@ def annotate_slice(node: Node, quantization_config: QuantizationConfig) -> None: annotate_single_in_single_out(node, quantization_config) +@register_annotator([torch.ops.aten.slice_scatter.default]) +def annotate_slice_scatter(node: Node, quantization_config: QuantizationConfig) -> None: + input = node.args[0] + value = node.args[1] + + input_qspec_map = {} + input_qspec_map[input] = quantization_config.input_activation + input_qspec_map[value] = SharedQuantizationSpec((input, node)) + + node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=SharedQuantizationSpec((input, node)), + _annotated=True, + ) + + 
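Regarding the new `SliceScatterVisitor`: since QNN exposes no slice_scatter op, the builder lowers `aten.slice_scatter.default` to ScatterNd by materializing a static index tensor (built with `torch.cartesian_prod`) that enumerates the coordinates being overwritten. A rough eager-mode sketch of the same index construction, reusing the dim=0 case from the new unit tests (names and shapes here are illustrative, not taken from the builder):

```python
import torch

# dim=0, start=3, end=5, step=1: overwrite rows 3..4 of an 8x8 tensor, as in the new test.
x = torch.zeros(8, 8)
src = torch.ones(2, 8)
dim, start, end, step = 0, 3, 5, 1

ref = torch.slice_scatter(x, src, dim=dim, start=start, end=end, step=step)

# Build an index tensor whose last axis is a (dim+1)-tuple of coordinates,
# mirroring what the builder feeds to ScatterNd as a static tensor.
ranges = [torch.arange(x.size(i)) for i in range(dim)]
ranges.append(torch.arange(start, end, step))
index = torch.cartesian_prod(*ranges).reshape(
    *[x.size(i) for i in range(dim)], -1, dim + 1
)

# Scatter `src` into a copy of `x` at those coordinates.
out = x.clone()
out[index.unbind(-1)] = src

assert torch.equal(ref, out)
```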
@register_annotator([torch.ops.aten.sqrt.default]) def annotate_sqrt(node: Node, quantization_config: QuantizationConfig) -> None: annotate_single_in_single_out(node, quantization_config) @@ -1067,6 +1088,7 @@ def annotate_cdist(node: Node, quantization_config: QuantizationConfig) -> None: torch.ops.aten.conv1d.default, torch.ops.aten.conv_transpose2d.input, torch.ops.aten.conv_transpose1d.default, + torch.ops.aten.convolution.default, ] ) def annotate_conv(node: Node, quantization_config: QuantizationConfig) -> None: @@ -1121,6 +1143,12 @@ def annotate_linear(node: Node, quantization_config: QuantizationConfig) -> None if _is_annotated([node]): return + # block quantization + if quantization_config.block_size is not None: + quantization_config.weight.observer_or_fake_quant_ctr.p.keywords.update( + {"block_size": quantization_config.block_size} + ) + annotate_input_qspec_map( node, act_node, diff --git a/backends/qualcomm/scripts/build.sh b/backends/qualcomm/scripts/build.sh index 43d968813a9..297f81fc85d 100755 --- a/backends/qualcomm/scripts/build.sh +++ b/backends/qualcomm/scripts/build.sh @@ -116,6 +116,21 @@ if [ "$BUILD_AARCH64" = true ]; then -B$EXAMPLE_ROOT cmake --build $EXAMPLE_ROOT -j$BUILD_JOB_NUMBER + + LLAMA_EXAMPLE_ROOT=examples/models/llama + cmake $PRJ_ROOT/$LLAMA_EXAMPLE_ROOT \ + -DBUILD_TESTING=OFF \ + -DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK_ROOT/build/cmake/android.toolchain.cmake \ + -DCMAKE_BUILD_TYPE=$BUILD_TYPE \ + -DANDROID_ABI='arm64-v8a' \ + -DANDROID_PLATFORM=android-30 \ + -DCMAKE_PREFIX_PATH=$CMAKE_PREFIX_PATH \ + -DEXECUTORCH_ENABLE_LOGGING=ON \ + -DCMAKE_FIND_ROOT_PATH_MODE_PACKAGE=BOTH \ + -DPYTHON_EXECUTABLE=$PYTHON_EXECUTABLE \ + -B$LLAMA_EXAMPLE_ROOT + + cmake --build $LLAMA_EXAMPLE_ROOT -j$BUILD_JOB_NUMBER fi if [ "$BUILD_X86_64" = true ]; then @@ -172,4 +187,18 @@ if [ "$BUILD_X86_64" = true ]; then -B$EXAMPLE_ROOT cmake --build $EXAMPLE_ROOT -j$BUILD_JOB_NUMBER + + LLAMA_EXAMPLE_ROOT=examples/models/llama + cmake $PRJ_ROOT/$LLAMA_EXAMPLE_ROOT \ + -DBUILD_TESTING=OFF \ + -DCMAKE_BUILD_TYPE=$BUILD_TYPE \ + -DANDROID_ABI='arm64-v8a' \ + -DANDROID_PLATFORM=android-30 \ + -DCMAKE_PREFIX_PATH=$CMAKE_PREFIX_PATH \ + -DEXECUTORCH_ENABLE_LOGGING=ON \ + -DCMAKE_FIND_ROOT_PATH_MODE_PACKAGE=BOTH \ + -DPYTHON_EXECUTABLE=$PYTHON_EXECUTABLE \ + -B$LLAMA_EXAMPLE_ROOT + + cmake --build $LLAMA_EXAMPLE_ROOT -j$BUILD_JOB_NUMBER fi diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py index f5f6676e123..54624471763 100644 --- a/backends/qualcomm/tests/models.py +++ b/backends/qualcomm/tests/models.py @@ -274,6 +274,19 @@ def forward(self, x, y): return torch.cat((y, y, x, x), axis=2) +class CausalMask(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_buffer("causal_mask", torch.zeros((1, 1, 1, 128))) + self.mask_length = 128 + + def forward(self, padding_mask): + self.causal_mask[:, :, :, : self.mask_length] = self.causal_mask[ + :, :, :, : self.mask_length + ].masked_fill(padding_mask, 1) + return self.causal_mask + 1 + + class CDist(torch.nn.Module): def __init__(self): super().__init__() @@ -1124,7 +1137,7 @@ def forward(self, x): class Linear(torch.nn.Module): def __init__(self, use_bias: bool = True): super().__init__() - self.linear = torch.nn.Linear(4, 5, use_bias).eval() + self.linear = torch.nn.Linear(512, 32, use_bias).eval() def forward(self, x): return self.linear(x) @@ -1702,6 +1715,20 @@ def forward(self, x, y): ) +class SliceScatter(torch.nn.Module): + def __init__(self, dim, start, end, step): + 
super().__init__() + self.dim = dim + self.start = start + self.end = end + self.step = step + + def forward(self, x, y): + return x.slice_scatter( + y, dim=self.dim, start=self.start, end=self.end, step=self.step + ) + + class Softmax(torch.nn.Module): def __init__(self, dim): super().__init__() diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index 5a71244bacd..f5f8a1d904f 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -890,7 +890,7 @@ def test_qnn_backend_linalg_vector_norm(self): def test_qnn_backend_linear(self): module = Linear() # noqa: F405 - sample_input = (torch.randn([3, 4]),) + sample_input = (torch.randn([3, 512]),) self.lower_module_and_test_output(module, sample_input) def test_qnn_backend_log(self): @@ -1102,6 +1102,42 @@ def test_qnn_backend_slice_copy(self): for module, sample_input in zip(modules, sample_inputs): self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_slice_scatter(self): + test_comb = [ + { + QCOM_MODULE: [ + SliceScatter(dim=0, start=3, end=5, step=1) # noqa: F405 + ], + QCOM_SAMPLE_INPUTS: [ + ( + torch.zeros(8, 8), + torch.ones(2, 8), + ) + ], + }, + { + QCOM_MODULE: [ + SliceScatter(dim=1, start=2, end=6, step=2) # noqa: F405 + ], + QCOM_SAMPLE_INPUTS: [ + ( + ( + torch.zeros(8, 8), + torch.ones(8, 2), + ) + ) + ], + }, + ] + + index = 0 + for comb in test_comb: + for module in comb[QCOM_MODULE]: + for sample_input in comb[QCOM_SAMPLE_INPUTS]: + with self.subTest(i=index): + self.lower_module_and_test_output(module, sample_input) + index += 1 + def test_qnn_backend_stack(self): module = Stack() # noqa: F405 sample_input = ( @@ -1239,6 +1275,11 @@ def test_qnn_backend_argmin_view_squeeze_conv2d(self): sample_input = (torch.randn(32), torch.randn(32, 3, 32, 32)) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_causal_mask(self): + module = CausalMask() # noqa: F405 + sample_input = (torch.rand((1, 1, 1, 128)) < 0.5,) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_chunk_add(self): module = ChunkAdd() # noqa: F405 torch.manual_seed(8) @@ -1448,7 +1489,7 @@ def test_qnn_backend_16a4w_layer_norm(self): def test_qnn_backend_16a4w_linear(self): module = Linear() # noqa: F405 - sample_input = (torch.randn([3, 4]),) + sample_input = (torch.randn([3, 512]),) module = self.get_qdq_module( module, sample_input, @@ -1458,7 +1499,7 @@ def test_qnn_backend_16a4w_linear(self): def test_qnn_backend_16a4w_per_channel_linear(self): module = Linear(use_bias=False) # noqa: F405 - sample_input = (torch.randn([3, 4]),) + sample_input = (torch.randn([3, 512]),) module = self.get_qdq_module( module, sample_input, @@ -1469,7 +1510,7 @@ def test_qnn_backend_16a4w_per_channel_linear(self): def test_qnn_backend_16a4w_per_channel_linear_with_bias(self): module = Linear() # noqa: F405 - sample_input = (torch.randn([3, 4]),) + sample_input = (torch.randn([3, 512]),) module = self.get_qdq_module( module, sample_input, @@ -2338,16 +2379,37 @@ def test_qnn_backend_linalg_vector_norm(self): def test_qnn_backend_linear(self): module = Linear() # noqa: F405 - sample_input = (torch.randn([3, 4]),) + sample_input = (torch.randn([3, 512]),) module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) + @unittest.skipIf(is_qnn_sdk_version_less_than("2.30"), "UT pass after QNN 2.30") + def test_qnn_backend_linear_block(self): + modules = [ + 
Linear(use_bias=False), # noqa: F405 + Linear(use_bias=True), # noqa: F405 + ] + + sample_input = (torch.randn([3, 512]),) + for i, module in enumerate(modules): + with self.subTest(i=i): + # update block size for linear weight (OI) + # channel dimension(O) is defaultly sliced in QNN + # divide dimension(I) into 16 groups + module = self.get_qdq_module( + module, + sample_input, + quant_dtype=QuantDtype.use_16a4w_block, + block_size_map={"linear": (1, 32)}, + ) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_linear_qat(self): """ Prototype to test qat model """ module = Linear() # noqa: F405 - sample_input = (torch.randn([3, 4]),) + sample_input = (torch.randn([3, 512]),) prepared = self.get_prepared_qat_module(module, sample_input) module = self.get_converted_sgd_trained_module(module, prepared, sample_input) self.lower_module_and_test_output(module, sample_input) @@ -2600,6 +2662,43 @@ def test_qnn_backend_slice_copy(self): module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_slice_scatter(self): + test_comb = [ + { + QCOM_MODULE: [ + SliceScatter(dim=0, start=3, end=5, step=1) # noqa: F405 + ], + QCOM_SAMPLE_INPUTS: [ + ( + torch.zeros(8, 8), + torch.ones(2, 8), + ) + ], + }, + { + QCOM_MODULE: [ + SliceScatter(dim=1, start=2, end=6, step=2) # noqa: F405 + ], + QCOM_SAMPLE_INPUTS: [ + ( + ( + torch.zeros(8, 8), + torch.ones(8, 2), + ) + ) + ], + }, + ] + + index = 0 + for comb in test_comb: + for module in comb[QCOM_MODULE]: + for sample_input in comb[QCOM_SAMPLE_INPUTS]: + with self.subTest(i=index): + module = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(module, sample_input) + index += 1 + def test_qnn_backend_softmax(self): modules = [Softmax(dim=1), Softmax(dim=-1)] # noqa: F405 sample_input = (torch.randn([1, 4, 8, 8]),) @@ -2751,6 +2850,12 @@ def test_qnn_backend_argmin_view_squeeze_conv2d(self): module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_causal_mask(self): + module = CausalMask() # noqa: F405 + sample_input = (torch.rand((1, 1, 1, 128)) < 0.5,) + module = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_chunk_add(self): module = ChunkAdd() # noqa: F405 torch.manual_seed(8) @@ -4878,7 +4983,7 @@ def test_static_qwen3(self): msg["inference_speed"], inference_speed_ref[self.model] ) - def test_smollm2(self): + def test_static_smollm2(self): if not self.required_envs(): self.skipTest("missing required envs") @@ -4936,6 +5041,60 @@ def test_smollm2(self): self.assertLessEqual(msg["wiki_ppl"], 25) self.assertGreaterEqual(msg["inference_speed"], 200) + def test_qwen2_5(self): + if not self.required_envs([]): + self.skipTest("missing required envs") + prompt = "My favourite condiment is " + cmds = [ + "python", + f"{self.executorch_root}/examples/qualcomm/oss_scripts/qwen2_5/qwen2_5.py", + "--prompt", + prompt, + "--decoder_model", + "qwen2.5_0.5B", + "--ptq", + "16a8w", + "--enable_spinquant_r3", + "--max_seq_len", + "128", + "--artifact", + self.artifact_dir, + "--build_folder", + self.build_folder, + "--model", + self.model, + "--ip", + self.ip, + "--port", + str(self.port), + ] + if self.compile_only: + cmds.extend(["--compile_only"]) + elif self.device: + cmds.extend(["--device", self.device]) + if self.host: + cmds.extend(["--host", self.host]) + elif self.enable_x86_64: + 
cmds.extend(["--enable_x86_64"]) + if self.pre_gen_pte: + cmds.extend(["--pre_gen_pte", self.pre_gen_pte]) + + golden_start_with = "My favourite condiment is iced tea." + p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL) + with Listener((self.ip, self.port)) as listener: + conn = listener.accept() + p.communicate() + msg = json.loads(conn.recv()) + if "Error" in msg: + self.fail(msg["Error"]) + else: + if not self.compile_only: + model_out = msg["result"][0] + self.assertTrue( + model_out.startswith(golden_start_with), + f"Expected Output: '{golden_start_with}' Actual Output: '{model_out}'", + ) + class TestExampleOssScript(TestQNN): def test_albert(self): diff --git a/backends/qualcomm/utils/utils.py b/backends/qualcomm/utils/utils.py index 14153c6942e..be4e86de50f 100644 --- a/backends/qualcomm/utils/utils.py +++ b/backends/qualcomm/utils/utils.py @@ -333,8 +333,9 @@ def to_edge_transform_and_lower_to_qnn( passes_job: Optional[Union[OrderedDict, Dict[str, OrderedDict]]] = None, skip_node_id_set: Optional[set] = None, skip_node_op_set: Optional[set] = None, - skip_mutable_buffer: bool = False, - generate_etrecord: bool = False, + skip_mutable_buffer: Optional[bool] = False, + generate_etrecord: Optional[bool] = False, + convert_linear_to_conv2d: Optional[bool] = False, ) -> EdgeProgramManager: """ Transforms and lowers a given PyTorch module to the QNN backend. @@ -359,8 +360,10 @@ def to_edge_transform_and_lower_to_qnn( Set of node IDs to skip during partitioning. skip_node_op_set (Optional[set]): Set of node operations to skip during partitioning. - skip_mutable_buffer (Optional[set]): + skip_mutable_buffer (Optional[bool]): Whether to skip delegating the mutable buffer in QNN backend. + convert_linear_to_conv2d (Optional[bool]): + Whether to convert linear to conv2d in some cases to improve performance in HTP backend. Returns: EdgeProgramManager: @@ -432,7 +435,9 @@ def ensure_graph_specific_dict(value, graph_names): # If placed in the to_edge_transform_passes, it will be executed # after the lift_constant_tensor_pass, causing the operation builder # to fail to correctly retrieve the parameter by the get_parameter. 
- aten_programs[graph_name] = QnnPassManager().transform_for_export_pipeline(ep) + aten_programs[graph_name] = QnnPassManager().transform_for_export_pipeline( + ep, convert_linear_to_conv2d=convert_linear_to_conv2d + ) transform_passes[graph_name] = QnnPassManager().get_to_edge_transform_passes( ep, passes_job=passes_job[graph_name], dep_table=dep_table[graph_name] ) @@ -1040,7 +1045,7 @@ def generate_qnn_executorch_compiler_spec( qnn_executorch_options.log_level = ( QnnExecuTorchLogLevel.kLogLevelDebug if debug - else QnnExecuTorchLogLevel.kLogLevelWarn + else QnnExecuTorchLogLevel.kLogLevelError ) qnn_executorch_options.dump_intermediate_outputs = dump_intermediate_outputs diff --git a/examples/models/llama/CMakeLists.txt b/examples/models/llama/CMakeLists.txt index 2cc5902c43a..a485244f9a7 100644 --- a/examples/models/llama/CMakeLists.txt +++ b/examples/models/llama/CMakeLists.txt @@ -219,3 +219,4 @@ endif() target_include_directories(llama_main PUBLIC ${_common_include_directories}) target_link_libraries(llama_main PUBLIC llama_runner ${link_libraries}) target_compile_options(llama_main PUBLIC ${_common_compile_options}) +set_target_properties(llama_main PROPERTIES LINK_FLAGS "-Wl,-rpath='$ORIGIN'") diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index 457349edd10..7ad029ad16a 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -72,8 +72,6 @@ get_quant_embedding_transform, get_quant_weight_transform, ) -from .source_transformation.rms_norm import replace_rms_norm_with_native_rms_norm - from .source_transformation.rope import materialze_broadcast_of_rope_freq_cis from .source_transformation.sdpa import ( replace_causal_mask, @@ -938,7 +936,6 @@ def _to_edge_and_lower_llama( # noqa: C901 AnnotateStack, ConvertBmmToMatmul, FoldQDQ, - RecomposeRmsNorm, TagQuantIO, ) @@ -980,7 +977,6 @@ def _to_edge_and_lower_llama( # noqa: C901 dep_table = get_passes_dependency_for_capture_program() passes_job[AnnotateStack][QCOM_PASS_ACTIVATE_KEY] = True passes_job[ConvertBmmToMatmul][QCOM_PASS_ACTIVATE_KEY] = True - passes_job[RecomposeRmsNorm][QCOM_PASS_ACTIVATE_KEY] = True passes_job[TagQuantIO][QCOM_PASS_ACTIVATE_KEY] = True passes_job[TagQuantIO][QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY][ "get_quant_io_dtype_fn" @@ -1438,14 +1434,12 @@ def _get_source_transforms( # noqa transforms.append(get_model_with_r1_r2(optimized_rotation_path)) transforms.append(replace_attention_to_attention_sha) transforms.append(replace_causal_mask) - transforms.append(replace_rms_norm_with_native_rms_norm) # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`. 
transforms.append(convert_linear_to_conv2d) else: transforms.append(replace_kv_cache_with_simple_kv_cache) transforms.append(replace_sdpa_with_flex_sdpa) transforms.append(replace_causal_mask) - transforms.append(replace_rms_norm_with_native_rms_norm) if optimized_rotation_path: transforms.append(fuse_layer_norms) transforms.append(get_model_with_r1_r2(optimized_rotation_path)) diff --git a/examples/qualcomm/oss_scripts/llm_utils/README.md b/examples/qualcomm/oss_scripts/llm_utils/README.md new file mode 100644 index 00000000000..a713d6df5de --- /dev/null +++ b/examples/qualcomm/oss_scripts/llm_utils/README.md @@ -0,0 +1,78 @@ +## Tutorial to run [eval_decoder_model_qnn.py](./eval_decoder_model_qnn.py) +This script, [`eval_decoder_model_qnn.py`](./eval_decoder_model_qnn.py), is designed to evaluate large language models (LLMs) from transformers that have been compiled into ExecuTorch Portable Executable (PTE) format for execution on Qualcomm devices. It leverages the `lm-evaluation-harness` library to perform various NLP evaluation tasks. + +> ⚠️ **Important:** Note that this script runs PTE files generated specifically for Hugging Face Transformers, such as [qwen2_5.py](../qwen2_5/qwen2_5.py), rather than [the static LLaMA version](../llama/llama.py). + +### Features: + +* Evaluates ExecuTorch PTE models on Qualcomm devices (requires ADB setup and QNN SDK). +* Integrates with `lm-evaluation-harness` for standardized LLM evaluation tasks. + +### Prerequisites + +Before running this script, ensure you have the following: + +1. **Setup ExecuTorch** Follow the [tutorial](https://pytorch.org/executorch/main/getting-started-setup) to set up ExecuTorch. +2. **Setup QNN ExecuTorch** Follow the [tutorial](https://pytorch.org/executorch/main/backends-qualcomm) to build Qualcomm AI Engine Direct Backend. +3. **`lm-evaluation-harness`:** Install the `lm-evaluation-harness` library. +4. **PTE Model:** A pre-exported ExecuTorch PTE model of your decoder LLM. +5. **Tokenizer:** A tokenizer in a format supported by `pytorch-tokenizers` (e.g., SentencePiece `.json` or `.model`, Tiktoken). + +### How to Use + +The script evaluates the model by running the PTE file on a connected Qualcomm device. + +#### Command Line Arguments: + +* `-a`, `--artifact`: (Required) Path for storing generated artifacts by this example. +* `--tokenizer_path`: (Required) Path to your tokenizer file (e.g., `tokenizer.model` or `tokenizer.json`). +* `--pte`: (Required) Path to the ExecuTorch Portable Executable (`.pte`) model file. +* `--logits_quant_attr_path`: (Optional) Path to a JSON file containing quantization attributes. This is needed if your PTE model uses tag quant I/O for logits and requires de-quantization before evaluation. +* `--max_seq_len`: (Optional, default: 128) Maximum sequence length the model can process. +* `--tasks`: (Optional, default: `["wikitext"]`) A list of `lm-evaluation-harness` tasks to evaluate. You can specify multiple tasks separated by spaces (e.g., `--tasks wikitext piqa`). +* `--limit`: (Optional) Number of samples to evaluate per task. If not set, all samples will be evaluated. +* `--num_fewshot`: (Optional) Number of examples to use in few-shot context for evaluation. +* `--model`: (Required for QNN execution) The SoC model name (e.g., `SM8550`, `SM8650`). +* `--device`: (Required for QNN execution) The ADB device ID. +* `--host`: (Required for QNN execution) The ADB host ID (usually `localhost`). 
+* `--build_folder`: (Optional, default: `build-android`) The build folder for ExecuTorch artifacts, relative to the current directory. + +#### Example Usage: + +```bash +python examples/qualcomm/oss_scripts/llm_utils/eval_decoder_model_qnn.py \ + --artifact ./eval_output \ + --tokenizer_path /path/to/your/tokenizer.model \ + --pte /path/to/your/model.pte \ + --model SM8550 \ + --device YOUR_DEVICE_ID \ + --host localhost \ + --tasks wikitext \ + --limit 1 \ + --max_seq_len 512 +``` + +Replace `/path/to/your/tokenizer.model`, `/path/to/your/model.pte`, and `YOUR_DEVICE_ID` with your actual paths and device ID. + +If your model's logits are quantized and require de-quantization: + +```bash +python examples/qualcomm/oss_scripts/llm_utils/eval_decoder_model_qnn.py \ + --artifact ./eval_output \ + --tokenizer_path /path/to/your/tokenizer.model \ + --pte /path/to/your/model.pte \ + --logits_quant_attr_path /path/to/your/logits_quant_attrs.json \ + --model SM8550 \ + --device YOUR_DEVICE_ID \ + --host localhost \ + --tasks wikitext \ + --limit 1 \ + --max_seq_len 512 +``` + +### Output + +The script will print the evaluation results for each specified task to the console, similar to the `lm-evaluation-harness` output format. For example: + +``` +wikitext: {'word_perplexity': ..., 'byte_perplexity': ..., 'bits_per_byte': ...} diff --git a/examples/qualcomm/oss_scripts/llm_utils/decoder_model_wrapper.py b/examples/qualcomm/oss_scripts/llm_utils/decoder_model_wrapper.py new file mode 100644 index 00000000000..f59dc548c44 --- /dev/null +++ b/examples/qualcomm/oss_scripts/llm_utils/decoder_model_wrapper.py @@ -0,0 +1,160 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import math +from typing import Optional + +import scipy +import torch +import transformers +from transformers import GenerationConfig, PretrainedConfig + +from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM + +TRANSFORMERS_VERSION = "4.53.1" + + +def save_config_to_constant_methods( + config: PretrainedConfig, + generation_config: Optional[GenerationConfig] = None, + **kwargs, +): + # Initialize metadata with values from model config + metadata = { + "get_bos_id": getattr(config, "bos_token_id", None), + "get_eos_id": getattr(config, "eos_token_id", None), + "get_vocab_size": getattr(config, "vocab_size", None), + "get_max_seq_len": getattr(config, "max_position_embeddings", None), + "use_kv_cache": getattr(generation_config, "use_cache", None), + "use_sdpa_with_kv_cache": False, + } + + # Safely access fields from generation_config if it exists + if generation_config is not None: + # Check for cache_config and its attributes + cache_config = getattr(generation_config, "cache_config", None) + if cache_config is not None: + max_seq_len = getattr(cache_config, "max_cache_len", None) + if max_seq_len is not None: + metadata["get_max_seq_len"] = max_seq_len + + # Combine with any additional kwargs and filter out None values + return {k: v for k, v in {**metadata, **kwargs}.items() if v is not None} + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand( + batch, num_key_value_heads, n_rep, slen, head_dim + ) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +@torch._dynamo.assume_constant_result +def get_transposed_hadamard_matrix(head_dim): + r3_weight = torch.tensor( + scipy.linalg.hadamard(head_dim, dtype=float) / math.sqrt(head_dim), + dtype=torch.float32, + ) + return r3_weight.transpose(0, 1) + + +def _qnn_attention( + module: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + if getattr(module.config, "enable_spinquant_r3", False): + r3_weight = get_transposed_hadamard_matrix(module.head_dim) + query = torch.matmul(query, r3_weight) + key = torch.matmul(key, r3_weight) + + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = torch.nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32 + ).to(query.dtype) + attn_weights = torch.nn.functional.dropout( + attn_weights, p=dropout, training=module.training + ) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +def _qnn_attention_mask( + batch_size: int, + cache_position: torch.Tensor, + kv_length: int, + **kwargs, +): + kv_arange = torch.arange(kv_length, device=cache_position.device) + reshaped_cache_position = cache_position.view(-1, 1) + + # Simplest and most efficient way to obtain a causal mask + causal_mask = kv_arange <= reshaped_cache_position + atten_mask = torch.full((causal_mask.shape[0], kv_length), torch.tensor(-65504.0)) + atten_mask = atten_mask.masked_fill(causal_mask, 0) + atten_mask = atten_mask[None, None, :, :].expand(batch_size, -1, -1, -1) + + return atten_mask + + +class QnnCausalLMExportableModule(torch.nn.Module): + def __init__(self, model): + super().__init__() + self.model = model + self.config = model.config + self._metadata = save_config_to_constant_methods( + model.config, model.generation_config + ) + logging.info(f"Metadata to be recorded in PTE: {self._metadata}") + self.exportable_module = TorchExportableModuleForDecoderOnlyLM( + self.model, + max_batch_size=1, + max_cache_len=self._metadata.get("get_max_seq_len"), + ) + self._register_attention_mask_for_4_53(self.exportable_module) + + def _register_attention_mask_for_4_53(self, exportable_module: torch.nn.Module): + if transformers.__version__ >= TRANSFORMERS_VERSION: + from transformers.masking_utils import AttentionMaskInterface + from transformers.modeling_utils import AttentionInterface + + AttentionInterface.register("qnn_attention", _qnn_attention) + AttentionMaskInterface.register("qnn_attention", _qnn_attention_mask) + exportable_module.model.model.config._attn_implementation = "qnn_attention" + self._metadata.update({"use_sdpa_with_kv_cache": False}) + + def get_example_inputs(self): + example_input_ids = 
torch.tensor([[1]], dtype=torch.long) + example_cache_position = torch.tensor([0], dtype=torch.long) + return (example_input_ids, example_cache_position) + + def forward(self, input_ids: torch.Tensor, cache_position: torch.Tensor): + return self.exportable_module(input_ids, cache_position) + + def get_metadata(self): + return self._metadata diff --git a/examples/qualcomm/oss_scripts/llm_utils/eval_decoder_model_qnn.py b/examples/qualcomm/oss_scripts/llm_utils/eval_decoder_model_qnn.py new file mode 100644 index 00000000000..49cdc192f22 --- /dev/null +++ b/examples/qualcomm/oss_scripts/llm_utils/eval_decoder_model_qnn.py @@ -0,0 +1,293 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import argparse +import json +import os + +from typing import Optional, Union + +import numpy as np + +import torch +from executorch.examples.models.llama.evaluate.eager_eval import EagerEvalWrapper + +from executorch.examples.qualcomm.utils import ( + make_output_dir, + setup_common_args_and_variables, +) + +from lm_eval.evaluator import simple_evaluate +from pytorch_tokenizers import get_tokenizer +from pytorch_tokenizers.llama2c import Llama2cTokenizer as SentencePieceTokenizer +from pytorch_tokenizers.tiktoken import TiktokenTokenizer as Tiktoken + + +def create_device_inputs(example_inputs): + # TODO: support batch inputs if necessary + input_list = "" + inputs = [] + + for index, data in enumerate(example_inputs): + inputs.append(data) + input_list += " ".join([f"input_{index}_{i}.raw" for i in range(len(data))]) + input_list += "\n" + return inputs, input_list + + +class GraphModuleCalibrationWrapper(EagerEvalWrapper): + """ + A wrapper class for calibration + """ + + def __init__( + self, + model: torch.fx.GraphModule, + tokenizer: Union[SentencePieceTokenizer, Tiktoken], + max_seq_length: Optional[int] = None, + use_kv_cache: bool = False, + generate_full_logits: bool = False, + enable_dynamic_shape: bool = True, + ): + super().__init__( + model=model, tokenizer=tokenizer, max_seq_length=max_seq_length + ) + self._model = model.to(self.device) + self._use_kv_cache = use_kv_cache + self._generate_full_logits = generate_full_logits + self._enable_dynamic_shape = enable_dynamic_shape + + def _model_call(self, inps): + if self._use_kv_cache: + if not self._enable_dynamic_shape: + # graph module exported without dynamic shape won't work with a different shape. + # And we have to do single token prefill here. + result_logits = [] + for pos in range(inps.shape[-1]): + pos_tensor = torch.tensor([pos], dtype=torch.int64) + logits = self._model(inps[:, pos : pos + 1], pos_tensor) + result_logits.append(logits) + if self._generate_full_logits: + return torch.cat(result_logits, dim=1) + else: + return torch.stack(result_logits, dim=1) + else: + pos_tensor = torch.tensor([0], dtype=torch.int64, device=self.device) + # Batch process the whole sequence. + logits = self._model(inps[:, : self._max_seq_length], pos_tensor) + return logits + + else: + return self._model(inps) + + def _model_generate(self, context, max_length, eos_token_id): + raise Exception("unimplemented") + + +class QNNRunnerEvalWrapper(EagerEvalWrapper): + """ + A wrapper class for ExecuTorch Runtime integration with the + lm-evaluation-harness library. 
+ """ + + def __init__( + self, + model: str, + tokenizer: Union[SentencePieceTokenizer, Tiktoken], + soc_model: str, + device: str, + host: str, + max_seq_length: Optional[int] = None, + output_dir: str = ".", + quant_attrs=None, + build_folder: str = "build-android", + ): + super().__init__(None, tokenizer, max_seq_length) + import getpass + + from executorch.examples.qualcomm.utils import SimpleADB + + self._model = model + self.output_dir = output_dir + self.quant_attrs = quant_attrs + workspace = f"/data/local/tmp/{getpass.getuser()}/executorch/meta_llama" + self.adb = SimpleADB( + qnn_sdk=os.getenv("QNN_SDK_ROOT"), + build_path=build_folder, + pte_path=model, + workspace=workspace, + device_id=device, + host_id=host, + soc_model=soc_model, + ) + self.adb.push() + + def _model_call(self, inps): + # Given inps (tokens), return the logits from a single + # forward call + + # Example: + # inps: Tensor of shape (1, N) + # logits: Tensor of shape (1, N, vocab_size) + result_logits = [] + inputs = [] + + for pos in range(self._max_seq_length): + pos_tensor = torch.tensor([pos], dtype=torch.int64) + inputs.append([inps[:, pos : pos + 1], pos_tensor]) + + inputs, input_list = create_device_inputs(inputs) + self.adb.push(inputs=inputs, input_list=input_list, init_env=False) + self.adb.execute() + output_data_folder = f"{self.output_dir}/outputs" + make_output_dir(output_data_folder) + + def post_process(): + for f in sorted( + os.listdir(output_data_folder), key=lambda f: int(f.split("_")[1]) + ): + output_tensor = None + if self.quant_attrs: + output_tensor = torch.from_numpy( + np.fromfile( + os.path.join(output_data_folder, f), dtype=np.uint16 + ).reshape(1, 1, -1) + ) + output_tensor = ( + output_tensor.to(torch.float32) - self.quant_attrs["zero_point"] + ) * self.quant_attrs["scale"] + else: + output_tensor = torch.from_numpy( + np.fromfile( + os.path.join(output_data_folder, f), dtype=np.float32 + ).reshape(1, 1, -1) + ) + + result_logits.append(output_tensor) + + self.adb.pull(output_path=self.output_dir, callback=post_process) + return torch.cat(result_logits, dim=1) + + +def gen_eval_wrapper( + args: argparse.ArgumentParser, +): + """ + Generates a wrapper interface around the provided model and tokenizer for + the lm-evaluation-harness library. + + Returns: + eval_wrapper (LM): A wrapper interface for the lm-evaluation-harness library. 
+ """ + tokenizer = get_tokenizer(args.tokenizer_path) + + # ExecuTorch Binary Evaluation + if (model := args.pte) is not None: # pyre-ignore + assert args.device is not None, "please specify the device to execute pte" + quant_attrs = None + if args.logits_quant_attr_path is not None: + quant_attrs = json.load(open(f"{args.logits_quant_attr_path}")) + return QNNRunnerEvalWrapper( + model=model, + tokenizer=tokenizer, + soc_model=args.model, + device=args.device, + host=args.host, + max_seq_length=args.max_seq_len - 1, + output_dir=args.artifact, + quant_attrs=quant_attrs, + build_folder=args.build_folder, + ) + else: + raise RuntimeError("Currently only support evaluate pte on device") + + +def eval_llama( + args: argparse.ArgumentParser, +) -> None: + + # Generate the eval wrapper + eval_wrapper = gen_eval_wrapper(args) + + # Evaluate the model + with torch.no_grad(): + eval_results = simple_evaluate( + model=eval_wrapper, + tasks=args.tasks, + num_fewshot=args.num_fewshot, + limit=args.limit, + ) + + for task, res in eval_results["results"].items(): + print(f"{task}: {res}") + + +def main() -> None: + seed = 42 + torch.manual_seed(seed) + parser = setup_common_args_and_variables() + + parser.add_argument( + "-a", + "--artifact", + help="path for storing generated artifacts by this example.", + type=str, + ) + + parser.add_argument( + "--tokenizer_path", + help="path to tokenizer.json.", + type=str, + ) + + parser.add_argument( + "--pte", + type=str, + default=None, + help="[For ExecuTorch] Path to the ExecuTorch model being evaluated. If provided, don't go through the export flow", + ) + parser.add_argument( + "--logits_quant_attr_path", + type=str, + default=None, + help="For the pte with tag quant io, it needs to be dequantize and compute ppl.", + ) + + parser.add_argument( + "--max_seq_len", + type=int, + default=128, + help="This refers to maximum number of tokens that the model can process & consider at once to generate predictions/responses.", + ) + parser.add_argument( + "--tasks", + nargs="+", + type=str, + default=["wikitext"], + help="list of lm-eluther tasks to evaluate usage: --tasks task1 task2", + ) + parser.add_argument( + "--limit", + type=int, + default=None, + help="number of samples to evalulate. If not set, evaluate all samples", + ) + parser.add_argument( + "--num_fewshot", + type=int, + default=None, + metavar="N", + help="Number of examples in few-shot context", + ) + + args = parser.parse_args() + + eval_llama(args) + + +if __name__ == "__main__": + main() # pragma: no cover diff --git a/examples/qualcomm/oss_scripts/llm_utils/qnn_decoder_model_manager.py b/examples/qualcomm/oss_scripts/llm_utils/qnn_decoder_model_manager.py new file mode 100644 index 00000000000..2a2cd250a40 --- /dev/null +++ b/examples/qualcomm/oss_scripts/llm_utils/qnn_decoder_model_manager.py @@ -0,0 +1,331 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import logging +from functools import partial +from typing import Callable, List + +import torch +from executorch.backends.qualcomm._passes import TagQuantIO +from executorch.backends.qualcomm._passes.build_quant_io import BuildQuantIo +from executorch.backends.qualcomm._passes.qnn_pass_manager import ( + get_capture_program_passes, +) +from executorch.backends.qualcomm.builders.utils import is_graph_output +from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype +from executorch.backends.qualcomm.utils.constants import ( + QCOM_PASS_ACTIVATE_KEY, + QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY, + QCOM_QUANT_ATTRS_MAP, +) +from executorch.backends.qualcomm.utils.utils import ( + generate_htp_compiler_spec, + generate_qnn_executorch_compiler_spec, + get_soc_to_chipset_map, + to_edge_transform_and_lower_to_qnn, +) +from executorch.devtools.backend_debug import print_delegation_info +from executorch.examples.qualcomm.oss_scripts.llm_utils.decoder_model_wrapper import ( + QnnCausalLMExportableModule, +) +from executorch.examples.qualcomm.utils import make_quantizer +from executorch.exir.capture._config import ExecutorchBackendConfig +from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass +from pytorch_tokenizers import get_tokenizer +from torchao.quantization.pt2e import MinMaxObserver +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e +from transformers import AutoConfig, AutoModelForCausalLM, GenerationConfig + + +FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" +logging.basicConfig(level=logging.INFO, format=FORMAT) + +HUGGING_FACE_REPO_IDS = { + "qwen2.5_0.5B": "Qwen/Qwen2.5-0.5B", + "qwen2.5_1.5B_instruct": "Qwen/Qwen2.5-1.5B-Instruct", + "qwen2.5_0.5B_instruct": "Qwen/Qwen2.5-0.5B-Instruct", +} + + +def get_qnn_llm_edge_manager(model_name, max_seq_len=128, enable_spinquant_r3=True): + model_id = HUGGING_FACE_REPO_IDS[model_name] + config = AutoConfig.from_pretrained(model_id) + device = "cpu" + batch_size = 1 + dtype = "float32" + cache_implementation = "static" + attn_implementation = "eager" + + # Set configs + config.max_seq_len = max_seq_len + config.ar_len = 1 # kv mode + config.max_batch_size = batch_size + config.enable_spinquant_r3 = enable_spinquant_r3 + + # Some config has head_dim provided that is different from equation below(e.g., qwen3) + if not hasattr(config, "head_dim"): + config.head_dim = config.hidden_size // config.num_attention_heads + + model = AutoModelForCausalLM.from_pretrained( + model_id, + device_map=device, + torch_dtype=dtype, + config=config, + attn_implementation=attn_implementation, + generation_config=GenerationConfig( + use_cache=True, + cache_implementation=cache_implementation, + max_length=max_seq_len, + cache_config={ + "batch_size": batch_size, + "max_cache_len": max_seq_len, + }, + ), + ).eval() + model_wrapper = QnnCausalLMExportableModule(model) + + return QnnLLMEdgeManager(model_name, model_wrapper, config) + + +class QnnLLMEdgeManager: + def __init__(self, model_name, model_wrapper, config, verbose=True) -> None: + self.model_name = model_name + self.model_wrapper = model_wrapper + self.graph_module = model_wrapper + self.config = config + self.verbose = verbose + self.use_fp16 = True + self.passes_job = get_capture_program_passes() + self.edge_prog_mgr = None + self.logits_quant_attrs = None + + def source_transform( + self, transforms: List[Callable[[torch.nn.Module], torch.nn.Module]] + ) -> "QnnLLMEdgeManager": + """ + Apply source transforms to the model. 
The transforms are callables that + takes nn.Module as input and returns nn.Module. + Args: + transforms (List[Callable[[torch.nn.Module], torch.nn.Module]]): A + list of source transforms. + """ + for transform in transforms: + self.graph_module = transform(self.graph_module) + + if self.verbose: + logging.info(f"Applied source transforms: {transforms}") + logging.info(f"Model after source transforms: {self.graph_module}") + return self + + def _tag_ios(self, node, fixed_point_type, config): + # shape of k caches and v caches + kv_cache_shape = { + # single head, kv input + (config.head_dim, config.max_seq_len), + (config.max_seq_len, config.head_dim), + # single head, kv output + (config.head_dim, config.ar_len), + (config.ar_len, config.head_dim), + } + + logit_out_shape = { + ( + config.max_batch_size, + config.ar_len, + config.vocab_size, + ) + } + + quant_io_type = None + + if node.op == "placeholder": + if ( + len(users := list(node.users)) == 1 + and users[0].meta["val"].size()[-2:] in kv_cache_shape + ): + quant_io_type = fixed_point_type["kv_type"] + if is_graph_output(node): + if node.meta["val"].size()[-2:] in kv_cache_shape: + quant_io_type = fixed_point_type["kv_type"] + elif node.meta["val"].size() in logit_out_shape: + quant_io_type = fixed_point_type["io_type"] + + return quant_io_type + + def export(self): + with torch.no_grad(): + self.graph_module = torch.export.export( + self.graph_module, + args=self.model_wrapper.get_example_inputs(), + strict=True, + ).module() + + def pt2e_calibrate( + self, + calibration_tasks, + calibration_limit, + calibration_seq_length, + calibration_data, + tokenizer_path, + ): + try: + from executorch.examples.qualcomm.oss_scripts.llm_utils.eval_decoder_model_qnn import ( + GraphModuleCalibrationWrapper, + ) + from lm_eval.evaluator import simple_evaluate + except ImportError: + raise ImportError( + "Please install the llm eval dependency via examples/models/llama/install_requirements.sh" + ) + + tokenizer = get_tokenizer(tokenizer_path) + logging.info( + f"Calibrating with tasks: {calibration_tasks}, limit: {calibration_limit}, calibration_data: {calibration_data}, tokenizer_path: {tokenizer_path}, seq_length: {self.config.max_seq_len}" + ) + + def calibrate_template( + module: torch.fx.GraphModule, tokenizer, prompts: str, max_len: int + ): + # TODO: change criteria & support batch inputs if necessary + pos = 0 + token_list = tokenizer.encode(prompts, bos=True, eos=False) + + with torch.no_grad(): + while token_list[-1] != tokenizer.eos_id and pos < max_len: + cur_pos = torch.tensor([pos], dtype=torch.long) + logits = module(torch.full((1, 1), token_list[pos]), cur_pos) + pos += 1 + if pos >= len(token_list): + token_list.append(torch.argmax(logits, dim=-1).item()) + logging.info( + f"Result of LLM with static cache:\n {tokenizer.decode(token_list)} \n\n\n" + ) + + calibrate_template( + module=self.graph_module, + tokenizer=tokenizer, + prompts=calibration_data, + max_len=calibration_seq_length, + ) + if calibration_tasks is not None and calibration_limit is not None: + eval_wrapper = GraphModuleCalibrationWrapper( + model=self.graph_module, + tokenizer=tokenizer, + max_seq_length=calibration_seq_length, + use_kv_cache=True, + generate_full_logits=True, + enable_dynamic_shape=False, + ) + + # Evaluate the model + with torch.no_grad(): + eval_results = simple_evaluate( + model=eval_wrapper, + tasks=calibration_tasks, + limit=calibration_limit, + ) + + for task, res in eval_results["results"].items(): + print(f"{task}: {res}") + 
logging.info("Calibration finish...") + + def pt2e_quantize( + self, + quant_dtype, + fixed_point_type, + calibration_tasks, + calibration_limit, + calibration_data, + tokenizer_path, + ): + self.export() + + quantizer = make_quantizer( + quant_dtype=quant_dtype, + per_channel_linear=True, + per_channel_conv=True, + act_observer=MinMaxObserver, + ) + if quant_dtype == QuantDtype.use_16a4w_block: + + def extract_linear_nodes(graph): + linear_nodes = [] + for node in graph.nodes: + if node.target == torch.ops.aten.linear.default: + linear_nodes.append(node) # linear node + linear_nodes.append(node.args[1]) # weight node + return linear_nodes + + linear_nodes = extract_linear_nodes(self.graph_module.graph) + block_size_map = {n.name: (1, 16) for n in linear_nodes} + quantizer.set_block_size_map(block_size_map) + self.graph_module = prepare_pt2e(self.graph_module, quantizer) + self.pt2e_calibrate( + calibration_tasks, + calibration_limit, + self.config.max_seq_len, + calibration_data, + tokenizer_path, + ) + self.graph_module = convert_pt2e(self.graph_module) + + self.passes_job[TagQuantIO][QCOM_PASS_ACTIVATE_KEY] = True + self.passes_job[TagQuantIO][QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY][ + "get_quant_io_dtype_fn" + ] = partial( + self._tag_ios, fixed_point_type=fixed_point_type, config=self.config + ) + self.use_fp16 = False + + def to_edge_transform_and_lower_to_qnn( + self, soc_model, skip_node_id_set, skip_node_op_set + ): + backend_options = generate_htp_compiler_spec(use_fp16=self.use_fp16) + compiler_spec = generate_qnn_executorch_compiler_spec( + soc_model=get_soc_to_chipset_map()[soc_model], + backend_options=backend_options, + ) + with torch.no_grad(): + self.edge_prog_mgr = to_edge_transform_and_lower_to_qnn( + self.graph_module, + self.model_wrapper.get_example_inputs(), + compiler_spec, + constant_methods=self.model_wrapper.get_metadata(), + passes_job=self.passes_job, + skip_node_id_set=skip_node_id_set, + skip_node_op_set=skip_node_op_set, + convert_linear_to_conv2d=True, + ) + + print_delegation_info(self.edge_prog_mgr.exported_program().graph_module) + if not self.use_fp16: + logit_out_shape = { + ( + self.config.max_batch_size, + self.config.ar_len, + self.config.vocab_size, + ) + } + for n in self.edge_prog_mgr.exported_program().graph.nodes: + if n.op == "output": + for node, output_encoding in n.meta[QCOM_QUANT_ATTRS_MAP].items(): + if node.meta["val"].size() in logit_out_shape: + self.logits_quant_attrs = output_encoding + + def get_logits_quant_attrs(self): + return self.logits_quant_attrs + + def to_executorch(self, artifact, pte_filename): + executorch_config = ExecutorchBackendConfig( + memory_planning_pass=MemoryPlanningPass( + alloc_graph_input=False, + ), + passes=[BuildQuantIo()], + ) + exec_prog_mgr = self.edge_prog_mgr.to_executorch(config=executorch_config) + with open(f"{artifact}/{pte_filename}.pte", "wb") as file: + exec_prog_mgr.write_to_file(file) + logging.info(f"Saved exported program to {artifact}/{pte_filename}.pte") diff --git a/examples/qualcomm/oss_scripts/qwen2_5/qwen2_5.py b/examples/qualcomm/oss_scripts/qwen2_5/qwen2_5.py new file mode 100644 index 00000000000..6ee493e46b0 --- /dev/null +++ b/examples/qualcomm/oss_scripts/qwen2_5/qwen2_5.py @@ -0,0 +1,273 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import getpass +import json +import logging +import os +import subprocess +from multiprocessing.connection import Client + +import torch + +from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype + +from executorch.examples.qualcomm.oss_scripts.llm_utils.qnn_decoder_model_manager import ( + get_qnn_llm_edge_manager, + HUGGING_FACE_REPO_IDS, +) + +from executorch.examples.qualcomm.utils import ( + make_output_dir, + parse_skip_delegation_node, + setup_common_args_and_variables, + SimpleADB, +) + +from transformers import AutoTokenizer + +FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" +logging.basicConfig(level=logging.INFO, format=FORMAT) +logging.getLogger().setLevel(logging.INFO) + +PTE_FILENAME = "qwen_qnn_q16" + + +def compile(args): # noqa: C901 + skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args) + + # ensure the working directory exist. + os.makedirs(args.artifact, exist_ok=True) + + if not args.compile_only and args.device is None: + raise RuntimeError( + "device serial is required if not compile only. " + "Please specify a device serial by -s/--device argument." + ) + + manager = get_qnn_llm_edge_manager( + args.decoder_model, args.max_seq_len, args.enable_spinquant_r3 + ) + + fixed_point_type = {} + if args.ptq: + if args.ptq == "8a8w": + fixed_point_type["io_type"] = torch.uint8 + fixed_point_type["kv_type"] = torch.uint8 + elif args.ptq in ( + "16a8w", + "16a4w", + "16a4w_block", + "16a16w", + ): + fixed_point_type["io_type"] = torch.uint16 + fixed_point_type["kv_type"] = torch.uint16 + else: + raise ValueError( + f"No support for quant type {args.ptq}. Support 8a8w, 16a8w, 16a4w and 16a4w_block." + ) + quant_dtype = getattr(QuantDtype, f"use_{args.ptq}") + model_id = HUGGING_FACE_REPO_IDS[args.decoder_model] + tokenizer = AutoTokenizer.from_pretrained(model_id) + tokenizer_json_path = tokenizer.save_pretrained(args.artifact)[-1] + + manager.pt2e_quantize( + quant_dtype, + fixed_point_type, + args.calibration_tasks, + args.calibration_limit, + args.prompt, + tokenizer_json_path, + ) + + manager.to_edge_transform_and_lower_to_qnn( + args.model, skip_node_id_set, skip_node_op_set + ) + if args.ptq: + logits_quant_attrs = manager.get_logits_quant_attrs() + json.dump( + { + "scale": logits_quant_attrs["scale"], + "zero_point": logits_quant_attrs["zero_point"], + }, + open(f"{args.artifact}/{PTE_FILENAME}_quant_attrs.txt", "w"), + ) + + manager.to_executorch(args.artifact, PTE_FILENAME) + + +def inference(args): + workspace = f"/data/local/tmp/{getpass.getuser()}/executorch/{PTE_FILENAME}" + pte_path = f"{args.artifact}/{PTE_FILENAME}.pte" + # collect output data + output_data_folder = f"{args.artifact}/outputs" + make_output_dir(output_data_folder) + outputs = [] + + def post_process(): + with open(f"{args.artifact}/outputs/result.txt", "r") as f: + outputs.append(f.read()) + + model_id = HUGGING_FACE_REPO_IDS[args.decoder_model] + tokenizer = AutoTokenizer.from_pretrained(model_id) + tokenizer_json_path = tokenizer.save_pretrained(args.artifact)[-1] + seq_len = args.max_seq_len + if args.enable_x86_64: + # x86 emulator is intended for CI and not performance. Check only the first few tokens. 
+ seq_len = min(seq_len, 16) + + qnn_sdk = os.getenv("QNN_SDK_ROOT") + target = "x86_64-linux-clang" + runner_cmd = " ".join( + [ + f"export LD_LIBRARY_PATH={qnn_sdk}/lib/{target}/:{args.build_folder}/lib &&", + f"{args.build_folder}/examples/models/llama/llama_main", + f'--prompt "{args.prompt}"', + f"--tokenizer_path {tokenizer_json_path}", + f"--model_path {pte_path}", + f"--seq_len {seq_len}", + "--temperature 0", + f" > {output_data_folder}/result.txt", + ] + ) + subprocess.run( + runner_cmd, + shell=True, + executable="/bin/bash", + capture_output=True, + ) + post_process() + else: + runner_cmd = " ".join( + [ + f"cd {workspace} &&", + "./llama_main", + f'--prompt "{args.prompt}"', + "--tokenizer_path tokenizer.json", + f"--model_path {PTE_FILENAME}.pte", + f"--seq_len {seq_len}", + "--temperature 0", + " > outputs/result.txt", + ] + ) + adb = SimpleADB( + qnn_sdk=os.getenv("QNN_SDK_ROOT"), + build_path=f"{args.build_folder}", + pte_path=pte_path, + workspace=workspace, + device_id=args.device, + host_id=args.host, + soc_model=args.model, + runner="examples/models/llama/llama_main", + ) + # No pregen inputs, input_list is not required + adb.push(inputs=[], input_list="", files=[tokenizer_json_path]) + adb.execute(custom_runner_cmd=runner_cmd) + + adb.pull(output_path=args.artifact, callback=post_process) + + if args.ip and args.port != -1: + with Client((args.ip, args.port)) as conn: + conn.send( + json.dumps( + { + "result": outputs, + } + ) + ) + else: + for idx, output in enumerate(outputs): + logging.info(f"Results[{idx}]:\n{output}") + + +def main(args): + if args.compile_only and args.pre_gen_pte: + raise RuntimeError("Cannot set both compile_only and pre_gen_pte as true") + + if args.compile_only: + compile(args) + elif args.pre_gen_pte: + inference(args) + else: + compile(args) + inference(args) + + +if __name__ == "__main__": + parser = setup_common_args_and_variables() + + parser.add_argument( + "-a", + "--artifact", + help="path for storing generated artifacts by this example.", + default="qwen2_5", + type=str, + ) + + parser.add_argument( + "-P", + "--ptq", + choices=["8a8w", "16a8w", "16a4w", "16a4w_block"], + help="If specified, will do PTQ quantization.", + type=str, + ) + + parser.add_argument( + "--pre_gen_pte", + help="Run the pre-generated Qwen in the given directory.", + type=str, + ) + + parser.add_argument( + "--prompt", + help="User prompts for Qwen.", + required=True, + type=str, + ) + + parser.add_argument( + "--decoder_model", + choices=["qwen2.5_0.5B", "qwen2.5_0.5B_instruct", "qwen2.5_1.5B_instruct"], + help="The Qwen model to export. 
Current available options are: [qwen2.5_0.5B, qwen2.5_0.5B_instruct, qwen2.5_1.5B_instruct]", + required=True, + ) + + parser.add_argument( + "--max_seq_len", + help="This refers to maximum number of tokens that the model can process & consider at once to generate predictions/responses.", + default=128, + type=int, + ) + parser.add_argument( + "--calibration_tasks", + nargs="+", + type=str, + default=None, + help="Tasks for GPTQ calibration from lm_eval", + ) + parser.add_argument( + "--calibration_limit", + type=int, + default=None, + help="number of samples used for calibration from lm_eval", + ) + parser.add_argument( + "--enable_spinquant_r3", + action="store_true", + help="Specify to enable spin quant R3", + ) + + try: + args = parser.parse_args() + if args.artifact is None: + args.artifact = args.decoder_model + main(args) + except Exception as e: + if args.ip and args.port != -1: + with Client((args.ip, args.port)) as conn: + conn.send(json.dumps({"Error": str(e)})) + else: + raise Exception(e) diff --git a/extension/llm/runner/text_decoder_runner.h b/extension/llm/runner/text_decoder_runner.h index 2f9e9a67331..59761396f5c 100644 --- a/extension/llm/runner/text_decoder_runner.h +++ b/extension/llm/runner/text_decoder_runner.h @@ -76,10 +76,11 @@ class ET_EXPERIMENTAL TextDecoderRunner { } } ctx; - ET_SWITCH_THREE_TYPES( + ET_SWITCH_FOUR_TYPES( Float, Half, BFloat16, + UInt16, logits_tensor.scalar_type(), ctx, "logits_to_token", diff --git a/extension/llm/sampler/sampler.cpp b/extension/llm/sampler/sampler.cpp index 18e82418841..3beda885d6f 100644 --- a/extension/llm/sampler/sampler.cpp +++ b/extension/llm/sampler/sampler.cpp @@ -198,6 +198,7 @@ int32_t Sampler::sample(T* logits) { } template int32_t Sampler::sample(float* logits); +template int32_t Sampler::sample(uint16_t* logits); template int32_t Sampler::sample( executorch::aten::Half* logits); template int32_t Sampler::sample( diff --git a/runtime/core/exec_aten/util/scalar_type_util.h b/runtime/core/exec_aten/util/scalar_type_util.h index 895536b72be..a688313db6b 100644 --- a/runtime/core/exec_aten/util/scalar_type_util.h +++ b/runtime/core/exec_aten/util/scalar_type_util.h @@ -1341,6 +1341,25 @@ struct promote_types { CTYPE_ALIAS, \ __VA_ARGS__)) +#define ET_SWITCH_FOUR_TYPES( \ + T1, T2, T3, T4, TYPE, CONTEXT, NAME, CTYPE_ALIAS, ...) \ + ET_INTERNAL_SWITCH( \ + TYPE, \ + CONTEXT, \ + NAME, \ + ET_INTERNAL_SWITCH_CASE( \ + ::executorch::aten::ScalarType::T1, CTYPE_ALIAS, __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE( \ + ::executorch::aten::ScalarType::T2, CTYPE_ALIAS, __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE( \ + ::executorch::aten::ScalarType::T3, \ + CTYPE_ALIAS, \ + __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE( \ + ::executorch::aten::ScalarType::T4, \ + CTYPE_ALIAS, \ + __VA_ARGS__)) + } // namespace runtime } // namespace executorch
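
Note on quantized logits: the `UInt16` cases added above let the text decoder runner and sampler consume logits that stay in their quantized 16-bit form on device. For offline perplexity evaluation, `eval_decoder_model_qnn.py` instead de-quantizes the raw device outputs on the host using the scale and zero-point saved at compile time. The snippet below is a minimal sketch of that step, mirroring `QNNRunnerEvalWrapper.post_process`; the file names and values are illustrative, not fixed by the scripts.

```python
import json

import numpy as np
import torch

# Attributes written by qwen2_5.py at compile time, e.g. {"scale": ..., "zero_point": ...}.
with open("qwen_qnn_q16_quant_attrs.txt") as f:
    quant_attrs = json.load(f)

# One raw logits buffer pulled back from the device (uint16, shape (1, 1, vocab_size)).
raw = np.fromfile("outputs/output_0_0.raw", dtype=np.uint16).reshape(1, 1, -1)

# Affine de-quantization before computing perplexity on the host.
logits = (
    torch.from_numpy(raw).to(torch.float32) - quant_attrs["zero_point"]
) * quant_attrs["scale"]
```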