@@ -322,7 +322,7 @@ def build_args_parser() -> argparse.ArgumentParser:
         default="fp32",
         type=str,
         choices=["fp32", "fp16", "bf16"],
-        help="Override the dtype of the model (default is the checkpoint dtype)."
+        help="Provide the dtype of the model."
         "Options: fp32, fp16, bf16. Please be aware that only some backends support fp16 and bf16.",
     )

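A note on the flag change above: because `--dtype_override` now carries a concrete default of `"fp32"` instead of meaning "keep the checkpoint dtype", the export path in the next hunk can index straight into the `DType` enum without a `None` check. Below is a minimal, self-contained sketch of that interplay; the `DType` class is a stand-in for the exporter's real enum, and the flag spelling is an assumption (only the dest, `args.dtype_override`, is visible in the diff).

```python
import argparse
from enum import Enum

import torch


class DType(Enum):
    """Minimal stand-in for the exporter's DType enum (illustrative only)."""

    fp32 = "fp32"
    fp16 = "fp16"
    bf16 = "bf16"

    def to_torch_dtype(self) -> torch.dtype:
        return {
            DType.fp32: torch.float32,
            DType.fp16: torch.float16,
            DType.bf16: torch.bfloat16,
        }[self]


parser = argparse.ArgumentParser()
parser.add_argument(
    "--dtype_override",  # flag spelling assumed; only the dest appears in the diff
    default="fp32",
    type=str,
    choices=["fp32", "fp16", "bf16"],
    help="Provide the dtype of the model. Options: fp32, fp16, bf16.",
)

args = parser.parse_args([])  # no flag given -> "fp32", never None
print(DType[args.dtype_override].to_torch_dtype())  # torch.float32
```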
@@ -565,43 +565,40 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
     output_dir_path = canonical_path(args.output_dir, dir=True)
     weight_type = WeightType.FAIRSEQ2 if args.fairseq2 else WeightType.LLAMA

-    # dtype override
-    if args.dtype_override is not None:
-        dtype_override = DType[args.dtype_override]
-    elif args.quantization_mode in ["8da4w", "8da4w-gptq"]:
-        dtype_override = DType["fp16"]
-    else:
-        dtype_override = None
-
-    return (
-        _load_llama_model(
-            args.model,
-            checkpoint=checkpoint_path,
-            checkpoint_dir=checkpoint_dir,
-            params_path=params_path,
-            use_kv_cache=args.use_kv_cache,
-            use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache,
-            generate_full_logits=args.generate_full_logits,
-            weight_type=weight_type,
-            enable_dynamic_shape=args.enable_dynamic_shape,
-            calibration_tasks=args.calibration_tasks,
-            calibration_limit=args.calibration_limit,
-            calibration_seq_length=args.calibration_seq_length,
-            calibration_data=args.calibration_data,
-            tokenizer_path=args.tokenizer_path,
-            verbose=args.verbose,
-            max_seq_len=args.max_seq_length,
-            max_context_len=args.max_context_length,
-            input_prune_map_path=args.input_prune_map,
-            output_prune_map_path=args.output_prune_map,
-            metadata_str=args.metadata,
-            dtype_override=dtype_override,
-            args=args,
-        )
-        .set_output_dir(output_dir_path)
-        .source_transform(_get_source_transforms(args.model, dtype_override, args))
+    # Convert dtype override string arg to actual type.
+    dtype_override = DType[args.dtype_override]
+
+    edge_manager = _load_llama_model(
+        args.model,
+        checkpoint=checkpoint_path,
+        checkpoint_dir=checkpoint_dir,
+        params_path=params_path,
+        use_kv_cache=args.use_kv_cache,
+        use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache,
+        generate_full_logits=args.generate_full_logits,
+        weight_type=weight_type,
+        enable_dynamic_shape=args.enable_dynamic_shape,
+        calibration_tasks=args.calibration_tasks,
+        calibration_limit=args.calibration_limit,
+        calibration_seq_length=args.calibration_seq_length,
+        calibration_data=args.calibration_data,
+        tokenizer_path=args.tokenizer_path,
+        verbose=args.verbose,
+        max_seq_len=args.max_seq_length,
+        max_context_len=args.max_context_length,
+        input_prune_map_path=args.input_prune_map,
+        output_prune_map_path=args.output_prune_map,
+        metadata_str=args.metadata,
+        dtype_override=dtype_override,
+        args=args,
     )

+    # At this point, the model is loaded in the default fp32.
+    edge_manager.model = edge_manager.model.to(dtype=dtype_override.to_torch_dtype())
+    edge_manager.set_output_dir(output_dir_path).source_transform(_get_source_transforms(args.model, dtype_override, args))
+
+    return edge_manager
+


 def get_quantizer_and_quant_params(args):
     pt2e_quant_params = get_pt2e_quantization_params(
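Two behavioral notes on this hunk: the `8da4w` / `8da4w-gptq` quantization modes no longer implicitly force fp16, and the cast now happens exactly once, after `_load_llama_model` returns with the model in fp32. The sketch below shows what the final `.to(dtype=...)` reassignment does to an already-loaded module; it uses a toy `nn.Linear` rather than the exporter's `LLMEdgeManager`, and assumes `DType.bf16` maps to `torch.bfloat16`.

```python
import torch
import torch.nn as nn

# Stand-in for the model held by edge_manager.model; loaded in the default fp32.
model = nn.Linear(4, 4)
assert model.weight.dtype == torch.float32

# Equivalent of dtype_override.to_torch_dtype() for DType.bf16 (assumed mapping).
target_dtype = torch.bfloat16

# nn.Module.to(dtype=...) casts floating-point parameters and buffers and
# returns the module, so reassigning the attribute keeps the same object.
model = model.to(dtype=target_dtype)
assert model.weight.dtype == torch.bfloat16
```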
@@ -1006,6 +1003,8 @@ def _load_llama_model(
     else:
         raise ValueError(f"{modelname} is not a valid Llama model.")

+    torch_dtype = dtype_override.to_torch_dtype() if dtype_override else None
+
     model, example_inputs, example_kwarg_inputs, dynamic_shapes = (
         EagerModelFactory.create_model(
             module_name,
@@ -1022,41 +1021,16 @@ def _load_llama_model(
             enable_dynamic_shape=enable_dynamic_shape,
             input_prune_map_path=input_prune_map_path,
             output_prune_map_path=output_prune_map_path,
+            dtype=torch_dtype,
             args=args,
         )
     )
-    if dtype_override:
-        assert isinstance(
-            dtype_override, DType
-        ), "Override dtype needs to be of type <DType>"
-        torch_dtype = dtype_override.to_torch_dtype()
-        logging.info(f"model.to {torch_dtype}")
-        model = model.to(dtype=torch_dtype)
-        dtype = dtype_override
-    else:
-        state_dict = model.state_dict()
-        dtype = state_dict[next(iter(state_dict))].dtype
-        assert dtype in [
-            torch.bfloat16,
-            torch.float16,
-            torch.float32,
-        ], f"Only support bfloat16, fp16 or fp32 got {dtype}"
-        logging.info(f"Loaded model with dtype={dtype}")
-
-        if dtype == torch.bfloat16:
-            dtype = DType.bf16
-        elif dtype == torch.float16:
-            dtype = DType.fp16
-        elif dtype == torch.float32:
-            dtype = DType.fp32
-        else:
-            raise ValueError(f"Unsupported dtype {dtype}")

     return LLMEdgeManager(
         model=model,
         modelname=modelname,
         max_seq_len=model.max_seq_len,
-        dtype=dtype,
+        dtype=dtype_override,
         use_kv_cache=use_kv_cache,
         generate_full_logits=generate_full_logits,
         example_inputs=example_inputs,
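With the `dtype` kwarg threaded into `EagerModelFactory.create_model`, the dtype decision is made before the checkpoint is materialized, and `_load_llama_model` no longer inspects the loaded `state_dict` to map torch dtypes back to `DType`; it simply forwards `dtype_override` to `LLMEdgeManager`. Below is a hedged sketch of the constructor-side pattern this enables, using a hypothetical toy model class rather than the real model module behind the factory.

```python
from typing import Optional

import torch
import torch.nn as nn


class TinyModel(nn.Module):
    """Hypothetical model that accepts a dtype at construction time."""

    def __init__(self, dtype: Optional[torch.dtype] = None):
        super().__init__()
        self.proj = nn.Linear(8, 8)
        if dtype is not None:
            # Cast parameters once, right after they are materialized,
            # instead of having the caller cast the whole module afterwards.
            self.to(dtype=dtype)


model = TinyModel(dtype=torch.float16)
print(model.proj.weight.dtype)  # torch.float16
```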