@@ -579,49 +579,54 @@ def export_llama(
     if isinstance(export_options, argparse.Namespace):
         # Legacy CLI.
         args = export_options
-        llm_config = convert_args_to_llm_config(export_options)  # noqa: F841
+        llm_config = convert_args_to_llm_config(export_options)
     elif isinstance(export_options, DictConfig):
         # Hydra CLI.
-        llm_config = export_options  # noqa: F841
+        llm_config = export_options
+        # Create an args object for backward compatibility during transition
+        args = argparse.Namespace()
+        for key, value in llm_config.items():
+            setattr(args, key, value)
     else:
         raise ValueError(
             "Input to export_llama must be either of type argparse.Namespace or LlmConfig"
         )
 
-    # TODO: refactor rest of export_llama to use llm_config instead of args.
-
     # If a checkpoint isn't provided for an HF OSS model, download and convert the
     # weights first.
-    if not args.checkpoint and args.model in HUGGING_FACE_REPO_IDS:
-        repo_id = HUGGING_FACE_REPO_IDS[args.model]
-        if args.model == "qwen2_5":
+    model_name = llm_config.base.model_class
+    if not llm_config.base.checkpoint and model_name in HUGGING_FACE_REPO_IDS:
+        repo_id = HUGGING_FACE_REPO_IDS[model_name]
+        if model_name == "qwen2_5":
             from executorch.examples.models.qwen2_5 import (  # pyre-ignore[21]
                 convert_weights,
             )
-        elif args.model.startswith("qwen3"):
+        elif model_name.startswith("qwen3"):
             from executorch.examples.models.qwen3 import (  # pyre-ignore[21]
                 convert_weights,
             )
-        elif args.model == "phi_4_mini":
+        elif model_name == "phi_4_mini":
             from executorch.examples.models.phi_4_mini import (  # pyre-ignore[21]
                 convert_weights,
             )
-        elif args.model == "smollm2":
+        elif model_name == "smollm2":
             from executorch.examples.models.smollm2 import (  # pyre-ignore[21]
                 convert_weights,
             )
         else:
             raise ValueError(
-                f"Converting weights to meta format for {args.model}
+                f"Converting weights to meta format for {model_name}
             )
-        args.checkpoint = download_and_convert_hf_checkpoint(repo_id, convert_weights)
+        checkpoint = download_and_convert_hf_checkpoint(repo_id, convert_weights)
+        llm_config.base.checkpoint = checkpoint
+        args.checkpoint = checkpoint
 
-    if args.profile_path is not None:
+    if llm_config.debug.profile_path is not None:
         try:
             from executorch.util.python_profiler import CProfilerFlameGraph
 
-            with CProfilerFlameGraph(args.profile_path):
-                builder = _export_llama(args)
+            with CProfilerFlameGraph(llm_config.debug.profile_path):
+                builder = _export_llama(llm_config, args)
                 assert (
                     filename := builder.get_saved_pte_filename()
                 ) is not None, "Fail to get file name from builder"
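
For context on the Hydra branch above: the DictConfig is copied onto an argparse.Namespace only one level deep, so legacy args.* accesses see the nested config groups (args.base, args.export, ...) rather than the old flat flags. A minimal sketch of that bridge, assuming omegaconf is installed; the config keys and values shown are made up for illustration and are not a complete LlmConfig:

import argparse

from omegaconf import OmegaConf

# Hypothetical, partial config: only enough keys to show the copy loop.
llm_config = OmegaConf.create(
    {"base": {"model_class": "qwen2_5", "checkpoint": None}, "export": {"output_dir": "."}}
)

args = argparse.Namespace()
for key, value in llm_config.items():  # top-level groups only: "base", "export"
    setattr(args, key, value)

print(args.base.model_class)  # -> "qwen2_5"; each group stays a nested DictConfig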
@@ -632,53 +637,53 @@ def export_llama(
             )
             return ""
     else:
-        builder = _export_llama(args)
+        builder = _export_llama(llm_config, args)
         assert (
             filename := builder.get_saved_pte_filename()
         ) is not None, "Fail to get file name from builder"
         return filename
 
 
-def _prepare_for_llama_export(args) -> LLMEdgeManager:
+def _prepare_for_llama_export(llm_config, args) -> LLMEdgeManager:
     """
     Helper function for export_llama. Loads the model from checkpoint and params,
     and sets up a LLMEdgeManager with initial transforms and dtype conversion.
 
     Returns a LLMEdgeManager prior to calling export_to_edge with quantizers
     """
     # load model from checkpoint and params.json
-    checkpoint_path = canonical_path(args.checkpoint) if args.checkpoint else None
+    checkpoint_path = canonical_path(llm_config.base.checkpoint) if llm_config.base.checkpoint else None
     checkpoint_dir = (
-        canonical_path(args.checkpoint_dir) if args.checkpoint_dir else None
+        canonical_path(llm_config.base.checkpoint_dir) if llm_config.base.checkpoint_dir else None
     )
-    params_path = canonical_path(args.params) if args.params else None
-    output_dir_path = canonical_path(args.output_dir, dir=True)
-    weight_type = WeightType.FAIRSEQ2 if args.fairseq2 else WeightType.LLAMA
+    params_path = canonical_path(llm_config.base.params) if llm_config.base.params else None
+    output_dir_path = canonical_path(llm_config.export.output_dir, dir=True)
+    weight_type = WeightType.FAIRSEQ2 if llm_config.base.fairseq2 else WeightType.LLAMA
 
-    # Convert dtype override string arg to actual type.
-    dtype_override = DType[args.dtype_override]
+    # Convert dtype override string to actual type
+    dtype_override = DType[llm_config.model.dtype_override]
 
     edge_manager = _load_llama_model(
-        args.model,
+        llm_config.base.model_class,
         checkpoint=checkpoint_path,
         checkpoint_dir=checkpoint_dir,
         params_path=params_path,
-        use_kv_cache=args.use_kv_cache,
-        use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache,
-        generate_full_logits=args.generate_full_logits,
+        use_kv_cache=llm_config.model.use_kv_cache,
+        use_sdpa_with_kv_cache=llm_config.model.use_sdpa_with_kv_cache,
+        generate_full_logits=llm_config.debug.generate_full_logits,
         weight_type=weight_type,
-        enable_dynamic_shape=args.enable_dynamic_shape,
-        calibration_tasks=args.calibration_tasks,
-        calibration_limit=args.calibration_limit,
-        calibration_seq_length=args.calibration_seq_length,
-        calibration_data=args.calibration_data,
-        tokenizer_path=args.tokenizer_path,
-        verbose=args.verbose,
-        max_seq_len=args.max_seq_length,
-        max_context_len=args.max_context_length,
-        input_prune_map_path=args.input_prune_map,
-        output_prune_map_path=args.output_prune_map,
-        metadata_str=args.metadata,
+        enable_dynamic_shape=llm_config.model.enable_dynamic_shape,
+        calibration_tasks=llm_config.quantization.calibration_tasks,
+        calibration_limit=llm_config.quantization.calibration_limit,
+        calibration_seq_length=llm_config.quantization.calibration_seq_length,
+        calibration_data=llm_config.quantization.calibration_data,
+        tokenizer_path=llm_config.base.tokenizer_path,
+        verbose=llm_config.debug.verbose,
+        max_seq_len=llm_config.export.max_seq_length,
+        max_context_len=llm_config.export.max_context_length,
+        input_prune_map_path=llm_config.model.input_prune_map,
+        output_prune_map_path=llm_config.model.output_prune_map,
+        metadata_str=llm_config.base.metadata,
         dtype_override=dtype_override,
         args=args,
     )
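
A quick reference for the renames in this hunk: the flat argparse flags on the left map to the grouped llm_config fields that replace them. The mapping below is derived only from the replacements visible above (the dict name is illustrative) and is not the full translation performed by convert_args_to_llm_config:

# Partial mapping, taken from the lines changed in this hunk only.
ARGS_TO_LLM_CONFIG_FIELDS = {
    "model": "base.model_class",
    "checkpoint": "base.checkpoint",
    "checkpoint_dir": "base.checkpoint_dir",
    "params": "base.params",
    "tokenizer_path": "base.tokenizer_path",
    "metadata": "base.metadata",
    "fairseq2": "base.fairseq2",
    "dtype_override": "model.dtype_override",
    "use_kv_cache": "model.use_kv_cache",
    "use_sdpa_with_kv_cache": "model.use_sdpa_with_kv_cache",
    "enable_dynamic_shape": "model.enable_dynamic_shape",
    "input_prune_map": "model.input_prune_map",
    "output_prune_map": "model.output_prune_map",
    "generate_full_logits": "debug.generate_full_logits",
    "verbose": "debug.verbose",
    "output_dir": "export.output_dir",
    "max_seq_length": "export.max_seq_length",
    "max_context_length": "export.max_context_length",
    "calibration_tasks": "quantization.calibration_tasks",
    "calibration_limit": "quantization.calibration_limit",
    "calibration_seq_length": "quantization.calibration_seq_length",
    "calibration_data": "quantization.calibration_data",
}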
@@ -710,63 +715,63 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
     edge_manager = edge_manager.set_output_dir(output_dir_path).source_transform(
         _get_source_transforms(
             dtype_override=dtype_override,
-            checkpoint=args.checkpoint,
+            checkpoint=llm_config.base.checkpoint,
             checkpoint_dtype=DType.from_torch_dtype(checkpoint_dtype),  # type: ignore
-            tokenizer_path=args.tokenizer_path,
-            use_spin_quant=args.use_spin_quant,
-            embedding_quantize=args.embedding_quantize,
-            use_shared_embedding=args.use_shared_embedding,
-            quantization_mode=args.quantization_mode,
-            group_size=args.group_size,
-            calibration_tasks=args.calibration_tasks,
-            calibration_limit=args.calibration_limit,
-            calibration_seq_length=args.calibration_seq_length,
-            expand_rope_table=args.expand_rope_table,
+            tokenizer_path=llm_config.base.tokenizer_path,
+            use_spin_quant=llm_config.quantization.use_spin_quant,
+            embedding_quantize=llm_config.quantization.embedding_quantize,
+            use_shared_embedding=llm_config.model.use_shared_embedding,
+            quantization_mode=llm_config.quantization.qmode,
+            group_size=llm_config.quantization.group_size,
+            calibration_tasks=llm_config.quantization.calibration_tasks,
+            calibration_limit=llm_config.quantization.calibration_limit,
+            calibration_seq_length=llm_config.quantization.calibration_seq_length,
+            expand_rope_table=llm_config.model.expand_rope_table,
             use_custom_sdpa_with_attention_mask=getattr(
-                args, "use_custom_sdpa_with_attention_mask", False
+                llm_config.model, "use_custom_sdpa_with_attention_mask", False
             ),
-            use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache,
-            quantize_kv_cache=args.quantize_kv_cache,
-            use_kv_cache=args.use_kv_cache,
-            qnn=args.qnn,
-            use_qnn_sha=args.use_qnn_sha,
-            optimized_rotation_path=args.optimized_rotation_path,
-            mps=args.mps,
-            coreml=args.coreml,
-            coreml_ios=args.coreml_ios,
-            vulkan=args.vulkan,
-            use_qat=args.use_qat,
-            use_lora=args.use_lora,
-            preq_mode=args.preq_mode,
-            preq_group_size=args.preq_group_size,
-            preq_embedding_quantize=args.preq_embedding_quantize,
+            use_sdpa_with_kv_cache=llm_config.model.use_sdpa_with_kv_cache,
+            quantize_kv_cache=llm_config.model.quantize_kv_cache,
+            use_kv_cache=llm_config.model.use_kv_cache,
+            qnn=llm_config.backend.qnn.enabled,
+            use_qnn_sha=llm_config.backend.qnn.use_sha,
+            optimized_rotation_path=llm_config.backend.qnn.optimized_rotation_path,
+            mps=llm_config.backend.mps.enabled,
+            coreml=llm_config.backend.coreml.enabled,
+            coreml_ios=llm_config.backend.coreml.ios,
+            vulkan=llm_config.backend.vulkan.enabled,
+            use_qat=llm_config.quantization.use_qat,
+            use_lora=llm_config.base.use_lora,
+            preq_mode=llm_config.base.preq_mode,
+            preq_group_size=llm_config.base.preq_group_size,
+            preq_embedding_quantize=llm_config.base.preq_embedding_quantize,
         )
     )
 
     return edge_manager
 
 
-def get_quantizer_and_quant_params(args):
+def get_quantizer_and_quant_params(llm_config):
     pt2e_quant_params = get_pt2e_quantization_params(
-        args.pt2e_quantize, args.quantization_mode
+        llm_config.quantization.pt2e_quantize, llm_config.quantization.qmode
     )
-    quantizers = get_pt2e_quantizers(pt2e_quant_params, args.so_library)
+    quantizers = get_pt2e_quantizers(pt2e_quant_params, llm_config.export.so_library)
     quant_dtype = None
-    if args.qnn and args.pt2e_quantize:
+    if llm_config.backend.qnn.enabled and llm_config.quantization.pt2e_quantize:
         assert len(quantizers) == 0, "Should not enable both xnnpack and qnn"
         qnn_quantizer, quant_dtype = get_qnn_quantizer(
-            args.pt2e_quantize, args.quantization_mode
+            llm_config.quantization.pt2e_quantize, llm_config.quantization.qmode
         )
         quantizers.append(qnn_quantizer)
-    if args.coreml and args.pt2e_quantize:
+    if llm_config.backend.coreml.enabled and llm_config.quantization.pt2e_quantize:
         assert len(quantizers) == 0, "Should not enable both xnnpack / qnn and coreml"
-        coreml_quantizer = get_coreml_quantizer(args.pt2e_quantize)
+        coreml_quantizer = get_coreml_quantizer(llm_config.quantization.pt2e_quantize)
         quantizers.append(coreml_quantizer)
-    if args.vulkan and args.pt2e_quantize:
+    if llm_config.backend.vulkan.enabled and llm_config.quantization.pt2e_quantize:
         assert (
             len(quantizers) == 0
         ), "Should not enable both vulkan and other quantizers"
-        vulkan_quantizer = get_vulkan_quantizer(args.pt2e_quantize)
+        vulkan_quantizer = get_vulkan_quantizer(llm_config.quantization.pt2e_quantize)
         quantizers.append(vulkan_quantizer)
     logging.info(f"Applying quantizers: {quantizers}")
     return pt2e_quant_params, quantizers, quant_dtype
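
The asserts in get_quantizer_and_quant_params encode an at-most-one rule: pt2e quantizers may come from the default/XNNPACK path, QNN, CoreML, or Vulkan, but never from more than one of them in a single export. A standalone sketch of the same rule, using a hypothetical helper name and the llm_config fields read above (the real code enforces this via the len(quantizers) asserts instead):

def _assert_single_pt2e_backend(llm_config) -> None:
    # Illustrative only; not part of the patch.
    if not llm_config.quantization.pt2e_quantize:
        return
    enabled = [
        name
        for name, flag in (
            ("qnn", llm_config.backend.qnn.enabled),
            ("coreml", llm_config.backend.coreml.enabled),
            ("vulkan", llm_config.backend.vulkan.enabled),
        )
        if flag
    ]
    if len(enabled) > 1:
        raise ValueError(f"pt2e quantization supports at most one backend, got {enabled}")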
@@ -789,28 +794,32 @@ def _qmode_type(value):
     )
 
 
-def _validate_args(args):
+def _validate_args(llm_config):
     """
     TODO: Combine all the backends under --backend args
     """
 
-    if args.max_context_length < args.max_seq_length:
+    if llm_config.export.max_context_length < llm_config.export.max_seq_length:
         raise ValueError(
-            f"max_context_length {args.max_context_length} {args.max_seq_length}
+            f"max_context_length {llm_config.export.max_context_length} {llm_config.export.max_seq_length}
         )
-    if args.enable_dynamic_shape and (args.coreml or args.mps or args.qnn):
+    if llm_config.model.enable_dynamic_shape and (
+        llm_config.backend.coreml.enabled or
+        llm_config.backend.mps.enabled or
+        llm_config.backend.qnn.enabled
+    ):
         raise ValueError(
             "Dynamic shape is not supported with coreml, MPS or qnn backends."
             " Please use --disable_dynamic_shape."
         )
 
-    if args.num_sharding > 0 and not args.qnn:
+    if llm_config.backend.qnn.num_sharding > 0 and not llm_config.backend.qnn.enabled:
         raise ValueError("Model shard is only supported with qnn backend now.")
 
-    if args.use_shared_embedding:
+    if llm_config.model.use_shared_embedding:
         if not (
-            args.embedding_quantize is not None
-            and args.embedding_quantize.startswith("torchao:")
+            llm_config.quantization.embedding_quantize is not None
+            and llm_config.quantization.embedding_quantize.startswith("torchao:")
         ):
             raise ValueError(
                 "Shared embedding is only supported with torchao quantization."
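
To see the first rule in _validate_args trip, a throwaway config carrying only the attributes the function reads can be built from SimpleNamespace. This is illustrative only; real callers pass the LlmConfig/DictConfig objects used elsewhere in this file, and the field values here are made up:

from types import SimpleNamespace

cfg = SimpleNamespace(
    export=SimpleNamespace(max_context_length=512, max_seq_length=1024),
    model=SimpleNamespace(enable_dynamic_shape=False, use_shared_embedding=False),
    quantization=SimpleNamespace(embedding_quantize=None),
    backend=SimpleNamespace(
        coreml=SimpleNamespace(enabled=False),
        mps=SimpleNamespace(enabled=False),
        qnn=SimpleNamespace(enabled=False, num_sharding=0),
    ),
)

try:
    _validate_args(cfg)  # raises: max_context_length (512) < max_seq_length (1024)
except ValueError as err:
    print(err)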
@@ -1038,38 +1047,39 @@ def _to_edge_and_lower_llama(  # noqa: C901
     return builder
 
 
-def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
-    _validate_args(args)
+def _export_llama(llm_config, args) -> LLMEdgeManager:  # noqa: C901
+    _validate_args(llm_config)
 
-    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)
+    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(llm_config)
 
     additional_passes = []
-    if args.model in TORCHTUNE_DEFINED_MODELS:
+    if llm_config.base.model_class in TORCHTUNE_DEFINED_MODELS:
         additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])]
 
     # export_to_edge
-    builder_exported = _prepare_for_llama_export(args).export()
+    builder_exported = _prepare_for_llama_export(llm_config, args).export()
     builder_exported.run_canonical_optimizations()
     modelname = builder_exported.modelname
 
-    if args.export_only:
+    if llm_config.export.export_only:
         exit()
 
     if pt2e_quant_params is not None and pt2e_quant_params.quantize_linear is not None:
-        # Force xnnpack to be true if pt2e_quant_params is not None and args.xnnpack is False
+        # Force xnnpack to be true if pt2e_quant_params is not None and xnnpack is False
+        llm_config.backend.xnnpack.enabled = True
         args.xnnpack = True
 
-    if args.xnnpack:
+    if llm_config.backend.xnnpack.enabled:
         builder = _to_edge_and_lower_llama_xnnpack(
             builder_exported,
             modelname,
             additional_passes,
             pt2e_quant_params,
             quantizers,
             quant_dtype,
-            xnnpack_extended_ops=args.xnnpack_extended_ops,
-            generate_etrecord=args.generate_etrecord,
-            verbose=args.verbose,
+            xnnpack_extended_ops=llm_config.backend.xnnpack.extended_ops,
+            generate_etrecord=llm_config.debug.generate_etrecord,
+            verbose=llm_config.debug.verbose,
         )
     else:
         builder = _to_edge_and_lower_llama(
@@ -1079,33 +1089,33 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
             pt2e_quant_params,
             quantizers,
             quant_dtype,
-            vulkan=args.vulkan,
-            mps=args.mps,
-            coreml=args.coreml,
-            qnn=args.qnn,
-            dtype_override=args.dtype_override,
-            enable_dynamic_shape=args.enable_dynamic_shape,
-            use_kv_cache=args.use_kv_cache,
-            embedding_quantize=args.embedding_quantize,
-            pt2e_quantize=args.pt2e_quantize,
-            coreml_ios=args.coreml_ios,
-            coreml_quantize=args.coreml_quantize,
-            coreml_compute_units=args.coreml_compute_units,
-            use_qnn_sha=args.use_qnn_sha,
-            num_sharding=args.num_sharding,
-            soc_model=args.soc_model,
-            generate_etrecord=args.generate_etrecord,
-            verbose=args.verbose,
+            vulkan=llm_config.backend.vulkan.enabled,
+            mps=llm_config.backend.mps.enabled,
+            coreml=llm_config.backend.coreml.enabled,
+            qnn=llm_config.backend.qnn.enabled,
+            dtype_override=llm_config.model.dtype_override,
+            enable_dynamic_shape=llm_config.model.enable_dynamic_shape,
+            use_kv_cache=llm_config.model.use_kv_cache,
+            embedding_quantize=llm_config.quantization.embedding_quantize,
+            pt2e_quantize=llm_config.quantization.pt2e_quantize,
+            coreml_ios=llm_config.backend.coreml.ios_version,
+            coreml_quantize=llm_config.backend.coreml.quantize,
+            coreml_compute_units=llm_config.backend.coreml.compute_units,
+            use_qnn_sha=llm_config.backend.qnn.use_sha,
+            num_sharding=llm_config.backend.qnn.num_sharding,
+            soc_model=llm_config.backend.qnn.soc_model,
+            generate_etrecord=llm_config.debug.generate_etrecord,
+            verbose=llm_config.debug.verbose,
         )
 
-    if args.profile_memory:
+    if llm_config.debug.profile_memory:
         generate_memory_trace(builder.export_program, "memory_profile.json")
 
     if builder.dtype == DType.fp16:
         modelname = f"{modelname}
 
-    if args.output_name:
-        modelname = args.output_name
+    if llm_config.export.output_name:
+        modelname = llm_config.export.output_name
         if modelname.endswith(".pte"):
             output_file = modelname
             modelname = modelname[:-4]
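
The diff is cut off just after the output_name handling, but the visible branch is simple: when the user-supplied name already ends in ".pte" it is used verbatim as the output file and the suffix is stripped from modelname. A small illustration of that visible branch only, with a made-up name:

output_name = "llama_xnnpack_kv.pte"  # hypothetical llm_config.export.output_name
modelname = output_name
if modelname.endswith(".pte"):
    output_file = modelname
    modelname = modelname[:-4]

print(modelname, output_file)  # -> llama_xnnpack_kv llama_xnnpack_kv.pte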