@@ -661,10 +661,37 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
     logging.info(f"Checkpoint dtype: {edge_manager.model.checkpoint_dtype}")
     edge_manager = edge_manager.set_output_dir(output_dir_path).source_transform(
         _get_source_transforms(
-            modelname=args.model,
             dtype_override=dtype_override,
+            checkpoint=args.checkpoint,
             checkpoint_dtype=DType.from_torch_dtype(checkpoint_dtype),  # type: ignore
-            args=args,
+            tokenizer_path=args.tokenizer_path,
+            use_spin_quant=args.use_spin_quant,
+            embedding_quantize=args.embedding_quantize,
+            use_shared_embedding=args.use_shared_embedding,
+            quantization_mode=args.quantization_mode,
+            group_size=args.group_size,
+            calibration_tasks=args.calibration_tasks,
+            calibration_limit=args.calibration_limit,
+            calibration_seq_length=args.calibration_seq_length,
+            expand_rope_table=args.expand_rope_table,
+            use_custom_sdpa_with_attention_mask=getattr(
+                args, "use_custom_sdpa_with_attention_mask", False
+            ),
+            use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache,
+            quantize_kv_cache=args.quantize_kv_cache,
+            use_kv_cache=args.use_kv_cache,
+            qnn=args.qnn,
+            use_qnn_sha=args.use_qnn_sha,
+            optimized_rotation_path=args.optimized_rotation_path,
+            mps=args.mps,
+            coreml=args.coreml,
+            coreml_ios=args.coreml_ios,
+            vulkan=args.vulkan,
+            use_qat=args.use_qat,
+            use_lora=args.use_lora,
+            preq_mode=args.preq_mode,
+            preq_group_size=args.preq_group_size,
+            preq_embedding_quantize=args.preq_embedding_quantize,
         )
     )

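The getattr call above is what keeps older argparse namespaces working: if the parsed arguments predate the use_custom_sdpa_with_attention_mask flag, the call site falls back to False instead of raising AttributeError. A minimal, self-contained sketch of that behaviour; the namespaces below are illustrative and not taken from the script.

from argparse import Namespace

# A namespace built by an older parser may lack the newer flag entirely;
# getattr with a False default keeps the call site working in both cases.
legacy_args = Namespace(use_sdpa_with_kv_cache=True)
current_args = Namespace(use_sdpa_with_kv_cache=True, use_custom_sdpa_with_attention_mask=True)

assert getattr(legacy_args, "use_custom_sdpa_with_attention_mask", False) is False
assert getattr(current_args, "use_custom_sdpa_with_attention_mask", False) is True
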
@@ -1189,23 +1216,69 @@ def _load_llama_model(


 def _get_source_transforms(  # noqa
-    modelname: str,
     dtype_override: DType,
     *,
+    checkpoint: Optional[str] = None,
     checkpoint_dtype: Optional[DType] = None,
-    args,
+    tokenizer_path: Optional[str] = None,
+    use_spin_quant: Optional[str] = None,
+    embedding_quantize: Optional[str] = None,
+    use_shared_embedding: bool = False,
+    quantization_mode: Optional[str] = None,
+    group_size: Optional[int] = None,
+    calibration_tasks: Optional[List[str]] = None,
+    calibration_limit: Optional[int] = None,
+    calibration_seq_length: Optional[int] = None,
+    expand_rope_table: bool = False,
+    use_custom_sdpa_with_attention_mask: bool = False,
+    use_sdpa_with_kv_cache: bool = False,
+    quantize_kv_cache: bool = False,
+    use_kv_cache: bool = False,
+    qnn: bool = False,
+    use_qnn_sha: bool = False,
+    optimized_rotation_path: Optional[str] = None,
+    mps: bool = False,
+    coreml: bool = False,
+    coreml_ios: int = 15,
+    vulkan: bool = False,
+    use_qat: bool = False,
+    use_lora: int = 0,
+    preq_mode: Optional[str] = None,
+    preq_group_size: Optional[int] = None,
+    preq_embedding_quantize: Optional[str] = None,
 ) -> List[Callable[[torch.nn.Module], torch.nn.Module]]:
     """
     Return a list of functions that transform a graph.

     Args:
-        modelname: The name of the model.
         dtype_override: The dtype to use for the model.
+        checkpoint: Path to the checkpoint file.
         checkpoint_dtype: The dtype of the checkpoint. At the moment, if this is specified,
             it means that you want to run quantize transformations on the weights represented
             in their original dtype, while the overall dtype of the model may be something
             different. If not specified, defaults to dtype_override.
-        args: The arguments passed to the script.
+        tokenizer_path: Path to the tokenizer file.
+        use_spin_quant: Type of spin quant to use ("cuda" or "native").
+        embedding_quantize: Type of embedding quantization.
+        quantization_mode: Type of quantization mode.
+        expand_rope_table: Whether to expand rope table.
+        use_custom_sdpa_with_attention_mask: Whether to use custom SDPA with attention mask.
+        use_sdpa_with_kv_cache: Whether to use SDPA with KV cache.
+        quantize_kv_cache: Whether to quantize KV cache.
+        use_kv_cache: Whether to use KV cache.
+        qnn: Whether to use QNN.
+        use_qnn_sha: Whether to use QNN SHA.
+        optimized_rotation_path: Path to optimized rotation.
+        mps: Whether to use MPS.
+        coreml: Whether to use CoreML.
+        coreml_ios: CoreML iOS version.
+        vulkan: Whether to use Vulkan.
+        use_shared_embedding: Whether to use shared embedding.
+        use_qat: Whether to use QAT.
+        use_lora: LoRA rank (0 means no LoRA).
+        preq_mode: Pre-quantization mode.
+        preq_group_size: Pre-quantization group size.
+        preq_embedding_quantize: Pre-quantization embedding quantize.

     Returns:
         A list of transformation functions.
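With the argparse namespace gone from the signature, the transform list can be built from plain keyword arguments. Below is a minimal sketch of a direct call under the new signature; the import paths and the DType.fp32 member are assumptions based on the surrounding file, and the flag values are illustrative.

# Assumed import locations; adjust to wherever _get_source_transforms and DType
# actually live in your checkout.
from executorch.examples.models.llama.export_llama_lib import _get_source_transforms
from executorch.extension.llm.export.builder import DType

transforms = _get_source_transforms(
    dtype_override=DType.fp32,     # only required argument; the rest are keyword-only
    use_kv_cache=True,             # enable the KV-cache rewrites
    use_sdpa_with_kv_cache=True,   # swap in the custom SDPA op
)
# Each entry is a Callable[[torch.nn.Module], torch.nn.Module]; see the sketch
# after the final hunk for how the list is applied.
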
@@ -1216,21 +1289,21 @@ def _get_source_transforms( # noqa

     transforms = []

-    if args.use_spin_quant:
-        if args.use_spin_quant == "cuda":
+    if use_spin_quant:
+        if use_spin_quant == "cuda":
             from .source_transformation.spin_quant import (
                 inject_fast_hadamard_transform_cuda_for_spin_quant,
             )

             transforms.append(inject_fast_hadamard_transform_cuda_for_spin_quant)
-        elif args.use_spin_quant == "native":
+        elif use_spin_quant == "native":
             from .source_transformation.spin_quant import (
                 inject_fast_hadamard_transform_native_for_spin_quant,
             )

             transforms.append(inject_fast_hadamard_transform_native_for_spin_quant)

-    if args.embedding_quantize:
+    if embedding_quantize:
         """
         When this option is selected, it finds all embedding layers and transforms
         into quantized embedding equivalent module.
@@ -1240,12 +1313,27 @@ def _get_source_transforms( # noqa
         transformations based on the given checkpoint first. In those cases,
         this will be a no-op.
         """
-        modelname = f"{modelname}_e"
+
+        # Create a mock args object with the necessary attributes
+        class Args:
+            pass
+
+        args = Args()
+        args.checkpoint = checkpoint
+        args.tokenizer_path = tokenizer_path
+        args.embedding_quantize = embedding_quantize
+        args.use_shared_embedding = use_shared_embedding
+        args.use_qat = use_qat
+        args.use_lora = use_lora
+        args.preq_mode = preq_mode
+        args.preq_group_size = preq_group_size
+        args.preq_embedding_quantize = preq_embedding_quantize
+
         transforms.append(get_quant_embedding_transform(args, checkpoint_dtype))

     # quantization_mode should be applied after embedding_quantize
     # to support shared_embedding
-    if args.quantization_mode:
+    if quantization_mode:
         """
         When this option is selected, it finds all linear layers and transforms
         into quantized linear equivalent module.
@@ -1259,7 +1347,25 @@ def _get_source_transforms( # noqa
         There are cases where this may be a no-op, namely, if all linears are
         quantized in the checkpoint.
         """
-        modelname = f"{modelname}_q"
+
+        # Create a mock args object with the necessary attributes
+        class Args:
+            pass
+
+        args = Args()
+        args.checkpoint = checkpoint
+        args.tokenizer_path = tokenizer_path
+        args.quantization_mode = quantization_mode
+        args.group_size = group_size
+        args.use_shared_embedding = use_shared_embedding
+        args.calibration_tasks = calibration_tasks
+        args.calibration_limit = calibration_limit
+        args.calibration_seq_length = calibration_seq_length
+        args.use_shared_embedding = use_shared_embedding
+        args.use_qat = use_qat
+        args.use_lora = use_lora
+        args.preq_mode = preq_mode
+
         transforms.append(
             get_quant_weight_transform(
                 args=args,
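The throwaway Args classes above exist only to satisfy the attribute-style interface that get_quant_embedding_transform and get_quant_weight_transform still expect. Assuming those helpers only read attributes, the same shim could be expressed with types.SimpleNamespace; a self-contained sketch, with illustrative values rather than defaults:

from types import SimpleNamespace

# SimpleNamespace gives the same attribute access as the ad-hoc Args class,
# with the whole bundle visible in a single expression.
quant_args = SimpleNamespace(
    checkpoint=None,
    tokenizer_path=None,
    quantization_mode="8da4w",   # illustrative mode string
    group_size=128,              # illustrative group size
    use_shared_embedding=False,
    calibration_tasks=None,
    calibration_limit=None,
    calibration_seq_length=None,
    use_qat=False,
    use_lora=0,
    preq_mode=None,
)
assert quant_args.quantization_mode == "8da4w"
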
@@ -1268,15 +1374,12 @@ def _get_source_transforms( # noqa
             )
         )

-    if args.expand_rope_table:
+    if expand_rope_table:
         transforms.append(materialze_broadcast_of_rope_freq_cis)

-    use_attention_mask_for_custom_sdpa = False
-    if isinstance(args, argparse.Namespace):
-        if getattr(args, "use_custom_sdpa_with_attention_mask", None):
-            use_attention_mask_for_custom_sdpa = True
+    use_attention_mask_for_custom_sdpa = use_custom_sdpa_with_attention_mask

-    if args.use_sdpa_with_kv_cache:
+    if use_sdpa_with_kv_cache:
         transforms.append(replace_kv_cache_with_custom_kv_cache)
         # todo: do this optionally
         # if use attention mask instead of causal attention
@@ -1288,24 +1391,22 @@ def _get_source_transforms( # noqa
         else:
             transforms.append(replace_sdpa_with_custom_op)

-    if args.quantize_kv_cache:
-        assert args.use_kv_cache, "quantize_kv_cache requires use_kv_cache=True"
+    if quantize_kv_cache:
+        assert use_kv_cache, "quantize_kv_cache requires use_kv_cache=True"
         transforms.append(replace_kv_cache_with_quantized_kv_cache)
         # Right now
         transforms.append(replace_sdpa_with_quantized_sdpa)

-    if args.use_kv_cache:
-        if args.qnn:
+    if use_kv_cache:
+        if qnn:
             from executorch.backends.qualcomm.utils.utils import (
                 convert_linear_to_conv2d,
             )

-            if args.use_qnn_sha:
-                if args.optimized_rotation_path:
+            if use_qnn_sha:
+                if optimized_rotation_path:
                     transforms.append(fuse_layer_norms)
-                    transforms.append(
-                        get_model_with_r1_r2(args.optimized_rotation_path)
-                    )
+                    transforms.append(get_model_with_r1_r2(optimized_rotation_path))
                 transforms.append(replace_attention_to_attention_sha)
                 transforms.append(replace_causal_mask)
                 transforms.append(replace_rms_norm_with_native_rms_norm)
@@ -1316,29 +1417,27 @@ def _get_source_transforms( # noqa
                 transforms.append(replace_sdpa_with_flex_sdpa)
                 transforms.append(replace_causal_mask)
                 transforms.append(replace_rms_norm_with_native_rms_norm)
-                if args.optimized_rotation_path:
+                if optimized_rotation_path:
                     transforms.append(fuse_layer_norms)
-                    transforms.append(
-                        get_model_with_r1_r2(args.optimized_rotation_path)
-                    )
+                    transforms.append(get_model_with_r1_r2(optimized_rotation_path))
             # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
             transforms.append(convert_linear_to_conv2d)

-        elif args.mps:
+        elif mps:
             # Currently mps doesn't support sdpa op, use the simpler decomposition
             # to get free perf gain.
             transforms.append(replace_sdpa_with_simple_sdpa)
             transforms.append(replace_causal_mask)

-        elif args.coreml:
+        elif coreml:
             # iOS 18 introduced fused sdpa op
-            if args.coreml_ios >= 18:
+            if coreml_ios >= 18:
                 transforms.append(replace_sdpa_with_coreml_sdpa)
             else:
                 transforms.append(replace_sdpa_with_simple_sdpa)
             transforms.append(replace_kv_cache_with_coreml_kv_cache)

-    if args.vulkan:
+    if vulkan:
         transforms.append(replace_with_vulkan_rotary_emb)

     return transforms
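For context on how the returned list is consumed: the call site in _prepare_for_llama_export hands it to LLMEdgeManager.source_transform, which applies each callable to the model in order, so list order is what enforces constraints like quantization_mode running after embedding_quantize. A rough, self-contained equivalent of that application loop (apply_source_transforms is a name made up for this sketch):

import torch


def apply_source_transforms(model: torch.nn.Module, transforms) -> torch.nn.Module:
    """Apply each transform in order; every entry maps a module to a module."""
    for transform in transforms:
        model = transform(model)
    return model


# Tiny check with identity transforms standing in for the real rewrites.
module = torch.nn.Linear(4, 4)
assert apply_source_transforms(module, [lambda m: m, lambda m: m]) is module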