@@ -800,26 +800,26 @@ def _load_llama_model(
800800 modelname = "llama2"
801801 model_class_name = "Llama2Model"
802802 elif modelname in TORCHTUNE_DEFINED_MODELS :
803- raise NotImplementedError ("Torchtune Llama models are not yet supported in ExecuTorch export." )
803+ raise NotImplementedError (
804+ "Torchtune Llama models are not yet supported in ExecuTorch export."
805+ )
804806 else :
805807 raise ValueError (f"{ modelname } is not a valid Llama model." )
806808
807- model , example_inputs , example_kwarg_inputs , _ = (
808- EagerModelFactory .create_model (
809- modelname ,
810- model_class_name ,
811- checkpoint = checkpoint ,
812- checkpoint_dir = checkpoint_dir ,
813- params = params_path ,
814- use_kv_cache = use_kv_cache ,
815- use_sdpa_with_kv_cache = use_sdpa_with_kv_cache ,
816- generate_full_logits = generate_full_logits ,
817- fairseq2 = weight_type == WeightType .FAIRSEQ2 ,
818- max_seq_len = max_seq_len ,
819- enable_dynamic_shape = enable_dynamic_shape ,
820- output_prune_map_path = output_prune_map_path ,
821- args = args ,
822- )
809+ model , example_inputs , example_kwarg_inputs , _ = EagerModelFactory .create_model (
810+ modelname ,
811+ model_class_name ,
812+ checkpoint = checkpoint ,
813+ checkpoint_dir = checkpoint_dir ,
814+ params = params_path ,
815+ use_kv_cache = use_kv_cache ,
816+ use_sdpa_with_kv_cache = use_sdpa_with_kv_cache ,
817+ generate_full_logits = generate_full_logits ,
818+ fairseq2 = weight_type == WeightType .FAIRSEQ2 ,
819+ max_seq_len = max_seq_len ,
820+ enable_dynamic_shape = enable_dynamic_shape ,
821+ output_prune_map_path = output_prune_map_path ,
822+ args = args ,
823823 )
824824 if dtype_override :
825825 assert isinstance (
0 commit comments