@@ -370,6 +370,28 @@ def build_args_parser() -> argparse.ArgumentParser:
370370 help = "Use SpinQuant for better quantization performance. Only support cuda and native." ,
371371 )
372372
373+ parser .add_argument (
374+ "--spin_qmode" ,
375+ type = str ,
376+ default = None ,
377+ choices = ["8da4w" ],
378+ help = "Quantization mode for SpinQuant. Only support 8da4w right now." ,
379+ )
380+
381+ parser .add_argument (
382+ "--spin_group_size" ,
383+ type = int ,
384+ default = 32 ,
385+ help = "group_size for SpinQuant weight quantization" ,
386+ )
387+
388+ parser .add_argument (
389+ "--spin_embedding_quantize" ,
390+ default = "8,0" ,
391+ type = str ,
392+ help = "type of embedding quantization for SpinQuant, '<bitwidth>,<groupsize>', e.g., '8,1024'." ,
393+ )
394+
373395 parser .add_argument (
374396 "--output_prune_map" ,
375397 default = None ,
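The three new flags only configure SpinQuant-specific quantization. A purely illustrative way to see them parsed, assuming build_args_parser() has no required positional arguments (not guaranteed by this diff):

    parser = build_args_parser()
    # Hypothetical command line, not a documented invocation.
    args = parser.parse_args(
        [
            "--use_spin_quant", "native",
            "--spin_qmode", "8da4w",
            "--spin_group_size", "32",
            "--spin_embedding_quantize", "8,1024",
        ]
    )
    assert args.spin_qmode == "8da4w" and args.spin_group_size == 32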
@@ -466,10 +488,10 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
             max_seq_len=args.max_seq_length,
             output_prune_map_path=args.output_prune_map,
             metadata_str=args.metadata,
+            dtype_override=dtype_override,
             args=args,
         )
         .set_output_dir(output_dir_path)
-        .to_dtype(dtype_override)
         .source_transform(_get_source_transforms(modelname, dtype_override, args))
     )
 
@@ -691,6 +713,7 @@ def _load_llama_model(
     max_seq_len: int = 128,
     output_prune_map_path: Optional[str] = None,
     metadata_str: Optional[str] = None,
+    dtype_override: Optional[DType] = None,
     args,
 ) -> "LLMEdgeManager":
     """
@@ -720,23 +743,32 @@ def _load_llama_model(
         output_prune_map_path=output_prune_map_path,
         args=args,
     )
-    state_dict = model.state_dict()
-    dtype = state_dict[next(iter(state_dict))].dtype
-    assert dtype in [
-        torch.bfloat16,
-        torch.float16,
-        torch.float32,
-    ], f"Only support bfloat16, fp16 or fp32 got {dtype}"
-    logging.info(f"Loaded model with dtype={dtype}")
-
-    if dtype == torch.bfloat16:
-        dtype = DType.bf16
-    elif dtype == torch.float16:
-        dtype = DType.fp16
-    elif dtype == torch.float32:
-        dtype = DType.fp32
+    if dtype_override:
+        assert isinstance(
+            dtype_override, DType
+        ), "Override dtype needs to be of type <DType>"
+        torch_dtype = dtype_override.to_torch_dtype()
+        logging.info(f"model.to {torch_dtype}")
+        model = model.to(dtype=torch_dtype)
+        dtype = dtype_override
     else:
-        raise ValueError(f"Unsupported dtype {dtype}")
+        state_dict = model.state_dict()
+        dtype = state_dict[next(iter(state_dict))].dtype
+        assert dtype in [
+            torch.bfloat16,
+            torch.float16,
+            torch.float32,
+        ], f"Only support bfloat16, fp16 or fp32 got {dtype}"
+        logging.info(f"Loaded model with dtype={dtype}")
+
+        if dtype == torch.bfloat16:
+            dtype = DType.bf16
+        elif dtype == torch.float16:
+            dtype = DType.fp16
+        elif dtype == torch.float32:
+            dtype = DType.fp32
+        else:
+            raise ValueError(f"Unsupported dtype {dtype}")
 
     return LLMEdgeManager(
         model=model,
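For reference, the dtype handling above reduces to: an explicit DType override wins and the eager model is converted to it, otherwise the dtype is inferred from the first checkpoint tensor. A minimal sketch of that logic (hypothetical helper name, reusing the DType enum and torch imports already used in this file):

    def _resolve_dtype(model: torch.nn.Module, dtype_override: Optional[DType]) -> DType:
        if dtype_override is not None:
            # Override path: convert the eager model before export.
            model.to(dtype=dtype_override.to_torch_dtype())
            return dtype_override
        # Fallback path: trust whatever dtype the checkpoint was saved in.
        checkpoint_dtype = next(iter(model.state_dict().values())).dtype
        return {
            torch.bfloat16: DType.bf16,
            torch.float16: DType.fp16,
            torch.float32: DType.fp32,
        }[checkpoint_dtype]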
@@ -769,21 +801,9 @@ def _get_source_transforms( # noqa
     modelname: str, dtype_override: Optional[DType], args
 ) -> List[Callable[[torch.nn.Module], torch.nn.Module]]:
     transforms = []
-    if args.quantization_mode:
-        modelname = f"{modelname}_q"
-        if args.use_spin_quant is None:
-            transforms.append(
-                get_quant_weight_transform(args, dtype_override, verbose_export())
-            )
-        # For SpinQuant, the checkpoints are already quantized
-        # aka the weights have corresponding scales value,
-        # So that means, we don't need to apply quantization
-        # transform. However, we will still need to apply
-        # transformations that change the model structure to
-        # match the checkpoint format.
-        # transform_for_spinquant() will apply these transformations
-        # later in model.py file.
-        elif args.use_spin_quant == "cuda":
+
+    if args.use_spin_quant:
+        if args.use_spin_quant == "cuda":
             from .source_transformation.spin_quant import (
                 inject_fast_hadamard_transform_cuda_for_spin_quant,
             )
@@ -796,7 +816,35 @@ def _get_source_transforms( # noqa
 
         transforms.append(inject_fast_hadamard_transform_native_for_spin_quant)
 
+    if args.quantization_mode:
+        """
+        When this option is selected, it finds all linear layers and transforms
+        them into their quantized linear equivalents.
+
+        There are cases where the checkpoint is already quantized, for example
+        when use_spin_quant is enabled. In that case, the appropriate
+        checkpoint-specific transformations are applied first. If
+        quantization_mode is also enabled, it will quantize any remaining
+        linear ops that are not yet quantized.
+
+        There are cases where this may be a no-op, namely, if all linears are
+        quantized in the checkpoint.
+        """
+        modelname = f"{modelname}_q"
+        transforms.append(
+            get_quant_weight_transform(args, dtype_override, verbose_export())
+        )
+
     if args.embedding_quantize:
+        """
+        When this option is selected, it finds all embedding layers and transforms
+        them into their quantized embedding equivalents.
+
+        There are cases where the checkpoint is already quantized, for example
+        when use_spin_quant is enabled. In that case, the appropriate
+        checkpoint-specific transformations are applied first, and this
+        transform will be a no-op.
+        """
         modelname = f"{modelname}_e"
         transforms.append(get_quant_embedding_transform(args))
 
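The reordering matters because the returned transform list is applied sequentially: SpinQuant structural/Hadamard transforms run first, then quantization_mode quantizes any linears still left in floating point, then embedding quantization. A rough illustration of that composition (hypothetical driver loop, not code from this file):

    def apply_source_transforms(
        model: torch.nn.Module,
        transforms: List[Callable[[torch.nn.Module], torch.nn.Module]],
    ) -> torch.nn.Module:
        # Each transform receives the output of the previous one, so the order
        # returned by _get_source_transforms() is the order of application.
        for transform in transforms:
            model = transform(model)
        return model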