@@ -579,49 +579,54 @@ def export_llama(
     if isinstance(export_options, argparse.Namespace):
         # Legacy CLI.
         args = export_options
-        llm_config = convert_args_to_llm_config(export_options)  # noqa: F841
+        llm_config = convert_args_to_llm_config(export_options)
     elif isinstance(export_options, DictConfig):
         # Hydra CLI.
-        llm_config = export_options  # noqa: F841
+        llm_config = export_options
+        # Create an args object for backward compatibility during transition
+        args = argparse.Namespace()
+        for key, value in llm_config.items():
+            setattr(args, key, value)
     else:
         raise ValueError(
             "Input to export_llama must be either of type argparse.Namespace or LlmConfig"
         )

-    # TODO: refactor rest of export_llama to use llm_config instead of args.
-
     # If a checkpoint isn't provided for an HF OSS model, download and convert the
     # weights first.
-    if not args.checkpoint and args.model in HUGGING_FACE_REPO_IDS:
-        repo_id = HUGGING_FACE_REPO_IDS[args.model]
-        if args.model == "qwen2_5":
+    model_name = llm_config.base.model_class
+    if not llm_config.base.checkpoint and model_name in HUGGING_FACE_REPO_IDS:
+        repo_id = HUGGING_FACE_REPO_IDS[model_name]
+        if model_name == "qwen2_5":
             from executorch.examples.models.qwen2_5 import (  # pyre-ignore[21]
                 convert_weights,
             )
-        elif args.model.startswith("qwen3"):
+        elif model_name.startswith("qwen3"):
             from executorch.examples.models.qwen3 import (  # pyre-ignore[21]
                 convert_weights,
             )
-        elif args.model == "phi_4_mini":
+        elif model_name == "phi_4_mini":
             from executorch.examples.models.phi_4_mini import (  # pyre-ignore[21]
                 convert_weights,
             )
-        elif args.model == "smollm2":
+        elif model_name == "smollm2":
             from executorch.examples.models.smollm2 import (  # pyre-ignore[21]
                 convert_weights,
             )
         else:
             raise ValueError(
-                f"Converting weights to meta format for {args.model} is not yet supported"
+                f"Converting weights to meta format for {model_name} is not yet supported"
             )
-        args.checkpoint = download_and_convert_hf_checkpoint(repo_id, convert_weights)
+        checkpoint = download_and_convert_hf_checkpoint(repo_id, convert_weights)
+        llm_config.base.checkpoint = checkpoint
+        args.checkpoint = checkpoint

-    if args.profile_path is not None:
+    if llm_config.debug.profile_path is not None:
         try:
             from executorch.util.python_profiler import CProfilerFlameGraph

-            with CProfilerFlameGraph(args.profile_path):
-                builder = _export_llama(args)
+            with CProfilerFlameGraph(llm_config.debug.profile_path):
+                builder = _export_llama(llm_config, args)
                 assert (
                     filename := builder.get_saved_pte_filename()
                 ) is not None, "Fail to get file name from builder"
@@ -632,53 +637,53 @@ def export_llama(
             )
             return ""
     else:
-        builder = _export_llama(args)
+        builder = _export_llama(llm_config, args)
         assert (
             filename := builder.get_saved_pte_filename()
         ) is not None, "Fail to get file name from builder"
         return filename


-def _prepare_for_llama_export(args) -> LLMEdgeManager:
+def _prepare_for_llama_export(llm_config, args) -> LLMEdgeManager:
     """
     Helper function for export_llama. Loads the model from checkpoint and params,
     and sets up a LLMEdgeManager with initial transforms and dtype conversion.

     Returns a LLMEdgeManager prior to calling export_to_edge with quantizers
     """
     # load model from checkpoint and params.json
-    checkpoint_path = canonical_path(args.checkpoint) if args.checkpoint else None
+    checkpoint_path = canonical_path(llm_config.base.checkpoint) if llm_config.base.checkpoint else None
     checkpoint_dir = (
-        canonical_path(args.checkpoint_dir) if args.checkpoint_dir else None
+        canonical_path(llm_config.base.checkpoint_dir) if llm_config.base.checkpoint_dir else None
     )
-    params_path = canonical_path(args.params) if args.params else None
-    output_dir_path = canonical_path(args.output_dir, dir=True)
-    weight_type = WeightType.FAIRSEQ2 if args.fairseq2 else WeightType.LLAMA
+    params_path = canonical_path(llm_config.base.params) if llm_config.base.params else None
+    output_dir_path = canonical_path(llm_config.export.output_dir, dir=True)
+    weight_type = WeightType.FAIRSEQ2 if llm_config.base.fairseq2 else WeightType.LLAMA

-    # Convert dtype override string arg to actual type.
-    dtype_override = DType[args.dtype_override]
+    # Convert dtype override string to actual type
+    dtype_override = DType[llm_config.model.dtype_override]

     edge_manager = _load_llama_model(
-        args.model,
+        llm_config.base.model_class,
         checkpoint=checkpoint_path,
         checkpoint_dir=checkpoint_dir,
         params_path=params_path,
-        use_kv_cache=args.use_kv_cache,
-        use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache,
-        generate_full_logits=args.generate_full_logits,
+        use_kv_cache=llm_config.model.use_kv_cache,
+        use_sdpa_with_kv_cache=llm_config.model.use_sdpa_with_kv_cache,
+        generate_full_logits=llm_config.debug.generate_full_logits,
         weight_type=weight_type,
-        enable_dynamic_shape=args.enable_dynamic_shape,
-        calibration_tasks=args.calibration_tasks,
-        calibration_limit=args.calibration_limit,
-        calibration_seq_length=args.calibration_seq_length,
-        calibration_data=args.calibration_data,
-        tokenizer_path=args.tokenizer_path,
-        verbose=args.verbose,
-        max_seq_len=args.max_seq_length,
-        max_context_len=args.max_context_length,
-        input_prune_map_path=args.input_prune_map,
-        output_prune_map_path=args.output_prune_map,
-        metadata_str=args.metadata,
+        enable_dynamic_shape=llm_config.model.enable_dynamic_shape,
+        calibration_tasks=llm_config.quantization.calibration_tasks,
+        calibration_limit=llm_config.quantization.calibration_limit,
+        calibration_seq_length=llm_config.quantization.calibration_seq_length,
+        calibration_data=llm_config.quantization.calibration_data,
+        tokenizer_path=llm_config.base.tokenizer_path,
+        verbose=llm_config.debug.verbose,
+        max_seq_len=llm_config.export.max_seq_length,
+        max_context_len=llm_config.export.max_context_length,
+        input_prune_map_path=llm_config.model.input_prune_map,
+        output_prune_map_path=llm_config.model.output_prune_map,
+        metadata_str=llm_config.base.metadata,
         dtype_override=dtype_override,
         args=args,
     )
@@ -710,63 +715,63 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
     edge_manager = edge_manager.set_output_dir(output_dir_path).source_transform(
         _get_source_transforms(
             dtype_override=dtype_override,
-            checkpoint=args.checkpoint,
+            checkpoint=llm_config.base.checkpoint,
             checkpoint_dtype=DType.from_torch_dtype(checkpoint_dtype),  # type: ignore
-            tokenizer_path=args.tokenizer_path,
-            use_spin_quant=args.use_spin_quant,
-            embedding_quantize=args.embedding_quantize,
-            use_shared_embedding=args.use_shared_embedding,
-            quantization_mode=args.quantization_mode,
-            group_size=args.group_size,
-            calibration_tasks=args.calibration_tasks,
-            calibration_limit=args.calibration_limit,
-            calibration_seq_length=args.calibration_seq_length,
-            expand_rope_table=args.expand_rope_table,
+            tokenizer_path=llm_config.base.tokenizer_path,
+            use_spin_quant=llm_config.quantization.use_spin_quant,
+            embedding_quantize=llm_config.quantization.embedding_quantize,
+            use_shared_embedding=llm_config.model.use_shared_embedding,
+            quantization_mode=llm_config.quantization.qmode,
+            group_size=llm_config.quantization.group_size,
+            calibration_tasks=llm_config.quantization.calibration_tasks,
+            calibration_limit=llm_config.quantization.calibration_limit,
+            calibration_seq_length=llm_config.quantization.calibration_seq_length,
+            expand_rope_table=llm_config.model.expand_rope_table,
             use_custom_sdpa_with_attention_mask=getattr(
-                args, "use_custom_sdpa_with_attention_mask", False
+                llm_config.model, "use_custom_sdpa_with_attention_mask", False
             ),
-            use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache,
-            quantize_kv_cache=args.quantize_kv_cache,
-            use_kv_cache=args.use_kv_cache,
-            qnn=args.qnn,
-            use_qnn_sha=args.use_qnn_sha,
-            optimized_rotation_path=args.optimized_rotation_path,
-            mps=args.mps,
-            coreml=args.coreml,
-            coreml_ios=args.coreml_ios,
-            vulkan=args.vulkan,
-            use_qat=args.use_qat,
-            use_lora=args.use_lora,
-            preq_mode=args.preq_mode,
-            preq_group_size=args.preq_group_size,
-            preq_embedding_quantize=args.preq_embedding_quantize,
+            use_sdpa_with_kv_cache=llm_config.model.use_sdpa_with_kv_cache,
+            quantize_kv_cache=llm_config.model.quantize_kv_cache,
+            use_kv_cache=llm_config.model.use_kv_cache,
+            qnn=llm_config.backend.qnn.enabled,
+            use_qnn_sha=llm_config.backend.qnn.use_sha,
+            optimized_rotation_path=llm_config.backend.qnn.optimized_rotation_path,
+            mps=llm_config.backend.mps.enabled,
+            coreml=llm_config.backend.coreml.enabled,
+            coreml_ios=llm_config.backend.coreml.ios,
+            vulkan=llm_config.backend.vulkan.enabled,
+            use_qat=llm_config.quantization.use_qat,
+            use_lora=llm_config.base.use_lora,
+            preq_mode=llm_config.base.preq_mode,
+            preq_group_size=llm_config.base.preq_group_size,
+            preq_embedding_quantize=llm_config.base.preq_embedding_quantize,
         )
     )

     return edge_manager


-def get_quantizer_and_quant_params(args):
+def get_quantizer_and_quant_params(llm_config):
     pt2e_quant_params = get_pt2e_quantization_params(
-        args.pt2e_quantize, args.quantization_mode
+        llm_config.quantization.pt2e_quantize, llm_config.quantization.qmode
     )
-    quantizers = get_pt2e_quantizers(pt2e_quant_params, args.so_library)
+    quantizers = get_pt2e_quantizers(pt2e_quant_params, llm_config.export.so_library)
     quant_dtype = None
-    if args.qnn and args.pt2e_quantize:
+    if llm_config.backend.qnn.enabled and llm_config.quantization.pt2e_quantize:
         assert len(quantizers) == 0, "Should not enable both xnnpack and qnn"
         qnn_quantizer, quant_dtype = get_qnn_quantizer(
-            args.pt2e_quantize, args.quantization_mode
+            llm_config.quantization.pt2e_quantize, llm_config.quantization.qmode
         )
         quantizers.append(qnn_quantizer)
-    if args.coreml and args.pt2e_quantize:
+    if llm_config.backend.coreml.enabled and llm_config.quantization.pt2e_quantize:
         assert len(quantizers) == 0, "Should not enable both xnnpack / qnn and coreml"
-        coreml_quantizer = get_coreml_quantizer(args.pt2e_quantize)
+        coreml_quantizer = get_coreml_quantizer(llm_config.quantization.pt2e_quantize)
         quantizers.append(coreml_quantizer)
-    if args.vulkan and args.pt2e_quantize:
+    if llm_config.backend.vulkan.enabled and llm_config.quantization.pt2e_quantize:
         assert (
             len(quantizers) == 0
         ), "Should not enable both vulkan and other quantizers"
-        vulkan_quantizer = get_vulkan_quantizer(args.pt2e_quantize)
+        vulkan_quantizer = get_vulkan_quantizer(llm_config.quantization.pt2e_quantize)
         quantizers.append(vulkan_quantizer)
     logging.info(f"Applying quantizers: {quantizers}")
     return pt2e_quant_params, quantizers, quant_dtype
@@ -789,28 +794,32 @@ def _qmode_type(value):
     )


-def _validate_args(args):
+def _validate_args(llm_config):
     """
     TODO: Combine all the backends under --backend args
     """

-    if args.max_context_length < args.max_seq_length:
+    if llm_config.export.max_context_length < llm_config.export.max_seq_length:
         raise ValueError(
-            f"max_context_length {args.max_context_length} must be >= max_seq_len {args.max_seq_length}. max_context_length impacts kv cache size that is used to remember history, while max_seq_length refers to user prompt length. Please use --max_context_length to specify context length."
+            f"max_context_length {llm_config.export.max_context_length} must be >= max_seq_len {llm_config.export.max_seq_length}. max_context_length impacts kv cache size that is used to remember history, while max_seq_length refers to user prompt length. Please use --max_context_length to specify context length."
         )
-    if args.enable_dynamic_shape and (args.coreml or args.mps or args.qnn):
+    if llm_config.model.enable_dynamic_shape and (
+        llm_config.backend.coreml.enabled or
+        llm_config.backend.mps.enabled or
+        llm_config.backend.qnn.enabled
+    ):
         raise ValueError(
             "Dynamic shape is not supported with coreml, MPS or qnn backends."
             " Please use --disable_dynamic_shape."
         )

-    if args.num_sharding > 0 and not args.qnn:
+    if llm_config.backend.qnn.num_sharding > 0 and not llm_config.backend.qnn.enabled:
         raise ValueError("Model shard is only supported with qnn backend now.")

-    if args.use_shared_embedding:
+    if llm_config.model.use_shared_embedding:
         if not (
-            args.embedding_quantize is not None
-            and args.embedding_quantize.startswith("torchao:")
+            llm_config.quantization.embedding_quantize is not None
+            and llm_config.quantization.embedding_quantize.startswith("torchao:")
         ):
             raise ValueError(
                 "Shared embedding is only supported with torchao quantization."
@@ -1038,38 +1047,39 @@ def _to_edge_and_lower_llama( # noqa: C901
     return builder


-def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
-    _validate_args(args)
+def _export_llama(llm_config, args) -> LLMEdgeManager:  # noqa: C901
+    _validate_args(llm_config)

-    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)
+    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(llm_config)

     additional_passes = []
-    if args.model in TORCHTUNE_DEFINED_MODELS:
+    if llm_config.base.model_class in TORCHTUNE_DEFINED_MODELS:
         additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])]

     # export_to_edge
-    builder_exported = _prepare_for_llama_export(args).export()
+    builder_exported = _prepare_for_llama_export(llm_config, args).export()
     builder_exported.run_canonical_optimizations()
     modelname = builder_exported.modelname

-    if args.export_only:
+    if llm_config.export.export_only:
         exit()

     if pt2e_quant_params is not None and pt2e_quant_params.quantize_linear is not None:
-        # Force xnnpack to be true if pt2e_quant_params is not None and args.xnnpack is False
+        # Force xnnpack to be true if pt2e_quant_params is not None and xnnpack is False
+        llm_config.backend.xnnpack.enabled = True
         args.xnnpack = True

-    if args.xnnpack:
+    if llm_config.backend.xnnpack.enabled:
         builder = _to_edge_and_lower_llama_xnnpack(
             builder_exported,
             modelname,
             additional_passes,
             pt2e_quant_params,
             quantizers,
             quant_dtype,
-            xnnpack_extended_ops=args.xnnpack_extended_ops,
-            generate_etrecord=args.generate_etrecord,
-            verbose=args.verbose,
+            xnnpack_extended_ops=llm_config.backend.xnnpack.extended_ops,
+            generate_etrecord=llm_config.debug.generate_etrecord,
+            verbose=llm_config.debug.verbose,
         )
     else:
         builder = _to_edge_and_lower_llama(
@@ -1079,33 +1089,33 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
             pt2e_quant_params,
             quantizers,
             quant_dtype,
-            vulkan=args.vulkan,
-            mps=args.mps,
-            coreml=args.coreml,
-            qnn=args.qnn,
-            dtype_override=args.dtype_override,
-            enable_dynamic_shape=args.enable_dynamic_shape,
-            use_kv_cache=args.use_kv_cache,
-            embedding_quantize=args.embedding_quantize,
-            pt2e_quantize=args.pt2e_quantize,
-            coreml_ios=args.coreml_ios,
-            coreml_quantize=args.coreml_quantize,
-            coreml_compute_units=args.coreml_compute_units,
-            use_qnn_sha=args.use_qnn_sha,
-            num_sharding=args.num_sharding,
-            soc_model=args.soc_model,
-            generate_etrecord=args.generate_etrecord,
-            verbose=args.verbose,
+            vulkan=llm_config.backend.vulkan.enabled,
+            mps=llm_config.backend.mps.enabled,
+            coreml=llm_config.backend.coreml.enabled,
+            qnn=llm_config.backend.qnn.enabled,
+            dtype_override=llm_config.model.dtype_override,
+            enable_dynamic_shape=llm_config.model.enable_dynamic_shape,
+            use_kv_cache=llm_config.model.use_kv_cache,
+            embedding_quantize=llm_config.quantization.embedding_quantize,
+            pt2e_quantize=llm_config.quantization.pt2e_quantize,
+            coreml_ios=llm_config.backend.coreml.ios,
+            coreml_quantize=llm_config.backend.coreml.quantize,
+            coreml_compute_units=llm_config.backend.coreml.compute_units,
+            use_qnn_sha=llm_config.backend.qnn.use_sha,
+            num_sharding=llm_config.backend.qnn.num_sharding,
+            soc_model=llm_config.backend.qnn.soc_model,
+            generate_etrecord=llm_config.debug.generate_etrecord,
+            verbose=llm_config.debug.verbose,
         )

-    if args.profile_memory:
+    if llm_config.debug.profile_memory:
         generate_memory_trace(builder.export_program, "memory_profile.json")

     if builder.dtype == DType.fp16:
         modelname = f"{modelname}_h"

-    if args.output_name:
-        modelname = args.output_name
+    if llm_config.export.output_name:
+        modelname = llm_config.export.output_name
     if modelname.endswith(".pte"):
         output_file = modelname
         modelname = modelname[:-4]
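
For reference, here is a minimal standalone sketch (not part of the commit) of the DictConfig-to-Namespace bridge introduced above: each top-level config group is copied onto an argparse.Namespace so legacy call sites that still take args keep working during the transition. It assumes omegaconf is installed and uses a hypothetical two-group subset of the config schema; the real LlmConfig has many more groups and defaults.

import argparse

from omegaconf import DictConfig, OmegaConf

# Hypothetical subset of the schema referenced in the diff (base.*, model.*).
llm_config: DictConfig = OmegaConf.create(
    {
        "base": {"model_class": "qwen2_5", "checkpoint": None},
        "model": {"use_kv_cache": True, "dtype_override": "fp32"},
    }
)

# Same pattern as in export_llama: each top-level group becomes an attribute
# on the Namespace, holding the nested DictConfig for that group.
args = argparse.Namespace()
for key, value in llm_config.items():
    setattr(args, key, value)

print(args.base.model_class)    # "qwen2_5"
print(args.model.use_kv_cache)  # True

Note that the bridge copies whole groups, so the Namespace exposes args.base.checkpoint rather than the old flat args.checkpoint; that is why the commit still assigns args.checkpoint explicitly after downloading a Hugging Face checkpoint.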