
Commit 37c471d

qio + sha + cpu quantized embedding +r1r2 cuda
1 parent 46ea1a4

10 files changed: +348 -29 lines

backends/qualcomm/_passes/i64_to_i32.py

Lines changed: 6 additions & 1 deletion
@@ -61,7 +61,12 @@ def _cast_to_int32(self, graph_module: torch.fx.GraphModule):
             to_dst_node.meta["val"] = node_val.to(torch.int32)
 
             # Replace usage of the src dtype result with the dst dtype result.
-            n.replace_all_uses_with(to_dst_node)
+            if n.name != "tokens":
+                n.replace_all_uses_with(to_dst_node)
+            else:
+                for user in n.users.copy():
+                    if user.name != "quantized_decomposed_embedding_4bit_dtype":
+                        user.replace_input_with(n, to_dst_node)
             to_dst_node.args = (n,)
 
     def call(self, graph_module: torch.fx.GraphModule):
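
The new branch keeps the `tokens` placeholder feeding `quantized_decomposed_embedding_4bit_dtype` in its original dtype while every other consumer is rerouted through the int32 cast. A minimal, self-contained torch.fx sketch of that "rewire all users except one" pattern (node and module names here are illustrative, not taken from the commit):

import torch

def selective_rewire(gm: torch.fx.GraphModule, src_name: str, keep_user_name: str):
    # Insert an int32 cast after the node called `src_name` and reroute every
    # consumer except `keep_user_name` through it, mirroring the pass logic above.
    for n in gm.graph.nodes:
        if n.name != src_name:
            continue
        with gm.graph.inserting_after(n):
            cast = gm.graph.call_method("to", (n, torch.int32))
        for user in list(n.users):
            if user is not cast and user.name != keep_user_name:
                user.replace_input_with(n, cast)
    gm.recompile()
    return gm

class Toy(torch.nn.Module):
    def forward(self, tokens):
        a = tokens + 1  # rerouted through the cast
        b = tokens * 2  # kept on the original `tokens`, like the 4-bit embedding user
        return a, b

gm = selective_rewire(torch.fx.symbolic_trace(Toy()), "tokens", "mul")
print(gm.code)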

backends/qualcomm/partition/common_defs.py

Lines changed: 1 addition & 0 deletions
@@ -14,6 +14,7 @@
     exir_ops.edge.aten.full.default,
     exir_ops.edge.aten.slice_scatter.default,
     exir_ops.edge.aten.copy.default,
+    exir_ops.edge.quantized_decomposed.embedding_4bit.dtype,
 ]
 
 to_be_implemented_operator = [
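
Presumably this list is the set of ops QNN will not delegate (the `to_be_implemented_operator` list begins right after it); keeping `quantized_decomposed.embedding_4bit.dtype` out of the QNN partition matches the "cpu quantized embedding" part of the commit message and explains the `tokens` special case in the i64-to-i32 pass above.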

backends/qualcomm/quantizer/custom_annotation.py

Lines changed: 33 additions & 0 deletions
@@ -12,6 +12,7 @@
     QuantizationConfig,
 )
 from executorch.backends.qualcomm.quantizer.utils import QUANT_ANNOTATION_KEY
+from executorch.exir.dialects._ops import ops as exir_ops
 from torch.ao.quantization.quantizer import (
     QuantizationAnnotation,
     SharedQuantizationSpec,

@@ -144,3 +145,35 @@ def annotate_matmul(node: Node, quantization_config: QuantizationConfig):
     for node in gm.graph.nodes:
         if node.op == "call_function" and node.target == torch.ops.aten.matmul.default:
             annotate_matmul(node, quantization_config_16a8w)
+
+
+def get_custom_quant_ios_dtype(
+    cache_shape: torch.Size,
+    node: torch.fx.Node,
+    kv_dtype=torch.uint8,
+    sharding_dtype=torch.uint16,
+):
+    """
+    This function is specific for llama inputs and outputs
+    """
+    if node.op == "placeholder" and "attention_sdpa_kv_cache_past_" in node.name:
+        return kv_dtype
+
+    # Tag index put node before copy node, because copy is a skipped node in qnn
+    if (
+        exir_ops.edge.aten.index_put.default == node.target
+        and node.meta["val"].shape == cache_shape
+    ):
+        return kv_dtype
+
+    # Tag sharding io
+    if exir_ops.edge.llama.fallback.default in [
+        u.target for u in list(node.users.keys())
+    ] + [node.target]:
+        return sharding_dtype
+
+    # Tag index op as quantized tensors. It is caused by sharding
+    if exir_ops.edge.aten.index.Tensor in [
+        u.target for u in list(node.users.keys())
+    ] + [node.target]:
+        return sharding_dtype

backends/qualcomm/utils/utils.py

Lines changed: 10 additions & 0 deletions
@@ -71,6 +71,7 @@
     QCOM_PASS_EXPAND_BROADCAST_SHAPE,
     QCOM_PASS_SKIP_ADVANCED_REQUANT,
     QCOM_QNN_COMPILE_SPEC,
+    QCOM_QUANTIZED_IO,
 )
 
 from executorch.exir import ExirExportedProgram

@@ -876,3 +877,12 @@ def get_soc_to_chipset_map():
         "SM8475": QcomChipset.SM8475,
         "SM8450": QcomChipset.SM8450,
     }
+
+
+def tag_quant_io(gm: torch.fx.GraphModule, get_quant_io_dtype_fn: Callable):
+    """
+    Tag io nodes which get/output quantized tensor. No need to insert q/dq in qnn_preprocess
+    """
+    for node in gm.graph.nodes:
+        if dtype := get_quant_io_dtype_fn(node):
+            node.meta[QCOM_QUANTIZED_IO] = dtype
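
`get_custom_quant_ios_dtype` returns `None` for nodes it does not recognize, and the walrus check in `tag_quant_io` skips those, so only matching io nodes receive the `QCOM_QUANTIZED_IO` tag; the actual wiring, a `functools.partial` over the KV-cache shape, appears in the export_llama_lib.py hunk below. A runnable toy sketch of the filtering (the callback, module, and meta key are stand-ins, not the real constants):

import torch

def toy_quant_io_dtype(node: torch.fx.Node):
    # Stand-in for get_custom_quant_ios_dtype: tag only the `tokens` placeholder.
    if node.op == "placeholder" and "tokens" in node.name:
        return torch.uint8
    return None  # every other node falls through untagged

class Toy(torch.nn.Module):
    def forward(self, tokens, mask):
        return tokens + mask

gm = torch.fx.symbolic_trace(Toy())
for node in gm.graph.nodes:
    if dtype := toy_quant_io_dtype(node):  # None is falsy, so non-io nodes are skipped
        node.meta["quantized_io"] = dtype  # the real pass uses the QCOM_QUANTIZED_IO key

print([(n.name, n.meta.get("quantized_io")) for n in gm.graph.nodes])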

examples/models/llama/export_llama.py

Lines changed: 3 additions & 0 deletions
@@ -7,11 +7,14 @@
 # Example script for exporting Llama2 to flatbuffer
 
 import logging
+import sys
 
 import torch
 
 from .export_llama_lib import build_args_parser, export_llama
 
+sys.setrecursionlimit(4096)
+
 
 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
 logging.basicConfig(level=logging.INFO, format=FORMAT)
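
For context: CPython's default recursion limit is 1000 frames, and the bump to 4096 presumably guards against deep recursion while tracing or serializing the large exported llama graph. The effect is simply:

import sys

print(sys.getrecursionlimit())  # 1000 on a stock CPython interpreter
sys.setrecursionlimit(4096)     # the value the export script now sets at import time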

examples/models/llama/export_llama_lib.py

Lines changed: 55 additions & 10 deletions
@@ -50,6 +50,8 @@
     fuse_layer_norms,
     get_model_with_r1_r2,
 )
+
+from .source_transformation.attention import replace_attention_to_attention_sha
 from .source_transformation.quantize import (
     get_quant_embedding_transform,
     get_quant_weight_transform,

@@ -174,6 +176,12 @@ def build_args_parser() -> argparse.ArgumentParser:
         help="checkpoint directory. Use with a sharded checkpoint, not for the standard llama2 model. Note, checkpoint_dir takes precedence over checkpoint if both are set.",
     )
 
+    parser.add_argument(
+        "--use_qnn_sha",
+        action="store_true",
+        help="Change multi head attention to multiple single head attention for qnn backend (Qualcomm)",
+    )
+
     parser.add_argument(
         "--calibration_tasks",
         nargs="+",

@@ -642,7 +650,7 @@ def _export_llama(modelname, args) -> LLMEdgeManager: # noqa: C901
                 )
             )
         # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.utils.utils`
-        from executorch.backends.qualcomm.utils.utils import _transform
+        from executorch.backends.qualcomm.utils.utils import _transform, tag_quant_io
 
         # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`, Optional type has no attribute `exported_program`
         _transform(builder_exported_to_edge.edge_manager.exported_program())

@@ -654,7 +662,32 @@ def _export_llama(modelname, args) -> LLMEdgeManager: # noqa: C901
                 builder_exported_to_edge.metadata["get_n_layers"],
                 shares=args.num_sharding,
             )
+        from functools import partial
 
+        from executorch.backends.qualcomm.quantizer.custom_annotation import (
+            get_custom_quant_ios_dtype,
+        )
+        atten = builder_exported_to_edge.model.layers[0].attention
+        if args.use_qnn_sha:
+            cache_shape = torch.Size(
+                (atten.max_batch_size, atten.max_seq_len, atten.head_dim)
+            )
+        else:
+            cache_shape = torch.Size(
+                (
+                    atten.max_batch_size,
+                    atten.max_seq_len,
+                    atten.n_kv_heads,
+                    atten.head_dim,
+                )
+            )
+        tag_quant_io(
+            builder_exported_to_edge.edge_manager.exported_program().graph_module,
+            partial(
+                get_custom_quant_ios_dtype,
+                cache_shape,
+            ),
+        )
         logging.info("Lowering model using following partitioner(s): ")
         for partitioner in partitioners:
             logging.info(f"--> {partitioner.__class__.__name__}")

@@ -919,15 +952,27 @@ def _get_source_transforms( # noqa
             convert_linear_to_conv2d,
         )
 
-        transforms.append(replace_kv_cache_with_simple_kv_cache)
-        transforms.append(replace_sdpa_with_flex_sdpa)
-        transforms.append(replace_causal_mask)
-        transforms.append(replace_rms_norm_with_native_rms_norm)
-        if args.optimized_rotation_path:
-            transforms.append(fuse_layer_norms)
-            transforms.append(get_model_with_r1_r2(args.optimized_rotation_path))
-        # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
-        transforms.append(convert_linear_to_conv2d)
+        if args.use_qnn_sha:
+            if args.optimized_rotation_path:
+                transforms.append(fuse_layer_norms)
+                transforms.append(
+                    get_model_with_r1_r2(args.optimized_rotation_path)
+                )
+            transforms.append(replace_attention_to_attention_sha)
+            transforms.append(replace_causal_mask)
+            transforms.append(replace_rms_norm_with_native_rms_norm)
+            transforms.append(convert_linear_to_conv2d)
+        else:
+            transforms.append(replace_kv_cache_with_simple_kv_cache)
+            transforms.append(replace_sdpa_with_flex_sdpa)
+            transforms.append(replace_causal_mask)
+            transforms.append(replace_rms_norm_with_native_rms_norm)
+            if args.optimized_rotation_path:
+                transforms.append(fuse_layer_norms)
+                transforms.append(
+                    get_model_with_r1_r2(args.optimized_rotation_path)
+                )
+            transforms.append(convert_linear_to_conv2d)
 
     elif args.mps:
         # Currently mps doesn't support sdpa op, use the simpler decomposition
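
The two `cache_shape` branches encode what `--use_qnn_sha` changes about the KV cache: with single-head attention each cache placeholder carries one head, so the `n_kv_heads` dimension drops out. A standalone sketch with illustrative dimensions (not taken from the commit):

import torch

# Illustrative llama-style geometry
max_batch_size, max_seq_len, n_kv_heads, head_dim = 1, 512, 8, 128

# Default multi-head path: one cache tensor spans all kv heads.
mha_cache_shape = torch.Size((max_batch_size, max_seq_len, n_kv_heads, head_dim))

# --use_qnn_sha path: attention is split into per-head modules, so each
# cache placeholder holds a single head.
sha_cache_shape = torch.Size((max_batch_size, max_seq_len, head_dim))

# Together the per-head caches hold the same number of elements.
assert mha_cache_shape.numel() == n_kv_heads * sha_cache_shape.numel()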

examples/models/llama/llama_transformer.py

Lines changed: 10 additions & 9 deletions
@@ -263,21 +263,22 @@ class Attention(nn.Module):
     def __init__(self, args: ModelArgs, layer_id: int):
         super().__init__()
         self.use_kv_cache = args.use_kv_cache
-        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
-        assert args.n_heads % self.n_kv_heads == 0
+        self.n_heads = args.n_heads
+        self.n_kv_heads = self.n_heads if args.n_kv_heads is None else args.n_kv_heads
+        assert self.n_heads % self.n_kv_heads == 0
         model_parallel_size = 1
-        self.n_local_heads = args.n_heads // model_parallel_size
+        self.n_local_heads = self.n_heads // model_parallel_size
         self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
         self.n_rep = self.n_local_heads // self.n_local_kv_heads
-        self.head_dim = args.dim // args.n_heads
+        self.head_dim = args.dim // self.n_heads
         self.max_batch_size = args.max_batch_size
         self.max_seq_len = args.max_seq_len
         self.dim = args.dim
-        # args.dim = 4096, args.n_heads = 32, self.head_dim = 4096 / 32 = 125
-        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
-        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
-        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
-        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
+        # args.dim = 4096, self.n_heads = 32, self.head_dim = 4096 / 32 = 125
+        self.wq = nn.Linear(self.dim, self.n_heads * self.head_dim, bias=False)
+        self.wk = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
+        self.wv = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
+        self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)
 
         self.layer_id = layer_id
 
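Promoting `n_heads` (and the values derived from it) to instance attributes is what lets the new `cache_shape` logic in export_llama_lib.py read the geometry straight off `model.layers[0].attention`. The "= 125" carried over in the comment appears to be a pre-existing typo: 4096 / 32 is 128.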
examples/models/llama/source_transformation/apply_spin_quant_r1_r2.py

Lines changed: 1 addition & 1 deletion
@@ -98,7 +98,7 @@ def get_model_with_r1_r2(optimized_rotation_path: str):
 
 
 def apply_spin_quant_r1_r2(model: torch.nn.Module, optimized_rotation_path: str):
-    optimized_rotation = torch.load(optimized_rotation_path, weights_only=True)
+    optimized_rotation = torch.load(optimized_rotation_path, weights_only=True, map_location=torch.device('cpu'))
     R1 = optimized_rotation["R1"].to(torch.float32)
     config = model.params
     num_heads = config.n_heads
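
The added `map_location` lines up with the "r1r2 cuda" part of the commit message: a rotation checkpoint saved from CUDA tensors would otherwise fail to deserialize on a CPU-only host. A minimal sketch (the file name is hypothetical):

import torch

# Without map_location, loading CUDA-saved tensors on a machine where
# torch.cuda.is_available() is False raises a deserialization error.
optimized_rotation = torch.load(
    "optimized_rotation.pt",           # hypothetical path
    weights_only=True,
    map_location=torch.device("cpu"),  # remap every tensor onto CPU at load time
)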
