@@ -661,10 +661,35 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
     logging.info(f"Checkpoint dtype: {edge_manager.model.checkpoint_dtype}")
     edge_manager = edge_manager.set_output_dir(output_dir_path).source_transform(
         _get_source_transforms(
-            modelname=args.model,
             dtype_override=dtype_override,
+            checkpoint=args.checkpoint,
             checkpoint_dtype=DType.from_torch_dtype(checkpoint_dtype),  # type: ignore
-            args=args,
+            tokenizer_path=args.tokenizer_path,
+            use_spin_quant=args.use_spin_quant,
+            embedding_quantize=args.embedding_quantize,
+            use_shared_embedding=args.use_shared_embedding,
+            quantization_mode=args.quantization_mode,
+            group_size=args.group_size,
+            calibration_tasks=args.calibration_tasks,
+            calibration_limit=args.calibration_limit,
+            calibration_seq_length=args.calibration_seq_length,
+            expand_rope_table=args.expand_rope_table,
+            use_custom_sdpa_with_attention_mask=getattr(args, "use_custom_sdpa_with_attention_mask", False),
+            use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache,
+            quantize_kv_cache=args.quantize_kv_cache,
+            use_kv_cache=args.use_kv_cache,
+            qnn=args.qnn,
+            use_qnn_sha=args.use_qnn_sha,
+            optimized_rotation_path=args.optimized_rotation_path,
+            mps=args.mps,
+            coreml=args.coreml,
+            coreml_ios=args.coreml_ios,
+            vulkan=args.vulkan,
+            use_qat=args.use_qat,
+            use_lora=args.use_lora,
+            preq_mode=args.preq_mode,
+            preq_group_size=args.preq_group_size,
+            preq_embedding_quantize=args.preq_embedding_quantize,
         )
     )

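The call site above now spells out every flag explicitly, which keeps the call self-documenting and friendly to type checkers. A more compact, purely hypothetical alternative would forward a whitelisted set of argparse attributes; the sketch below assumes every listed attribute exists on args and is not part of this commit:

    _SOURCE_TRANSFORM_KEYS = (
        "checkpoint", "tokenizer_path", "use_spin_quant", "embedding_quantize",
        "use_shared_embedding", "quantization_mode", "group_size",
        "calibration_tasks", "calibration_limit", "calibration_seq_length",
        "expand_rope_table", "use_sdpa_with_kv_cache", "quantize_kv_cache",
        "use_kv_cache", "qnn", "use_qnn_sha", "optimized_rotation_path",
        "mps", "coreml", "coreml_ios", "vulkan", "use_qat", "use_lora",
        "preq_mode", "preq_group_size", "preq_embedding_quantize",
    )

    # Collect only the attributes that are actually present on the namespace.
    source_transform_kwargs = {
        key: getattr(args, key) for key in _SOURCE_TRANSFORM_KEYS if hasattr(args, key)
    }
    edge_manager = edge_manager.set_output_dir(output_dir_path).source_transform(
        _get_source_transforms(
            dtype_override=dtype_override,
            checkpoint_dtype=DType.from_torch_dtype(checkpoint_dtype),  # type: ignore
            use_custom_sdpa_with_attention_mask=getattr(
                args, "use_custom_sdpa_with_attention_mask", False
            ),
            **source_transform_kwargs,
        )
    )

The explicit keyword list in the commit trades verbosity for readability and static checking, so the sketch is shown only for comparison.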
@@ -1155,23 +1180,69 @@ def _load_llama_model(


 def _get_source_transforms(  # noqa
-    modelname: str,
     dtype_override: DType,
     *,
+    checkpoint: Optional[str] = None,
     checkpoint_dtype: Optional[DType] = None,
-    args,
+    tokenizer_path: Optional[str] = None,
+    use_spin_quant: Optional[str] = None,
+    embedding_quantize: Optional[str] = None,
+    use_shared_embedding: bool = False,
+    quantization_mode: Optional[str] = None,
+    group_size: Optional[int] = None,
+    calibration_tasks: Optional[List[str]] = None,
+    calibration_limit: Optional[int] = None,
+    calibration_seq_length: Optional[int] = None,
+    expand_rope_table: bool = False,
+    use_custom_sdpa_with_attention_mask: bool = False,
+    use_sdpa_with_kv_cache: bool = False,
+    quantize_kv_cache: bool = False,
+    use_kv_cache: bool = False,
+    qnn: bool = False,
+    use_qnn_sha: bool = False,
+    optimized_rotation_path: Optional[str] = None,
+    mps: bool = False,
+    coreml: bool = False,
+    coreml_ios: int = 15,
+    vulkan: bool = False,
+    use_qat: bool = False,
+    use_lora: int = 0,
+    preq_mode: Optional[str] = None,
+    preq_group_size: Optional[int] = None,
+    preq_embedding_quantize: Optional[str] = None,
 ) -> List[Callable[[torch.nn.Module], torch.nn.Module]]:
     """
     Return a list of functions that transform a graph.

     Args:
-        modelname: The name of the model.
         dtype_override: The dtype to use for the model.
+        checkpoint: Path to the checkpoint file.
         checkpoint_dtype: The dtype of the checkpoint. At the moment, if this is specified,
             it means that you want to run quantize transformations on the weights represented
             in their original dtype, while the overall dtype of the model may be something
             different. If not specified, defaults to dtype_override.
-        args: The arguments passed to the script.
+        tokenizer_path: Path to the tokenizer file.
+        use_spin_quant: Type of spin quant to use ("cuda" or "native").
+        embedding_quantize: Type of embedding quantization.
+        quantization_mode: Type of quantization mode.
+        group_size: Group size for weight quantization.
+        calibration_tasks: Tasks to calibrate on.
+        calibration_limit: Number of samples used for calibration.
+        calibration_seq_length: Sequence length used for calibration.
+        expand_rope_table: Whether to expand the rope table.
+        use_custom_sdpa_with_attention_mask: Whether to use custom SDPA with attention mask.
+        use_sdpa_with_kv_cache: Whether to use SDPA with KV cache.
+        quantize_kv_cache: Whether to quantize the KV cache.
+        use_kv_cache: Whether to use the KV cache.
+        qnn: Whether to use QNN.
+        use_qnn_sha: Whether to use QNN SHA.
+        optimized_rotation_path: Path to optimized rotation.
+        mps: Whether to use MPS.
+        coreml: Whether to use CoreML.
+        coreml_ios: CoreML iOS version.
+        vulkan: Whether to use Vulkan.
+        use_shared_embedding: Whether to use shared embedding.
+        use_qat: Whether to use QAT.
+        use_lora: LoRA rank (0 means no LoRA).
+        preq_mode: Pre-quantization mode.
+        preq_group_size: Pre-quantization group size.
+        preq_embedding_quantize: Pre-quantization embedding quantize.

     Returns:
         A list of transformation functions.
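For context, here is a minimal usage sketch of the refactored keyword-only signature. It is illustrative only: the paths are placeholders, DType.fp32 stands in for whatever dtype the export flow selects, and load_llama_module is a hypothetical helper, not something defined in this file:

    transforms = _get_source_transforms(
        dtype_override=DType.fp32,
        checkpoint="/path/to/checkpoint.pth",
        checkpoint_dtype=DType.fp32,
        use_kv_cache=True,
        use_sdpa_with_kv_cache=True,
    )
    model = load_llama_module()  # hypothetical helper returning the eager torch.nn.Module
    # Each entry is a Callable[[torch.nn.Module], torch.nn.Module], applied in order.
    for transform in transforms:
        model = transform(model)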
@@ -1182,21 +1253,21 @@ def _get_source_transforms(  # noqa

     transforms = []

-    if args.use_spin_quant:
-        if args.use_spin_quant == "cuda":
+    if use_spin_quant:
+        if use_spin_quant == "cuda":
             from .source_transformation.spin_quant import (
                 inject_fast_hadamard_transform_cuda_for_spin_quant,
             )

             transforms.append(inject_fast_hadamard_transform_cuda_for_spin_quant)
-        elif args.use_spin_quant == "native":
+        elif use_spin_quant == "native":
             from .source_transformation.spin_quant import (
                 inject_fast_hadamard_transform_native_for_spin_quant,
             )

             transforms.append(inject_fast_hadamard_transform_native_for_spin_quant)

-    if args.embedding_quantize:
+    if embedding_quantize:
         """
         When this option is selected, it finds all embedding layers and transforms
         into quantized embedding equivalent module.
@@ -1206,12 +1277,25 @@ def _get_source_transforms(  # noqa
         transformations based on the given checkpoint first. In those cases,
         this will be a no-op.
         """
-        modelname = f"{modelname}_e"
+        # Create a mock args object with the necessary attributes
+        class Args:
+            pass
+        args = Args()
+        args.checkpoint = checkpoint
+        args.tokenizer_path = tokenizer_path
+        args.embedding_quantize = embedding_quantize
+        args.use_shared_embedding = use_shared_embedding
+        args.use_qat = use_qat
+        args.use_lora = use_lora
+        args.preq_mode = preq_mode
+        args.preq_group_size = preq_group_size
+        args.preq_embedding_quantize = preq_embedding_quantize
+
         transforms.append(get_quant_embedding_transform(args, checkpoint_dtype))

     # quantization_mode should be applied after embedding_quantize
     # to support shared_embedding
-    if args.quantization_mode:
+    if quantization_mode:
         """
         When this option is selected, it finds all linear layers and transforms
         into quantized linear equivalent module.
@@ -1225,7 +1309,23 @@ def _get_source_transforms(  # noqa
         There are cases where this may be a no-op, namely, if all linears are
         quantized in the checkpoint.
         """
-        modelname = f"{modelname}_q"
+        # Create a mock args object with the necessary attributes
+        class Args:
+            pass
+        args = Args()
+        args.checkpoint = checkpoint
+        args.tokenizer_path = tokenizer_path
+        args.quantization_mode = quantization_mode
+        args.group_size = group_size
+        args.use_shared_embedding = use_shared_embedding
+        args.calibration_tasks = calibration_tasks
+        args.calibration_limit = calibration_limit
+        args.calibration_seq_length = calibration_seq_length
+        args.use_qat = use_qat
+        args.use_lora = use_lora
+        args.preq_mode = preq_mode
+
         transforms.append(
             get_quant_weight_transform(
                 args=args,
@@ -1234,15 +1334,12 @@ def _get_source_transforms(  # noqa
             )
         )

-    if args.expand_rope_table:
+    if expand_rope_table:
         transforms.append(materialze_broadcast_of_rope_freq_cis)

-    use_attention_mask_for_custom_sdpa = False
-    if isinstance(args, argparse.Namespace):
-        if getattr(args, "use_custom_sdpa_with_attention_mask", None):
-            use_attention_mask_for_custom_sdpa = True
+    use_attention_mask_for_custom_sdpa = use_custom_sdpa_with_attention_mask

-    if args.use_sdpa_with_kv_cache:
+    if use_sdpa_with_kv_cache:
         transforms.append(replace_kv_cache_with_custom_kv_cache)
         # todo: do this optionally
         # if use attention mask instead of causal attention
@@ -1254,23 +1351,23 @@ def _get_source_transforms(  # noqa
         else:
             transforms.append(replace_sdpa_with_custom_op)

-    if args.quantize_kv_cache:
-        assert args.use_kv_cache, "quantize_kv_cache requires use_kv_cache=True"
+    if quantize_kv_cache:
+        assert use_kv_cache, "quantize_kv_cache requires use_kv_cache=True"
         transforms.append(replace_kv_cache_with_quantized_kv_cache)
         # Right now
         transforms.append(replace_sdpa_with_quantized_sdpa)

-    if args.use_kv_cache:
-        if args.qnn:
+    if use_kv_cache:
+        if qnn:
             from executorch.backends.qualcomm.utils.utils import (
                 convert_linear_to_conv2d,
             )

-            if args.use_qnn_sha:
-                if args.optimized_rotation_path:
+            if use_qnn_sha:
+                if optimized_rotation_path:
                     transforms.append(fuse_layer_norms)
                     transforms.append(
-                        get_model_with_r1_r2(args.optimized_rotation_path)
+                        get_model_with_r1_r2(optimized_rotation_path)
                     )
                 transforms.append(replace_attention_to_attention_sha)
                 transforms.append(replace_causal_mask)
@@ -1282,29 +1379,29 @@ def _get_source_transforms(  # noqa
                 transforms.append(replace_sdpa_with_flex_sdpa)
                 transforms.append(replace_causal_mask)
                 transforms.append(replace_rms_norm_with_native_rms_norm)
-                if args.optimized_rotation_path:
+                if optimized_rotation_path:
                     transforms.append(fuse_layer_norms)
                     transforms.append(
-                        get_model_with_r1_r2(args.optimized_rotation_path)
+                        get_model_with_r1_r2(optimized_rotation_path)
                     )
             # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
             transforms.append(convert_linear_to_conv2d)

-        elif args.mps:
+        elif mps:
             # Currently mps doesn't support sdpa op, use the simpler decomposition
             # to get free perf gain.
             transforms.append(replace_sdpa_with_simple_sdpa)
             transforms.append(replace_causal_mask)

-        elif args.coreml:
+        elif coreml:
             # iOS 18 introduced fused sdpa op
-            if args.coreml_ios >= 18:
+            if coreml_ios >= 18:
                 transforms.append(replace_sdpa_with_coreml_sdpa)
             else:
                 transforms.append(replace_sdpa_with_simple_sdpa)
             transforms.append(replace_kv_cache_with_coreml_kv_cache)

-    if args.vulkan:
+    if vulkan:
         transforms.append(replace_with_vulkan_rotary_emb)

     return transforms
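A side note on the two mock Args blocks above: because get_quant_embedding_transform and get_quant_weight_transform still expect an args-like object, the commit builds ad-hoc Args holders attribute by attribute. A standard-library types.SimpleNamespace expresses the same attribute bag more compactly; the following is only a sketch of that alternative for the embedding-quantization branch, not part of the commit:

    from types import SimpleNamespace

    # Equivalent attribute bag to the ad-hoc Args class, built in one expression.
    quant_args = SimpleNamespace(
        checkpoint=checkpoint,
        tokenizer_path=tokenizer_path,
        embedding_quantize=embedding_quantize,
        use_shared_embedding=use_shared_embedding,
        use_qat=use_qat,
        use_lora=use_lora,
        preq_mode=preq_mode,
        preq_group_size=preq_group_size,
        preq_embedding_quantize=preq_embedding_quantize,
    )
    transforms.append(get_quant_embedding_transform(quant_args, checkpoint_dtype))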