46 changes: 43 additions & 3 deletions examples/models/llama/export_llama_lib.py
@@ -54,6 +54,7 @@
)
from .source_transformation.quantized_kv_cache import (
replace_kv_cache_with_quantized_kv_cache,
replace_torchtune_kv_cache_with_quantized_kv_cache,
)
from .source_transformation.rms_norm import replace_rms_norm_with_native_rms_norm

@@ -65,10 +66,15 @@
replace_sdpa_with_coreml_sdpa,
replace_sdpa_with_custom_op,
replace_sdpa_with_flex_sdpa,
replace_sdpa_with_sdpa_only_custom_op,
replace_sdpa_with_simple_sdpa,
)

from .source_transformation.torchtune.attention import replace_mha_with_inference_mha

from .source_transformation.vulkan_rope import replace_with_vulkan_rotary_emb


IS_FBCODE = True # os.environ.get("FBCODE_PLATFORM", False)
FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=FORMAT)
@@ -237,7 +243,7 @@ def build_args_parser() -> argparse.ArgumentParser:
"--use_sdpa_with_kv_cache",
default=False,
action="store_true",
help="Whether to use sdpa_with_kv_cache update op when using kv cache",
help="Whether to use a custom sdpa + kv_cache update when kv cache is enabled.",
)
parser.add_argument(
"--disable_dynamic_shape",
@@ -589,6 +595,18 @@ def _validate_args(args):
if args.num_sharding > 0 and not args.qnn:
raise ValueError("Model shard is only supported with qnn backend now.")

if args.model in TORCHTUNE_DEFINED_MODELS:
if args.use_sdpa_with_kv_cache:
if not (args.use_kv_cache and args.quantize_kv_cache):
raise ValueError(
f"TorchTune-defined {args.model} only works with custom SDPA op + quantized KV cache at the moment. Please enable use_kv_cache and quantize_kv_cache when use_sdpa_with_kv_cache is enabled."
)
if args.use_kv_cache:
if not args.quantize_kv_cache:
raise ValueError(
f"TorchTune-defined {args.model} only works with quantized KV cache at the moment. Please enable quantize_kv_cache when use_kv_cache is enabled."
)
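
For orientation, here is a minimal sketch of the flag combinations this new validation accepts for a TorchTune-defined model. The helper function and the model name are hypothetical stand-ins used only for illustration; the attribute names match the flags referenced in this diff.

```python
# Hedged sketch, not part of the PR: mirrors the TorchTune-specific checks above.
from argparse import Namespace

TORCHTUNE_MODELS = ["some_torchtune_model"]  # placeholder for TORCHTUNE_DEFINED_MODELS


def check_torchtune_flags(args) -> None:
    if args.model in TORCHTUNE_MODELS:
        if args.use_sdpa_with_kv_cache and not (args.use_kv_cache and args.quantize_kv_cache):
            raise ValueError("use_sdpa_with_kv_cache needs use_kv_cache and quantize_kv_cache")
        if args.use_kv_cache and not args.quantize_kv_cache:
            raise ValueError("use_kv_cache needs quantize_kv_cache")


# Accepted: custom SDPA op together with a quantized KV cache.
check_torchtune_flags(
    Namespace(
        model="some_torchtune_model",
        use_sdpa_with_kv_cache=True,
        use_kv_cache=True,
        quantize_kv_cache=True,
    )
)

# Rejected: enabling the KV cache without quantization raises ValueError.
# check_torchtune_flags(
#     Namespace(
#         model="some_torchtune_model",
#         use_sdpa_with_kv_cache=False,
#         use_kv_cache=True,
#         quantize_kv_cache=False,
#     )
# )
```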


def _export_llama(args) -> LLMEdgeManager: # noqa: C901
_validate_args(args)
@@ -892,6 +910,7 @@ def _load_llama_model(
def _get_source_transforms( # noqa
modelname: str, dtype_override: Optional[DType], args
) -> List[Callable[[torch.nn.Module], torch.nn.Module]]:
is_torchtune_model = modelname in TORCHTUNE_DEFINED_MODELS
transforms = []

if args.use_spin_quant:
@@ -943,12 +962,29 @@ def _get_source_transforms( # noqa
if args.expand_rope_table:
transforms.append(materialze_broadcast_of_rope_freq_cis)

transforms.append(replace_mha_with_inference_mha)
if args.use_sdpa_with_kv_cache:
transforms.append(replace_sdpa_with_custom_op)
if is_torchtune_model:
assert (
args.use_kv_cache and args.quantize_kv_cache
), "use_sdpa_with_kv_cache requires use_kv_cache=True and quantize_kv_cache=True for TorchTune at the moment."
transforms.append(replace_mha_with_inference_mha)
transforms.append(replace_sdpa_with_sdpa_only_custom_op)
else:
transforms.append(replace_sdpa_with_custom_op)

if args.quantize_kv_cache:
assert args.use_kv_cache, "quantize_kv_cache requires use_kv_cache=True"
transforms.append(replace_kv_cache_with_quantized_kv_cache)
if is_torchtune_model:
transforms.append(
lambda module: replace_torchtune_kv_cache_with_quantized_kv_cache(
module,
is_transposed=not args.use_sdpa_with_kv_cache,
enable_dynamic_shape=args.enable_dynamic_shape,
)
)
else:
transforms.append(replace_kv_cache_with_quantized_kv_cache)

if args.use_kv_cache:
if args.qnn:
@@ -983,4 +1019,8 @@ def _get_source_transforms( # noqa
if args.vulkan:
transforms.append(replace_with_vulkan_rotary_emb)

print(
f"Performing the following source transformations: {[transform.__name__ for transform in transforms]}"
)

return transforms
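
For a TorchTune-defined model exported with use_kv_cache, quantize_kv_cache, and use_sdpa_with_kv_cache all enabled, the branches above build roughly the pipeline sketched below. This is a hedged illustration, not part of the PR; `model` is a placeholder, and the transform names come from this diff.

```python
# Hedged sketch: expected transform order for a TorchTune-defined model with
# --use_kv_cache --quantize_kv_cache --use_sdpa_with_kv_cache (other optional
# transforms disabled). The MHA rewrite runs first, then the SDPA-only custom op
# swap, and finally the quantized KV cache replacement, which takes over the cache
# update that the SDPA-only op no longer performs.
def apply_transforms(model, transforms):
    for transform in transforms:
        model = transform(model)
    return model

# transforms is roughly:
#     [
#         replace_mha_with_inference_mha,
#         replace_sdpa_with_sdpa_only_custom_op,
#         lambda m: replace_torchtune_kv_cache_with_quantized_kv_cache(
#             m, is_transposed=False, enable_dynamic_shape=args.enable_dynamic_shape
#         ),
#     ]
```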
examples/models/llama/source_transformation/quantized_kv_cache.py
@@ -11,6 +11,7 @@
import torch.nn as nn
from executorch.examples.models.llama.llama_transformer import KVCache
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401
from torchtune.modules.kv_cache import KVCache as TorchTuneKVCache


"""
@@ -207,8 +208,31 @@ def from_float(cls, kv_cache, cache_type: QuantizedCacheType):
kv_cache.enable_dynamic_shape,
)

@classmethod
def from_torchtune_float(
cls,
kv_cache,
cache_type: QuantizedCacheType,
is_transposed: bool,
enable_dynamic_shape: bool,
):
cache_shape = kv_cache.k_cache.shape
# TorchTune's KVCache does not carry a transposed flag; rely on the explicit argument.
if is_transposed:
max_batch_size, n_heads, max_seq_length, head_dim = cache_shape
else:
max_batch_size, max_seq_length, n_heads, head_dim = cache_shape
return cls(
max_batch_size,
max_seq_length,
n_heads,
head_dim,
cache_type,
is_transposed,
enable_dynamic_shape,
)


def replace_kv_cache_with_quantized_kv_cache(module):
def replace_kv_cache_with_quantized_kv_cache(module: nn.Module) -> nn.Module:
logging.warning(
"Replacing KVCache with QuantizedKVCache. This modifies the model in place."
)
@@ -222,3 +246,41 @@ def replace_kv_cache_with_quantized_kv_cache(module):
else:
replace_kv_cache_with_quantized_kv_cache(child)
return module


def replace_torchtune_kv_cache_with_quantized_kv_cache(
module: nn.Module, is_transposed: bool, enable_dynamic_shape: bool
) -> nn.Module:
"""
Replace TorchTune KVCache with ExecuTorch's quantized KVCache.

Args:
module: the module whose TorchTune KVCache submodules should be replaced.
is_transposed: whether q, k, and v are transposed. Should be set to False when the SDPA custom op source transform is enabled.
enable_dynamic_shape: whether dynamic shapes are enabled.

Returns:
The passed-in module, modified in place.
"""
logging.warning(
"Replacing KVCache with QuantizedKVCache. This modifies the model in place."
)
for name, child in module.named_children():
if isinstance(child, TorchTuneKVCache):
cache_shape = child.k_cache.shape
if is_transposed:
max_batch_size, n_heads, max_seq_length, head_dim = cache_shape
else:
max_batch_size, max_seq_length, n_heads, head_dim = cache_shape
setattr(
module,
name,
QuantizedKVCache.from_torchtune_float(
child,
QuantizedCacheType.AffineAsymmetric,
is_transposed,
enable_dynamic_shape,
),
)
else:
# Recurse with the TorchTune-aware replacement so nested TorchTune KVCaches are found.
replace_torchtune_kv_cache_with_quantized_kv_cache(
child, is_transposed, enable_dynamic_shape
)
return module
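
A hedged usage sketch for the new transform, assuming executorch and torchtune are installed. The tiny wrapper module and the positional KVCache arguments are assumptions for illustration (torchtune's KVCache takes a batch size, max sequence length, number of (KV) heads, head dim, and dtype, though the keyword names have changed across releases); only the transform itself comes from this diff.

```python
import torch
import torch.nn as nn
from torchtune.modules.kv_cache import KVCache


class TinyAttention(nn.Module):
    """Hypothetical wrapper holding a single TorchTune KVCache."""

    def __init__(self) -> None:
        super().__init__()
        # Positional args: batch_size, max_seq_len, num (KV) heads, head_dim, dtype.
        self.kv_cache = KVCache(1, 128, 8, 64, torch.float32)


attn = TinyAttention()
# torchtune's KVCache stores k/v as [batch, heads, seq, head_dim], i.e. the
# transposed layout, so is_transposed=True for a raw TorchTune cache; the export
# flow above passes is_transposed=not args.use_sdpa_with_kv_cache instead.
attn = replace_torchtune_kv_cache_with_quantized_kv_cache(
    attn, is_transposed=True, enable_dynamic_shape=False
)
print(type(attn.kv_cache))  # QuantizedKVCache after the swap
```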
61 changes: 60 additions & 1 deletion examples/models/llama/source_transformation/sdpa.py
@@ -80,7 +80,7 @@ def forward(
input_pos[0].item(),
seqlen,
None, # Attention mask
0, # dropout probability. Ignored by the code
0, # Dropout probability, ignored by the code
True, # is_causal
)
return output.view(bsz, seqlen, self.dim).to(dtype=input_dtype)
@@ -105,6 +105,65 @@ def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module:
return module


class SDPAOnlyCustom(torch.nn.Module):
"""
Just the custom SDPA op, no KV cache update included. Can only be used
in conjunction with a quantized KV cache.
"""

def __init__(
self,
):
super().__init__()

def forward(
self,
input_pos: torch.Tensor,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
bsz: int,
seqlen: int,
mask: torch.Tensor = None,
):
# Custom op only supports float32 currently. Converting to/from float32 is
# faster than not having the op.
input_dtype = q.dtype
q = q.to(dtype=torch.float)
k = k.to(dtype=torch.float)
v = v.to(dtype=torch.float)
output = torch.ops.llama.custom_sdpa(
q,
k,
v,
input_pos[0].item(),
None, # Attention mask
0, # Dropout probability, ignored by the code.
True, # is_causal
)
return output.view(bsz, seqlen, -1).to(dtype=input_dtype)


def _replace_sdpa_with_sdpa_only_custom_op(module: torch.nn.Module):
for name, child in module.named_children():
if isinstance(child, SDPA):
assert (
child.kv_cache.cache_fp_type == torch.float32
), "Only float32 is supported for custom SDPA"
setattr(
module,
name,
SDPAOnlyCustom(),
)
else:
_replace_sdpa_with_sdpa_only_custom_op(child)


def replace_sdpa_with_sdpa_only_custom_op(module: torch.nn.Module) -> torch.nn.Module:
_replace_sdpa_with_sdpa_only_custom_op(module)
return module
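
A hedged calling-convention sketch for SDPAOnlyCustom. It assumes the ExecuTorch custom op library that registers torch.ops.llama.custom_sdpa has been built and loaded, and that q, k, and v arrive in a [batch, seq, heads, head_dim] layout; the shapes below are illustrative only.

```python
import torch

sdpa = SDPAOnlyCustom()
bsz, seqlen, n_heads, head_dim = 1, 8, 32, 64
q = torch.randn(bsz, seqlen, n_heads, head_dim)
k = torch.randn(bsz, seqlen, n_heads, head_dim)  # cache update happens in the quantized KV cache, not here
v = torch.randn(bsz, seqlen, n_heads, head_dim)
input_pos = torch.tensor([0], dtype=torch.long)

# Returns [bsz, seqlen, n_heads * head_dim]; inputs are cast to float32 internally.
out = sdpa(input_pos, q, k, v, bsz, seqlen)
```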


class SDPASimple(torch.nn.Module):

def __init__(