@@ -561,42 +561,49 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
     output_dir_path = canonical_path(args.output_dir, dir=True)
     weight_type = WeightType.FAIRSEQ2 if args.fairseq2 else WeightType.LLAMA
 
-    # dtype override
-    if args.dtype_override is not None:
-        dtype_override = DType[args.dtype_override]
-    elif args.quantization_mode in ["8da4w", "8da4w-gptq"]:
+    # Convert dtype override string to actual type.
+    if args.quantization_mode in ["8da4w", "8da4w-gptq"]:
         dtype_override = DType["fp16"]
     else:
-        dtype_override = None
+        dtype_override = DType[args.dtype_override]
 
-    return (
-        _load_llama_model(
-            args.model,
-            checkpoint=checkpoint_path,
-            checkpoint_dir=checkpoint_dir,
-            params_path=params_path,
-            use_kv_cache=args.use_kv_cache,
-            use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache,
-            generate_full_logits=args.generate_full_logits,
-            weight_type=weight_type,
-            enable_dynamic_shape=args.enable_dynamic_shape,
-            calibration_tasks=args.calibration_tasks,
-            calibration_limit=args.calibration_limit,
-            calibration_seq_length=args.calibration_seq_length,
-            calibration_data=args.calibration_data,
-            tokenizer_path=args.tokenizer_path,
-            verbose=args.verbose,
-            max_seq_len=args.max_seq_length,
-            max_context_len=args.max_context_length,
-            input_prune_map_path=args.input_prune_map,
-            output_prune_map_path=args.output_prune_map,
-            metadata_str=args.metadata,
-            dtype_override=dtype_override,
-            args=args,
-        )
-        .set_output_dir(output_dir_path)
-        .source_transform(_get_source_transforms(args.model, dtype_override, args))
+    edge_manager = _load_llama_model(
+        args.model,
+        checkpoint=checkpoint_path,
+        checkpoint_dir=checkpoint_dir,
+        params_path=params_path,
+        use_kv_cache=args.use_kv_cache,
+        use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache,
+        generate_full_logits=args.generate_full_logits,
+        weight_type=weight_type,
+        enable_dynamic_shape=args.enable_dynamic_shape,
+        calibration_tasks=args.calibration_tasks,
+        calibration_limit=args.calibration_limit,
+        calibration_seq_length=args.calibration_seq_length,
+        calibration_data=args.calibration_data,
+        tokenizer_path=args.tokenizer_path,
+        verbose=args.verbose,
+        max_seq_len=args.max_seq_length,
+        max_context_len=args.max_context_length,
+        input_prune_map_path=args.input_prune_map,
+        output_prune_map_path=args.output_prune_map,
+        metadata_str=args.metadata,
+        dtype_override=dtype_override,
+        args=args,
     )
+    edge_manager = edge_manager.set_output_dir(output_dir_path).source_transform(
+        _get_source_transforms(args.model, dtype_override, args)
+    )
+
+    # Override dtype of the model as specified by the user args.
+    if dtype_override:
+        assert isinstance(
+            dtype_override, DType
+        ), "Override dtype needs to be of type <DType>"
+        torch_dtype = dtype_override.to_torch_dtype()
+        logging.info(f"model.to {torch_dtype}")
+        edge_manager.model = edge_manager.model.to(dtype=torch_dtype)
+
+    return edge_manager
 
 
 def get_quantizer_and_quant_params(args):
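Both hunks lean on the DType helper: the name lookup DType["fp16"], the members bf16/fp16/fp32, and to_torch_dtype(). A minimal sketch of what that helper provides is shown below; the member names and method come from this diff, but the exact definition lives elsewhere in the repo and may differ in detail.

from enum import Enum

import torch


class DType(Enum):
    # Assumed minimal shape of the helper used by the code above.
    fp32 = "fp32"
    fp16 = "fp16"
    bf16 = "bf16"

    def to_torch_dtype(self) -> torch.dtype:
        # Map each member onto the matching torch dtype.
        return {
            DType.fp32: torch.float32,
            DType.fp16: torch.float16,
            DType.bf16: torch.bfloat16,
        }[self]


# e.g. the lookup style used in the first hunk:
assert DType["fp16"].to_torch_dtype() is torch.float16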
@@ -971,38 +978,12 @@ def _load_llama_model(
             args=args,
         )
     )
-    if dtype_override:
-        assert isinstance(
-            dtype_override, DType
-        ), "Override dtype needs to be of type <DType>"
-        torch_dtype = dtype_override.to_torch_dtype()
-        logging.info(f"model.to {torch_dtype}")
-        model = model.to(dtype=torch_dtype)
-        dtype = dtype_override
-    else:
-        state_dict = model.state_dict()
-        dtype = state_dict[next(iter(state_dict))].dtype
-        assert dtype in [
-            torch.bfloat16,
-            torch.float16,
-            torch.float32,
-        ], f"Only support bfloat16, fp16 or fp32 got {dtype}"
-        logging.info(f"Loaded model with dtype={dtype}")
-
-        if dtype == torch.bfloat16:
-            dtype = DType.bf16
-        elif dtype == torch.float16:
-            dtype = DType.fp16
-        elif dtype == torch.float32:
-            dtype = DType.fp32
-        else:
-            raise ValueError(f"Unsupported dtype {dtype}")
 
     return LLMEdgeManager(
         model=model,
         modelname=modelname,
         max_seq_len=model.max_seq_len,
-        dtype=dtype,
+        dtype=dtype_override,
         use_kv_cache=use_kv_cache,
         generate_full_logits=generate_full_logits,
         example_inputs=example_inputs,
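With the second hunk, _load_llama_model no longer inspects the checkpoint to infer a dtype; it simply forwards dtype_override to LLMEdgeManager, and the cast now happens in _prepare_for_llama_export. If a caller still wants the old state_dict-based autodetection, the deleted logic boils down to a small helper along these lines (a sketch distilled from the removed lines, not something this diff adds):

import torch
import torch.nn as nn


def detect_checkpoint_dtype(model: nn.Module) -> torch.dtype:
    # Look at the first state_dict entry, as the removed code did,
    # and accept only the floating-point dtypes it supported.
    state_dict = model.state_dict()
    dtype = state_dict[next(iter(state_dict))].dtype
    assert dtype in [
        torch.bfloat16,
        torch.float16,
        torch.float32,
    ], f"Only support bfloat16, fp16 or fp32 got {dtype}"
    return dtype


# Example: a freshly constructed module reports the default fp32.
print(detect_checkpoint_dtype(nn.Linear(4, 4)))  # torch.float32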