@@ -157,7 +157,8 @@ def build_model(
     argString = f"--model {model} {checkpoint} {params} {extra_opts} {output_dir}"
     parser = build_args_parser()
     args = parser.parse_args(shlex.split(argString))
-    return export_llama(args)
+    llm_config = convert_args_to_llm_config(args)
+    return export_llama(llm_config)
 
 
 def parse_list_of_ints(s):
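
For `build_model` callers nothing changes: the legacy flag string is still parsed with argparse, then converted to an `LlmConfig` before the export call. A minimal sketch of the new hand-off, with illustrative flag values (the real string is assembled from the function's arguments):

    import shlex

    arg_string = "--model llama3 --checkpoint ckpt.pt --params params.json"  # illustrative
    args = build_args_parser().parse_args(shlex.split(arg_string))
    llm_config = convert_args_to_llm_config(args)  # argparse.Namespace -> LlmConfig
    pte_path = export_llama(llm_config)            # export now consumes the config
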
@@ -579,15 +580,10 @@ def export_llama(
 ) -> str:
     if isinstance(export_options, argparse.Namespace):
         # Legacy CLI.
-        args = export_options
         llm_config = convert_args_to_llm_config(export_options)
     elif isinstance(export_options, DictConfig):
         # Hydra CLI.
         llm_config = export_options
-        # Create an args object for backward compatibility during transition
-        args = argparse.Namespace()
-        for key, value in llm_config.items():
-            setattr(args, key, value)
     else:
         raise ValueError(
             "Input to export_llama must be either of type argparse.Namespace or LlmConfig"
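
With the backward-compatibility shim removed, both entry points are normalized to one config object and `args` never escapes this branch. A sketch of the two supported call shapes, assuming `omegaconf` provides the Hydra-side `DictConfig` (the config keys shown are placeholders):

    import argparse
    from omegaconf import OmegaConf

    # Legacy path: the Namespace is converted internally.
    legacy = argparse.Namespace(model="llama3", checkpoint="ckpt.pt")

    # Hydra path: the DictConfig is used as the llm_config directly.
    hydra_cfg = OmegaConf.create({"base": {"model_class": "llama3"}})

    # export_llama(legacy)     # both paths reach _export_llama(llm_config)
    # export_llama(hydra_cfg)
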
@@ -626,7 +622,7 @@ def export_llama(
             from executorch.util.python_profiler import CProfilerFlameGraph
 
             with CProfilerFlameGraph(llm_config.debug.profile_path):
-                builder = _export_llama(llm_config, args)
+                builder = _export_llama(llm_config)
                 assert (
                     filename := builder.get_saved_pte_filename()
                 ) is not None, "Fail to get file name from builder"
@@ -637,14 +633,14 @@ def export_llama(
             )
             return ""
     else:
-        builder = _export_llama(llm_config, args)
+        builder = _export_llama(llm_config)
         assert (
             filename := builder.get_saved_pte_filename()
         ) is not None, "Fail to get file name from builder"
         return filename
 
 
-def _prepare_for_llama_export(llm_config, args) -> LLMEdgeManager:
+def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager:
     """
     Helper function for export_llama. Loads the model from checkpoint and params,
     and sets up a LLMEdgeManager with initial transforms and dtype conversion.
@@ -672,7 +668,7 @@ def _prepare_for_llama_export(llm_config, args) -> LLMEdgeManager:
     dtype_override = DType[llm_config.model.dtype_override]
 
     edge_manager = _load_llama_model(
-        llm_config.base.model_class,
+        llm_config,
         checkpoint=checkpoint_path,
         checkpoint_dir=checkpoint_dir,
         params_path=params_path,
@@ -695,7 +691,6 @@ def _prepare_for_llama_export(llm_config, args) -> LLMEdgeManager:
         dtype_override=dtype_override,
         use_qnn=llm_config.backend.qnn.enabled,
         export_only=llm_config.export.export_only,
-        args=args,
     )
 
     # At this point, the model is loaded in the default fp32.
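
The keyword values above now come from nested `LlmConfig` fields instead of flat argparse attributes. A toy before/after of that access pattern, using `SimpleNamespace` stand-ins (field names are illustrative, not the real `LlmConfig` schema):

    from types import SimpleNamespace

    args = SimpleNamespace(qnn=False, export_only=False)   # flat legacy flags
    llm_config = SimpleNamespace(                          # nested config
        backend=SimpleNamespace(qnn=SimpleNamespace(enabled=args.qnn)),
        export=SimpleNamespace(export_only=args.export_only),
    )
    assert llm_config.backend.qnn.enabled == args.qnn
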
@@ -1054,7 +1049,7 @@ def _to_edge_and_lower_llama(  # noqa: C901
     return builder
 
 
-def _export_llama(llm_config, args) -> LLMEdgeManager:  # noqa: C901
+def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager:  # noqa: C901
     _validate_args(llm_config)
 
     pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(
@@ -1066,7 +1061,7 @@ def _export_llama(llm_config, args) -> LLMEdgeManager:  # noqa: C901
         additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])]
 
     # export_to_edge
-    builder_exported = _prepare_for_llama_export(llm_config, args).export()
+    builder_exported = _prepare_for_llama_export(llm_config).export()
     builder_exported.run_canonical_optimizations()
     modelname = builder_exported.modelname
 
@@ -1174,7 +1169,7 @@ def _load_llama_model_metadata(
 
 
 def _load_llama_model(
-    modelname: str = "llama3",
+    llm_config: LlmConfig,
     *,
     checkpoint: Optional[str] = None,
     checkpoint_dir: Optional[str] = None,
@@ -1198,8 +1193,6 @@ def _load_llama_model(
     dtype_override: Optional[DType] = None,
     use_qnn: bool = False,
     export_only: bool = False,
-    args,
-    llm_config: Optional[LlmConfig] = None,
 ) -> "LLMEdgeManager":
     """
     A helper util that builds a Llama2 model. It returns a LLMEdgeManager that
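
Note the bare `*` right after `llm_config`: the config is now the only positional parameter, and every remaining option must be passed by keyword. A minimal illustration of that pattern (names are placeholders):

    def load(config, *, checkpoint=None, export_only=False):
        return config, checkpoint, export_only

    load("cfg", checkpoint="ckpt.pt")  # OK: keywords after the bare *
    # load("cfg", "ckpt.pt")           # TypeError: takes 1 positional argument
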
@@ -1208,6 +1201,7 @@ def _load_llama_model(
         An instance of LLMEdgeManager which contains the eager mode model.
     """
 
+    modelname = llm_config.base.model_class
     if modelname in EXECUTORCH_DEFINED_MODELS:
         module_name = "llama"
         model_class_name = "Llama2Model"  # TODO: Change to "LlamaModel" in examples/models/llama/model.py.
@@ -1220,13 +1214,11 @@ def _load_llama_model(
     else:
         raise ValueError(f"{modelname} is not a valid Llama model.")
 
-    torch_dtype = dtype_override.to_torch_dtype() if dtype_override else None
-
     model, example_inputs, example_kwarg_inputs, dynamic_shapes = (
         EagerModelFactory.create_model(
             module_name,
             model_class_name,
-            model_args={"llm_config": llm_config},
+            llm_config=llm_config,
         )
     )
 
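
Passing `llm_config=llm_config` as a plain keyword relies on the factory forwarding unknown kwargs into the model constructor, rather than unpacking a `model_args` dict. A hedged mirror of that pattern, not the actual `EagerModelFactory` source:

    import importlib

    def create_model_sketch(module_name: str, class_name: str, **kwargs):
        module = importlib.import_module(module_name)
        model_cls = getattr(module, class_name)
        # Extra kwargs (e.g. llm_config=...) flow straight into the constructor.
        return model_cls(**kwargs)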