diff --git a/backends/qualcomm/serialization/qnn_compile_spec_schema.py b/backends/qualcomm/serialization/qnn_compile_spec_schema.py
index 8471aad982d..df7b968c125 100644
--- a/backends/qualcomm/serialization/qnn_compile_spec_schema.py
+++ b/backends/qualcomm/serialization/qnn_compile_spec_schema.py
@@ -33,6 +33,7 @@ class QcomChipset(IntEnum):
     SM8450 = 36  # v69
     SM8475 = 42  # v69
     SM8550 = 43  # v73
+    SSG2115P = 46  # v73... I wish I knew where this number comes from...
     SM8650 = 57  # v75
@@ -47,6 +48,7 @@ class SocInfo:
     QcomChipset.SM8475: SocInfo(QcomChipset.SM8475, HtpInfo(HtpArch.V69, 8)),
     QcomChipset.SM8550: SocInfo(QcomChipset.SM8550, HtpInfo(HtpArch.V73, 8)),
     QcomChipset.SM8650: SocInfo(QcomChipset.SM8650, HtpInfo(HtpArch.V75, 8)),
+    QcomChipset.SSG2115P: SocInfo(QcomChipset.SSG2115P, HtpInfo(HtpArch.V73, 2)),
 }
diff --git a/backends/qualcomm/serialization/schema.fbs b/backends/qualcomm/serialization/schema.fbs
index 4e7fdb56e89..f2275377f7b 100644
--- a/backends/qualcomm/serialization/schema.fbs
+++ b/backends/qualcomm/serialization/schema.fbs
@@ -32,6 +32,7 @@ enum QcomChipset: int {
   SM8450 = 36,
   SM8475 = 42,
   SM8550 = 43,
+  SSG2115P = 46,
   SM8650 = 57,
 }
@@ -170,7 +171,7 @@ table QnnExecuTorchOptions {
   /// Profiling level of the delegate and the backend. Default is off.
   profile_level:QnnExecuTorchProfileLevel;
-
+
   /// Enables usage of shared buffer between application and backend for graph I/O.
   shared_buffer:bool;
diff --git a/backends/qualcomm/tests/utils.py b/backends/qualcomm/tests/utils.py
index 7209b0a2678..52ffac46eee 100644
--- a/backends/qualcomm/tests/utils.py
+++ b/backends/qualcomm/tests/utils.py
@@ -118,6 +118,7 @@ class TestQNN(unittest.TestCase):
     model: QcomChipset = None
     compiler_specs: List[CompileSpec] = None
     arch_table = {
+        "SSG2115P": QcomChipset.SSG2115P,
         "SM8650": QcomChipset.SM8650,
         "SM8550": QcomChipset.SM8550,
         "SM8475": QcomChipset.SM8475,
diff --git a/examples/models/llama2/export_llama_lib.py b/examples/models/llama2/export_llama_lib.py
index a39bb048200..e2e5a178edf 100644
--- a/examples/models/llama2/export_llama_lib.py
+++ b/examples/models/llama2/export_llama_lib.py
@@ -53,21 +53,23 @@
     get_quant_embedding_transform,
     get_quant_weight_transform,
 )
-from .source_transformation.quantized_kv_cache import (
-    replace_kv_cache_with_quantized_kv_cache,
-)
+
+# from .source_transformation.quantized_kv_cache import (
+#     replace_kv_cache_with_quantized_kv_cache,
+# )
 from .source_transformation.rms_norm import replace_rms_norm_with_native_rms_norm
 
 from .source_transformation.rope import materialze_broadcast_of_rope_freq_cis
-from .source_transformation.sdpa import (
-    replace_causal_mask,
-    replace_kv_cache_with_coreml_kv_cache,
-    replace_kv_cache_with_simple_kv_cache,
-    replace_sdpa_with_coreml_sdpa,
-    replace_sdpa_with_custom_op,
-    replace_sdpa_with_flex_sdpa,
-    replace_sdpa_with_simple_sdpa,
-)
+
+# from .source_transformation.sdpa import (
+#     replace_causal_mask,
+#     replace_kv_cache_with_coreml_kv_cache,
+#     replace_kv_cache_with_simple_kv_cache,
+#     replace_sdpa_with_coreml_sdpa,
+#     replace_sdpa_with_custom_op,
+#     replace_sdpa_with_flex_sdpa,
+#     replace_sdpa_with_simple_sdpa,
+# )
 
 IS_FBCODE = True  #  os.environ.get("FBCODE_PLATFORM", False)
 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
@@ -893,23 +895,20 @@ def _get_source_transforms(  # noqa
         assert args.use_kv_cache, "quantize_kv_cache requires use_kv_cache=True"
         transforms.append(replace_kv_cache_with_quantized_kv_cache)
 
+    if args.qnn:
+        # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.utils.utils`
+        from executorch.backends.qualcomm.utils.utils import convert_linear_to_conv2d
+
+        # transforms.append(replace_kv_cache_with_simple_kv_cache)
+        # transforms.append(replace_sdpa_with_flex_sdpa)
+        # transforms.append(replace_causal_mask)
+        transforms.append(replace_rms_norm_with_native_rms_norm)
+        if args.optimized_rotation_path:
+            transforms.append(fuse_layer_norms)
+            transforms.append(get_model_with_r1_r2(args.optimized_rotation_path))
+        transforms.append(convert_linear_to_conv2d)
     if args.use_kv_cache:
-        if args.qnn:
-            # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.utils.utils`
-            from executorch.backends.qualcomm.utils.utils import (
-                convert_linear_to_conv2d,
-            )
-
-            transforms.append(replace_kv_cache_with_simple_kv_cache)
-            transforms.append(replace_sdpa_with_flex_sdpa)
-            transforms.append(replace_causal_mask)
-            transforms.append(replace_rms_norm_with_native_rms_norm)
-            if args.optimized_rotation_path:
-                transforms.append(fuse_layer_norms)
-                transforms.append(get_model_with_r1_r2(args.optimized_rotation_path))
-            transforms.append(convert_linear_to_conv2d)
-
-        elif args.mps:
+        if args.mps:
             # Currently mps doesn't support sdpa op, use the simpler decomposition
             # to get free perf gain.
             transforms.append(replace_sdpa_with_simple_sdpa)
diff --git a/examples/models/llama2/llama_transformer.py b/examples/models/llama2/llama_transformer.py
index 8e17013ae3d..2a229a87609 100644
--- a/examples/models/llama2/llama_transformer.py
+++ b/examples/models/llama2/llama_transformer.py
@@ -1,27 +1,14 @@
-# @lint-ignore-every LICENSELINT
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# Llama 2 is licensed under the LLAMA 2 Community License,
-# Copyright (c) Meta Platforms, Inc. All Rights Reserved.
-
-# Please refer to README.md in the same folder for more information.
-
+import logging
+import math
 from dataclasses import dataclass
-from functools import partial
-from typing import Dict, Optional, Tuple
+from typing import Optional, Tuple
 
 import torch
-import torch.nn.functional as F
-from executorch.examples.models.llama2.rope import (
-    apply_rotary_emb,
-    hf_apply_rotary_emb,
-    hf_precompute_freqs_cis,
-    precompute_freqs_cis,
-)
+from torch.nn import functional as F
+
 
-from torch import nn
+logger: logging.Logger = logging.getLogger()
 
 
 class RMSNorm(torch.nn.Module):
@@ -39,9 +26,8 @@ def __init__(self, dim: int, eps: float = 1e-6):
         """
         super().__init__()
-        self.dim = dim
         self.eps = eps
-        self.weight = nn.Parameter(torch.ones(dim))
+        self.weight = torch.nn.Parameter(torch.ones(dim))
 
     def _norm(self, x):
         """
@@ -54,7 +40,7 @@ def _norm(self, x):
             torch.Tensor: The normalized tensor.
 
""" - return x * torch.rsqrt((x * x).mean(-1, keepdim=True) + self.eps) + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) def forward(self, x): """ @@ -71,12 +57,6 @@ def forward(self, x): return output * self.weight -def find_multiple(n: int, k: int) -> int: - if n % k == 0: - return n - return n + k - (n % k) - - @dataclass class ModelArgs: dim: int = 4096 @@ -84,182 +64,56 @@ class ModelArgs: n_heads: int = 32 n_kv_heads: Optional[int] = None vocab_size: int = -1 # defined later by tokenizer - hidden_dim: Optional[int] = None + invocation_vocab_size: int = -1 # defined later by tokenizer multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2 ffn_dim_multiplier: Optional[float] = None norm_eps: float = 1e-5 max_batch_size: int = 32 max_seq_len: int = 2048 - moe: bool = False # True to enable the MoE (Mixture of Experts) - num_experts: int = 8 # Number of experts - num_activated_experts: int = 2 # Number of experts to activate - use_kv_cache: bool = False # Use key/value cache - use_sdpa_with_kv_cache_op: bool = ( - False # Use custom sdpa op that updates kv cache in-place - ) - # Generate logits for all inputs. When it's True, it would take big memory usage - # at runtime. Enable it only necessary (e.g., use perplexity tools that requires - # logits for all input tokens.) - generate_full_logits: bool = False - enable_dynamic_shape: bool = False # export model with dynamic shape support - # A dictionary mapping from pruned token-id to original token-id - output_prune_map: Optional[Dict[int, int]] = None - use_hf_rope: bool = False # Use HuggingFace's RoPE implementation - rope_theta: Optional[float] = ( - None # The official name to override self.rope_freq_base. - ) - rope_freq_base: float = 10000.0 # The base frequency for RoPE. Keep it for BC. - use_scaled_rope: bool = False # Use scaled RoPE, introduced in llama3.1. - # Additional Model Metadata needed at runtime - bos_idx: int = 1 - eos_idx: int = 3 - bos_count: int = -1 # i.e., a single EOS is used as BOS - eos_count: int = 2 - - def __post_init__(self): - if self.n_kv_heads is None: - self.n_kv_heads = self.n_heads - - # rope_theta overrides rope_freq_base since it's the official name. 
- if self.rope_theta is not None: - self.rope_freq_base = self.rope_theta - - if self.use_sdpa_with_kv_cache_op: - assert self.use_kv_cache, "use_sdpa_with_kv_cache_op requires use_kv_cache" - - if self.hidden_dim is None: - # If hidden_dim is not explicitly set in the ModelArgs, - # then calculate implicitly based on dim and also multiple of `args.multiple_of` - multiple_of = self.multiple_of - hidden_dim = 4 * self.dim - hidden_dim = int(2 * hidden_dim / 3) - if self.ffn_dim_multiplier is not None: - hidden_dim = int(self.ffn_dim_multiplier * hidden_dim) - self.hidden_dim = find_multiple(hidden_dim, multiple_of) + hidden_dim: Optional[int] = None -class KVCache(nn.Module): - def __init__( - self, - max_batch_size: int, - max_seq_length: int, - n_heads: int, - head_dim: int, - transpose_cache: bool, - enable_dynamic_shape: bool, - dtype=torch.float32, - ): - super().__init__() - self.max_seq_length = max_seq_length - self.is_tranposed = transpose_cache - if transpose_cache: - cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim) - else: - cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim) - - self.max_batch_size = max_batch_size - self.n_heads = n_heads - self.head_dim = head_dim - self.transpose_cache = transpose_cache - self.enable_dynamic_shape = enable_dynamic_shape - self.register_buffer( - "k_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu") - ) - self.register_buffer( - "v_cache", torch.zeros(cache_shape, dtype=dtype, device="cpu") - ) +def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device) # pyre-ignore + freqs = torch.outer(t, freqs).float() # pyre-ignore + freqs_cos = torch.cos(freqs) + freqs_sin = torch.sin(freqs) + return freqs_cos, freqs_sin - def update( - self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: - # input_pos: [S], k_val: [B, H, S, D] or [B, S, H, D] depending on transpose_cache - if self.enable_dynamic_shape: - start_pos = input_pos[0].item() - torch._check_is_size(start_pos) - torch._check(start_pos < self.max_seq_length) - dim_to_slice = 2 if self.transpose_cache else 1 - seq_length = k_val.size(dim_to_slice) - # Replace the entry in the cache for this token - # The following lines are equivalent to: - # cache_k[:bsz, start_pos : start_pos + seqlen] = xk - # cache_v[:bsz, start_pos : start_pos + seqlen] = xv - # when dim_to_slice is 1 - # We use .narrow() here to make the compiler happy - # pyre-ignore: Incompatible parameter type [6] - narrowed_k = self.k_cache.narrow(dim_to_slice, start_pos, seq_length) - # pyre-ignore: Incompatible parameter type [6] - narrowed_v = self.v_cache.narrow(dim_to_slice, start_pos, seq_length) - - narrowed_k.copy_(k_val) - narrowed_v.copy_(v_val) - return self.k_cache, self.v_cache - else: - k_out = self.k_cache - v_out = self.v_cache - if self.transpose_cache: - k_out[:, :, input_pos] = k_val - v_out[:, :, input_pos] = v_val - else: - k_out[:, input_pos] = k_val - v_out[:, input_pos] = v_val - return k_out, v_out +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): + ndim = x.ndim + assert 0 <= 1 < ndim + assert freqs_cis.shape == (x.shape[1], x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(shape) -class SDPA(nn.Module): - def __init__( - self, - kv_cache: KVCache, - dim: int, - head_dim: int, - n_rep: int, - 
-        max_seq_len: int,
-        enable_dynamic_shape: bool,
-    ):
-        super().__init__()
-        self.kv_cache = kv_cache
-        self.dim = dim
-        self.head_dim = head_dim
-        self.n_rep = n_rep
-        self.max_seq_len = max_seq_len
-        self.enable_dynamic_shape = enable_dynamic_shape
 
+def apply_rotary_emb(
+    xq: torch.Tensor, xk: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor
+) -> Tuple[torch.Tensor, torch.Tensor]:
 
-    def forward(
-        self,
-        input_pos: torch.Tensor,
-        q: torch.Tensor,  # Already have rotary embeddings. (bs, seqlen, n_local_heads, head_dim)
-        k: torch.Tensor,  # Already have rotary embeddings. (bs, seqlen, n_local_kv_heads, head_dim)
-        v: torch.Tensor,  # (bs, seqlen, n_local_kv_heads, head_dim)
-        bsz,
-        seqlen,
-        mask: torch.Tensor,
-    ) -> torch.Tensor:
-        q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
-        k = k.transpose(1, 2)
-        v = v.transpose(1, 2)
-
-        k, v = self.kv_cache.update(input_pos, k, v)
-        if self.enable_dynamic_shape:
-            start_pos = input_pos[-1].item()
-            torch._check_is_size(start_pos)
-            torch._check(start_pos < self.max_seq_len)
-            seq_length = q.size(2)
-            # pyre-ignore: Incompatible parameter type [6]
-            attn_mask = mask.narrow(0, start_pos, seq_length)
-        else:
-            attn_mask = mask[None, None, input_pos]
+    xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1)
+    xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1)
+
+    freqs_cos = reshape_for_broadcast(freqs_cos, xq_r)
+    freqs_sin = reshape_for_broadcast(freqs_sin, xq_r)
 
-        k = k.repeat_interleave(self.n_rep, dim=1)
-        v = v.repeat_interleave(self.n_rep, dim=1)
-        y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0)
+    xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
+    xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
+    xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin
+    xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos
 
-        return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
+    xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(3)
+    xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(3)
 
+    return xq_out.type_as(xq), xk_out.type_as(xk)
 
-class Attention(nn.Module):
-    def __init__(self, args: ModelArgs, layer_id: int):
+
+class Attention(torch.nn.Module):
+    def __init__(self, args: ModelArgs):
         super().__init__()
-        self.use_kv_cache = args.use_kv_cache
         self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
         assert args.n_heads % self.n_kv_heads == 0
         model_parallel_size = 1
@@ -267,295 +121,146 @@ def __init__(self, args: ModelArgs, layer_id: int):
         self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
         self.n_rep = self.n_local_heads // self.n_local_kv_heads
         self.head_dim = args.dim // args.n_heads
-        self.max_batch_size = args.max_batch_size
-        self.max_seq_len = args.max_seq_len
-        self.dim = args.dim
-        # args.dim = 4096, args.n_heads = 32, self.head_dim = 4096 / 32 = 125
-        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
-        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
-        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
-        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
-
-        self.layer_id = layer_id
-
-        causal_mask = torch.tril(
-            torch.ones(
-                self.max_seq_len,
-                self.max_seq_len,
-                dtype=torch.bool,
-                device="cpu",
-            )
-        )
-        self.register_buffer("mask", causal_mask, persistent=False)
-
-        if self.use_kv_cache:
-            self.kv_cache = KVCache(
-                args.max_batch_size,
-                args.max_seq_len,
-                self.n_kv_heads,
-                self.head_dim,
-                not args.use_sdpa_with_kv_cache_op,  # if we are using the custom op dont transpose the cache. Expect untransposed q k v
-                args.enable_dynamic_shape,
-            )
-            self.SDPA = SDPA(
-                kv_cache=self.kv_cache,
-                dim=self.dim,
-                head_dim=self.head_dim,
-                n_rep=self.n_rep,
-                max_seq_len=self.max_seq_len,
-                enable_dynamic_shape=args.enable_dynamic_shape,
-            )
-        if args.use_hf_rope:
-            self.apply_rotary_emb = hf_apply_rotary_emb
-        else:
-            self.apply_rotary_emb = apply_rotary_emb
+        self.wq = torch.nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
+        self.wk = torch.nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
+        self.wv = torch.nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
+        self.wo = torch.nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
+        # use a large negative value (-inf, or -32768 with int16) for the positions
+        # we want to ignore in the mask
+        mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-32768"))
+        mask = torch.triu(mask, diagonal=1)
+        self.register_buffer("mask", mask)
 
     def forward(
         self,
         x: torch.Tensor,
         freqs_cos: torch.Tensor,
         freqs_sin: torch.Tensor,
-        input_pos: Optional[torch.Tensor] = None,
     ):
         bsz, seqlen, _ = x.shape
 
         # QKV
-        q, k, v = self.wq(x), self.wk(x), self.wv(x)
-        # We need view_copy elimination
-        q = q.view(bsz, seqlen, self.n_local_heads, self.head_dim)
-        k = k.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
-        v = v.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
+        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
+        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
+        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
 
         # RoPE relative positional embeddings
-        q, k = self.apply_rotary_emb(q, k, freqs_cos, freqs_sin)
-
-        if self.use_kv_cache:
-            assert input_pos is not None
-            output = self.SDPA(input_pos, q, k, v, bsz, seqlen, self.mask)
-            return self.wo(output)
-
-        q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
-        k = k.transpose(1, 2)
-        v = v.transpose(1, 2)
+        xq, xk = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)
 
         # grouped multiquery attention: expand out keys and values
-        k = k.repeat_interleave(self.n_rep, dim=1)
-        v = v.repeat_interleave(self.n_rep, dim=1)
-
+        xk = [
+            torch.cat([xk[:, :, i : i + 1, :]] * self.n_rep, dim=2)
+            for i in range(xk.size(2))
+        ]
+        xk = torch.cat(xk, dim=2)
+
+        xv = [
+            torch.cat([xv[:, :, i : i + 1, :]] * self.n_rep, dim=2)
+            for i in range(xv.size(2))
+        ]
+        xv = torch.cat(xv, dim=2)
+
+        # make heads into a batch dimension
+        xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
+        xk = xk.transpose(1, 2)
+        xv = xv.transpose(1, 2)
+
+        scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
         assert hasattr(self, "mask")
-
-        mask = self.mask[:seqlen, :seqlen]
-
-        output = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+        scores = (
+            scores + self.mask[:, :, :seqlen, :seqlen]
+        )  # (bs, n_local_heads, seqlen, cache_len + seqlen)
+        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
+        output = torch.matmul(scores, xv)  # (bs, n_local_heads, seqlen, head_dim)
 
         output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
 
         output = self.wo(output)
-
         return output
 
 
-class FeedForward(nn.Module):
-    def __init__(self, args: ModelArgs):
+class FeedForward(torch.nn.Module):
+    def __init__(self, dim: int, hidden_dim: int, multiple_of: int):
         super().__init__()
-        assert args.hidden_dim is not None
-        hidden_dim: int = args.hidden_dim
-        self.w1 = nn.Linear(args.dim, hidden_dim, bias=False)
-        self.w2 = nn.Linear(hidden_dim, args.dim, bias=False)
-        self.w3 = nn.Linear(args.dim, hidden_dim, bias=False)
+        self.w1 = torch.nn.Linear(dim, hidden_dim, bias=False)
+        self.w2 = torch.nn.Linear(hidden_dim, dim, bias=False)
+        self.w3 = torch.nn.Linear(dim, hidden_dim, bias=False)
 
     def forward(self, x):
-        return self.w2(F.silu(self.w1(x)) * self.w3(x))
+        x = F.silu(self.w1(x)) * self.w3(x)
+        x = self.w2(x)
+        return x
 
 
-class ConditionalFeedForward(nn.Module):
-    def __init__(self, args: ModelArgs):
-        super().__init__()
-        self.dim = args.dim
-        hidden_dim = args.hidden_dim
-        if hidden_dim is None:
-            # If hidden_dim is not explicitly set in the ModelArgs,
-            # then calculate implicitly based on dim and also multiple of `args.multiple_of`
-            multiple_of = args.multiple_of
-            hidden_dim = 4 * self.dim
-            hidden_dim = int(2 * hidden_dim / 3)
-            hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
-
-        self.w1 = nn.Parameter(torch.randn(args.num_experts, hidden_dim, self.dim))
-        self.w2 = nn.Parameter(torch.randn(args.num_experts, hidden_dim, self.dim))
-        self.w3 = nn.Parameter(torch.randn(args.num_experts, hidden_dim, self.dim))
-        self.num_experts = args.num_experts
-
-    def forward(self, x: torch.Tensor, expert_indices: torch.Tensor) -> torch.Tensor:
-        w1_weights = self.w1[expert_indices].transpose(-1, -2)  # [T, A, D, D]
-        w3_weights = self.w3[expert_indices].transpose(-1, -2)  # [T, A, D, D]
-        w2_weights = self.w2[expert_indices]  # [T, A, D, D]
-        x1 = F.silu(torch.einsum("ti,taio -> tao", x, w1_weights))
-        x3 = torch.einsum("ti, taio -> tao", x, w3_weights)
-        expert_outs = torch.einsum("tao, taoi -> tai", (x1 * x3), w2_weights)
-        return expert_outs
-
-
-class MOEFeedForward(nn.Module):
-    def __init__(self, config) -> None:
-        super().__init__()
-        self.gate = nn.Linear(config.dim, config.num_experts, bias=False)
-        self.cond_ffn = ConditionalFeedForward(config)
-        self.dim = config.dim
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = x.view(-1, self.dim)
-        # T = num_tokens, E = num_experts, D = hidden dim, A = activated experts
-        # x: [T, D]
-        scores = self.gate(x)  # [T, E]
-        expert_weights, expert_indices = torch.topk(scores, 2, dim=-1)  # [T, A], [T, A]
-        expert_weights = expert_weights.softmax(dim=-1)  # [T, A]
-        expert_outs = self.cond_ffn(x, expert_indices)
-        return torch.einsum("tai,ta -> ti", expert_outs, expert_weights)
-
-
-class TransformerBlock(nn.Module):
+class TransformerBlock(torch.nn.Module):
     def __init__(self, layer_id: int, args: ModelArgs):
         super().__init__()
-        self.use_kv_cache = args.use_kv_cache
         self.n_heads = args.n_heads
         self.dim = args.dim
         self.head_dim = args.dim // args.n_heads
-        self.attention = Attention(args, layer_id)
-        if args.moe:
-            self.block_sparse_moe = MOEFeedForward(args)
+        self.attention = Attention(args)
+        if args.hidden_dim is None:
+            hidden_dim = 4 * args.dim
+            hidden_dim = int(2 * hidden_dim / 3)
+            hidden_dim = args.multiple_of * (
+                (hidden_dim + args.multiple_of - 1) // args.multiple_of
+            )
         else:
-            self.feed_forward = FeedForward(args)
-        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
-        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
-
-    def forward(self, x, freqs_cos, freqs_sin, input_pos=None):  # x: 1xN
-        h = self.attention.forward(
-            self.attention_norm(x), freqs_cos, freqs_sin, input_pos
+            hidden_dim = args.hidden_dim
+        self.feed_forward = FeedForward(
+            dim=args.dim,
+            hidden_dim=hidden_dim,
+            multiple_of=args.multiple_of,
         )
+        self.layer_id = layer_id
+        self.attention_norm = torch.nn.RMSNorm(args.dim, eps=args.norm_eps)
+        self.ffn_norm = torch.nn.RMSNorm(args.dim, eps=args.norm_eps)
 
-        h = x + h
-        if hasattr(self, "block_sparse_moe"):
-            out = h + self.block_sparse_moe(self.ffn_norm(h))
-        else:
-            out = h + self.feed_forward(self.ffn_norm(h))
+    def forward(self, x, freqs_cos, freqs_sin):
+        h = x + self.attention.forward(self.attention_norm(x), freqs_cos, freqs_sin)
+        out = h + self.feed_forward.forward(self.ffn_norm(h))
         return out
 
 
-class Transformer(nn.Module):
+class LastTimeStepPool(torch.nn.Module):
+    def forward(self, logits: torch.Tensor, seq_lens: torch.Tensor) -> torch.Tensor:
+        bsz, _, dim = logits.shape
+        idx = seq_lens.unsqueeze(1).expand(bsz, dim).unsqueeze(1)
+        return logits.gather(1, idx - 1).squeeze(1)
+
+
+class Transformer(torch.nn.Module):
     def __init__(self, params: ModelArgs):
         super().__init__()
         self.params = params
         self.vocab_size = params.vocab_size
         self.n_layers = params.n_layers
 
-        self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
+        self.tok_embeddings = torch.nn.Embedding(params.vocab_size, params.dim)
         self.layers = torch.nn.ModuleList()
         for layer_id in range(params.n_layers):
             self.layers.append(TransformerBlock(layer_id, params))
-        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
-        self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
-        self.use_kv_cache = params.use_kv_cache
-        self.generate_full_logits = params.generate_full_logits
-        self.max_seq_len = params.max_seq_len
-        self.output_prune_map = params.output_prune_map
-        if params.use_hf_rope:
-            self.precompute_freqs_cis = hf_precompute_freqs_cis
-        else:
-            self.precompute_freqs_cis = partial(
-                precompute_freqs_cis, use_scaled=params.use_scaled_rope
-            )
-        freqs_cos, freqs_sin = self.precompute_freqs_cis(
-            params.dim // params.n_heads,
-            (
-                params.max_seq_len  # Normal llama2.
-                if params.ffn_dim_multiplier is None
-                else params.max_seq_len * 2  # Sharded checkpoint.
-            ),
-            params.rope_freq_base,
+        self.norm = torch.nn.RMSNorm(params.dim, eps=params.norm_eps)
+        self.out = torch.nn.Linear(params.dim, params.vocab_size, bias=False)
+
+        freqs_cos, freqs_sin = precompute_freqs_cis(
+            self.params.dim // self.params.n_heads, self.params.max_seq_len
         )
         self.register_buffer("freqs_cos", freqs_cos, persistent=False)
         self.register_buffer("freqs_sin", freqs_sin, persistent=False)
 
-    def forward(
-        self,
-        tokens: Optional[torch.LongTensor] = None,  # tokens
-        input_pos: Optional[
-            torch.LongTensor
-        ] = None,  # Scalar tensor indicating size of window of the caches
-        h: Optional[torch.FloatTensor] = None,  # embeddings
-    ) -> torch.Tensor:
-        if (tokens is None) ^ (h is not None):
-            raise ValueError(
-                "You cannot specify both tokens and h at the same time, and must specify either one"
-            )
-        if tokens is not None and h is None:
-            h = self.tok_embeddings(tokens)
-        seqlen = h.shape[1]
-
-        if self.use_kv_cache:
-            assert (
-                input_pos is not None
-            ), "input_pos must be provided when use_kv_cache is True"
-
-            if self.params.enable_dynamic_shape:
-                # when KV cache is used, seqlen is most likely 1. We want to slice from the start_pos.
-                input_pos_item = input_pos[-1].item()
-                torch._check_is_size(input_pos_item)
-                torch._check(input_pos_item < self.params.max_seq_len)
-                # pyre-ignore: Incompatible parameter type [6]: torch.narrow does expect int or Tensor
-                freqs_cos = self.freqs_cos.narrow(0, input_pos_item, seqlen)
-                # pyre-ignore: Incompatible parameter type [6]
-                freqs_sin = self.freqs_sin.narrow(0, input_pos_item, seqlen)
-            else:
-                # When not using dynamic shape, use of the .item results in
-                # symints, due to querying the data from tensor.
-                # this path avoids that for mps backend, although probably mps backend
-                # can support dynamic shape?
-                freqs_cos = self.freqs_cos[input_pos]
-                freqs_sin = self.freqs_sin[input_pos]
-
-        else:
-            assert input_pos is None, "input_pos is unused when use_kv_cache is False"
-            freqs_cos = self.freqs_cos[:seqlen]
-            freqs_sin = self.freqs_sin[:seqlen]
+    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
+        _bsz, seqlen = tokens.shape
+        h = self.tok_embeddings(tokens)
+        freqs_cos = self.freqs_cos[:seqlen]
+        freqs_sin = self.freqs_sin[:seqlen]
 
         for layer in self.layers:
-            h = layer(
-                h,
-                freqs_cos,
-                freqs_sin,
-                input_pos,
-            )
-
-        if not self.generate_full_logits:
-            # Only the last logit is used for the new generated token
-            h = h[:, -1, :]
+            h = layer(h, freqs_cos, freqs_sin)
 
         h = self.norm(h)
 
-        logits = self.output(h)
-
-        if self.output_prune_map is not None:
-            # expand to original size so that downstream applications can use the logits as-is.
-            if self.generate_full_logits:
-                # (1, seq_len, pruned_size) -> (1, seq_len, original_size)
-                expanded_logits = torch.full(
-                    [logits.shape[0], logits.shape[1], self.vocab_size],
-                    float("-inf"),
-                    device=logits.device,
-                    dtype=logits.dtype,
-                )
-                expanded_logits[:, :, list(self.output_prune_map.values())] = logits
-            else:
-                # (1, pruned_size) -> (1, original_size)
-                expanded_logits = torch.full(
-                    [logits.shape[0], self.vocab_size],
-                    float("-inf"),
-                    device=logits.device,
-                    dtype=logits.dtype,
-                )
-                expanded_logits[:, list(self.output_prune_map.values())] = logits
-            logits = expanded_logits
-
-        return logits
+        invocation_logits = self.out(h)
+
+        return invocation_logits
diff --git a/examples/models/llama2/model.py b/examples/models/llama2/model.py
index a4081d1bd57..5842ef544f0 100644
--- a/examples/models/llama2/model.py
+++ b/examples/models/llama2/model.py
@@ -150,16 +150,31 @@ def __init__(self, **kwargs):
             output_prune_map = {int(k): v for (k, v) in output_prune_map.items()}
         max_seq_len = self.max_seq_len
         max_batch_size = 1
+        print("params: ", params)
+        params.pop("rope_theta", None)
         model_args: ModelArgs = ModelArgs(
             max_seq_len=max_seq_len,
             max_batch_size=max_batch_size,
-            use_kv_cache=self.use_kv_cache,
-            use_sdpa_with_kv_cache_op=self.use_sdpa_with_kv_cache_op,
-            generate_full_logits=self.generate_full_logits,
-            output_prune_map=output_prune_map,
-            enable_dynamic_shape=self.enable_dynamic_shape,
+            # input_vocab_size=params["input_vocab_size"],
+            # use_kv_cache=self.use_kv_cache,
+            # use_sdpa_with_kv_cache_op=self.use_sdpa_with_kv_cache_op,
+            # generate_full_logits=self.generate_full_logits,
+            # output_prune_map=output_prune_map,
+            # enable_dynamic_shape=self.enable_dynamic_shape,
             **params,
         )
+        # model_args: ModelArgs = (
+        #     ModelArgs(
+        #         dim=512,
+        #         hidden_dim=1536,
+        #         n_heads=8,
+        #         n_kv_heads=2,
+        #         n_layers=19,
+        #         vocab_size=128256,
+        #         invocation_vocab_size=8,
+        #         use_layer_norm_op=True,
+        #     ),
+        # )
         if kwargs.get("fairseq2", False):
             print("Using fairseq2 checkpoint")
             checkpoint = convert_to_llama_checkpoint(checkpoint=checkpoint)
@@ -170,10 +185,24 @@ def __init__(self, **kwargs):
                 print(f"{key} : {weights.numel()} : {weights.size()}")
             print("============= /weights ================")
 
-        # Within the device="meta" context, tensors that are created do not carry data.
-        # They possess all other metadata a tensor carries such as size, stride, requires_grad.
-        with torch.device("meta"):
-            self.model_ = Transformer(model_args)
+        # Within the device="meta" context, tensors that are created do not carry data.
+        # They possess all other metadata a tensor carries such as size, stride, requires_grad.
+        # with torch.device("meta"):
+        #     self.model_ = Transformer(model_args)
+        # self.model_ = Transformer(
+        #     ModelArgs(
+        #         dim=512,
+        #         hidden_dim=1536,
+        #         n_heads=8,
+        #         n_kv_heads=2,
+        #         n_layers=19,
+        #         vocab_size=128256,
+        #         invocation_vocab_size=8,
+        #         use_layer_norm_op=True,
+        #     ),
+        # )
+        self.model_ = Transformer(model_args)
+        print("model: ", self.model_)
 
         if "int8" in str(checkpoint_path):
             print("Using int8 weight-only quantization!")
@@ -263,11 +292,11 @@ def __init__(self, **kwargs):
         # assign=True: load params/buffers by assignment instead of performing an in-place copy.
         # Because we are using device="meta", tensors do not have memory associated with them
         # and an in-place copy is a no-op. Use assign=True in load_state_dict for this scenario.
-        missing, unexpected = self.model_.load_state_dict(
-            checkpoint,
-            strict=False,
-            assign=True,
-        )  # self.model_ = Transformer(gptconf)
+        # missing, unexpected = self.model_.load_state_dict(
+        #     checkpoint,
+        #     strict=False,
+        #     assign=True,
+        # )  # self.model_ = Transformer(gptconf)
         if kwargs.get("verbose", False):
             print("============= missing keys ================")
             print(missing)
@@ -296,11 +325,13 @@ def get_example_inputs(self):
         if self.use_kv_cache:
             return self.get_example_inputs_kvcache_sdpa()
         else:
-            return (
-                torch.tensor(
-                    [[1, 2, 3]], dtype=torch.long
-                ),  # tokens, with kv cache our input token length is always just 1 token.
-            )
+            # return (
+            #     torch.tensor(
+            #         [[1, 2, 3]], dtype=torch.long
+            #     ),  # tokens, with kv cache our input token length is always just 1 token.
+            # )
+            b = torch.ones(1, 64, dtype=torch.long)
+            return (b,)
 
     # assumption is the custom op doesnt support dynamic shape right now. It might but its untested so lets first get static shape working
 
     def get_example_inputs_kvcache_sdpa(self):
diff --git a/examples/models/llama2/params/demo_config.json b/examples/models/llama2/params/demo_config.json
index 13287f117e9..754d09b5ca2 100644
--- a/examples/models/llama2/params/demo_config.json
+++ b/examples/models/llama2/params/demo_config.json
@@ -1 +1 @@
-{"dim": 64, "multiple_of": 4, "n_heads": 8, "n_layers": 5, "norm_eps": 1e-05, "vocab_size": 512}
\ No newline at end of file
+{"dim": 64, "multiple_of": 4, "n_heads": 8, "n_layers": 1, "norm_eps": 1e-05, "vocab_size": 512}
diff --git a/examples/models/llama2/runner/targets.bzl b/examples/models/llama2/runner/targets.bzl
index 96d47ffce21..eb5dfe87299 100644
--- a/examples/models/llama2/runner/targets.bzl
+++ b/examples/models/llama2/runner/targets.bzl
@@ -29,6 +29,7 @@ def define_common_targets():
            ],
            # qnn_executorch_backend can be added below //executorch/backends/qualcomm:qnn_executorch_backend
            exported_deps = [
+                "//executorch/backends/qualcomm:qnn_executorch_backend",
                "//executorch/backends/xnnpack:xnnpack_backend",
                "//executorch/extension/llm/runner:stats",
                "//executorch/extension/llm/runner:text_decoder_runner" + aten_suffix,
diff --git a/examples/models/model_factory.py b/examples/models/model_factory.py
index fb317e3bca3..8913bd50484 100644
--- a/examples/models/model_factory.py
+++ b/examples/models/model_factory.py
@@ -35,9 +35,11 @@ def create_model(
            ValueError: If the provided model class is not found in the module.
        """
        package_prefix = "executorch." if not os.getcwd().endswith("executorch") else ""
-        module = importlib.import_module(
-            f"{package_prefix}examples.models.{module_name}"
-        )
+        print(f"package_prefix: {package_prefix}")
+        # module = importlib.import_module(
+        #     f"{package_prefix}examples.models.{module_name}"
+        # )
+        module = importlib.import_module(f"executorch.examples.models.{module_name}")
 
        if hasattr(module, model_class_name):
            model_class = getattr(module, model_class_name)
diff --git a/examples/qualcomm/oss_scripts/llama2/llama.py b/examples/qualcomm/oss_scripts/llama2/llama.py
index d74cfa0ef07..f16736248dd 100644
--- a/examples/qualcomm/oss_scripts/llama2/llama.py
+++ b/examples/qualcomm/oss_scripts/llama2/llama.py
@@ -48,6 +48,7 @@
 soc_to_chipset_map = {
+    "SSG2115P": QcomChipset.SSG2115P,
    "SM8650": QcomChipset.SM8650,
    "SM8550": QcomChipset.SM8550,
    "SM8475": QcomChipset.SM8475,
diff --git a/examples/qualcomm/utils.py b/examples/qualcomm/utils.py
index 9c4cd4453f0..ac16343a8e8 100755
--- a/examples/qualcomm/utils.py
+++ b/examples/qualcomm/utils.py
@@ -83,6 +83,7 @@ def __init__(
        self.debug_output_path = f"{self.workspace}/debug_output.bin"
        self.output_folder = f"{self.workspace}/outputs"
        self.arch_table = {
+            "SSG2115P": "73",
            "SM8650": "75",
            "SM8550": "73",
            "SM8475": "69",
diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py
index ae0ca6df757..16f77668839 100644
--- a/extension/llm/export/builder.py
+++ b/extension/llm/export/builder.py
@@ -9,6 +9,7 @@
 # ExecuTorch.
 
 import logging
+import os
 from enum import Enum
 from typing import Any, Callable, List, Optional
 
@@ -34,6 +35,7 @@
 from torch.ao.quantization.quantizer import Quantizer
 from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer
 from torch.nn.attention import SDPBackend
+from tqdm import tqdm
 
 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
 logging.basicConfig(level=logging.INFO, format=FORMAT)
@@ -150,6 +152,7 @@ def source_transform(
        return self
 
    def _get_dynamic_shape(self) -> Any:
+        return None
        if self.dynamic_shapes:
            return self.dynamic_shapes
diff --git a/extension/llm/export/partitioner_lib.py b/extension/llm/export/partitioner_lib.py
index 37b215a51ff..8f827443f7a 100644
--- a/extension/llm/export/partitioner_lib.py
+++ b/extension/llm/export/partitioner_lib.py
@@ -154,9 +154,9 @@ def get_qnn_partitioner(
    num_sharding: int = 0,
    soc_model: str = "SM8650",  # default to SM8650
 ):
-    assert (
-        use_kv_cache is True
-    ), "Qualcomm backend currently only supports static shape and use_kv_cache=True is the only way to support it at the moment"
+    # assert (
+    #     use_kv_cache is True
+    # ), "Qualcomm backend currently only supports static shape and use_kv_cache=True is the only way to support it at the moment"
    try:
        # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.partition.qnn_partitioner`
        from executorch.backends.qualcomm.partition.qnn_partitioner import (
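
Note on the rewritten grouped-query attention in `llama_transformer.py`: the patch replaces the removed `repeat_interleave` call with a `torch.cat`-based expansion of the key/value heads. A minimal sanity-check sketch (illustrative only; the tensor shapes below are made up, not taken from the patch) showing the two forms should produce the same result:

```python
import torch

# Hypothetical small shapes for the check: (batch, seqlen, n_kv_heads, head_dim)
bsz, seqlen, n_kv_heads, head_dim, n_rep = 1, 8, 2, 16, 4
xk = torch.randn(bsz, seqlen, n_kv_heads, head_dim)

# cat-based expansion as written in the new Attention.forward
expanded = torch.cat(
    [torch.cat([xk[:, :, i : i + 1, :]] * n_rep, dim=2) for i in range(xk.size(2))],
    dim=2,
)

# the repeat_interleave form the patch removes (applied on the head dim before transpose)
reference = xk.repeat_interleave(n_rep, dim=2)

assert torch.equal(expanded, reference)
```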
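Similarly, the new attention path adds a `-32768` upper-triangular mask to the scores and applies an explicit softmax/matmul, where the old path used `F.scaled_dot_product_attention` with a boolean causal mask. A rough equivalence check (again a sketch with assumed shapes, not part of the patch) for the two formulations:

```python
import math

import torch
import torch.nn.functional as F

bsz, n_heads, seqlen, head_dim = 1, 4, 8, 16
q = torch.randn(bsz, n_heads, seqlen, head_dim)
k = torch.randn(bsz, n_heads, seqlen, head_dim)
v = torch.randn(bsz, n_heads, seqlen, head_dim)

# additive causal mask as built in the new Attention.__init__
add_mask = torch.triu(torch.full((1, 1, seqlen, seqlen), -32768.0), diagonal=1)
scores = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim) + add_mask
manual = torch.matmul(F.softmax(scores, dim=-1), v)

# boolean causal mask, roughly what the removed SDPA-based path relied on
bool_mask = torch.tril(torch.ones(seqlen, seqlen, dtype=torch.bool))
reference = F.scaled_dot_product_attention(q, k, v, attn_mask=bool_mask, dropout_p=0.0)

assert torch.allclose(manual, reference, atol=1e-5)
```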