@@ -612,6 +612,7 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
     edge_manager = edge_manager.set_output_dir(output_dir_path).source_transform(
         _get_source_transforms(
             args.model,
+            dtype_override,
             DType.from_torch_dtype(edge_manager.model.checkpoint_dtype),
             args,
         )
@@ -1040,7 +1041,10 @@ def _load_llama_model(
 
 
 def _get_source_transforms(  # noqa
-    modelname: str, dtype_override: Optional[DType], args
+    modelname: str,
+    dtype_override: DType,
+    checkpoint_dtype: Optional[DType],
+    args,
 ) -> List[Callable[[torch.nn.Module], torch.nn.Module]]:
     transforms = []
 
@@ -1074,7 +1078,7 @@ def _get_source_transforms( # noqa
         """
         modelname = f"{modelname}_q"
         transforms.append(
-            get_quant_weight_transform(args, dtype_override, verbose_export())
+            get_quant_weight_transform(args, checkpoint_dtype, verbose_export())
         )
 
     if args.embedding_quantize:
@@ -1088,7 +1092,7 @@ def _get_source_transforms( # noqa
         this will be a no-op.
         """
         modelname = f"{modelname}_e"
-        transforms.append(get_quant_embedding_transform(args, dtype_override))
+        transforms.append(get_quant_embedding_transform(args, checkpoint_dtype))
 
     if args.quantization_mode or args.embedding_quantize:
         transforms.append(
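
For reference, a minimal self-contained sketch of the idea behind the signature change: the export-time dtype override and the checkpoint's native dtype travel as two separate arguments, and the quantization transforms consume the checkpoint dtype rather than the override. This is a hypothetical simplification, not the ExecuTorch helpers; `get_source_transforms` and the `quantize` flag below are stand-ins for the real `_get_source_transforms`, `get_quant_weight_transform`, and `args`-driven logic shown in the diff.

```python
from typing import Callable, List, Optional

import torch


def get_source_transforms(
    modelname: str,
    dtype_override: torch.dtype,           # dtype the exported graph should target
    checkpoint_dtype: Optional[torch.dtype],  # dtype the weights were saved in
    quantize: bool,
) -> List[Callable[[torch.nn.Module], torch.nn.Module]]:
    transforms: List[Callable[[torch.nn.Module], torch.nn.Module]] = []
    if quantize:
        # Quantization works off the precision the checkpoint was saved in,
        # falling back to the export override when none is known.
        source_dtype = checkpoint_dtype if checkpoint_dtype is not None else dtype_override
        transforms.append(lambda m: m.to(source_dtype))
    return transforms


# Usage sketch: an fp16 "checkpoint" exported with an fp32 override still
# quantizes against fp16 weights.
model = torch.nn.Linear(4, 4).to(torch.float16)
for transform in get_source_transforms(
    "llama",
    dtype_override=torch.float32,
    checkpoint_dtype=torch.float16,
    quantize=True,
):
    model = transform(model)
```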