@@ -561,11 +561,13 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
     output_dir_path = canonical_path(args.output_dir, dir=True)
     weight_type = WeightType.FAIRSEQ2 if args.fairseq2 else WeightType.LLAMA
 
-    # Conver dtype override string to actual type.
-    if args.quantization_mode in ["8da4w", "8da4w-gptq"]:
+    # Convert dtype override string to actual type.
+    if args.dtype_override is not None:
+        dtype_override = DType[args.dtype_override]
+    elif args.quantization_mode in ["8da4w", "8da4w-gptq"]:
         dtype_override = DType["fp16"]
     else:
-        dtype_override = DType[args.dtype_override]
+        dtype_override = None
 
     edge_manager = _load_llama_model(
         args.model,
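The hunk above changes how `dtype_override` is resolved: an explicit `--dtype-override` flag now wins, the 8da4w quantization modes keep their fp16 default, and everything else falls through to `None` so the checkpoint dtype can take over later. A minimal standalone sketch of that precedence, where the two-member `DType` enum and the `resolve_dtype_override` helper are illustrative stand-ins rather than the real executorch types:

```python
from enum import Enum
from typing import Optional


class DType(Enum):  # stand-in for the executorch DType enum
    fp16 = "fp16"
    fp32 = "fp32"


def resolve_dtype_override(
    dtype_override: Optional[str], quantization_mode: Optional[str]
) -> Optional[DType]:
    # 1. An explicit --dtype-override always wins.
    if dtype_override is not None:
        return DType[dtype_override]
    # 2. The 8da4w quantization modes keep their fp16 default.
    if quantization_mode in ["8da4w", "8da4w-gptq"]:
        return DType["fp16"]
    # 3. Otherwise defer to the checkpoint's own dtype, resolved later.
    return None


assert resolve_dtype_override("fp32", "8da4w") is DType.fp32
assert resolve_dtype_override(None, "8da4w-gptq") is DType.fp16
assert resolve_dtype_override(None, None) is None
```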
@@ -590,7 +592,16 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
         metadata_str=args.metadata,
         dtype_override=dtype_override,
         args=args,
-    ).set_output_dir(output_dir_path).source_transform(_get_source_transforms(args.model, dtype_override, args))
+    )
+    # Assumes the checkpoint has a uniform dtype.
+    checkpoint_dtype = next(edge_manager.model.parameters()).dtype
+    print(f"checkpoint dtype: {checkpoint_dtype}")
+    # Quantize with the model in the checkpoint dtype before casting to dtype_override.
+    edge_manager = edge_manager.set_output_dir(output_dir_path).source_transform(
+        _get_source_transforms(
+            args.model, DType.from_torch_dtype(checkpoint_dtype), args
+        )
+    )
 
     # Override dtype of the model as specified by the user args.
     if dtype_override:
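This hunk reorders things so that source transforms, quantization included, run while the model still carries its checkpoint dtype, and only afterwards is the model cast to `dtype_override`. A hedged sketch of that ordering in plain PyTorch; `TinyModel` is a hypothetical stand-in and the transform step itself is elided:

```python
import torch
import torch.nn as nn


class TinyModel(nn.Module):  # hypothetical stand-in for the loaded llama model
    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(4, 4)


model = TinyModel()

# As the diff's comment assumes, the checkpoint dtype is uniform, so peeking
# at the first parameter is enough.
checkpoint_dtype = next(model.parameters()).dtype
assert checkpoint_dtype == torch.float32

# Quantization-style source transforms would run here, while the weights are
# still in the checkpoint dtype...

# ...and only then is the model cast to the user-requested override.
model = model.to(torch.float16)
assert next(model.parameters()).dtype == torch.float16
```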
@@ -977,11 +988,21 @@ def _load_llama_model(
         )
     )
 
+    if dtype_override:
+        assert isinstance(
+            dtype_override, DType
+        ), "Override dtype needs to be of type <DType>"
+        dtype = dtype_override
+    else:
+        checkpoint_dtype = next(model.parameters()).dtype
+        dtype = DType.from_torch_dtype(checkpoint_dtype)
+    logging.info(f"Loaded model with dtype={dtype}")
+
     return LLMEdgeManager(
         model=model,
         modelname=modelname,
         max_seq_len=model.max_seq_len,
-        dtype=dtype_override,
+        dtype=dtype,
         use_kv_cache=use_kv_cache,
         generate_full_logits=generate_full_logits,
         example_inputs=example_inputs,
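The final hunk makes `_load_llama_model` fall back to the checkpoint dtype when no override is given, via `DType.from_torch_dtype`. That method is used in the diff itself; a sketch of what such a reverse mapping could look like, where the specific member set is an assumption:

```python
from enum import Enum

import torch


class DType(Enum):  # sketch only; the real executorch enum may differ
    fp16 = torch.float16
    fp32 = torch.float32
    bf16 = torch.bfloat16

    @classmethod
    def from_torch_dtype(cls, dtype: torch.dtype) -> "DType":
        # Reverse-map a torch dtype onto the matching enum member.
        for member in cls:
            if member.value == dtype:
                return member
        raise ValueError(f"Unsupported torch dtype: {dtype}")


assert DType.from_torch_dtype(torch.float16) is DType.fp16
assert DType.from_torch_dtype(torch.bfloat16) is DType.bf16
```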