diff --git a/backends/qualcomm/_passes/__init__.py b/backends/qualcomm/_passes/__init__.py index 154a360689e..af2e99c86db 100644 --- a/backends/qualcomm/_passes/__init__.py +++ b/backends/qualcomm/_passes/__init__.py @@ -11,6 +11,7 @@ from .canonicalize_conv import CanonicalizeConv from .convert_bmm_to_matmul import ConvertBmmToMatmul from .convert_linear_to_conv2d import ConvertLinearToConv2d +from .convert_mha_to_sha import ConvertMhaToSha from .convert_square_to_pow import ConvertSquareToPow from .decompose_any import DecomposeAny from .decompose_binary_alpha import DecomposeBinaryAlpha @@ -55,6 +56,7 @@ CanonicalizeConv, ConvertBmmToMatmul, ConvertLinearToConv2d, + ConvertMhaToSha, ConvertSquareToPow, DecomposeAny, DecomposeBinaryAlpha, diff --git a/backends/qualcomm/_passes/canonicalize_conv.py b/backends/qualcomm/_passes/canonicalize_conv.py index dc5c26c1a94..8836ed44328 100644 --- a/backends/qualcomm/_passes/canonicalize_conv.py +++ b/backends/qualcomm/_passes/canonicalize_conv.py @@ -9,7 +9,6 @@ import torch from executorch.backends.qualcomm.builders.utils import get_parameter, set_parameter -from executorch.backends.qualcomm.utils.constants import QCOM_REQUANTIZE from executorch.exir.pass_base import ExportPass, PassResult from torch._guards import detect_fake_mode @@ -197,14 +196,6 @@ def call(self, graph_module: torch.fx.GraphModule): ) squeeze_node.meta = copy_meta(node.meta) - if QCOM_REQUANTIZE in input_node.meta: - input_node.meta.pop(QCOM_REQUANTIZE) - if QCOM_REQUANTIZE in node.meta: - squeeze_node.meta[QCOM_REQUANTIZE] = node.meta[ - QCOM_REQUANTIZE - ] - conv2d_node.meta.pop(QCOM_REQUANTIZE, None) - for user in node.users.copy(): user.replace_input_with(node, squeeze_node) diff --git a/backends/qualcomm/_passes/convert_bmm_to_matmul.py b/backends/qualcomm/_passes/convert_bmm_to_matmul.py index 3d4e44dfa42..262a3b9ef0f 100644 --- a/backends/qualcomm/_passes/convert_bmm_to_matmul.py +++ b/backends/qualcomm/_passes/convert_bmm_to_matmul.py @@ -47,7 +47,13 @@ def call(self, graph_module: torch.fx.GraphModule): graph = graph_module.graph partitions = get_source_partitions( graph, - [operator.matmul, torch.matmul, torch.bmm, torch.ops.aten.matmul.default], + [ + "matmul", + operator.matmul, + torch.matmul, + torch.bmm, + torch.ops.aten.matmul.default, + ], ) for _, src_partitions in partitions.items(): for src_partition in src_partitions: diff --git a/backends/qualcomm/_passes/convert_mha_to_sha.py b/backends/qualcomm/_passes/convert_mha_to_sha.py new file mode 100644 index 00000000000..dcf152cc9e2 --- /dev/null +++ b/backends/qualcomm/_passes/convert_mha_to_sha.py @@ -0,0 +1,627 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ + +import logging +from dataclasses import dataclass +from typing import List, Optional, Tuple + +import torch + +from executorch.backends.qualcomm._passes.utils import find_pattern +from executorch.backends.qualcomm.utils.constants import ( + QCOM_BLOCK_SIZE, + QCOM_QUANT_ATTRS, + QCOM_REQUANTIZE, + QCOM_SCALE, + QCOM_SCALES, + QCOM_ZERO_POINT, + QCOM_ZERO_POINTS, +) + +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult + +from executorch.exir.passes.constant_prop_pass import constant_prop_pass + +FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" +logging.basicConfig(level=logging.INFO, format=FORMAT) + + +def _is_node(node): + return isinstance(node, torch.fx.Node) + + +def _is_output(node): + return _is_node(node) and node.op == "output" + + +def _is_call(node): + return _is_node(node) and node.op == "call_function" + + +def _is_unsqueeze(node): + return _is_call(node) and node.target == exir_ops.edge.aten.unsqueeze_copy.default + + +def _is_view(node): + return _is_call(node) and node.target == exir_ops.edge.aten.view_copy.default + + +def _is_permute(node): + return _is_call(node) and node.target == exir_ops.edge.aten.permute_copy.default + + +def _is_matmul(node): + return _is_call(node) and node.target == exir_ops.edge.aten.matmul.default + + +def _is_bmm(node): + return _is_call(node) and node.target == exir_ops.edge.aten.bmm.default + + +def _is_expand(node): + return _is_call(node) and node.target == exir_ops.edge.aten.expand_copy.default + + +def _is_conv(node): + return _is_call(node) and node.target == exir_ops.edge.aten.convolution.default + + +def _is_softmax(node): + return _is_call(node) and node.target in [ + exir_ops.edge.aten._softmax.default, + exir_ops.edge.aten._safe_softmax.default, + ] + + +def _shape(node): + assert "val" in node.meta + return list(node.meta["val"].shape) + + +@dataclass +class Sha: + axis: int + heads: int + + def __repr__(self): + return f"Sha(axis={self.axis}, heads={self.heads})" + + +class ConvertMhaToSha(ExportPass): + """ + b=batch, e=emb=h*d, h=heads, d=head_size, s=seq_len, p=past, c=s+p + + i[bse] ─┬─ q[bse] ─ [bhsd] ─ RoPE ─ [bhsd] ───────────────────── qk[bhsc] ─ mask ─ softmax ─ qkv[bhsd] ─ [bse] ─ o[bse] + ├─ k[bse] ─ [bhsd] ─ RoPE ─ [bhds] ─ k_cat[bhdc] ─(k_exp)─┘ │ + │ past_k[bhdp] ──┘ │ + └─ v[bse] ─ [bhsd] ───────────────── v_cat[bhcd] ─(v_exp)-────────────────────────────┘ + past_v[bhpd] ──┘ + """ + + def __init__( + self, + edge_program: torch.export.ExportedProgram, + verbose=False, + ): + super().__init__() + self.edge_program = edge_program + self.verbose = verbose + + def _nodes(self, graph_module, wanted_sources, node_checker=None): + nodes = [] + for node in graph_module.graph.nodes: + if node.op == "call_function" and node.target in wanted_sources: + if node_checker is None or node_checker(node): + nodes.append(node) + return nodes + + def _get_attention_output(self, softmax): + """Output of MHA block or input of output projection""" + + pattern_qk = [_is_softmax, "*", lambda x: _is_matmul(x) or _is_bmm(x)] + qk = find_pattern(softmax, pattern_qk) + if not qk: + return None, None, None + + patterns_qkv = [ + _is_softmax, + "*", + lambda x: _is_matmul(x) or _is_bmm(x), + "*", + _is_permute, + _is_view, + ] + + qkv = find_pattern(softmax, patterns_qkv, from_args=False) + if qkv is None: + return None, None, None + + permute, reshape = qkv[0][-2:] + matmul = qkv[0][2] + attn_output = matmul + sha_axis = 1 + remove_nodes = 
[permute, reshape] + + # the shape of attn_output should be [bhsd] + shape = _shape(attn_output.args[0]) + heads = shape[sha_axis] + sha = Sha(axis=sha_axis, heads=heads) + + return attn_output, sha, remove_nodes + + def _update_requantize_user(self, node): + if QCOM_REQUANTIZE in node.meta: + user_node_list = [user.name for user in node.users.keys()] + + new_dict = {} + for original_key in node.meta[QCOM_REQUANTIZE]: + for new_key in user_node_list: + # new_keys are the name of the split nodes whose naming pattern follows: _h_xxx + if original_key in new_key: + new_dict.update( + {new_key: node.meta[QCOM_REQUANTIZE][original_key]} + ) + node.meta[QCOM_REQUANTIZE] = new_dict + + def _split( # noqa: C901 + self, + graph_module: torch.fx.GraphModule, + attn_output: torch.fx.Node, + sha: Sha, + remove_nodes: List, + ): + """ + Main MHA to SHAs + - Start from the attention output or the input of the output projection node, assuming the head axis is 2. + - Recursively visit parent nodes until reaching the static Linear/Conv2D nodes, which must be the Q/K/V projection nodes. + - Splitting begins from the end of the recursion, which must be the Q/K/V projection nodes. + - The visit call will return the split nodes, which will be used by subsequent child visitors. + + Known issue + - Packed Q/K/V projection is not supported yet + """ + + def _visit_reshape(node, sha): + """Reshape: handle GQA pattern""" + in_shape, out_shape = _shape(node.args[0]), _shape(node) + if out_shape[sha.axis] % sha.heads == 1: + return _no_split(node, sha) + + assert ( + out_shape[sha.axis] % sha.heads == 0 + ), f"mismatching num_heads, {out_shape[sha.axis]} % {sha.heads} != 0" + + pattern_simple_gqa = [ + _is_view, + lambda x: _is_expand(x) and len(_shape(x)) == 5, + _is_unsqueeze, + ] + + if gqa := find_pattern(node, pattern_simple_gqa): + # GQA pattern: skip these and adjust sha.heads + if self.verbose: + logging.info(f"{__name__}:_visit_reshape: {node} is for GQA!") + _, expand, unsqueeze = gqa[0] + expand_shape = expand.args[1] + unsqueeze_dim = unsqueeze.args[1] + repeat_count = expand_shape[unsqueeze_dim] + kv_sha = Sha(sha.axis, in_shape[sha.axis]) + new_arg0s = _visit(unsqueeze.args[0], kv_sha) + new_arg0s = [arg for arg in new_arg0s for _ in range(repeat_count)] + else: + new_arg0s = _visit(node.args[0], sha) + + out_shape[sha.axis] //= sha.heads + new_args = [(arg0, out_shape) for arg0 in new_arg0s] + return _split_call(node, sha, new_args, out_shape) + + def _visit_permute(node, sha): + """Transpose: permute sha axis as well""" + out_shape = _shape(node) + assert ( + out_shape[sha.axis] % sha.heads == 0 + ), f"mismatching num_heads, {out_shape[sha.axis]} % {sha.heads} != 0" + out_shape[sha.axis] //= sha.heads + permute = node.args[1] + sha_permuted = Sha(axis=permute[sha.axis], heads=sha.heads) + new_arg0s = _visit(node.args[0], sha_permuted) + new_args = [(arg0, node.args[1]) for arg0 in new_arg0s] + return _split_call(node, sha, new_args, out_shape) + + def _visit_expand(node, sha): + out_shape = _shape(node) + if out_shape[sha.axis] != 1: + assert ( + out_shape[sha.axis] % sha.heads == 0 + ), f"mismatching num_heads, {out_shape[sha.axis]} % {sha.heads} != 0" + out_shape[sha.axis] //= sha.heads + + exp_shape = node.args[1] + if exp_shape[sha.axis] == 1: + return _visit_default(node, sha) + + assert ( + exp_shape[sha.axis] % sha.heads == 0 + ), f"mismatching expand shape, {exp_shape[sha.axis]} % {sha.heads} != 0" + new_exp_shape = type(exp_shape)( + [ + dim // sha.heads if axis == sha.axis else dim + for 
axis, dim in enumerate(exp_shape) + ] + ) + new_args = [(node.args[0], new_exp_shape)] * sha.heads + new_nodes = _split_call(node, sha, new_args, out_shape) + return new_nodes + + def _visit_cat(node, sha): + out_shape = _shape(node) + if out_shape[sha.axis] != 1: + assert ( + out_shape[sha.axis] % sha.heads == 0 + ), f"mismatching num_heads, {out_shape[sha.axis]} % {sha.heads} != 0" + out_shape[sha.axis] //= sha.heads + + assert isinstance(node.args[0], (tuple, list)) # concat + split_arg0s = [_visit(arg, sha) for arg in node.args[0]] + new_arg0s = list(zip(*split_arg0s)) + split_arg1s = [_visit(arg, sha) for arg in node.args[1:]] + new_arg1s = list(zip(*split_arg1s)) + new_args = [(arg0, *arg1) for arg0, arg1 in zip(new_arg0s, new_arg1s)] + + new_nodes = _split_call(node, sha, new_args, out_shape) + return new_nodes + + def _visit_default(node, sha): + out_shape = _shape(node) + + if out_shape[sha.axis] != 1: + assert ( + out_shape[sha.axis] % sha.heads == 0 + ), f"mismatching num_heads, {out_shape[sha.axis]} % {sha.heads} != 0" + out_shape[sha.axis] //= sha.heads + + assert not isinstance( + node.args[0], (tuple, list) + ), f"Unexpected cat node:{node}" + split_args = [_visit(arg, sha) for arg in node.args] + new_args = list(zip(*split_args)) + new_nodes = _split_call(node, sha, new_args, out_shape) + return new_nodes + + def _is_mha(node, sha): + if not _is_node(node): + return False + out_shape = _shape(node) + return len(out_shape) > sha.axis and out_shape[sha.axis] == sha.heads + + def _visit_binary(node, sha): + """elementwise binary operator visit mha inputs only""" + out_shape = _shape(node) + if out_shape[sha.axis] != 1: + assert ( + out_shape[sha.axis] % sha.heads == 0 + ), f"mismatching num_heads, {out_shape[sha.axis]} % {sha.heads} != 0" + out_shape[sha.axis] //= sha.heads + + split_args = [ + (_visit(arg, sha) if _is_mha(arg, sha) else [arg] * sha.heads) + for arg in node.args + ] + new_args = list(zip(*split_args)) + new_nodes = _split_call(node, sha, new_args, out_shape) + return new_nodes + + def _visit_placeholder(node, sha): + in_shape = _shape(node) + if ( + in_shape + and len(in_shape) > sha.axis + and in_shape[sha.axis] == sha.heads + ): # split past_kv by heads + new_nodes = _split_placeholder( + node, axis=sha.axis, size=1, count=sha.heads + ) + else: + # position embedding, attention mask, and R3 weights + new_nodes = _no_split(node, sha) + return new_nodes + + def _get_slicers(count, axis, size): + return [ + tuple( + [ + ( + slice(size * idx, size * (idx + 1)) + if ax == axis + else slice(None) + ) + for ax in range(axis + 1) + ] + ) + for idx in range(count) + ] + + def _split_call(node, sha, new_args, out_shape): + with graph_module.graph.inserting_after(node): + new_nodes = [] + slicers = _get_slicers(sha.heads, sha.axis, out_shape[sha.axis]) + for head, (args, slicer) in enumerate(zip(new_args, slicers)): + name = f"{node.name}_h_{head}" + new_nodes.append( + _duplicate_call(node, args, None, slicer, name=name) + ) + return new_nodes + + def _create_call( + op_target, args: Tuple, kwargs: Optional[dict] = None, name: str = None + ): + return graph_module.graph.create_node( + "call_function", + op_target, + args=args, + kwargs=kwargs or {}, + name=name, + ) + + def _no_split(node, sha): + return [node] * sha.heads + + def _copy_meta(dst_node, src_node, slicer): + dst_node.meta = src_node.meta.copy() + dst, src = dst_node.meta, src_node.meta + if "val" in src: + dst["val"] = src["val"].clone()[slicer] + if src_tensor_meta := src.get("tensor_meta", None) is not 
None: + tensor_meta = dict(zip(src_tensor_meta._fields, [*src_tensor_meta])) + tensor_meta["shape"] = dst["val"].shape + tensor_meta["stride"] = dst["val"].stride() + dst["tensor_meta"] = type(src_tensor_meta)(**tensor_meta) + # PCQ + if QCOM_QUANT_ATTRS in src and QCOM_SCALES in src[QCOM_QUANT_ATTRS]: + dst[QCOM_QUANT_ATTRS] = src[QCOM_QUANT_ATTRS].copy() + # slice for per channel quantize + dst[QCOM_QUANT_ATTRS][QCOM_SCALES] = src[QCOM_QUANT_ATTRS][ + QCOM_SCALES + ].clone()[slicer] + dst[QCOM_QUANT_ATTRS][QCOM_ZERO_POINTS] = src[QCOM_QUANT_ATTRS][ + QCOM_ZERO_POINTS + ].clone()[slicer] + + # LPBQ + if QCOM_QUANT_ATTRS in src and QCOM_BLOCK_SIZE in src[QCOM_QUANT_ATTRS]: + dst[QCOM_QUANT_ATTRS] = src[QCOM_QUANT_ATTRS].copy() + dst[QCOM_QUANT_ATTRS][QCOM_SCALE] = src[QCOM_QUANT_ATTRS][ + QCOM_SCALE + ].clone()[slicer] + dst[QCOM_QUANT_ATTRS][QCOM_ZERO_POINT] = src[QCOM_QUANT_ATTRS][ + QCOM_ZERO_POINT + ].clone()[slicer] + + if "example_value" in src: + dst["example_value"] = src["example_value"].clone()[slicer] + + if QCOM_REQUANTIZE in src: + # We assume there is no requantize happens on the per-channel quantization weights, only per-tensor quantization + dst[QCOM_REQUANTIZE] = src[QCOM_REQUANTIZE].copy() + + def _duplicate_call( + node, args: Tuple, kwargs: Optional[dict] = None, slicer=None, name=None + ): + """Create SHA nodes by duplicating""" + assert ( + node.op == "call_function" + ), f"Unexpected node:{node.name}:{node.target}" + new_node = _create_call(node.target, args, kwargs, name=name) + _copy_meta(new_node, node, slicer) + return new_node + + def _split_placeholder(node, axis, size, count): + slice_op = exir_ops.edge.aten.slice_copy.Tensor + with graph_module.graph.inserting_after(node): + sliced_nodes = [] + for head, slicer in zip(range(count), _get_slicers(count, axis, size)): + sliced = _create_call( + slice_op, + (node, axis, slicer[axis].start, slicer[axis].stop), + name=f"{node.name}_h_{head}", + ) + _copy_meta(sliced, node, slicer) + sliced_nodes.append(sliced) + return sliced_nodes + + def _visit_linear_conv(node, sha): + """ + 0. Reshape of making multi-heads of MHA + - embedding = head * head_dim + - [batch, sequence, embedding] -> [batch, sequence, head, head_dim], + - [batch, sequence, embedding, 1] -> [batch, sequence, head, head_dim], embedding=head * head_dim + + 1. **q/k/v projections => stop recursion** + - 3D input and output + - Split output features + - ConvInplaceLinear + - [3d-unsqueeze-4d-permute-conv2d-permute-squeeze-3d] + - input: permute_copy(input): 4D[batch, in_feature, 1, num_input] => re-use + - weight[out_feature = heads * head_dim, in_feature, 1, 1] => heads * [head_dim, in_feature, 1, 1] + - So, split_axis=0 for Conv2D + + 2. 
**R3 of SpinQuant => continue recursion** + - 4D input and output + - ConvInplaceLinear + - [4d-permute-conv2d-permute-4d], **same as 3D case but no squeeze/unsqueeze** + - input: 4D [batch, head_dim, heads, num_input] => heads * [batch, head_dim, 1, num_input] + - weight: 2D [head_dim, head_dim, 1, 1] => re-use + """ + + def _is_making_mha(cur): + cur_sha = sha + pattern_conv_mha = ([_is_conv, "*", _is_permute, "*", _is_view], False) + if mha := find_pattern(cur, *pattern_conv_mha): + permute, reshape = mha[0][-3], mha[0][-1] + permutation = permute.args[1] + cur_sha = Sha( + permutation.index(sha.axis), sha.heads + ) # to reverse permute + else: + return False + + # Check whether this reshape is to make multi-heads or not + if len(reshape.args[1]) == 4: + # got MHA reshape + in_shape, out_shape = _shape(reshape.args[0]), _shape(reshape) + if ( + len(out_shape) > cur_sha.axis + 1 + and in_shape[cur_sha.axis] + == out_shape[cur_sha.axis] * out_shape[cur_sha.axis + 1] + ): + return True + return False + + if _is_making_mha(node): + if self.verbose: + logging.info( + f"{__name__}:_visit_linear_conv: {node} is making MHA!" + ) + out_feature, *_ = _shape(node.args[1]) + assert out_feature % sha.heads == 0 + out_feature_per_head = out_feature // sha.heads + + split_axis = 0 + new_weights = _split_placeholder( + node.args[1], + axis=split_axis, + size=out_feature_per_head, + count=sha.heads, + ) + if node.args[2] is not None: + new_bias = _split_placeholder( + node.args[2], + axis=split_axis, + size=out_feature_per_head, + count=sha.heads, + ) + + with graph_module.graph.inserting_after(node): + new_nodes = [] + slicers = _get_slicers(sha.heads, 1, out_feature_per_head) + if node.args[2] is not None: + for head, (weight, bias, slicer) in enumerate( + zip(new_weights, new_bias, slicers) + ): + name = f"{node.name}_h_{head}" + sliced = _duplicate_call( + node, + (node.args[0], weight, bias) + node.args[3:], + None, + slicer, + name=name, + ) + new_nodes.append(sliced) + else: + for head, (weight, slicer) in enumerate( + zip(new_weights, slicers) + ): + name = f"{node.name}_h_{head}" + sliced = _duplicate_call( + node, + (node.args[0], weight) + node.args[2:], + None, + slicer, + name=name, + ) + new_nodes.append(sliced) + + return new_nodes + else: + return _visit_default(node, sha) + + def _concat_sha_nodes(node, sha): + """Concat sha nodes and replace old node""" + sha_nodes = visited[node] + with graph_module.graph.inserting_after(sha_nodes[0]): + cat = exir_ops.edge.aten.cat.default + name = f"{node.name}_sha_concat" + new_node = _create_call(cat, (sha_nodes, sha.axis), name=name) + new_node.meta = node.meta.copy() + fake_tensors = [n.meta["val"] for n in sha_nodes] + result_fake_tensor = torch.cat(fake_tensors, sha.axis) + new_node.meta["val"] = result_fake_tensor + node.replace_all_uses_with(new_node) + + def _visit(node, sha): + if not _is_node(node): + return [node for _ in range(sha.heads)] + + if node in visited: + return visited[node] + + visitors = { + "placeholder": _visit_placeholder, + exir_ops.edge.aten.expand_copy.default: _visit_expand, + exir_ops.edge.aten.view_copy.default: _visit_reshape, + exir_ops.edge.aten.permute_copy.default: _visit_permute, + exir_ops.edge.aten.convolution.default: _visit_linear_conv, + exir_ops.edge.aten.mm.default: _visit_linear_conv, + exir_ops.edge.aten.cat.default: _visit_cat, + exir_ops.edge.aten.add.Tensor: _visit_binary, + exir_ops.edge.aten.mul.Tensor: _visit_binary, + exir_ops.edge.aten.eq.Tensor: _no_split, + } + + target = node.target if 
_is_call(node) else node.op + visited[node] = visitors.get(target, _visit_default)(node, sha) + + if [user for user in node.users.keys() if _is_output(user)]: + _concat_sha_nodes(node, sha) + return visited[node] + + if self.verbose: + logging.info(f"{__name__}:_split: attn_output:{attn_output}, sha:{sha}!") + visited = {} + _visit(attn_output, sha) + opt_sha = Sha(axis=3, heads=sha.heads) + _concat_sha_nodes(attn_output, opt_sha) + for remove_node in remove_nodes: + assert _is_permute(remove_node) or _is_view( + remove_node + ), "The removed nodes must be either transpose or reshape" + rnode_input = remove_node.args[0] + for user in list(remove_node.users): + new_args = tuple( + rnode_input if arg is remove_node else arg for arg in user.args + ) + user.args = new_args + for remove_node in remove_nodes: + graph_module.graph.erase_node(remove_node) + + def call(self, graph_module: torch.fx.GraphModule): + modified = False + softmaxes = self._nodes( + graph_module, + [ + exir_ops.edge.aten._softmax.default, + exir_ops.edge.aten._safe_softmax.default, + ], + ) + for softmax in softmaxes: + attn_output, sha, remove_nodes = self._get_attention_output(softmax) + if not attn_output: + continue + + self._split(graph_module, attn_output, sha, remove_nodes) + modified = True + + if modified: + for node in graph_module.graph.nodes: + self._update_requantize_user(node) + graph_module.graph.eliminate_dead_code() + constant_prop_pass(self.edge_program) # need to fuse sha weights + graph_module.recompile() + graph_module.graph.lint() + + return PassResult(graph_module, modified=modified) diff --git a/backends/qualcomm/_passes/qnn_pass_manager.py b/backends/qualcomm/_passes/qnn_pass_manager.py index 80b4675d2f1..56ca984269a 100644 --- a/backends/qualcomm/_passes/qnn_pass_manager.py +++ b/backends/qualcomm/_passes/qnn_pass_manager.py @@ -16,6 +16,7 @@ CanonicalizeConv, ConvertBmmToMatmul, ConvertLinearToConv2d, + ConvertMhaToSha, ConvertSquareToPow, DecomposeAny, DecomposeBinaryAlpha, @@ -87,7 +88,6 @@ def get_capture_program_passes(): (AnnotateQuantAttrs, True), (AnnotateStack, True), (AnnotateUnbind, True), - (CanonicalizeConv, True), (ConvertBmmToMatmul, False), (DecomposeAny, True), (DecomposeColIm, True), @@ -242,8 +242,12 @@ def transform_for_export_pipeline( ep = lift_constant_tensor_pass(exported_program) return ep - def transform_for_preprocess_pipeline(self, exported_program: ExportedProgram): + def transform_for_preprocess_pipeline( + self, exported_program: ExportedProgram, use_mha2sha=False + ): self.add_pass(FoldQDQ(exported_program, force_fold=True)) + if use_mha2sha: + self.add_pass(ConvertMhaToSha(exported_program)) self.add_pass(InsertRequantize()) self.add_pass(InsertIOQDQ(exported_program)) self.add_pass(LayoutTransform(exported_program, insert_permute=True)) diff --git a/backends/qualcomm/_passes/utils.py b/backends/qualcomm/_passes/utils.py index eebfa4d9eb4..a475542de23 100755 --- a/backends/qualcomm/_passes/utils.py +++ b/backends/qualcomm/_passes/utils.py @@ -177,7 +177,7 @@ def _next(node, from_args=True): yield from list(node.users) -def _find_pattern( +def find_pattern( node: torch.fx.Node, pattern: List[Callable[[torch.fx.Node], bool] | str], from_args: bool = True, @@ -190,6 +190,7 @@ def _find_pattern( - pattern: predicate list, can contain followings Callable(fx.node): predicate '*': wildcard + '?': any single node - from_args: if True find from node.args, otherwise from node.users - max_wildcard_life: max number of skips for wildcard @@ -197,7 +198,7 @@ def _find_pattern( 
Otherwise, return list of matched node list, which is the same length as pattern """ - asterisk = "*" + asterisk, question = "*", "?" def _probe( cur, hist, pat_idx, asterisk_life_count=max_wildcard_life, verbose=verbose @@ -212,7 +213,7 @@ def _probe( print( f"cur:{cur}, idx:{pat_idx}, life={asterisk_life_count}, pattern:{pattern[pat_idx]} hist={hist}" ) - if _pred(cur, pattern[pat_idx]): + if pattern[pat_idx] == question or _pred(cur, pattern[pat_idx]): hist.append(cur) for child in _next(cur, from_args): _probe(child, hist, pat_idx + 1) @@ -236,7 +237,8 @@ def _probe( # Check if pattern is valid assert all( - isinstance(i, Callable) or (isinstance(i, str) and i == "*") for i in pattern + isinstance(i, Callable) or (isinstance(i, str) and (i == "*" or i == "?")) + for i in pattern ), f"Invalid pattern: {pattern}" # Start probing @@ -249,7 +251,7 @@ def find_patterns(node, patterns, **kwargs): assert isinstance(patterns, list) and isinstance(patterns[0], list) results = [] for pattern in patterns: - result = _find_pattern(node, pattern, **kwargs) + result = find_pattern(node, pattern, **kwargs) results.append(result) return results diff --git a/backends/qualcomm/qnn_preprocess.py b/backends/qualcomm/qnn_preprocess.py index 4e9cda21d02..187ddfd7e5e 100644 --- a/backends/qualcomm/qnn_preprocess.py +++ b/backends/qualcomm/qnn_preprocess.py @@ -42,9 +42,12 @@ def _build_op_wrappers( edge_program: ExportedProgram, enable_tensor_dump: bool, op_package_infos: List[QnnExecuTorchOpPackageInfo], + use_mha2sha: bool, ): # QNN Delegate Specific Passes - graph_module = QnnPassManager().transform_for_preprocess_pipeline(edge_program) + graph_module = QnnPassManager().transform_for_preprocess_pipeline( + edge_program, use_mha2sha=use_mha2sha + ) assert graph_module is not None nodes_to_wrappers = defaultdict(dict) @@ -106,6 +109,7 @@ def preprocess( edge_program, qnn_manager.IsTensorDump(), obj_options.op_package_options.op_package_infos, + obj_options.use_mha2sha, ) qnn_context_binary = qnn_manager.Compile( @@ -165,6 +169,7 @@ def preprocess_multimethod( programs[i], qnn_manager.IsTensorDump(), option.op_package_options.op_package_infos, + option.use_mha2sha, ) if isinstance(py_op_wrappers, bytes): ctx_binary_list.append(py_op_wrappers) diff --git a/backends/qualcomm/quantizer/annotators.py b/backends/qualcomm/quantizer/annotators.py index 8b59de3bd4e..716bc8401b1 100644 --- a/backends/qualcomm/quantizer/annotators.py +++ b/backends/qualcomm/quantizer/annotators.py @@ -690,7 +690,9 @@ def annotate_pad(node: Node, quantization_config: QuantizationConfig) -> None: @register_annotator([torch.ops.aten.reshape.default, torch.ops.aten.unflatten.int]) def annotate_reshape(node: Node, quantization_config: QuantizationConfig) -> None: - annotate_single_in_single_out(node, quantization_config) + annotate_in_out_obs_sharing_op(node, quantization_config) + if not _is_annotated([node]): + annotate_single_in_share_out(node, quantization_config) @register_annotator([torch.ops.aten.select.int]) diff --git a/backends/qualcomm/quantizer/custom_annotation.py b/backends/qualcomm/quantizer/custom_annotation.py index c592ad64da6..ca2b07ddc16 100644 --- a/backends/qualcomm/quantizer/custom_annotation.py +++ b/backends/qualcomm/quantizer/custom_annotation.py @@ -57,9 +57,9 @@ class StaticLLMQuantConfig(Enum): Layer namespace configuration for Qualcomm's static LLaMA quantization. 
""" - wq_sha = "wq_sha" # Query weight (single head) - wk_sha = "wk_sha" # Key weight (single head) - wv_sha = "wv_sha" # Value weight (single head) + wq = "wq" # Query weight + wk = "wk" # Key weight + wv = "wv" # Value weight def annotate_eurobert(gm: torch.fx.GraphModule): @@ -373,7 +373,10 @@ def annotate_matmul_input1(node: Node, is_qat: str): torch.ops.aten.transpose.int, torch.ops.aten.view.default, torch.ops.aten.reshape.default, + torch.ops.aten.select.int, torch.ops.aten.slice.Tensor, + torch.ops.aten.expand.default, + torch.ops.aten.unsqueeze.default, ]: annotate_single_in_single_out(node, quantization_config_8a8w) node = node.args[0] diff --git a/backends/qualcomm/serialization/qc_compiler_spec.fbs b/backends/qualcomm/serialization/qc_compiler_spec.fbs index 145ae0010fc..64dac7a5965 100644 --- a/backends/qualcomm/serialization/qc_compiler_spec.fbs +++ b/backends/qualcomm/serialization/qc_compiler_spec.fbs @@ -239,14 +239,17 @@ table QnnExecuTorchOptions { /// Is model from qnn context binary is_from_context_binary:bool; - // Enable this option to record all QNN API calls for debugging purpose + /// Enable this option to record all QNN API calls for debugging purpose saver:bool; - // Path to saver output folder + /// Path to saver output folder saver_output_dir:string; /// Optional structure to specify op packages loaded and used by the backend. op_package_options:QnnExecuTorchOpPackageOptions; + + /// This experimental parameter is used to decide whether to enable multi-head attention to single-head attention pass, aiming to reduce time consumption in AOT and improve performance on HTP. + use_mha2sha:bool; } root_type QnnExecuTorchOptions; diff --git a/backends/qualcomm/serialization/qc_schema.py b/backends/qualcomm/serialization/qc_schema.py index 9f4b37c13d1..dafc0cff7c3 100644 --- a/backends/qualcomm/serialization/qc_schema.py +++ b/backends/qualcomm/serialization/qc_schema.py @@ -194,3 +194,4 @@ class QnnExecuTorchOptions: op_package_options: QnnExecuTorchOpPackageOptions = field( default_factory=QnnExecuTorchOpPackageOptions ) + use_mha2sha: bool = False diff --git a/backends/qualcomm/tests/test_passes.py b/backends/qualcomm/tests/test_passes.py index 94a5d08acc1..8af66c4cbef 100644 --- a/backends/qualcomm/tests/test_passes.py +++ b/backends/qualcomm/tests/test_passes.py @@ -1,7 +1,15 @@ import unittest import torch -from executorch.backends.qualcomm._passes import InsertReshapeForReduceOps +from executorch.backends.qualcomm._passes import ( + ConvertBmmToMatmul, + ConvertMhaToSha, + InsertReshapeForReduceOps, + RemoveRedundancy, +) + +from executorch.exir import to_edge +from executorch.exir.dialects._ops import ops as exir_ops class TestPasses(unittest.TestCase): @@ -49,6 +57,98 @@ def forward(self, x): torch.equal(*out, ref), f"Output mismatch: got {out}, expected {ref}" ) + def test_mha_to_sha(self): + from executorch.backends.qualcomm.utils.utils import convert_linear_to_conv2d + from executorch.examples.models.llama.model_args import ModelArgs + from executorch.examples.qualcomm.oss_scripts.llama.masking_utils import ( + CausalAttentionMask, + ) + from executorch.examples.qualcomm.oss_scripts.llama.model.static_llama import ( + LlamaAttention, + ) + + # Initailize model config + args = ModelArgs() + args.max_seq_len = 128 + args.ar_len = 32 + args.use_kv_cache = True + args.dim = 32 + args.n_heads = 8 + args.n_kv_heads = 8 + args.n_layers = 2 + args.head_dim = args.dim // args.n_heads + mod = convert_linear_to_conv2d(LlamaAttention(0, args, True)) + + # Prepare 
inputs + hidden_states = torch.randn(args.max_batch_size, args.ar_len, args.dim) + freqs_cos = torch.randn(args.ar_len, 1) + freqs_sin = torch.randn(args.ar_len, 1) + atten_mask = CausalAttentionMask( + args.max_batch_size, args.ar_len, args.max_seq_len + ) + k_cache = torch.zeros( + args.max_batch_size, + args.n_kv_heads, + args.head_dim, + args.max_seq_len - args.ar_len, + ) + + v_cache = torch.zeros( + args.max_batch_size, + args.n_kv_heads, + args.max_seq_len - args.ar_len, + args.head_dim, + ) + sample_input = ( + hidden_states, + freqs_cos, + freqs_sin, + atten_mask.mask, + k_cache, + v_cache, + ) + + # Run original module for reference + refs = mod(*sample_input) + + # Export the module and convert linear to conv2d + edge_program = to_edge(torch.export.export(mod, sample_input)) + new_ep = edge_program.exported_program() + + conv_nodes = [ + n + for n in new_ep.graph.nodes + if n.target == exir_ops.edge.aten.convolution.default + ] + # WQ, WK, WV, O + self.assertTrue(len(conv_nodes) == 4, "Convolution nodes missing") + + # Convert MHA to SHA + # This is a simplified version of what happens in the full pipeline to test the core functionality + graph_module = RemoveRedundancy(quantization_capture=False)( + new_ep.graph_module + ).graph_module + graph_module = ConvertBmmToMatmul()(graph_module).graph_module + graph_module = ConvertMhaToSha(new_ep)(graph_module).graph_module + + conv_nodes = [ + n + for n in new_ep.graph.nodes + if n.target == exir_ops.edge.aten.convolution.default + ] + # Check graph structure: WQ, WK, WV should be converted to SHA + self.assertTrue(len(conv_nodes) == 25, "Convolution nodes should be splited") + + # Execute new graph and compare with reference + outs = graph_module( + *new_ep.state_dict.values(), *new_ep.constants.values(), *sample_input + ) + for i, (out, ref) in enumerate(zip(outs, refs)): + self.assertTrue( + torch.allclose(out, *ref, rtol=1e-6, atol=1e-6), + f"Output {i} mismatch: got {out}, expected {ref}", + ) + if __name__ == "__main__": unittest.main() diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index 082c1ea5a08..fb9cb24c9cf 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -5902,8 +5902,6 @@ def test_static_gemma3_1b(self): str(self.port), "--prompt", f"{prompt}", - "--ptq", - "16a4w_block", "--temperature", "0", "--decoder_model", @@ -5917,7 +5915,6 @@ def test_static_gemma3_1b(self): "wikitext", "--limit", "1", - "--enable_masked_softmax", ] if self.compile_only: cmds.extend(["--compile_only"]) diff --git a/backends/qualcomm/utils/utils.py b/backends/qualcomm/utils/utils.py index 91610301515..9ca02932d80 100644 --- a/backends/qualcomm/utils/utils.py +++ b/backends/qualcomm/utils/utils.py @@ -157,7 +157,7 @@ def __init__(self, weight, bias=None): def forward(self, x): rank = x.dim() - x = x.unsqueeze(-1) if rank == 3 else x.reshape(1, *x.shape, 1) + x = x.reshape(*x.shape, 1) if rank == 3 else x.reshape(1, *x.shape, 1) x = torch.transpose(x, 1, 2) res = self.conv(x) res = torch.transpose(res, 1, 2) @@ -989,6 +989,7 @@ def generate_qnn_executorch_compiler_spec( is_from_context_binary: bool = False, graph_name: str = "forward", op_package_options: QnnExecuTorchOpPackageOptions = None, + use_mha2sha: bool = False, ) -> List[CompileSpec]: """ Helper function generating compiler specs for Qualcomm AI Engine Direct @@ -1019,6 +1020,7 @@ def generate_qnn_executorch_compiler_spec( graph_name: Assign unique graph name if lowering 
multiple methods. op_package_options: Optional structure to specify op packages loaded and used by the backend. + use_mha2sha: This experimental parameter is used to decide whether to enable multi-head attention to single-head attention pass, aiming to reduce time consumption in AOT and improve performance on HTP. Returns: List[CompileSpec]: Compiler specs for Qualcomm AI Engine Direct. @@ -1081,6 +1083,8 @@ def generate_qnn_executorch_compiler_spec( if op_package_options and len(op_package_options.op_package_infos) > 0: qnn_executorch_options.op_package_options = op_package_options + qnn_executorch_options.use_mha2sha = use_mha2sha + return [ CompileSpec(QCOM_QNN_COMPILE_SPEC, option_to_flatbuffer(qnn_executorch_options)) ] diff --git a/examples/qualcomm/oss_scripts/llama/README.md b/examples/qualcomm/oss_scripts/llama/README.md index e6fa9a66e26..9deb84db4d2 100644 --- a/examples/qualcomm/oss_scripts/llama/README.md +++ b/examples/qualcomm/oss_scripts/llama/README.md @@ -143,20 +143,12 @@ python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL ``` ### KV Cache update mechanism -We have two distinct mechanisms for updating the key-value (KV) cache, which can be selected at runtime. Shift Pointer and Smart Mask. - -#### Shift Pointer mechanism - -
- Shift Pointer mechanism
- The figure illustrates the process of updating the key and value caches during each inference step. In key cache update process, we initially allocate memory for each layer with num_head size of (head_dim + 1) * (seq_len - 1). After a single inference, the new key cache is copied from the key output pointer k_out and appended to the key cache. Subsequently, the buffer start pointer of the key cache k_in moves to the next token, making the previous position of the buffer start pointer unused. This process is repeated for each subsequent inference step. - For the value cache update process, we first allocate a contiguous memory of size (num_head + 1) * head_dim * (seq_len - 1) for each layer, with the last head reserved for I/O shifting, After the first inference, the cache is updated by simply shifting the pointers of all heads to the next token position, making only the previous head_dim * 1 section of the buffer start pointer v_in of the first head unused. This process is repeated for each subsequent inference step.
-
+We use the Smart Mask mechanism for updating the key-value (KV) cache.
 #### Smart Mask mechanism:
Smart Mask mechanism -
The Smart Mask mechanism streamlines the process of updating tokens in the cache. Unlike the Shift Pointer mechanism, which requires moving the buffer start pointer k_in/v_in of the cache, the Smart Mask mechanism updates only the new token at the specified position. This approach eliminates the need to adjust the buffer start pointer. This mechanism is beneficial for shared buffers but requires CPU memory copying.
+
The figure illustrates how key and value caches are updated during each inference step. The Smart Mask mechanism simplifies updating tokens in the cache by modifying only the new token at the designated position. This approach is useful for shared buffers, though it does require copying data in CPU memory to update the KV cache.
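For illustration only (not part of this patch), the host-side update described above boils down to an in-place slice assignment at the current write position. The minimal sketch below mirrors the 4-D cache layouts used by `smart_mask_updater` in `decoder_utils.py`; the function name, the example shapes, and the omission of the attention-mask and lookahead handling are simplifications for this example.

```python
import torch

def smart_mask_update(pos, n_updates, k_caches, v_caches, new_k_caches, new_v_caches):
    """In-place Smart Mask update: write the newly computed tokens at position `pos`.

    Assumed per-layer cache layouts, matching smart_mask_updater in decoder_utils.py:
      K cache: [batch, n_heads, head_dim, cache_len]
      V cache: [batch, n_heads, cache_len, head_dim]
    """
    for k_cache, new_k in zip(k_caches, new_k_caches):
        k_cache[:, :, :, pos : pos + n_updates] = new_k[:, :, :, :n_updates]
    for v_cache, new_v in zip(v_caches, new_v_caches):
        v_cache[:, :, pos : pos + n_updates, :] = new_v[:, :, :n_updates, :]
    return pos + n_updates

# Example: one layer, 8 heads, head_dim 64, cache length 96, writing 32 new tokens at pos 0
k_caches = [torch.zeros(1, 8, 64, 96)]
v_caches = [torch.zeros(1, 8, 96, 64)]
new_k = [torch.randn(1, 8, 64, 32)]
new_v = [torch.randn(1, 8, 32, 64)]
pos = smart_mask_update(0, 32, k_caches, v_caches, new_k, new_v)  # pos is now 32
```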
#### Analysis KV Cache Update Mechanism for each Layer each inference @@ -173,13 +165,6 @@ We have two distinct mechanisms for updating the key-value (KV) cache, which can K V - - Shift Pointer - num_head * head_dim - 1 - num_head * (head_dim + 1) * seq_len - (num_head + 1) * head_dim * (seq_len - 1) - Smart Mask num_head * head_dim @@ -203,14 +188,6 @@ On the other hand, if you already have a pre-compiled .pte model, you can perfor python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --decoder_model llama3_2 --model_mode hybrid --prefill_ar_len 32 --max_seq_len 128 --prompt "what is 1+1" --pre_gen_pte ${FOLDER_TO_PRE_GEN_PTE} ``` -#### KV Cache Updater - -You can select the KV Cache update mechanism at runtime by setting the `KV_UPDATER` variable to either "shift_pointer" or "smart_mask". By default, it is set to "smart_mask". -`KV_UPDATER` = "shift_pointer" -```bash -python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --decoder_model llama3_2 --model_mode hybrid --prefill_ar_len 32 --max_seq_len 128 --prompt "what is 1+1" --kv_updator ${KV_UPDATER} -``` - #### Lookahead Decoding Mode You can choose the lookahead mode to enhance decoding speed. To use this mode, you need to specify the following parameters: diff --git a/examples/qualcomm/oss_scripts/llama/__init__.py b/examples/qualcomm/oss_scripts/llama/__init__.py index e2407e6812a..f9112b2ceee 100644 --- a/examples/qualcomm/oss_scripts/llama/__init__.py +++ b/examples/qualcomm/oss_scripts/llama/__init__.py @@ -62,12 +62,12 @@ annotate_wqkv_sha = partial( annotate_qkv_proj_sha, qkv_tags={ - StaticLLMQuantConfig.wq_sha, - StaticLLMQuantConfig.wk_sha, - StaticLLMQuantConfig.wv_sha, + StaticLLMQuantConfig.wq, + StaticLLMQuantConfig.wk, + StaticLLMQuantConfig.wv, }, ) -annotate_wv_sha = partial(annotate_qkv_proj_sha, qkv_tags={StaticLLMQuantConfig.wv_sha}) +annotate_wv_sha = partial(annotate_qkv_proj_sha, qkv_tags={StaticLLMQuantConfig.wv}) @dataclass(init=False, frozen=True) @@ -204,16 +204,9 @@ class LlamaStories260K(LLMModelConfig): r1 = False r2 = False r3 = False - quantization_config_wv_sha_8a4w = get_ptq_per_channel_quant_config( - act_dtype=torch.uint8, - weight_dtype=torch.int4, - act_observer=MinMaxObserver, - act_symmetric=True, - ) custom_annotation = ( annotate_kv_8bit, annotate_output_16a8w, - partial(annotate_wv_sha, quantization_config=quantization_config_wv_sha_8a4w), ) @@ -437,7 +430,7 @@ class Qwen2_5_0_5B(LLMModelConfig): seq_mse_candidates = 0 r1 = False r2 = False - r3 = True + r3 = False custom_annotation = () @@ -460,7 +453,7 @@ class Qwen2_5_1_5B(LLMModelConfig): seq_mse_candidates = 0 r1 = False r2 = False - r3 = True + r3 = False custom_annotation = (annotate_output_16a8w,) diff --git a/examples/qualcomm/oss_scripts/llama/artifacts/stories260k_hybrid_llama_qnn.pte b/examples/qualcomm/oss_scripts/llama/artifacts/stories260k_hybrid_llama_qnn.pte index 198b96e5b9b..213ec5b5f6f 100644 Binary files a/examples/qualcomm/oss_scripts/llama/artifacts/stories260k_hybrid_llama_qnn.pte and b/examples/qualcomm/oss_scripts/llama/artifacts/stories260k_hybrid_llama_qnn.pte differ diff --git a/examples/qualcomm/oss_scripts/llama/decoder_utils.py b/examples/qualcomm/oss_scripts/llama/decoder_utils.py index 085e2a6c07e..b14c8999da8 100644 --- 
a/examples/qualcomm/oss_scripts/llama/decoder_utils.py +++ b/examples/qualcomm/oss_scripts/llama/decoder_utils.py @@ -60,7 +60,6 @@ def __init__( # noqa: C901 ar_len: int, use_kv_cache: bool, get_example_inputs: Callable, - kv_updater: Callable, use_i64_token: bool, seq_mse_candidates: int, ): @@ -74,7 +73,6 @@ def __init__( # noqa: C901 self._use_kv_cache = use_kv_cache self.get_example_inputs = get_example_inputs self.max_seq_length = max_seq_length - self.kv_updater = kv_updater self.use_i64_token = use_i64_token self.seq_mse_candidates = seq_mse_candidates @@ -83,7 +81,6 @@ def _model_call(self, inps): kwargs = {} if self._use_kv_cache: kwargs["ar_len"] = self.ar_len - kwargs["kv_updater"] = self.kv_updater kwargs["seq_mse_candidates"] = self.seq_mse_candidates all_logits = INFERENCE_REGISTRY[self._use_kv_cache]( @@ -389,7 +386,6 @@ def post_process(): f"--performance_output_path {self.args.artifact}/{performance_output_path}", f"--eval_mode {EVAL_MODE[self.args.model_mode]}", "--temperature 0", - "--kv_updater ShiftPointer", f"--dump_logits_path {self.args.artifact}/{dump_logits_path}", f"--tokenized_prompt {input_file_name}", ] @@ -414,7 +410,6 @@ def post_process(): f"--seq_len {self.max_seq_length}", f"--output_path {outputs_path}", f"--performance_output_path {performance_output_path}", - f"--kv_updater {'SmartMask' if self.args.kv_updater == smart_mask_updater else 'ShiftPointer'}", f"--window {self.args.window}", f"--gcap {self.args.gcap}", f"--ngram {self.args.ngram}", @@ -422,6 +417,7 @@ def post_process(): "--temperature 0", f"--dump_logits_path {dump_logits_path}", f"--tokenized_prompt {os.path.basename(input_file_name)}", + "--shared_buffer", ] ) @@ -452,13 +448,17 @@ def smart_mask_updater( for i, offset in enumerate(lade_token_offset): current_pos = pos + i for j, (k_cache, v_cache) in enumerate(zip(k_caches, v_caches)): - k_cache[:, :, current_pos] = new_k_caches[j][:, :, offset] - v_cache[:, current_pos, :] = new_v_caches[j][:, offset, :] + k_cache[:, :, :, current_pos] = new_k_caches[j][:, :, :, offset] + v_cache[:, :, current_pos, :] = new_v_caches[j][:, :, offset, :] else: for i, k_cache in enumerate(k_caches): - k_cache[:, :, pos : pos + n_updates] = new_k_caches[i][:, :, :n_updates] + k_cache[:, :, :, pos : pos + n_updates] = new_k_caches[i][ + :, :, :, :n_updates + ] for i, v_cache in enumerate(v_caches): - v_cache[:, pos : pos + n_updates, :] = new_v_caches[i][:, :n_updates, :] + v_cache[:, :, pos : pos + n_updates, :] = new_v_caches[i][ + :, :, :n_updates, : + ] atten_mask.smart_mask_update(pos, n_updates, lade_pos_offset) @@ -466,57 +466,6 @@ def smart_mask_updater( return pos, k_caches, v_caches -def shift_pointer_updater( - n_updates: int, - atten_mask: AttentionMask, - pos, - k_caches, - v_caches, - new_k_caches, - new_v_caches, - # lookahead decoding related - lade_token_offset=None, - lade_pos_offset=None, -): - max_cache_len = k_caches[0].size(-1) - if pos + n_updates <= max_cache_len: - if lade_token_offset is not None: - # lookahead decode update - for offset in lade_token_offset: - for i, (k_cache, v_cache) in enumerate(zip(k_caches, v_caches)): - k_caches[i] = torch.cat( - [ - k_cache[:, :, 1:], - new_k_caches[i][:, :, offset].unsqueeze(-1), - ], - dim=-1, - ) - v_caches[i] = torch.cat( - [v_cache[:, 1:, :], new_v_caches[i][:, offset, :].unsqueeze(1)], - dim=1, - ) - else: - k_caches = [ - torch.cat( - [k_cache[:, :, n_updates:], new_k_caches[i][:, :, :n_updates]], - dim=-1, - ) - for i, k_cache in enumerate(k_caches) - ] - v_caches = [ - torch.cat( - 
[v_cache[:, n_updates:, :], new_v_caches[i][:, :n_updates, :]], - dim=1, - ) - for i, v_cache in enumerate(v_caches) - ] - - atten_mask.shift_pointer_update(pos, n_updates, lade_pos_offset) - - pos += n_updates - return pos, k_caches, v_caches - - @register_inference(use_kv_cache=True) def kv_inference( # noqa: C901 get_example_inputs, @@ -525,7 +474,6 @@ def kv_inference( # noqa: C901 tokenizer, ar_len=1, max_seq_len=512, - kv_updater=smart_mask_updater, use_i64_token=False, collect_logits=False, seq_mse_candidates=0, @@ -601,7 +549,7 @@ def kv_inference( # noqa: C901 ) # Update the pos, KV cache and attention mask. - pos, k_caches, v_caches = kv_updater( + pos, k_caches, v_caches = smart_mask_updater( num_tokens_in_chunk, atten_mask, pos, @@ -647,7 +595,7 @@ def kv_inference( # noqa: C901 *v_caches, ) - pos, k_caches, v_caches = kv_updater( + pos, k_caches, v_caches = smart_mask_updater( 1, atten_mask, pos, @@ -713,7 +661,7 @@ def kv_inference( # noqa: C901 for e in range(num_match) ] # update kv cache - pos, k_caches, v_caches = kv_updater( + pos, k_caches, v_caches = smart_mask_updater( len(lade_token_offset), atten_mask, pos, @@ -811,7 +759,6 @@ def graph_module_inference( tokenizer, ar_len=1, max_seq_len=512, - kv_updater=smart_mask_updater, prompt=None, tasks=None, tasks_limit=1, @@ -834,7 +781,6 @@ def graph_module_inference( kwargs = {} if use_kv_cache: kwargs["ar_len"] = ar_len - kwargs["kv_updater"] = kv_updater kwargs["lookahead_config"] = lookahead_config INFERENCE_REGISTRY[use_kv_cache]( @@ -855,7 +801,6 @@ def graph_module_inference( ar_len=ar_len, use_kv_cache=use_kv_cache, get_example_inputs=get_example_inputs, - kv_updater=kv_updater, use_i64_token=use_i64_token, seq_mse_candidates=seq_mse_candidates, ) diff --git a/examples/qualcomm/oss_scripts/llama/eval_llama_qnn.py b/examples/qualcomm/oss_scripts/llama/eval_llama_qnn.py index 9af9cdf9549..6fc1fbd583d 100644 --- a/examples/qualcomm/oss_scripts/llama/eval_llama_qnn.py +++ b/examples/qualcomm/oss_scripts/llama/eval_llama_qnn.py @@ -255,8 +255,6 @@ def prequant_algorithm(model, prefill_config, args): reverse_quantize_module_swap(wrapped_model) for layer in model.layers: - if getattr(layer.attention, "prepare_sha", None): - layer.attention.prepare_sha() if getattr(layer.feed_forward, "prepare_feedfoward_conv", None): layer.feed_forward.prepare_feedfoward_conv() if args.embedding_quantize: @@ -342,7 +340,6 @@ def eval_llm(args): tokenizer=tokenizer, ar_len=args.max_seq_len, max_seq_len=args.max_seq_len, - kv_updater=args.kv_updater, tasks=["wikitext"], tasks_limit=1, use_i64_token=use_i64_token, @@ -364,7 +361,6 @@ def eval_llm(args): # tokenizer=tokenizer, # ar_len=args.max_seq_len, # max_seq_len=args.max_seq_len, - # kv_updater=args.kv_updater, # prompt="Can you tell me about Facebook?", # use_i64_token=use_i64_token, # event_name="convert_pt2e_prompt", @@ -378,7 +374,6 @@ def eval_llm(args): tokenizer=tokenizer, ar_len=args.max_seq_len, max_seq_len=args.max_seq_len, - kv_updater=args.kv_updater, tasks=["wikitext"], tasks_limit=0.1, use_i64_token=use_i64_token, @@ -416,13 +411,6 @@ def main() -> None: help="if you select this option we quantize linear layers only", action="store_true", ) - parser.add_argument( - "--kv_updater", - help="Choose how to update kv cache during runtime", - choices=["smart_mask", "shift_pointer"], - default="smart_mask", - type=str, - ) parser.add_argument( "--decoder_model", help=f"The Llama model to export. 
Current available options are: {SUPPORTED_LLM_MODELS.keys()}", diff --git a/examples/qualcomm/oss_scripts/llama/llama.py b/examples/qualcomm/oss_scripts/llama/llama.py index 91d82531654..bef5c928a54 100755 --- a/examples/qualcomm/oss_scripts/llama/llama.py +++ b/examples/qualcomm/oss_scripts/llama/llama.py @@ -71,8 +71,6 @@ apply_prompt_template, graph_module_inference, QnnRunnerEvalWrapper, - shift_pointer_updater, - smart_mask_updater, ) from executorch.examples.qualcomm.oss_scripts.llama.model.static_llama import ( LlamaModel, @@ -263,7 +261,6 @@ def quantize( tokenizer=tokenizer, ar_len=self.llama_meta["get_ar_len"], max_seq_len=self.llama_meta["get_max_seq_len"], - kv_updater=args.kv_updater, tasks=args.tasks, tasks_limit=args.limit, num_fewshot=args.num_fewshot, @@ -287,7 +284,6 @@ def quantize( tokenizer=tokenizer, ar_len=self.llama_meta["get_ar_len"], max_seq_len=self.llama_meta["get_max_seq_len"], - kv_updater=args.kv_updater, prompt=prompt, use_i64_token=args.embedding_quantize is not None, event_name="prepare_pt2e_prompt", @@ -312,7 +308,6 @@ def quantize( tokenizer=tokenizer, ar_len=self.llama_meta["get_ar_len"], max_seq_len=self.llama_meta["get_max_seq_len"], - kv_updater=args.kv_updater, tasks=args.tasks, tasks_limit=args.limit, num_fewshot=args.num_fewshot, @@ -341,7 +336,6 @@ def quantize( tokenizer=tokenizer, ar_len=self.llama_meta["get_ar_len"], max_seq_len=self.llama_meta["get_max_seq_len"], - kv_updater=args.kv_updater, prompt=prompt, use_i64_token=args.embedding_quantize is not None, event_name="convert_pt2e_prompt", @@ -393,6 +387,7 @@ def lowering_modules( soc_model=soc_model, backend_options=backend_options, shared_buffer=shared_buffer, + use_mha2sha=True, ) skip_node_op_set = {"llama.fallback.default"} @@ -652,9 +647,6 @@ def permute(w, heads, partial_rotary_dim): for llama_instance in llama_instance_list: for layer in llama_instance.layers: - if getattr(layer.attention, "prepare_sha", None): - layer.attention.prepare_sha() - if getattr(layer.feed_forward, "prepare_feedfoward_conv", None): layer.feed_forward.prepare_feedfoward_conv() @@ -800,7 +792,8 @@ def permute(w, heads, partial_rotary_dim): args.artifact, use_fp16=use_fp16, soc_model=get_soc_to_chipset_map()[args.model], - shared_buffer=args.shared_buffer, + shared_buffer=not args.enable_x86_64, # x86 emulator does not support shared buffer + verbose=args.verbose, ) elif args.model_mode in ["hybrid", "lookahead"]: sample_inputs_list = [ @@ -816,8 +809,9 @@ def permute(w, heads, partial_rotary_dim): generate_qnn_executorch_compiler_spec( soc_model=get_soc_to_chipset_map()[args.model], backend_options=backend_options, - shared_buffer=args.shared_buffer, + shared_buffer=not args.enable_x86_64, # x86 emulator does not support shared buffer graph_name=graph_name, + use_mha2sha=True, ) for graph_name in graph_names ] @@ -966,11 +960,6 @@ def post_process(): # x86 emulator is intended for CI and not performance. Check only the first few tokens. 
seq_len = min(seq_len, 16) - if args.kv_updater == smart_mask_updater: - logging.warning( - "x86 only support ShiftPointer, overwrite kv_updater to ShiftPointer" - ) - qnn_sdk = os.getenv("QNN_SDK_ROOT") target = "x86_64-linux-clang" runner_cmd = " ".join( @@ -983,7 +972,6 @@ def post_process(): f"--seq_len {seq_len}", f"--output_path {args.artifact}/outputs/outputs.txt", f"--performance_output_path {args.artifact}/{performance_output_path}", - f"--kv_updater ShiftPointer", runner_args, ] ) @@ -1005,7 +993,7 @@ def post_process(): f"--seq_len {seq_len}", "--output_path outputs/outputs.txt", f"--performance_output_path {performance_output_path}", - f"--kv_updater {'SmartMask' if args.kv_updater == smart_mask_updater else 'ShiftPointer'}", + "--shared_buffer", runner_args, ] ) @@ -1018,7 +1006,7 @@ def post_process(): device_id=args.device, host_id=args.host, soc_model=args.model, - shared_buffer=args.shared_buffer, + shared_buffer=True, target=args.target, runner=f"examples/qualcomm/oss_scripts/llama/qnn_llama_runner", ) @@ -1182,14 +1170,6 @@ def _build_parser(): type=int, ) - parser.add_argument( - "--kv_updater", - help="Choose how to update kv cache during runtime", - choices=["smart_mask", "shift_pointer"], - default="smart_mask", - type=str, - ) - parser.add_argument( "-E", "--embedding-quantize", @@ -1316,14 +1296,6 @@ def export_llama(args) -> None: json.dump(data, file, indent=4) file.truncate() - if args.kv_updater == "smart_mask": - args.shared_buffer = True - args.kv_updater = smart_mask_updater - elif args.kv_updater == "shift_pointer": - args.kv_updater = shift_pointer_updater - else: - raise RuntimeError(f"Using an unknown kv update {args.kv_updater}") - if args.pre_gen_pte: inference( args, decoder_model_config, pte_filename, runtime_tokenizer_path, tokenizer diff --git a/examples/qualcomm/oss_scripts/llama/masking_utils.py b/examples/qualcomm/oss_scripts/llama/masking_utils.py index 0031f468802..ea68f89276a 100644 --- a/examples/qualcomm/oss_scripts/llama/masking_utils.py +++ b/examples/qualcomm/oss_scripts/llama/masking_utils.py @@ -104,18 +104,6 @@ def smart_mask_update(self, pos, n_updates, lade_pos_offset): """ pass - @abstractmethod - def shift_pointer_update(self, pos, n_updates, lade_pos_offset): - """ - Update the attention mask by shift pointer update method after model forward. - - Args: - pos (int): Current position in the sequence. - n_updates (int): Number of tokens to shift. - lade_pos_offset (List[int]): Position offset of lookahead attention mask. - """ - pass - class CausalAttentionMask(BaseAttentionMask): def __init__(self, max_batch_size: int, ar_len: int, max_seq_len: int): @@ -161,41 +149,6 @@ def smart_mask_update(self, pos, n_updates, _): end_pos = pos + n_updates self.mask[:, :, start_pos:end_pos] = 0 - def shift_pointer_update(self, pos, n_updates, _): - """ - Shift Pointer mechanism for attention mask updating - - Initial mask(5x15) layout (before any updates): - Each row represents a query token in the autoregressive context. - ● = activate (can attend), ○ = inactivate (masked) - - 0 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ○ ○ ○ ○ - 1 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ○ ○ ○ - 2 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● ○ ○ - 3 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● ● ○ - 4 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● ● ● - - After 1st update (e.g., pos=0, n_updates=5): - Newly added tokens are unmasked (set to 0). 
- - 0 ○ ○ ○ ○ ○ ● ● ● ● ● ● ○ ○ ○ ○ - 1 ○ ○ ○ ○ ○ ● ● ● ● ● ● ● ○ ○ ○ - 2 ○ ○ ○ ○ ○ ● ● ● ● ● ● ● ● ○ ○ - 3 ○ ○ ○ ○ ○ ● ● ● ● ● ● ● ● ● ○ - 4 ○ ○ ○ ○ ○ ● ● ● ● ● ● ● ● ● ● - - After 2nd update (e.g., pos=5, n_updates=5): - - 0 ● ● ● ● ● ● ● ● ● ● ● ○ ○ ○ ○ - 1 ● ● ● ● ● ● ● ● ● ● ● ● ○ ○ ○ - 2 ● ● ● ● ● ● ● ● ● ● ● ● ● ○ ○ - 3 ● ● ● ● ● ● ● ● ● ● ● ● ● ● ○ - 4 ● ● ● ● ● ● ● ● ● ● ● ● ● ● ● - """ - start_pos = -pos - n_updates - self.ar_len - end_pos = -pos - self.ar_len - self.mask[:, :, start_pos:end_pos] = 0 - class SlidingWindowAttentionMask(BaseAttentionMask): def __init__( @@ -266,54 +219,6 @@ def smart_mask_update(self, pos, n_updates, lade_pos_offset): # TODO: [Optional]: it can be optimized by computing the exact start index self.mask[:, i, : end_pos - available_cache_len] = -255.0 - def shift_pointer_update(self, pos, n_updates, lade_pos_offset): - """ - Shift Pointer mechanism for attention mask updating - - Initial mask(5x15) layout (before any updates): - Each row represents a query token in the autoregressive context. - ● = activate (can attend), ○ = inactivate (masked) - - 0 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ○ ○ ○ ○ - 1 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ○ ○ ○ - 2 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● ○ ○ - 3 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● ○ - 4 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● - - After 1st update (e.g., pos=0, n_updates=5, sliding_window=3): - - 0 ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● ○ ○ ○ ○ - 1 ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● ○ ○ ○ - 2 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● ○ ○ - 3 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● ○ - 4 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● - - After 2nd update (e.g., pos=5, n_updates=5, sliding_window=3): - - 0 ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● ○ ○ ○ ○ - 1 ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● ○ ○ ○ - 2 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● ○ ○ - 3 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● ○ - 4 ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ○ ● ● ● - """ - - start_pos = -pos - n_updates - self.ar_len - end_pos = -pos - self.ar_len - self.mask[:, :, start_pos:end_pos] = 0 - - for i in range(self.ar_len): - available_cache_len = self.sliding_window - ( - (i + 1) if lade_pos_offset is None else (lade_pos_offset[i] + 1) - ) - if abs(start_pos + self.ar_len) > available_cache_len: - self.mask[ - :, - i, - start_pos : start_pos - + abs(start_pos + self.ar_len) - - available_cache_len, - ] = -255.0 - class AttentionMask: def __init__(self, masks: Union[BaseAttentionMask, List[BaseAttentionMask]]): @@ -323,9 +228,5 @@ def smart_mask_update(self, pos, n_updates, lade_pos_offset=None): for mask in self.masks: mask.smart_mask_update(pos, n_updates, lade_pos_offset) - def shift_pointer_update(self, pos, n_updates, lade_pos_offset=None): - for mask in self.masks: - mask.shift_pointer_update(pos, n_updates, lade_pos_offset) - def __iter__(self): return iter([mask.mask for mask in self.masks]) diff --git a/examples/qualcomm/oss_scripts/llama/model/apply_rope.py b/examples/qualcomm/oss_scripts/llama/model/apply_rope.py index 6d011c47336..9ad23eb5c41 100644 --- a/examples/qualcomm/oss_scripts/llama/model/apply_rope.py +++ b/examples/qualcomm/oss_scripts/llama/model/apply_rope.py @@ -24,8 +24,8 @@ def decorator(fn: Callable): @register_rotary_emb("partial") def apply_partial_rotary_emb_single(x, freqs_cos, freqs_sin): if x.dim() == 4: - freqs_cos = freqs_cos[None, :, None, :] - freqs_sin = freqs_sin[None, :, None, :] + freqs_cos = freqs_cos[None, None, :, :] + freqs_sin = freqs_sin[None, None, :, :] rotary_dim = freqs_cos.shape[-1] * 2 x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:] x_r, x_i = x_rot[..., : x_rot.shape[-1] // 2], x_rot[..., x_rot.shape[-1] // 2 :] @@ -43,8 +43,8 @@ def apply_rotary_emb_single(x, freqs_cos, 
freqs_sin): x_r, x_i = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] # broadcast for batch_prefill mode input x if x.dim() == 4: - freqs_cos = freqs_cos[None, :, None, :] - freqs_sin = freqs_sin[None, :, None, :] + freqs_cos = freqs_cos[None, None, :, :] + freqs_sin = freqs_sin[None, None, :, :] x_out_r = x_r * freqs_cos - x_i * freqs_sin x_out_i = x_r * freqs_sin + x_i * freqs_cos diff --git a/examples/qualcomm/oss_scripts/llama/model/static_llama.py b/examples/qualcomm/oss_scripts/llama/model/static_llama.py index ba2d33d7890..6b503efd55b 100755 --- a/examples/qualcomm/oss_scripts/llama/model/static_llama.py +++ b/examples/qualcomm/oss_scripts/llama/model/static_llama.py @@ -31,6 +31,20 @@ ) +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand( + batch, num_key_value_heads, n_rep, slen, head_dim + ) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + class LlamaAttention(nn.Module): def __init__(self, layer_idx: int, config: ModelArgs, output_new_cache_only=False): super().__init__() @@ -41,6 +55,7 @@ def __init__(self, layer_idx: int, config: ModelArgs, output_new_cache_only=Fals self.n_kv_heads = config.n_kv_heads self.num_key_value_groups = config.n_heads // self.n_kv_heads self.max_seq_len = config.max_seq_len + self.use_kv_cache = config.use_kv_cache self.output_new_cache_only = output_new_cache_only self.enable_masked_softmax = getattr(config, "enable_masked_softmax", False) self.use_qk_norm = config.use_qk_norm @@ -95,168 +110,6 @@ def __init__(self, layer_idx: int, config: ModelArgs, output_new_cache_only=Fals persistent=False, ) - def prepare_sha(self): - self.wq_sha = nn.ModuleList( - [ - nn.Conv2d( - self.dim, - self.head_dim, - 1, - bias=getattr(self.config, "attention_qkv_bias", False), - ) - for _ in range(self.n_heads) - ] - ) - self.wk_sha = nn.ModuleList( - [ - nn.Conv2d( - self.dim, - self.head_dim, - 1, - bias=getattr(self.config, "attention_qkv_bias", False), - ) - for _ in range(self.n_kv_heads) - ] - ) - self.wv_sha = nn.ModuleList( - [ - nn.Conv2d( - self.dim, - self.head_dim, - 1, - bias=getattr(self.config, "attention_qkv_bias", False), - ) - for _ in range(self.n_kv_heads) - ] - ) - self.wo_sha = nn.Conv2d(self.n_heads * self.head_dim, self.dim, 1, bias=False) - - self.forward_mha = self.forward - self.forward = self.forward_sha - for i in range(self.n_heads): - self.wq_sha[i].weight.data.copy_( - self.wq.weight[ - i * self.head_dim : (i + 1) * self.head_dim, :, None, None - ] - ) - if self.wq_sha[i].bias is not None: - self.wq_sha[i].bias.data.copy_( - self.wq.bias[i * self.head_dim : (i + 1) * self.head_dim] - ) - for i in range(self.n_kv_heads): - self.wk_sha[i].weight.data.copy_( - self.wk.weight[ - i * self.head_dim : (i + 1) * self.head_dim, :, None, None - ] - ) - if self.wk_sha[i].bias is not None: - self.wk_sha[i].bias.data.copy_( - self.wk.bias[i * self.head_dim : (i + 1) * self.head_dim] - ) - self.wv_sha[i].weight.data.copy_( - self.wv.weight[ - i * self.head_dim : (i + 1) * self.head_dim, :, None, None - ] - ) - if self.wv_sha[i].bias is not None: - self.wv_sha[i].bias.data.copy_( - self.wv.bias[i * self.head_dim : 
(i + 1) * self.head_dim] - ) - self.wo_sha.weight.data.copy_(self.wo.weight[:, :, None, None]) - - def forward_sha( # noqa: C901 - self, - hidden_states: torch.Tensor, - freqs_cos: torch.Tensor, - freqs_sin: torch.Tensor, - atten_mask: torch.Tensor, - k_caches: Optional[List[torch.Tensor]] = None, - v_caches: Optional[List[torch.Tensor]] = None, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - bsz, seq_len, _ = hidden_states.shape - # In the HTP backend, the input axis order for the convolution operation is - # more efficient with [1, 1, seq_len, dim] compared to [1, seq_len, 1, dim]. - hidden_states = torch.reshape( - hidden_states, (bsz, seq_len, 1, self.dim) - ).transpose(1, 3) - q = [ - wq_sha(hidden_states) - .permute(0, 2, 3, 1) - .reshape(bsz, seq_len, self.head_dim) - for wq_sha in self.wq_sha - ] - k = [ - wk_sha(hidden_states) - .permute(0, 2, 3, 1) - .reshape(bsz, seq_len, self.head_dim) - for wk_sha in self.wk_sha - ] - v = [ - wv_sha(hidden_states) - .permute(0, 2, 3, 1) - .reshape(bsz, seq_len, self.head_dim) - for wv_sha in self.wv_sha - ] - for i in range(len(q)): - if self.use_qk_norm and self.qk_norm_before_rope: - q[i] = self.q_norm_fn(q[i]) - if self.use_rope: - q[i] = self.apply_rope_emb(q[i], freqs_cos, freqs_sin) - if self.use_qk_norm and not self.qk_norm_before_rope: - q[i] = self.q_norm_fn(q[i]) - if getattr(self.config, "enable_r3", False): - q[i] = torch.matmul(q[i], self.r3_weight) - - for i in range(len(k)): - if self.use_qk_norm and self.qk_norm_before_rope: - k[i] = self.k_norm_fn(k[i]) - if self.use_rope: - k[i] = self.apply_rope_emb(k[i], freqs_cos, freqs_sin) - if self.use_qk_norm and not self.qk_norm_before_rope: - k[i] = self.k_norm_fn(k[i]) - if getattr(self.config, "enable_r3", False): - k[i] = torch.matmul(k[i], self.r3_weight) - k[i] = k[i].transpose(1, 2) - - output_y = [] - kh, vh = [], [] - # kv cache mode - if k_caches and v_caches: - for i, _ in enumerate(k_caches): - kh.append(torch.cat([k_caches[i], k[i]], dim=-1)) - vh.append(torch.cat([v_caches[i], v[i]], dim=1)) - # batch_prefill mode - else: - kh = k - vh = v - - for i, _ in enumerate(q): - cache_idx = i // self.num_key_value_groups - attn = q[i] @ kh[cache_idx] - attn = attn / self.scale - if self.enable_masked_softmax: - attn_min = torch.amin(attn, dim=-1, keepdim=True) - minus_value = -20 - attn = torch.where(atten_mask == 0, attn, attn_min + minus_value) - else: - attn = attn + atten_mask - attn = self.attn_softmax(attn) - y = attn @ vh[cache_idx] - - output_y.append(y) - - y = torch.concat(output_y, dim=-1) - y = y.reshape(bsz, seq_len, 1, -1) - y = y.transpose(1, 3) - y = self.wo_sha(y) - y = y.transpose(1, 3) - y = y.reshape(bsz, seq_len, -1) - - if self.output_new_cache_only: - return y, k, v - - return y, kh, vh - def forward( self, hidden_states: torch.Tensor, @@ -268,10 +121,12 @@ def forward( ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: bsz, seq_len, _ = hidden_states.shape - q, k, v = self.wq(hidden_states), self.wk(hidden_states), self.wv(hidden_states) - q = q.view(bsz, seq_len, self.n_heads, self.head_dim) - k = k.view(bsz, seq_len, self.n_kv_heads, self.head_dim) - v = v.view(bsz, seq_len, self.n_kv_heads, self.head_dim) + q = self.wq(hidden_states) + k = self.wk(hidden_states) + v = self.wv(hidden_states) + q = q.view(bsz, seq_len, self.n_heads, self.head_dim).transpose(1, 2) + k = k.view(bsz, seq_len, self.n_kv_heads, self.head_dim).transpose(1, 2) + v = v.view(bsz, seq_len, self.n_kv_heads, self.head_dim).transpose(1, 2) if self.use_qk_norm and 
self.qk_norm_before_rope: q = self.q_norm_fn(q) @@ -280,55 +135,45 @@ def forward( if self.use_rope: q = self.apply_rope_emb(q, freqs_cos, freqs_sin) k = self.apply_rope_emb(k, freqs_cos, freqs_sin) - k = k.permute(0, 2, 3, 1) if self.use_qk_norm and not self.qk_norm_before_rope: q = self.q_norm_fn(q) k = self.k_norm_fn(k) + if getattr(self.config, "enable_r3", False): + q = torch.matmul(q, self.r3_weight) + k = torch.matmul(k, self.r3_weight) + k = k.transpose(2, 3) - output_kh, output_vh, output_y = [], [], [] - kh, vh = [], [] + kh, vh = None, None # kv cache mode - if k_caches and v_caches: - for i, _ in enumerate(k_caches): - kh.append(torch.cat([k_caches[i], k[:, i, :, :]], dim=-1)) - vh.append(torch.cat([v_caches[i], v[:, :, i, :]], dim=1)) - for i in range(self.n_heads): - cache_idx = i // self.num_key_value_groups - - attn = q[:, :, i, :] @ kh[cache_idx] - attn = attn / self.scale + atten_mask - attn = self.attn_softmax(attn) - y = attn @ vh[cache_idx] - - output_y.append(y) - + if self.use_kv_cache: + kh = torch.cat([k_caches, k], dim=-1) + vh = torch.cat([v_caches, v], dim=2) # batch_prefill mode else: kh = k vh = v - for i in range(self.n_heads): - cache_idx = i // self.num_key_value_groups - - attn = q[:, :, i, :] @ kh[:, cache_idx, :, :] - attn = attn / self.scale + atten_mask - attn = self.attn_softmax(attn) - y = attn @ vh[:, :, cache_idx, :] - output_y.append(y) + kh = repeat_kv(kh, self.num_key_value_groups) + vh = repeat_kv(vh, self.num_key_value_groups) - for i in range(self.n_kv_heads): - if self.output_new_cache_only: - output_kh.append(k[:, i, :, -1]) - output_vh.append(v[:, -1, i, :]) - else: - output_kh.append(k[:, i, :, :]) - output_vh.append(v[:, :, i, :]) - - y = torch.concat(output_y, dim=-1) + attn = q @ kh + attn = attn / self.scale + if self.enable_masked_softmax: + attn_min = torch.amin(attn, dim=-1, keepdim=True) + minus_value = -20 + attn = torch.where(atten_mask == 0, attn, attn_min + minus_value) + else: + attn = attn + atten_mask + attn = self.attn_softmax(attn) + y = attn @ vh + y = y.transpose(1, 2).reshape(bsz, seq_len, -1) y = self.wo(y) - return y, output_kh, output_vh + if self.output_new_cache_only: + return y, [k], [v] + + return y, [kh], [vh] class FeedForward(nn.Module): @@ -531,10 +376,11 @@ def forward( k_caches = None v_caches = None if self.use_kv_cache: - offset_k = ind * self.n_kv_heads - offset_v = self.n_layers * self.n_kv_heads + offset_k - k_caches = args[offset_k : offset_k + self.n_kv_heads] - v_caches = args[offset_v : offset_v + self.n_kv_heads] + offset_k = ind + offset_v = self.n_layers + offset_k + k_caches = args[offset_k] + v_caches = args[offset_v] + hidden_states, k, v = decoder_layer( hidden_states, freqs_cos=freqs_cos, @@ -564,24 +410,24 @@ def get_example_inputs(self, use_kv_cache=True): if use_kv_cache: pos_ids = torch.zeros((self.max_batch_size, self.ar_len), dtype=torch.int32) k_cache, v_cache = [], [] - for _ in range(self.n_layers): - for _ in range(self.n_kv_heads): - # transpose first to decrease the runtime efforts - k_cache.append( - torch.zeros( - self.max_batch_size, - self.head_dim, - self.max_seq_len - self.ar_len, - ) + # transpose first to decrease the runtime efforts + k_cache.append( + torch.zeros( + self.max_batch_size, + self.n_kv_heads, + self.head_dim, + self.max_seq_len - self.ar_len, ) - v_cache.append( - torch.zeros( - self.max_batch_size, - self.max_seq_len - self.ar_len, - self.head_dim, - ) + ) + v_cache.append( + torch.zeros( + self.max_batch_size, + self.n_kv_heads, + self.max_seq_len - 
self.ar_len, + self.head_dim, ) + ) return ( tokens, atten_mask, @@ -688,10 +534,10 @@ def forward( k_caches = None v_caches = None if self.use_kv_cache: - offset_k = ind * self.n_kv_heads - offset_v = self.n_layers * self.n_kv_heads + offset_k - k_caches = args[offset_k : offset_k + self.n_kv_heads] - v_caches = args[offset_v : offset_v + self.n_kv_heads] + offset_k = ind + offset_v = self.n_layers + offset_k + k_caches = args[offset_k] + v_caches = args[offset_v] if self.layer_types[ind] == "sliding_attention": hidden_states, k, v = decoder_layer( diff --git a/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp b/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp index 52796e886fd..3e3e7aa7849 100644 --- a/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp +++ b/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp @@ -65,10 +65,10 @@ DEFINE_int32( eval_mode, 1, "0: TokenGenerator(kv) / 1: HybridMode (prefill+kv) / 2: Lookahead Decoding"); -DEFINE_string( - kv_updater, - "SmartMask", - "How to update kv cache. Choose between SmartMask and ShiftPointer"); +DEFINE_bool( + shared_buffer, + false, + "Specifies to use shared buffers for zero-copy use case between the application and device/co-processor associated with the backend."); DEFINE_int32(num_iters, 1, "total num of iterations to run."); DEFINE_int32( ngram, @@ -159,7 +159,8 @@ std::string get_formatted_prompt( } formatted_prompt.append("<|im_start|>user\n"); formatted_prompt.append(prompt); - formatted_prompt.append("<|im_end|>\n\n"); + formatted_prompt.append("<|im_end|>\n"); + formatted_prompt.append("<|im_start|>assistant\n\n"); break; case example::DecoderModelVersion::kSmollm3: if (!system_prompt.empty()) { @@ -196,7 +197,7 @@ void start_runner( FLAGS_performance_output_path.c_str(), FLAGS_temperature, FLAGS_eval_mode, - FLAGS_kv_updater, + FLAGS_shared_buffer, FLAGS_ngram, FLAGS_window, FLAGS_gcap); diff --git a/examples/qualcomm/oss_scripts/llama/runner/client_mem.h b/examples/qualcomm/oss_scripts/llama/runner/client_mem.h index 0fd535796de..6d4dbb68f9c 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/client_mem.h +++ b/examples/qualcomm/oss_scripts/llama/runner/client_mem.h @@ -15,7 +15,7 @@ namespace example { /** * @class ClientMem * @brief Final class for client buffer allocation, implementing IBufferAlloc - * interface. Used for SHIFT_POINTER mode. + * interface. This is specifically designed for use cases without shared buffer. 
*/ class ClientMem final : public IMemAlloc { public: diff --git a/examples/qualcomm/oss_scripts/llama/runner/kv_manager.cpp b/examples/qualcomm/oss_scripts/llama/runner/kv_manager.cpp index c6e59097ffc..1399f726869 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/kv_manager.cpp +++ b/examples/qualcomm/oss_scripts/llama/runner/kv_manager.cpp @@ -10,40 +10,16 @@ #include namespace example { template -KVManager::KVManager(KVManagerMode kv_updater, Metadata metadata) - : kv_updater_(kv_updater), metadata_(metadata) { - k_cache_.resize( - metadata_.num_layers, std::vector>(metadata_.num_heads)); - v_cache_.resize( - metadata_.num_layers, std::vector>(metadata_.num_heads)); +KVManager::KVManager(Metadata metadata) : metadata_(metadata) { + k_cache_.resize(metadata_.num_layers); + v_cache_.resize(metadata_.num_layers); // Calculate cache size - switch (kv_updater_) { - case KVManagerMode::SMART_MASK: { - size_t cache_in_bytes = metadata_.num_layers * metadata_.num_heads * - metadata_.head_dim * metadata_.max_cache_len * sizeof(T); - size_t cache_out_bytes = metadata_.num_layers * metadata_.num_heads * - metadata_.head_dim * metadata_.max_ar_len * sizeof(T); - total_cache_size_ = 2 * (cache_in_bytes + cache_out_bytes); - break; - } - case KVManagerMode::SHIFT_POINTER: { - size_t k_cache_in_bytes = metadata_.num_layers * metadata_.num_heads * - (metadata_.head_dim + 1) * metadata_.max_cache_len * sizeof(T); - size_t k_cache_out_bytes = metadata_.num_layers * metadata_.num_heads * - metadata_.head_dim * metadata_.max_ar_len * sizeof(T); - // Use the same memory for input and output of value cache in shift - // pointer mode. Note that using context length to prevent exceeding the - // range when the AR-N model updates the last block in shift pointer - // mode. - size_t v_cache_bytes = metadata_.num_layers * (metadata_.num_heads + 1) * - metadata_.head_dim * metadata_.context_len * sizeof(T); - total_cache_size_ = k_cache_in_bytes + k_cache_out_bytes + v_cache_bytes; - break; - } - default: - break; - } + size_t cache_in_bytes = metadata_.num_layers * metadata_.num_heads * + metadata_.head_dim * metadata_.max_cache_len * sizeof(T); + size_t cache_out_bytes = metadata_.num_layers * metadata_.num_heads * + metadata_.head_dim * metadata_.max_ar_len * sizeof(T); + total_cache_size_ = 2 * (cache_in_bytes + cache_out_bytes); }; template @@ -63,56 +39,26 @@ void KVManager::init_attention_mask( std::fill_n(attention_mask, ar_len * metadata_.context_len, neg_val); // SMART_MASK requires special handling of attention mask - switch (kv_updater_) { - case KVManagerMode::SMART_MASK: { - uint16_t* past_ptr = attention_mask; - uint16_t* new_ptr = attention_mask + (metadata_.context_len - ar_len); - // All inputs will necessarily attend to n_past and itself - for (int i = 0; i < ar_len; i++) { - // Iterate across ar_len - if (attention_map[i] < 0) { - // If negative, attend to only past tokens - std::fill_n(past_ptr, n_past, pos_val); - } else { - // If positive, copy attention map from (relative to 0th input) parent - // Parent token index - const int32_t pidx = attention_map[i]; - uint16_t* parent_ptr = attention_mask + pidx * metadata_.context_len; - std::memcpy( - past_ptr, parent_ptr, metadata_.context_len * sizeof(uint16_t)); - } - // Attend to itself - new_ptr[i] = pos_val; - past_ptr += metadata_.context_len; - new_ptr += metadata_.context_len; - } - break; - } - case KVManagerMode::SHIFT_POINTER: { - // Only fill in ar_len. 
Rest will be padding - const size_t attn_row_start = metadata_.context_len - n_past - ar_len; - for (int i = 0; i < ar_len; i++) { - uint16_t* cur_ptr = - attention_mask + i * metadata_.context_len + attn_row_start; - // Attend to itself - cur_ptr[n_past + i] = pos_val; - if (attention_map[i] < 0) { - // If negative, attend to only past tokens - std::fill_n(cur_ptr, n_past, pos_val); - } else { - // If positive, copy attention map from (relative to 0th input) parent - // Parent token index - const int32_t pidx = attention_map[i]; - uint16_t* parent_ptr = - attention_mask + pidx * metadata_.context_len + attn_row_start; - std::memcpy( - cur_ptr, parent_ptr, (n_past + pidx + 1) * sizeof(uint16_t)); - } - } - break; + uint16_t* past_ptr = attention_mask; + uint16_t* new_ptr = attention_mask + (metadata_.context_len - ar_len); + // All inputs will necessarily attend to n_past and itself + for (int i = 0; i < ar_len; i++) { + // Iterate across ar_len + if (attention_map[i] < 0) { + // If negative, attend to only past tokens + std::fill_n(past_ptr, n_past, pos_val); + } else { + // If positive, copy attention map from (relative to 0th input) parent + // Parent token index + const int32_t pidx = attention_map[i]; + uint16_t* parent_ptr = attention_mask + pidx * metadata_.context_len; + std::memcpy( + past_ptr, parent_ptr, metadata_.context_len * sizeof(uint16_t)); } - default: - break; + // Attend to itself + new_ptr[i] = pos_val; + past_ptr += metadata_.context_len; + new_ptr += metadata_.context_len; } } @@ -135,64 +81,35 @@ void KVManager::init_attention_mask( std::fill_n(attention_mask, ar_len * metadata_.context_len, neg_val); // SMART_MASK requires special handling of attention mask - switch (kv_updater_) { - case KVManagerMode::SMART_MASK: { - uint16_t* past_ptr = attention_mask; - uint16_t* new_ptr = attention_mask + (metadata_.context_len - ar_len); - // All inputs will necessarily attend to n_past and itself - for (int i = 0; i < ar_len; i++) { - // Iterate across ar_len - if (attention_map[i] < 0) { - // If negative, attend to only past tokens - std::fill_n(past_ptr, n_past, pos_val); - } else { - // If positive, copy attention map from (relative to 0th input) parent - // Parent token index - const int32_t pidx = attention_map[i]; - uint16_t* parent_ptr = attention_mask + pidx * metadata_.context_len; - std::memcpy( - past_ptr, parent_ptr, metadata_.context_len * sizeof(uint16_t)); - } - // Attend to itself - new_ptr[i] = pos_val; - // mask by limitation of sliding_window - int32_t available_context_len = position_offset.empty() - ? 
sliding_window - (i + 1) - n_past - : sliding_window - (position_offset[i] + 1) - n_past; - if (n_past > available_context_len) { - std::fill_n(past_ptr, n_past - available_context_len, neg_val); - } - - past_ptr += metadata_.context_len; - new_ptr += metadata_.context_len; - } - break; + uint16_t* past_ptr = attention_mask; + uint16_t* new_ptr = attention_mask + (metadata_.context_len - ar_len); + // All inputs will necessarily attend to n_past and itself + for (int i = 0; i < ar_len; i++) { + // Iterate across ar_len + if (attention_map[i] < 0) { + // If negative, attend to only past tokens + std::fill_n(past_ptr, n_past, pos_val); + } else { + // If positive, copy attention map from (relative to 0th input) parent + // Parent token index + const int32_t pidx = attention_map[i]; + uint16_t* parent_ptr = attention_mask + pidx * metadata_.context_len; + std::memcpy( + past_ptr, parent_ptr, metadata_.context_len * sizeof(uint16_t)); } - case KVManagerMode::SHIFT_POINTER: { - // Only fill in ar_len. Rest will be padding - const size_t attn_row_start = metadata_.context_len - n_past - ar_len; - for (int i = 0; i < ar_len; i++) { - uint16_t* cur_ptr = - attention_mask + i * metadata_.context_len + attn_row_start; - // Attend to itself - cur_ptr[n_past + i] = pos_val; - if (attention_map[i] < 0) { - // If negative, attend to only past tokens - std::fill_n(cur_ptr, n_past, pos_val); - } else { - // If positive, copy attention map from (relative to 0th input) parent - // Parent token index - const int32_t pidx = attention_map[i]; - uint16_t* parent_ptr = - attention_mask + pidx * metadata_.context_len + attn_row_start; - std::memcpy( - cur_ptr, parent_ptr, (n_past + pidx + 1) * sizeof(uint16_t)); - } - } - break; + // Attend to itself + new_ptr[i] = pos_val; + + // mask by limitation of sliding_window + int32_t available_context_len = position_offset.empty() + ? sliding_window - (i + 1) - n_past + : sliding_window - (position_offset[i] + 1) - n_past; + if (n_past > available_context_len) { + std::fill_n(past_ptr, n_past - available_context_len, neg_val); } - default: - break; + + past_ptr += metadata_.context_len; + new_ptr += metadata_.context_len; } } @@ -204,10 +121,7 @@ void KVManager::update_attention_mask( int32_t n_update) { uint16_t pos_val = 65535; uint16_t* cur_ptr = attention_mask; - if (kv_updater_ == KVManagerMode::SMART_MASK) - cur_ptr += n_past; - if (kv_updater_ == KVManagerMode::SHIFT_POINTER) - cur_ptr += metadata_.context_len - n_past - ar_len - n_update; + cur_ptr += n_past; for (int i = 0; i < ar_len; i++) { std::fill_n(cur_ptr, n_update, pos_val); @@ -226,27 +140,16 @@ void KVManager::update_attention_mask( uint16_t pos_val = 65535; uint16_t neg_val = 0; uint16_t* cur_ptr = attention_mask; - if (kv_updater_ == KVManagerMode::SMART_MASK) - cur_ptr += n_past; - if (kv_updater_ == KVManagerMode::SHIFT_POINTER) - cur_ptr += metadata_.context_len - n_past - ar_len - n_update; + cur_ptr += n_past; for (int i = 0; i < ar_len; i++) { std::fill_n(cur_ptr, n_update, pos_val); int32_t available_cache_len = position_offset.empty() ? 
sliding_window - (i + 1) : sliding_window - (position_offset[i] + 1); - if (kv_updater_ == KVManagerMode::SMART_MASK) { - if (n_past + n_update > available_cache_len) { - std::fill_n( - cur_ptr - n_past, n_past + n_update - available_cache_len, neg_val); - } - } else if (kv_updater_ == KVManagerMode::SHIFT_POINTER) { - if (std::abs(n_past + ar_len) > available_cache_len) { - int32_t n_invalid = n_past - available_cache_len; - std::fill_n( - cur_ptr, std::abs(n_past + ar_len) - available_cache_len, neg_val); - } + if (n_past + n_update > available_cache_len) { + std::fill_n( + cur_ptr - n_past, n_past + n_update - available_cache_len, neg_val); } cur_ptr += metadata_.context_len; } @@ -259,75 +162,25 @@ void KVManager::init_cache(IMemAlloc* buffer_manager, int32_t ar_len) { metadata_.max_cache_len * sizeof(T); const size_t max_out_cache_block_in_bytes = metadata_.max_ar_len * sizeof(T); - switch (kv_updater_) { - case KVManagerMode::SMART_MASK: { - const size_t cache_in_bytes = - metadata_.head_dim * max_in_cache_block_in_bytes; - const size_t cache_out_bytes = - metadata_.head_dim * max_out_cache_block_in_bytes; - for (int layer = 0; layer < metadata_.num_layers; ++layer) { - for (int head = 0; head < metadata_.num_heads; ++head) { - // Allocate buffer for key cache and value cache - T* single_layer_k_cache_in = - reinterpret_cast(buffer_manager->allocate(cache_in_bytes)); - T* single_layer_k_cache_out = - reinterpret_cast(buffer_manager->allocate(cache_out_bytes)); - T* single_layer_v_cache_in = - reinterpret_cast(buffer_manager->allocate(cache_in_bytes)); - T* single_layer_v_cache_out = - reinterpret_cast(buffer_manager->allocate(cache_out_bytes)); + const size_t cache_in_bytes = + metadata_.num_heads * metadata_.head_dim * max_in_cache_block_in_bytes; + const size_t cache_out_bytes = + metadata_.num_heads * metadata_.head_dim * max_out_cache_block_in_bytes; + for (int layer = 0; layer < metadata_.num_layers; ++layer) { + // Allocate buffer for key cache and value cache + T* single_layer_k_cache_in = + reinterpret_cast(buffer_manager->allocate(cache_in_bytes)); + T* single_layer_k_cache_out = + reinterpret_cast(buffer_manager->allocate(cache_out_bytes)); + T* single_layer_v_cache_in = + reinterpret_cast(buffer_manager->allocate(cache_in_bytes)); + T* single_layer_v_cache_out = + reinterpret_cast(buffer_manager->allocate(cache_out_bytes)); - k_cache_[layer][head].buffer = single_layer_k_cache_in; - k_cache_[layer][head].output_buffer = single_layer_k_cache_out; - v_cache_[layer][head].buffer = single_layer_v_cache_in; - v_cache_[layer][head].output_buffer = single_layer_v_cache_out; - } - } - break; - } - case KVManagerMode::SHIFT_POINTER: { - const size_t k_cache_in_size_in_bytes = metadata_.num_heads * - (metadata_.head_dim + 1) * max_in_cache_block_in_bytes; - const size_t k_cache_out_size_in_bytes = metadata_.num_heads * - metadata_.head_dim * max_out_cache_block_in_bytes; - const size_t v_cache_size_in_bytes = (metadata_.num_heads + 1) * - metadata_.head_dim * metadata_.context_len * sizeof(T); - const int32_t single_head_size_in = - metadata_.head_dim * metadata_.max_cache_len; - const int32_t single_head_size_out = - metadata_.head_dim * metadata_.max_ar_len; - for (int layer = 0; layer < metadata_.num_layers; ++layer) { - // Allocate buffer for key cache and value cache - T* single_layer_k_cache_in = reinterpret_cast( - buffer_manager->allocate(k_cache_in_size_in_bytes)); - T* single_layer_k_cache_out = reinterpret_cast( - buffer_manager->allocate(k_cache_out_size_in_bytes)); - // 
Note that using context length to prevent exceeding the range when - // the AR-N model updates the last block in shift pointer mode. - T* single_layer_v_cache = reinterpret_cast( - buffer_manager->allocate(v_cache_size_in_bytes)); - for (int head = 0; head < metadata_.num_heads; ++head) { - k_cache_[layer][head].buffer = single_layer_k_cache_in + - head * (metadata_.head_dim + 1) * metadata_.max_cache_len; - k_cache_[layer][head].output_buffer = - single_layer_k_cache_out + head * single_head_size_out; - // v_cache: - // |cache_gap|h1_v_in_ptr|cache_len|h1_v_out_ptr|cache_gap|h2_v_in_ptr|cache_len|h2_v_out_ptr|...| - const int32_t cache_gap = (cur_ar_len_ == metadata_.context_len) - ? 0 - : metadata_.max_cache_len - (metadata_.context_len - cur_ar_len_); - v_cache_[layer][head].buffer = single_layer_v_cache + - head * metadata_.head_dim * metadata_.context_len + - cache_gap * metadata_.head_dim; - v_cache_[layer][head].output_buffer = single_layer_v_cache + - head * metadata_.head_dim * metadata_.context_len + - single_head_size_in; - } - } - break; - } - default: - break; + k_cache_[layer].buffer = single_layer_k_cache_in; + k_cache_[layer].output_buffer = single_layer_k_cache_out; + v_cache_[layer].buffer = single_layer_v_cache_in; + v_cache_[layer].output_buffer = single_layer_v_cache_out; } } @@ -337,10 +190,8 @@ void KVManager::rearrange_cache(int32_t ar_len_dst) { if (cur_ar_len_ == ar_len_dst) return; for (int layer = 0; layer < metadata_.num_layers; ++layer) { - for (int head = 0; head < metadata_.num_heads; ++head) { - rearrange_key(k_cache_[layer][head], ar_len_dst); - rearrange_value(v_cache_[layer][head], ar_len_dst); - } + rearrange_key(k_cache_[layer], ar_len_dst); + rearrange_value(v_cache_[layer], ar_len_dst); } // rearrange done. cur_ar_len_ = ar_len_dst; @@ -348,8 +199,6 @@ void KVManager::rearrange_cache(int32_t ar_len_dst) { template void KVManager::rearrange_key(KVCache& k_cache, int32_t ar_len_dst) { - // The output of key cache doesn't need to rearrange for both of SMART_MASK - // and SHIFT_POINTER const int32_t src_cache_num = (cur_ar_len_ == metadata_.context_len) ? 
metadata_.context_len : metadata_.context_len - cur_ar_len_; @@ -358,27 +207,20 @@ void KVManager::rearrange_key(KVCache& k_cache, int32_t ar_len_dst) { T* k_cache_in_write_ptr = k_cache.buffer; if (src_cache_num > dst_cache_num) { - if (kv_updater_ == KVManagerMode::SHIFT_POINTER) { - // Left padded KV$ - k_cache_in_read_ptr += src_cache_num; - k_cache_in_write_ptr += dst_cache_num; - } // copy from first dimension - for (int i = 0; i < metadata_.head_dim; i++) { + for (int i = 0; i < metadata_.head_dim * metadata_.num_heads; i++) { std::memmove( k_cache_in_write_ptr, k_cache_in_read_ptr, dst_cache_num * sizeof(T)); k_cache_in_read_ptr += src_cache_num; k_cache_in_write_ptr += dst_cache_num; } } else { - k_cache_in_read_ptr += (metadata_.head_dim - 1) * src_cache_num; - k_cache_in_write_ptr += (metadata_.head_dim - 1) * dst_cache_num; - if (kv_updater_ == KVManagerMode::SHIFT_POINTER) { - k_cache_in_read_ptr += src_cache_num; - k_cache_in_write_ptr += dst_cache_num; - } + k_cache_in_read_ptr += + (metadata_.head_dim * metadata_.num_heads - 1) * src_cache_num; + k_cache_in_write_ptr += + (metadata_.head_dim * metadata_.num_heads - 1) * dst_cache_num; // copy from last dimension - for (int i = 0; i < metadata_.head_dim; i++) { + for (int i = 0; i < metadata_.head_dim * metadata_.num_heads; i++) { std::memmove( k_cache_in_write_ptr, k_cache_in_read_ptr, src_cache_num * sizeof(T)); k_cache_in_read_ptr -= src_cache_num; @@ -389,54 +231,37 @@ void KVManager::rearrange_key(KVCache& k_cache, int32_t ar_len_dst) { template void KVManager::rearrange_value(KVCache& v_cache, int32_t ar_len_dst) { - // The input and output of the value cache don't need to rearrange for both - // SMART_MASK and SHIFT_POINTER. However, the input pointer of the value cache - // needs to be reset by ar_len_dst in SHIFT_POINTER mode. The output pointer - // of the value cache remains unchanged regardless of ar_len. - const int32_t ar_gap = (cur_ar_len_ == metadata_.context_len) - ? ar_len_dst - : ar_len_dst - cur_ar_len_; - if (kv_updater_ == KVManagerMode::SHIFT_POINTER) { - v_cache.buffer = v_cache.buffer + ar_gap * metadata_.head_dim; - } -} - -template -bool KVManager::update_cache_tensor( - std::vector>>& - k_cache_in, - std::vector>>& - k_cache_out, - std::vector>>& - v_cache_in, - std::vector>>& - v_cache_out, - int32_t ar_len, - int32_t n_past) { - ET_CHECK_MSG( - cur_ar_len_ == ar_len, - "Current AR length (%d) is not matched with target AR length (%d). Please rearrange cache first.", - cur_ar_len_, - ar_len); - bool updated = false; - // Data pointer in the tensors need to update only for SHIFT_POINTER mode - // The BERT model does not update the cache tensor because it does not use KV - // cache inputs. - if (kv_updater_ == KVManagerMode::SHIFT_POINTER && - metadata_.context_len != cur_ar_len_) { - for (int layer = 0; layer < metadata_.num_layers; ++layer) { - for (int head = 0; head < metadata_.num_heads; ++head) { - k_cache_in[layer][head]->set_data( - k_cache_[layer][head].buffer + n_past); - v_cache_in[layer][head]->set_data( - v_cache_[layer][head].buffer + n_past * metadata_.head_dim); - v_cache_out[layer][head]->set_data( - v_cache_[layer][head].output_buffer + n_past * metadata_.head_dim); - } + const int32_t src_cache_num = (cur_ar_len_ == metadata_.context_len) + ? 
metadata_.context_len + : metadata_.context_len - cur_ar_len_; + const int32_t dst_cache_num = metadata_.context_len - ar_len_dst; + T* v_cache_in_read_ptr = v_cache.buffer; + T* v_cache_in_write_ptr = v_cache.buffer; + if (src_cache_num > dst_cache_num) { + // copy from first dimension + for (int i = 0; i < metadata_.num_heads; i++) { + std::memmove( + v_cache_in_write_ptr, + v_cache_in_read_ptr, + dst_cache_num * metadata_.head_dim * sizeof(T)); + v_cache_in_read_ptr += src_cache_num * metadata_.head_dim; + v_cache_in_write_ptr += dst_cache_num * metadata_.head_dim; + } + } else { + v_cache_in_read_ptr += + metadata_.head_dim * (metadata_.num_heads - 1) * src_cache_num; + v_cache_in_write_ptr += + metadata_.head_dim * (metadata_.num_heads - 1) * dst_cache_num; + // copy from last dimension + for (int i = 0; i < metadata_.num_heads; i++) { + std::memmove( + v_cache_in_write_ptr, + v_cache_in_read_ptr, + src_cache_num * metadata_.head_dim * sizeof(T)); + v_cache_in_read_ptr -= src_cache_num * metadata_.head_dim; + v_cache_in_write_ptr -= dst_cache_num * metadata_.head_dim; } - updated = true; } - return updated; } template @@ -451,10 +276,8 @@ void KVManager::update_cache( cur_ar_len_, ar_len); for (int layer = 0; layer < metadata_.num_layers; ++layer) { - for (int head = 0; head < metadata_.num_heads; ++head) { - update_key(k_cache_[layer][head], n_past, n_update, selected); - update_value(v_cache_[layer][head], n_past, n_update, selected); - } + update_key(k_cache_[layer], n_past, n_update, selected); + update_value(v_cache_[layer], n_past, n_update, selected); } } @@ -472,12 +295,9 @@ void KVManager::update_key( : metadata_.context_len - cur_ar_len_; const int32_t out_size = cur_ar_len_; const int32_t past_size = n_past; - const int32_t n_iter = metadata_.head_dim; + const int32_t n_iter = metadata_.head_dim * metadata_.num_heads; - if (kv_updater_ == KVManagerMode::SHIFT_POINTER) - write_ptr += iter_size + past_size; - if (kv_updater_ == KVManagerMode::SMART_MASK) - write_ptr += past_size; + write_ptr += past_size; if (selected.empty()) { for (int i = 0; i < n_iter; ++i) { std::memcpy(write_ptr, read_ptr, copy_size); @@ -512,21 +332,20 @@ void KVManager::update_value( T* read_ptr = v_cache.output_buffer; const int32_t copy_size = n_update * metadata_.head_dim * sizeof(T); const int32_t past_size = n_past * metadata_.head_dim; + const int32_t n_iter = metadata_.num_heads; + const int32_t iter_size = (cur_ar_len_ == metadata_.context_len) + ? 
metadata_.context_len * metadata_.head_dim + : (metadata_.context_len - cur_ar_len_) * metadata_.head_dim; + const int32_t out_size = cur_ar_len_ * metadata_.head_dim; - if (kv_updater_ == KVManagerMode::SMART_MASK) - write_ptr += past_size; - - // Update the value cache for lookahead decoding in SHIFT_POINTER mode - if (kv_updater_ == KVManagerMode::SHIFT_POINTER) { - read_ptr += past_size; - write_ptr = read_ptr; - } + write_ptr += past_size; if (selected.empty()) { - // In general, value cache doesn't need to copy for SHIFT_POINTER mode - if (kv_updater_ == KVManagerMode::SHIFT_POINTER) - return; - std::memcpy(write_ptr, read_ptr, copy_size); + for (int i = 0; i < n_iter; i++) { + std::memcpy(write_ptr, read_ptr, copy_size); + write_ptr += iter_size; + read_ptr += out_size; + } } else { int32_t update_times = n_update; auto wp = write_ptr, rp = read_ptr; diff --git a/examples/qualcomm/oss_scripts/llama/runner/kv_manager.h b/examples/qualcomm/oss_scripts/llama/runner/kv_manager.h index ca24166aa9c..aa355335b68 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/kv_manager.h +++ b/examples/qualcomm/oss_scripts/llama/runner/kv_manager.h @@ -21,8 +21,6 @@ struct KVCache { T* output_buffer; }; -// Enumeration for key-value manager modes -enum KVManagerMode { SMART_MASK = 0x0, SHIFT_POINTER = 0x1 }; /** * @class KVManager * @brief Class for kv cache update, rearrangement, and buffer allocatation. @@ -38,12 +36,12 @@ class KVManager { int64_t num_heads; int64_t num_layers; }; - KVManager(KVManagerMode kv_updater, Metadata metadata); + KVManager(Metadata metadata); /** * @brief Allocate buffer for KV cache and set the cur_ar_len_. - * @param buffer_manager Pointer to IMemAlloc instance which depends on - * kv_updater. + * @param buffer_manager Pointer to IMemAlloc instance; by default, it uses a + * shared buffer with RPC memory. * @param ar_len Length of input tokens. */ void init_cache(IMemAlloc* buffer_manager, int32_t ar_len); @@ -141,31 +139,6 @@ class KVManager { int32_t sliding_window, const std::vector& position_offset = {}); - /** - * @brief Reset the data pointer of the I/O cache tensor based on number of - * past cache, kv manager mode, current ar length and KV cache data pointer - * for SHIFT_POINTER mode. - * @param k_cache_in Reference to the input key cache TensorImpl vector. - * @param k_cache_out Reference to the output key cache TensorImpl vector. - * @param v_cache_in Reference to the input value cache TensorImpl vector. - * @param v_cache_out Reference to the output value cache TensorImpl vector. - * @param ar_len Length of input tokens. - * @param n_past Number of past elements in the cache. - * @return Returns true if the data pointer is updated; otherwise, returns - * false. - */ - bool update_cache_tensor( - std::vector>>& - k_cache_in, - std::vector>>& - k_cache_out, - std::vector>>& - v_cache_in, - std::vector>>& - v_cache_out, - int32_t ar_len, - int32_t n_past); - /** * @brief Based on cur_ar_len_ to update cache * @param ar_len Length of input tokens. 
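With the SHIFT_POINTER branch gone, the cache footprint follows only the remaining SMART_MASK-style formula in KVManager's constructor: one contiguous input and one output buffer per layer for each of K and V. A minimal sketch of that sizing arithmetic (Python, with hypothetical toy metadata values, not taken from the runner):

def total_cache_size_bytes(num_layers, num_heads, head_dim,
                           max_cache_len, max_ar_len, elem_size=1):
    # per-layer K/V input buffers hold num_heads * head_dim rows of max_cache_len,
    # per-layer K/V output buffers hold num_heads * head_dim rows of max_ar_len
    cache_in_bytes = num_layers * num_heads * head_dim * max_cache_len * elem_size
    cache_out_bytes = num_layers * num_heads * head_dim * max_ar_len * elem_size
    # one input and one output buffer each for K and V, hence the factor of 2
    return 2 * (cache_in_bytes + cache_out_bytes)

# hypothetical toy configuration with single-byte cache elements
print(total_cache_size_bytes(num_layers=4, num_heads=8, head_dim=64,
                             max_cache_len=512, max_ar_len=32))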
@@ -179,10 +152,10 @@ class KVManager { int32_t n_update, const std::vector& selected); - const std::vector>>& get_k_cache_() const { + const std::vector>& get_k_cache_() const { return k_cache_; } - const std::vector>>& get_v_cache_() const { + const std::vector>& get_v_cache_() const { return v_cache_; } @@ -204,16 +177,15 @@ class KVManager { int32_t n_past, int32_t n_update, const std::vector& selected); - KVManagerMode kv_updater_; // metadata Metadata metadata_; size_t total_cache_size_; int32_t cur_ar_len_; // Store start pointer of k and v cache for input and output - // input: layer -> head -> head_dim * max_cache_len - // output: layer -> head -> head_dim * max_ar_len - std::vector>> k_cache_; - std::vector>> v_cache_; + // input: layer -> head * head_dim * max_cache_len + // output: layer -> head * head_dim * max_ar_len + std::vector> k_cache_; + std::vector> v_cache_; }; } // namespace example diff --git a/examples/qualcomm/oss_scripts/llama/runner/lhd_token_generator.cpp b/examples/qualcomm/oss_scripts/llama/runner/lhd_token_generator.cpp index 96a25e9c935..11a72ba421e 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/lhd_token_generator.cpp +++ b/examples/qualcomm/oss_scripts/llama/runner/lhd_token_generator.cpp @@ -268,26 +268,8 @@ Result LhdTokenGenerator::generate( } } + // Fill in the token and position data prepare_io(input_tokens, input_pos); - // Only update data pointer of the cache to the tensor for SHIFT_POINTER - // mode - bool updated = this->kv_manager_->update_cache_tensor( - this->k_cache_in_, - this->k_cache_out_, - this->v_cache_in_, - this->v_cache_out_, - metadata_.ar_len, - pos); - // Only update the output of module for SHIFT_POINTER mode - if (updated) { - // Update the output of the module - ET_CHECK_MSG( - this->decoder_runner_->set_outputs( - this->method_name_, this->output_tensors_) == - executorch::runtime::Error::Ok, - "Failed to set output tensor for module %s", - this->method_name_.c_str()); - } // Run inference auto logits_res = diff --git a/examples/qualcomm/oss_scripts/llama/runner/prompt_processor.cpp b/examples/qualcomm/oss_scripts/llama/runner/prompt_processor.cpp index 73da764b584..43c153b00f1 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/prompt_processor.cpp +++ b/examples/qualcomm/oss_scripts/llama/runner/prompt_processor.cpp @@ -130,28 +130,26 @@ void PromptProcessor::init_io( // [I] kv_cache size_t index = idx; // bypass input_tokens, atten_mask, input_pos for (int cache_group = 0; cache_group < 2; ++cache_group) { - std::vector>>& cache = + std::vector>& cache = (cache_group == 0 ? k_cache_in_ : v_cache_in_); - std::vector>> cache_ptrs = (cache_group == 0) + std::vector> cache_ptrs = (cache_group == 0) ? 
kv_manager_->get_k_cache_() : kv_manager_->get_v_cache_(); - for (int layer = 0; layer < metadata_.num_layers; ++layer) { - for (int head = 0; head < metadata_.num_heads; ++head, ++index) { - Result kv_cache = method_meta->input_tensor_meta(index); + for (int layer = 0; layer < metadata_.num_layers; ++layer, ++index) { + Result kv_cache = method_meta->input_tensor_meta(index); - T* cache_ptr = cache_ptrs[layer][head].buffer; + T* cache_ptr = cache_ptrs[layer].buffer; - cache[layer].emplace_back(std::make_unique( - kv_cache->scalar_type(), - kv_cache->sizes().size(), - const_cast(kv_cache->sizes().data()), - cache_ptr, - const_cast( - kv_cache->dim_order().data()))); - input_tensors_.emplace_back(cache[layer][head].get()); - buffer_manager->add_memory_info( - cache_ptr, cache[layer][head]->nbytes(), kv_cache.get()); - } + cache[layer] = std::make_unique( + kv_cache->scalar_type(), + kv_cache->sizes().size(), + const_cast(kv_cache->sizes().data()), + cache_ptr, + const_cast( + kv_cache->dim_order().data())); + input_tensors_.emplace_back(cache[layer].get()); + buffer_manager->add_memory_info( + cache_ptr, cache[layer]->nbytes(), kv_cache.get()); } } } @@ -172,26 +170,23 @@ void PromptProcessor::init_io( // [O] kv_cache size_t index = 1; for (int cache_group = 0; cache_group < 2; ++cache_group) { - std::vector>>& cache = + std::vector>& cache = (cache_group == 0 ? k_cache_out_ : v_cache_out_); - std::vector>> cache_ptrs = (cache_group == 0) + std::vector> cache_ptrs = (cache_group == 0) ? kv_manager_->get_k_cache_() : kv_manager_->get_v_cache_(); - for (int layer = 0; layer < metadata_.num_layers; ++layer) { - for (int head = 0; head < metadata_.num_heads; ++head, ++index) { - Result kv_cache = method_meta->output_tensor_meta(index); - T* cache_ptr = cache_ptrs[layer][head].output_buffer; - cache[layer].emplace_back(std::make_unique( - kv_cache->scalar_type(), - kv_cache->sizes().size(), - const_cast(kv_cache->sizes().data()), - cache_ptr, - const_cast( - kv_cache->dim_order().data()))); - output_tensors_.emplace_back(cache[layer][head].get()); - buffer_manager->add_memory_info( - cache_ptr, cache[layer][head]->nbytes(), kv_cache.get()); - } + for (int layer = 0; layer < metadata_.num_layers; ++layer, ++index) { + Result kv_cache = method_meta->output_tensor_meta(index); + T* cache_ptr = cache_ptrs[layer].output_buffer; + cache[layer] = std::make_unique( + kv_cache->scalar_type(), + kv_cache->sizes().size(), + const_cast(kv_cache->sizes().data()), + cache_ptr, + const_cast(kv_cache->dim_order().data())); + output_tensors_.emplace_back(cache[layer].get()); + buffer_manager->add_memory_info( + cache_ptr, cache[layer]->nbytes(), kv_cache.get()); } } // Prepare the vector of EValue to run inference @@ -290,24 +285,7 @@ Result PromptProcessor::prefill( for (int i = 0; i < num_iters; ++i) { // Fill in the token and position data prepare_io(prompt_tokens, prompt_pos, pos); - // Only update data pointer of the cache to the tensor for SHIFT_POINTER - // mode - bool updated = kv_manager_->update_cache_tensor( - k_cache_in_, - k_cache_out_, - v_cache_in_, - v_cache_out_, - metadata_.ar_len, - pos); - // Only update the output of module for SHIFT_POINTER mode - if (updated) { - // Update the output of the module - ET_CHECK_MSG( - decoder_runner_->set_outputs(method_name_, output_tensors_) == - executorch::runtime::Error::Ok, - "Failed to set output tensor for module %s", - method_name_.c_str()); - } + // Run inference decoder_runner_->step(method_name_, inputs_); if (dump_logits) { diff --git 
a/examples/qualcomm/oss_scripts/llama/runner/prompt_processor.h b/examples/qualcomm/oss_scripts/llama/runner/prompt_processor.h index a3dd2079461..5c97e510987 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/prompt_processor.h +++ b/examples/qualcomm/oss_scripts/llama/runner/prompt_processor.h @@ -41,8 +41,8 @@ class PromptProcessor { /** * @brief Initialize I/O tensor and allocate I/O data buffer. - * @param buffer_manager Pointer to IMemAlloc instance which depends on - * kv_updater. + * @param buffer_manager Pointer to IMemAlloc instance; by default, it uses a + * shared buffer with RPC memory. * @param method_meta Method metadata. */ void init_io( @@ -114,15 +114,11 @@ class PromptProcessor { TensorStruct window_attention_mask_; TensorStruct logits_; - // layer -> head -> TensorImpl - std::vector>> - k_cache_in_; - std::vector>> - v_cache_in_; - std::vector>> - k_cache_out_; - std::vector>> - v_cache_out_; + // layer -> TensorImpl + std::vector> k_cache_in_; + std::vector> v_cache_in_; + std::vector> k_cache_out_; + std::vector> v_cache_out_; std::vector inputs_; std::vector input_tensors_; diff --git a/examples/qualcomm/oss_scripts/llama/runner/runner.cpp b/examples/qualcomm/oss_scripts/llama/runner/runner.cpp index e239a2a5fe1..689a5ade581 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/runner.cpp +++ b/examples/qualcomm/oss_scripts/llama/runner/runner.cpp @@ -94,7 +94,7 @@ Runner::Runner( const std::string& performance_output_path, const float temperature, const int eval_mode, - const std::string& kv_updater, + const bool shared_buffer, const int ngram, const int window, const int gcap, @@ -108,15 +108,9 @@ Runner::Runner( dump_logits_path_(dump_logits_path), temperature_(temperature), eval_mode_(static_cast(eval_mode)), + shared_buffer_(shared_buffer), tokenizer_(std::move(tokenizer)) { stats_.reset(); - if (kv_updater == "SmartMask") { - kv_updater_ = KVManagerMode::SMART_MASK; - } else if (kv_updater == "ShiftPointer") { - kv_updater_ = KVManagerMode::SHIFT_POINTER; - } else { - ET_CHECK_MSG(false, "kv updater (%s) not found", kv_updater.c_str()); - } if (decoder_model_version == "llama2") { decoder_model_version_ = DecoderModelVersion::kLlama2; @@ -146,7 +140,6 @@ Runner::Runner( ET_LOG(Info, "creating module: model_path=%s", model_path.c_str()); ET_LOG(Info, "creating runner: tokenizer_path=%s", tokenizer_path_.c_str()); ET_LOG(Info, "eval mode=%d", eval_mode_); - ET_LOG(Info, "kv updater=%s", kv_updater.c_str()); } template @@ -229,9 +222,10 @@ Error Runner::load() { ET_UNWRAP(module_->get("get_n_layers")).toScalar().to(); ET_CHECK_MSG(num_layers != -1, "Could not retrieve num layers"); - // k_cache: [1, head_dim, seq_len] - int64_t head_dim = method_meta->output_tensor_meta(1)->sizes()[1]; - int64_t num_heads = (method_meta->num_outputs() - 1) / (num_layers * 2); + // k_cache: [1, n_heads, head_dim, seq_len] + auto k_cache_shape = method_meta->output_tensor_meta(1)->sizes(); + int64_t num_heads = k_cache_shape[1]; + int64_t head_dim = k_cache_shape[2]; bool use_int64_token = method_meta->input_tensor_meta(0)->scalar_type() == executorch::aten::ScalarType::Long; @@ -269,15 +263,13 @@ Error Runner::load() { if (module_->method_names()->count("get_sliding_window") > 0) { sliding_window = ET_UNWRAP(module_->get("get_sliding_window")).toInt(); } - kv_manager_ = std::make_unique>( - kv_updater_, - typename KVManager::Metadata{ - context_len_, - head_dim, - max_ar_len, - max_cache_len, - num_heads, - num_layers}); + kv_manager_ = std::make_unique>(typename 
KVManager::Metadata{ + context_len_, + head_dim, + max_ar_len, + max_cache_len, + num_heads, + num_layers}); prompt_processor_ = std::make_unique>( decoder_runner_.get(), @@ -332,13 +324,12 @@ Error Runner::load() { } buffer_manager_ = std::make_unique(); - if (kv_updater_ == KVManagerMode::SMART_MASK) { + if (shared_buffer_) { buffer_manager_ = std::make_unique( kv_manager_->total_cache_size_in_bytes(), prompt_processor_->total_prompt_processor_io_size_in_bytes(), token_generator_->total_token_generator_io_size_in_bytes()); } - ET_LOG(Info, "creating io_memory"); // prepare io kv_manager_->init_cache(buffer_manager_.get(), prompt_processor_ar_len); diff --git a/examples/qualcomm/oss_scripts/llama/runner/runner.h b/examples/qualcomm/oss_scripts/llama/runner/runner.h index 9cf730c3620..690e3f73a43 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/runner.h +++ b/examples/qualcomm/oss_scripts/llama/runner/runner.h @@ -59,7 +59,7 @@ class Runner : public executorch::extension::llm::IRunner { const std::string& dump_logits_path, const float temperature = 0.8f, const int eval_mode = EvalMode::kHybrid, - const std::string& kv_updater = "SmartMask", + const bool shared_buffer = false, const int ngram = 0, const int window = 0, const int gcap = 0, @@ -110,9 +110,9 @@ class Runner : public executorch::extension::llm::IRunner { std::string dump_logits_path_; float temperature_; EvalMode eval_mode_; + bool shared_buffer_; DecoderModelVersion decoder_model_version_; - KVManagerMode kv_updater_; std::unique_ptr buffer_manager_; std::unique_ptr> kv_manager_; std::unique_ptr tokenizer_; diff --git a/examples/qualcomm/oss_scripts/llama/runner/token_generator.cpp b/examples/qualcomm/oss_scripts/llama/runner/token_generator.cpp index 6775c08bd87..733577e7d59 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/token_generator.cpp +++ b/examples/qualcomm/oss_scripts/llama/runner/token_generator.cpp @@ -134,28 +134,25 @@ void TokenGenerator::init_io( // [I] kv_cache size_t index = idx; // bypass input_tokens, atten_mask, input_pos for (int cache_group = 0; cache_group < 2; ++cache_group) { - std::vector>>& cache = + std::vector>& cache = (cache_group == 0 ? k_cache_in_ : v_cache_in_); - std::vector>> cache_ptrs = (cache_group == 0) + std::vector> cache_ptrs = (cache_group == 0) ? 
kv_manager_->get_k_cache_() : kv_manager_->get_v_cache_(); - for (int layer = 0; layer < metadata_.num_layers; ++layer) { - for (int head = 0; head < metadata_.num_heads; ++head, ++index) { - Result kv_cache = method_meta->input_tensor_meta(index); + for (int layer = 0; layer < metadata_.num_layers; ++layer, ++index) { + Result kv_cache = method_meta->input_tensor_meta(index); - T* cache_ptr = cache_ptrs[layer][head].buffer; + T* cache_ptr = cache_ptrs[layer].buffer; - cache[layer].emplace_back(std::make_unique( - kv_cache->scalar_type(), - kv_cache->sizes().size(), - const_cast(kv_cache->sizes().data()), - cache_ptr, - const_cast( - kv_cache->dim_order().data()))); - input_tensors_.emplace_back(cache[layer][head].get()); - buffer_manager->add_memory_info( - cache_ptr, cache[layer][head]->nbytes(), kv_cache.get()); - } + cache[layer] = std::make_unique( + kv_cache->scalar_type(), + kv_cache->sizes().size(), + const_cast(kv_cache->sizes().data()), + cache_ptr, + const_cast(kv_cache->dim_order().data())); + input_tensors_.emplace_back(cache[layer].get()); + buffer_manager->add_memory_info( + cache_ptr, cache[layer]->nbytes(), kv_cache.get()); } } @@ -175,26 +172,23 @@ void TokenGenerator::init_io( // [O] kv_cache index = 1; for (int cache_group = 0; cache_group < 2; ++cache_group) { - std::vector>>& cache = + std::vector>& cache = (cache_group == 0 ? k_cache_out_ : v_cache_out_); - std::vector>> cache_ptrs = (cache_group == 0) + std::vector> cache_ptrs = (cache_group == 0) ? kv_manager_->get_k_cache_() : kv_manager_->get_v_cache_(); - for (int layer = 0; layer < metadata_.num_layers; ++layer) { - for (int head = 0; head < metadata_.num_heads; ++head, ++index) { - Result kv_cache = method_meta->output_tensor_meta(index); - T* cache_ptr = cache_ptrs[layer][head].output_buffer; - cache[layer].emplace_back(std::make_unique( - kv_cache->scalar_type(), - kv_cache->sizes().size(), - const_cast(kv_cache->sizes().data()), - cache_ptr, - const_cast( - kv_cache->dim_order().data()))); - output_tensors_.emplace_back(cache[layer][head].get()); - buffer_manager->add_memory_info( - cache_ptr, cache[layer][head]->nbytes(), kv_cache.get()); - } + for (int layer = 0; layer < metadata_.num_layers; ++layer, ++index) { + Result kv_cache = method_meta->output_tensor_meta(index); + T* cache_ptr = cache_ptrs[layer].output_buffer; + cache[layer] = std::make_unique( + kv_cache->scalar_type(), + kv_cache->sizes().size(), + const_cast(kv_cache->sizes().data()), + cache_ptr, + const_cast(kv_cache->dim_order().data())); + output_tensors_.emplace_back(cache[layer].get()); + buffer_manager->add_memory_info( + cache_ptr, cache[layer]->nbytes(), kv_cache.get()); } } // Prepare the vector of EValue to run inference @@ -261,24 +255,7 @@ Result TokenGenerator::generate( while (pos < seq_len - 1) { // Fill in the token and position data prepare_io(cur_token, pos); - // Only update data pointer of the cache to the tensor for SHIFT_POINTER - // mode - bool updated = kv_manager_->update_cache_tensor( - k_cache_in_, - k_cache_out_, - v_cache_in_, - v_cache_out_, - metadata_.ar_len, - pos); - // Only update the output of module for SHIFT_POINTER mode - if (updated) { - // Update the output of the module - ET_CHECK_MSG( - decoder_runner_->set_outputs(method_name_, output_tensors_) == - executorch::runtime::Error::Ok, - "Failed to set output tensor for module %s", - method_name_.c_str()); - } + // Run inference auto logits_res = decoder_runner_->step(method_name_, inputs_); if (dump_logits) { diff --git 
a/examples/qualcomm/oss_scripts/llama/runner/token_generator.h b/examples/qualcomm/oss_scripts/llama/runner/token_generator.h index 9f0198f3040..329a4d49cc6 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/token_generator.h +++ b/examples/qualcomm/oss_scripts/llama/runner/token_generator.h @@ -46,8 +46,8 @@ class TokenGenerator { virtual ~TokenGenerator() = default; /** * @brief Initialize I/O tensor and allocate I/O data buffer. - * @param buffer_manager Pointer to IMemAlloc instance which depends on - * kv_updater. + * @param buffer_manager Pointer to IMemAlloc instance; by default, it uses a + * shared buffer with RPC memory. * @param method_meta Method metadata. */ virtual void init_io( @@ -99,15 +99,11 @@ class TokenGenerator { TensorStruct window_attention_mask_; TensorStruct logits_; - // layer -> head -> TensorImpl - std::vector>> - k_cache_in_; - std::vector>> - v_cache_in_; - std::vector>> - k_cache_out_; - std::vector>> - v_cache_out_; + // layer -> TensorImpl + std::vector> k_cache_in_; + std::vector> v_cache_in_; + std::vector> k_cache_out_; + std::vector> v_cache_out_; std::vector inputs_; std::vector input_tensors_;
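The repeat_kv helper added to static_llama.py expands the grouped key/value heads to match the number of query heads before the single batched attention matmul. A small self-contained check, with toy shapes, of the equivalence stated in its docstring:

import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (batch, num_key_value_heads, seqlen, head_dim)
    #   -> (batch, num_key_value_heads * n_rep, seqlen, head_dim)
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

kv = torch.randn(1, 2, 5, 4)        # (batch=1, n_kv_heads=2, seq=5, head_dim=4)
expanded = repeat_kv(kv, n_rep=3)   # -> (1, 6, 5, 4), ready for 6 query heads
assert torch.equal(expanded, torch.repeat_interleave(kv, repeats=3, dim=1))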
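The per-layer KV cache I/O also changes shape: get_example_inputs now hands each layer one key tensor cached transposed as [batch, n_kv_heads, head_dim, max_seq_len - ar_len] and one value tensor as [batch, n_kv_heads, max_seq_len - ar_len, head_dim], so LlamaAttention.forward concatenates new keys on the last axis and new values on the sequence axis. A shape sketch with toy sizes, not the model's real dimensions:

import torch

bsz, n_kv_heads, head_dim, ar_len, max_seq_len = 1, 2, 4, 3, 11
cache_len = max_seq_len - ar_len
k_cache = torch.zeros(bsz, n_kv_heads, head_dim, cache_len)   # key cached transposed
v_cache = torch.zeros(bsz, n_kv_heads, cache_len, head_dim)

k_new = torch.randn(bsz, n_kv_heads, head_dim, ar_len)        # keys after transpose(2, 3)
v_new = torch.randn(bsz, n_kv_heads, ar_len, head_dim)

kh = torch.cat([k_cache, k_new], dim=-1)   # [b, kvh, d, cache_len + ar_len]
vh = torch.cat([v_cache, v_new], dim=2)    # [b, kvh, cache_len + ar_len, d]
assert kh.shape[-1] == vh.shape[2] == max_seq_len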
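Because q and k now stay head-major ([batch, heads, seq, head_dim]) throughout, apply_rope.py broadcasts the rotary frequencies over the first two axes instead of axes 0 and 2. A toy shape check; the (seq_len, head_dim // 2) frequency shape is an assumption made for illustration:

import torch

bsz, n_heads, seq_len, head_dim = 1, 4, 8, 16
x = torch.randn(bsz, n_heads, seq_len, head_dim)   # new layout: [b, h, s, d]
freqs_cos = torch.randn(seq_len, head_dim // 2)    # per-position frequencies (assumed shape)

# old layout [b, s, h, d] used freqs_cos[None, :, None, :];
# with the seq axis now at dim 2, the broadcast becomes:
freqs_cos_b = freqs_cos[None, None, :, :]          # -> [1, 1, s, d/2]
x_r, x_i = x[..., : head_dim // 2], x[..., head_dim // 2:]
assert (x_r * freqs_cos_b).shape == (bsz, n_heads, seq_len, head_dim // 2)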
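With the shift-pointer path removed, mask maintenance is purely in-place: the cache slots that have just been filled are zeroed out while the rest of the mask keeps its large negative fill, as in CausalAttentionMask.smart_mask_update. A toy illustration; the initial mask construction and the -255.0 fill value here are illustrative assumptions, not the class's exact initializer:

import torch

ar_len, max_seq_len = 2, 8
mask = torch.full((1, ar_len, max_seq_len), -255.0)   # everything masked initially
# current AR block (last ar_len columns) attends causally to itself
mask[:, :, -ar_len:] = torch.triu(torch.full((ar_len, ar_len), -255.0), diagonal=1)

def smart_mask_update(mask, pos, n_updates):
    # unmask the cache slots just written for positions [pos, pos + n_updates)
    mask[:, :, pos : pos + n_updates] = 0

smart_mask_update(mask, pos=0, n_updates=2)   # after the first decode step
print(mask[0])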