54 | 54 | ) |
55 | 55 | from .source_transformation.quantized_kv_cache import ( |
56 | 56 | replace_kv_cache_with_quantized_kv_cache, |
57 | | - replace_torchtune_kv_cache_with_quantized_kv_cache, |
58 | 57 | ) |
59 | 58 | from .source_transformation.rms_norm import replace_rms_norm_with_native_rms_norm |
60 | 59 |
66 | 65 | replace_sdpa_with_coreml_sdpa, |
67 | 66 | replace_sdpa_with_custom_op, |
68 | 67 | replace_sdpa_with_flex_sdpa, |
69 | | - replace_sdpa_with_sdpa_only_custom_op, |
70 | 68 | replace_sdpa_with_simple_sdpa, |
71 | 69 | ) |
72 | | - |
73 | | -from .source_transformation.torchtune.attention import replace_mha_with_inference_mha |
74 | | - |
75 | 70 | from .source_transformation.vulkan_rope import replace_with_vulkan_rotary_emb |
76 | 71 |
77 | | - |
78 | 72 | IS_FBCODE = True # os.environ.get("FBCODE_PLATFORM", False) |
79 | 73 | FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" |
80 | 74 | logging.basicConfig(level=logging.INFO, format=FORMAT) |
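
Note on the imports above: every source transformation in this file has the same shape, a callable that takes a torch.nn.Module and returns a (possibly rewritten) torch.nn.Module. A minimal sketch of that contract, using only stock PyTorch (the function name here is hypothetical, not part of the module):

    import torch

    def example_transform(module: torch.nn.Module) -> torch.nn.Module:
        # Recursively visit submodules; a real transform such as
        # replace_rms_norm_with_native_rms_norm would swap matching
        # children in place here before returning the module.
        for _name, child in module.named_children():
            example_transform(child)
        return module
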
@@ -243,7 +237,7 @@ def build_args_parser() -> argparse.ArgumentParser: |
243 | 237 | "--use_sdpa_with_kv_cache", |
244 | 238 | default=False, |
245 | 239 | action="store_true", |
246 | | - help="Whether to use a custom sdpa + kv_cache update when kv cache is enabled.", |
| 240 | + help="Whether to use sdpa_with_kv_cache update op when using kv cache", |
247 | 241 | ) |
248 | 242 | parser.add_argument( |
249 | 243 | "--disable_dynamic_shape", |
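
The hunk above only rewords the help text; the flag itself stays a plain store_true switch that defaults to False. A standalone sketch of the same argparse pattern (flag names copied from the diff, the second flag's setup assumed):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--use_sdpa_with_kv_cache",
        default=False,
        action="store_true",
        help="Whether to use sdpa_with_kv_cache update op when using kv cache",
    )
    parser.add_argument("--use_kv_cache", default=False, action="store_true")

    args = parser.parse_args(["--use_kv_cache", "--use_sdpa_with_kv_cache"])
    assert args.use_kv_cache and args.use_sdpa_with_kv_cache
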
@@ -595,18 +589,6 @@ def _validate_args(args): |
595 | 589 | if args.num_sharding > 0 and not args.qnn: |
596 | 590 | raise ValueError("Model shard is only supported with qnn backend now.") |
597 | 591 |
598 | | - if args.model in TORCHTUNE_DEFINED_MODELS: |
599 | | - if args.use_sdpa_with_kv_cache: |
600 | | - if not args.use_kv_cache and not args.quantize_kv_cache: |
601 | | - raise ValueError( |
602 | | - f"TorchTune-defined {args.model} only works with custom SDPA op + quantized KV cache at the moment. Please enable use_kv_cache and quantize_kv_cache when use_sdpa_with_kv_cache is enabled." |
603 | | - ) |
604 | | - if args.use_kv_cache: |
605 | | - if not args.quantize_kv_cache: |
606 | | - raise ValueError( |
607 | | - f"TorchTune-defined {args.model} only works with quantized KV cache at the moment. Please enable quantize_kv_cache when use_kv_cache is enabled." |
608 | | - ) |
609 | | - |
610 | 592 |
611 | 593 | def _export_llama(args) -> LLMEdgeManager: # noqa: C901 |
612 | 594 | _validate_args(args) |
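
The deleted block above is the standard dependent-flag check in _validate_args: raise ValueError as soon as a flag is set without its prerequisite, before any export work starts. A minimal sketch of that pattern (field names taken from the diff, everything else assumed):

    from types import SimpleNamespace

    def validate(args) -> None:
        # Mirrors the surviving assert further down: quantizing the kv
        # cache only makes sense when a kv cache is enabled at all.
        if args.quantize_kv_cache and not args.use_kv_cache:
            raise ValueError("quantize_kv_cache requires use_kv_cache=True")

    validate(SimpleNamespace(use_kv_cache=True, quantize_kv_cache=True))  # passes
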
@@ -910,7 +892,6 @@ def _load_llama_model( |
910 | 892 | def _get_source_transforms( # noqa |
911 | 893 | modelname: str, dtype_override: Optional[DType], args |
912 | 894 | ) -> List[Callable[[torch.nn.Module], torch.nn.Module]]: |
913 | | - is_torchtune_model = modelname in TORCHTUNE_DEFINED_MODELS |
914 | 895 | transforms = [] |
915 | 896 |
916 | 897 | if args.use_spin_quant: |
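
_get_source_transforms returns a List[Callable[[torch.nn.Module], torch.nn.Module]], and the caller folds that list over the model. A sketch of the composition, assuming nothing beyond the return type in the signature above:

    from typing import Callable, List

    import torch

    def apply_transforms(
        model: torch.nn.Module,
        transforms: List[Callable[[torch.nn.Module], torch.nn.Module]],
    ) -> torch.nn.Module:
        # Each transform consumes the previous result, so append order
        # in _get_source_transforms is also execution order.
        for transform in transforms:
            model = transform(model)
        return model
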
@@ -962,29 +943,12 @@ def _get_source_transforms( # noqa |
962 | 943 | if args.expand_rope_table: |
963 | 944 | transforms.append(materialze_broadcast_of_rope_freq_cis) |
964 | 945 |
965 | | - transforms.append(replace_mha_with_inference_mha) |
966 | 946 | if args.use_sdpa_with_kv_cache: |
967 | | - if is_torchtune_model: |
968 | | - assert ( |
969 | | - args.use_kv_cache and args.quantize_kv_cache |
970 | | - ), "use_sdpa_with_kv_cache requires use_kv_cache=True and quantize_kv_cache=True for TorchTune at the moment." |
971 | | - transforms.append(replace_mha_with_inference_mha) |
972 | | - transforms.append(replace_sdpa_with_sdpa_only_custom_op) |
973 | | - else: |
974 | | - transforms.append(replace_sdpa_with_custom_op) |
| 947 | + transforms.append(replace_sdpa_with_custom_op) |
975 | 948 |
976 | 949 | if args.quantize_kv_cache: |
977 | 950 | assert args.use_kv_cache, "quantize_kv_cache requires use_kv_cache=True" |
978 | | - if is_torchtune_model: |
979 | | - transforms.append( |
980 | | - lambda module: replace_torchtune_kv_cache_with_quantized_kv_cache( |
981 | | - module, |
982 | | - is_transposed=not args.use_sdpa_with_kv_cache, |
983 | | - enable_dynamic_shape=args.enable_dynamic_shape, |
984 | | - ) |
985 | | - ) |
986 | | - else: |
987 | | - transforms.append(replace_kv_cache_with_quantized_kv_cache) |
| 951 | + transforms.append(replace_kv_cache_with_quantized_kv_cache) |
988 | 952 |
989 | 953 | if args.use_kv_cache: |
990 | 954 | if args.qnn: |
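
The deleted TorchTune branch wrapped its transform in a lambda because replace_torchtune_kv_cache_with_quantized_kv_cache needed extra arguments, while the list only holds unary callables. functools.partial expresses the same adaptation; a sketch (quantize_kv_cache here is a hypothetical stand-in, not the real transform):

    from functools import partial

    import torch

    def quantize_kv_cache(
        module: torch.nn.Module, *, enable_dynamic_shape: bool
    ) -> torch.nn.Module:
        # Stand-in for a transform that needs configuration beyond the module.
        return module

    transforms = [partial(quantize_kv_cache, enable_dynamic_shape=True)]

One caveat with partial: unlike a lambda, a partial object has no __name__ attribute, which would have broken the debug print removed at the end of this function.
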
@@ -1019,8 +983,4 @@ def _get_source_transforms( # noqa |
1019 | 983 | if args.vulkan: |
1020 | 984 | transforms.append(replace_with_vulkan_rotary_emb) |
1021 | 985 |
1022 | | - print( |
1023 | | - f"Performing the following source transformations: {[transform.__name__ for transform in transforms]}" |
1024 | | - ) |
1025 | | - |
1026 | 986 | return transforms |
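
With the print above removed, nothing reports which transforms ran. Since the module already configures logging via logging.basicConfig near the top of the diff, an equivalent could go through the stdlib logger instead of print; a sketch, not part of the commit:

    import logging

    def log_transforms(transforms) -> None:
        # getattr tolerates callables without a __name__ attribute
        # (e.g. functools.partial); lambdas report "<lambda>".
        names = [getattr(t, "__name__", repr(t)) for t in transforms]
        logging.getLogger(__name__).info(
            "Performing source transformations: %s", names
        )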