Commit 26ce4dc

feat: refactor _get_source_transforms to remove args parameter and unused modelname
1 parent 2837867 commit 26ce4dc

1 file changed
examples/models/llama/export_llama_lib.py

Lines changed: 107 additions & 32 deletions
```diff
@@ -651,10 +651,29 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
     logging.info(f"Checkpoint dtype: {edge_manager.model.checkpoint_dtype}")
     edge_manager = edge_manager.set_output_dir(output_dir_path).source_transform(
         _get_source_transforms(
-            modelname=args.model,
             dtype_override=dtype_override,
             checkpoint_dtype=DType.from_torch_dtype(checkpoint_dtype),  # type: ignore
-            args=args,
+            use_spin_quant=args.use_spin_quant,
+            embedding_quantize=args.embedding_quantize,
+            quantization_mode=args.quantization_mode,
+            expand_rope_table=args.expand_rope_table,
+            use_custom_sdpa_with_attention_mask=getattr(args, "use_custom_sdpa_with_attention_mask", False),
+            use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache,
+            quantize_kv_cache=args.quantize_kv_cache,
+            use_kv_cache=args.use_kv_cache,
+            qnn=args.qnn,
+            use_qnn_sha=args.use_qnn_sha,
+            optimized_rotation_path=args.optimized_rotation_path,
+            mps=args.mps,
+            coreml=args.coreml,
+            coreml_ios=args.coreml_ios,
+            vulkan=args.vulkan,
+            use_shared_embedding=args.use_shared_embedding,
+            use_qat=args.use_qat,
+            use_lora=args.use_lora,
+            preq_mode=args.preq_mode,
+            preq_group_size=args.preq_group_size,
+            preq_embedding_quantize=args.preq_embedding_quantize,
         )
     )
```
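A side note on the call site above: `use_custom_sdpa_with_attention_mask` is the one flag read via `getattr(..., False)` rather than plain attribute access, presumably so the call keeps working when a caller's namespace never defined that flag. A minimal sketch of the difference, using a hypothetical namespace:

```python
from argparse import Namespace

# Hypothetical namespace that never defined the optional flag.
args = Namespace(use_kv_cache=True)

# args.use_custom_sdpa_with_attention_mask would raise AttributeError here;
# getattr with a default degrades to False instead.
flag = getattr(args, "use_custom_sdpa_with_attention_mask", False)
print(flag)  # False
```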

```diff
@@ -1145,23 +1164,61 @@ def _load_llama_model(
 
 
 def _get_source_transforms(  # noqa
-    modelname: str,
     dtype_override: DType,
     *,
     checkpoint_dtype: Optional[DType] = None,
-    args,
+    use_spin_quant: Optional[str] = None,
+    embedding_quantize: Optional[str] = None,
+    quantization_mode: Optional[str] = None,
+    expand_rope_table: bool = False,
+    use_custom_sdpa_with_attention_mask: bool = False,
+    use_sdpa_with_kv_cache: bool = False,
+    quantize_kv_cache: bool = False,
+    use_kv_cache: bool = False,
+    qnn: bool = False,
+    use_qnn_sha: bool = False,
+    optimized_rotation_path: Optional[str] = None,
+    mps: bool = False,
+    coreml: bool = False,
+    coreml_ios: int = 15,
+    vulkan: bool = False,
+    use_shared_embedding: bool = False,
+    use_qat: bool = False,
+    use_lora: int = 0,
+    preq_mode: Optional[str] = None,
+    preq_group_size: int = 32,
+    preq_embedding_quantize: str = "8,0",
 ) -> List[Callable[[torch.nn.Module], torch.nn.Module]]:
     """
     Return a list of functions that transform a graph.
 
     Args:
-        modelname: The name of the model.
         dtype_override: The dtype to use for the model.
         checkpoint_dtype: The dtype of the checkpoint. At the moment, if this is specified,
             it means that you want to run quantize transformations on the weights represented
             in their original dtype, while the overall dtype of the model maybe something
             different. If not specified, defaults to dtype_override.
-        args: The arguments passed to the script.
+        use_spin_quant: Type of spin quant to use ("cuda" or "native").
+        embedding_quantize: Type of embedding quantization.
+        quantization_mode: Type of quantization mode.
+        expand_rope_table: Whether to expand rope table.
+        use_custom_sdpa_with_attention_mask: Whether to use custom SDPA with attention mask.
+        use_sdpa_with_kv_cache: Whether to use SDPA with KV cache.
+        quantize_kv_cache: Whether to quantize KV cache.
+        use_kv_cache: Whether to use KV cache.
+        qnn: Whether to use QNN.
+        use_qnn_sha: Whether to use QNN SHA.
+        optimized_rotation_path: Path to optimized rotation.
+        mps: Whether to use MPS.
+        coreml: Whether to use CoreML.
+        coreml_ios: CoreML iOS version.
+        vulkan: Whether to use Vulkan.
+        use_shared_embedding: Whether to use shared embedding.
+        use_qat: Whether to use QAT.
+        use_lora: LoRA rank (0 means no LoRA).
+        preq_mode: Pre-quantization mode.
+        preq_group_size: Pre-quantization group size.
+        preq_embedding_quantize: Pre-quantization embedding quantize.
 
     Returns:
         A list of transformation functions.
```
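The parameter list is long, but the payoff is that the transform pipeline is now callable without an argparse namespace. A quick usage sketch, not from the commit: the `DType.fp32` member and the `"8da4w"` mode string are assumptions about the surrounding enum and quantization options, and everything not passed falls back to the defaults declared above.

```python
# Hypothetical direct call against the refactored keyword-only signature.
transforms = _get_source_transforms(
    DType.fp32,                 # dtype_override; assumes DType exposes fp32
    checkpoint_dtype=None,
    use_kv_cache=True,
    use_sdpa_with_kv_cache=True,
    quantization_mode="8da4w",  # assumed mode string
)
```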
```diff
@@ -1172,21 +1229,21 @@ def _get_source_transforms(  # noqa
 
     transforms = []
 
-    if args.use_spin_quant:
-        if args.use_spin_quant == "cuda":
+    if use_spin_quant:
+        if use_spin_quant == "cuda":
             from .source_transformation.spin_quant import (
                 inject_fast_hadamard_transform_cuda_for_spin_quant,
             )
 
             transforms.append(inject_fast_hadamard_transform_cuda_for_spin_quant)
-        elif args.use_spin_quant == "native":
+        elif use_spin_quant == "native":
             from .source_transformation.spin_quant import (
                 inject_fast_hadamard_transform_native_for_spin_quant,
             )
 
             transforms.append(inject_fast_hadamard_transform_native_for_spin_quant)
 
-    if args.embedding_quantize:
+    if embedding_quantize:
         """
         When this option is selected, it finds all embedding layers and transforms
         into quantized embedding equivalent module.
```
```diff
@@ -1196,12 +1253,23 @@ def _get_source_transforms(  # noqa
         transformations based on the given checkpoint first. In those cases,
         this wil be a no-op.
         """
-        modelname = f"{modelname}_e"
+        # Create a mock args object with the necessary attributes
+        class Args:
+            pass
+        args = Args()
+        args.embedding_quantize = embedding_quantize
+        args.use_shared_embedding = use_shared_embedding
+        args.use_qat = use_qat
+        args.use_lora = use_lora
+        args.preq_mode = preq_mode
+        args.preq_group_size = preq_group_size
+        args.preq_embedding_quantize = preq_embedding_quantize
+
         transforms.append(get_quant_embedding_transform(args, checkpoint_dtype))
 
     # quantization_mode should be applied after embedding_quantize
     # to support shared_embedding
-    if args.quantization_mode:
+    if quantization_mode:
         """
         When this option is selected, it finds all linear layers and transforms
         into quantized linear equivalent module.
```
```diff
@@ -1215,7 +1283,17 @@ def _get_source_transforms(  # noqa
         There are cases where this may be a no-op, namely, if all linears are
         quantized in the checkpoint.
         """
-        modelname = f"{modelname}_q"
+        # Create a mock args object with the necessary attributes
+        class Args:
+            pass
+        args = Args()
+        args.quantization_mode = quantization_mode
+        args.group_size = preq_group_size  # Using preq_group_size as group_size
+        args.use_shared_embedding = use_shared_embedding
+        args.use_qat = use_qat
+        args.use_lora = use_lora
+        args.preq_mode = preq_mode
+
         transforms.append(
             get_quant_weight_transform(
                 args=args,
```
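The two hunks above replace the dropped `modelname` bookkeeping with a local `Args` shim that re-packages the new keyword arguments into the attribute-style object `get_quant_embedding_transform` and `get_quant_weight_transform` still expect. For comparison only, the same adapter could be written as a drop-in inside `_get_source_transforms` with the standard library's `types.SimpleNamespace`; the commit itself defines the ad-hoc class instead.

```python
from types import SimpleNamespace

# Illustrative stdlib equivalent of the local Args class, reusing the function's
# parameters; group_size reuses preq_group_size, as the commit's comment notes.
args = SimpleNamespace(
    quantization_mode=quantization_mode,
    group_size=preq_group_size,
    use_shared_embedding=use_shared_embedding,
    use_qat=use_qat,
    use_lora=use_lora,
    preq_mode=preq_mode,
)
```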
```diff
@@ -1224,15 +1302,12 @@ def _get_source_transforms(  # noqa
             )
         )
 
-    if args.expand_rope_table:
+    if expand_rope_table:
         transforms.append(materialze_broadcast_of_rope_freq_cis)
 
-    use_attention_mask_for_custom_sdpa = False
-    if isinstance(args, argparse.Namespace):
-        if getattr(args, "use_custom_sdpa_with_attention_mask", None):
-            use_attention_mask_for_custom_sdpa = True
+    use_attention_mask_for_custom_sdpa = use_custom_sdpa_with_attention_mask
 
-    if args.use_sdpa_with_kv_cache:
+    if use_sdpa_with_kv_cache:
         transforms.append(replace_kv_cache_with_custom_kv_cache)
         # todo: do this optionally
         # if use attention mask instead of causal attention
```
```diff
@@ -1244,23 +1319,23 @@ def _get_source_transforms(  # noqa
         else:
             transforms.append(replace_sdpa_with_custom_op)
 
-    if args.quantize_kv_cache:
-        assert args.use_kv_cache, "quantize_kv_cache requires use_kv_cache=True"
+    if quantize_kv_cache:
+        assert use_kv_cache, "quantize_kv_cache requires use_kv_cache=True"
         transforms.append(replace_kv_cache_with_quantized_kv_cache)
         # Right now
         transforms.append(replace_sdpa_with_quantized_sdpa)
 
-    if args.use_kv_cache:
-        if args.qnn:
+    if use_kv_cache:
+        if qnn:
             from executorch.backends.qualcomm.utils.utils import (
                 convert_linear_to_conv2d,
             )
 
-            if args.use_qnn_sha:
-                if args.optimized_rotation_path:
+            if use_qnn_sha:
+                if optimized_rotation_path:
                     transforms.append(fuse_layer_norms)
                     transforms.append(
-                        get_model_with_r1_r2(args.optimized_rotation_path)
+                        get_model_with_r1_r2(optimized_rotation_path)
                     )
                     transforms.append(replace_attention_to_attention_sha)
                     transforms.append(replace_causal_mask)
```
```diff
@@ -1272,29 +1347,29 @@ def _get_source_transforms(  # noqa
                 transforms.append(replace_sdpa_with_flex_sdpa)
                 transforms.append(replace_causal_mask)
                 transforms.append(replace_rms_norm_with_native_rms_norm)
-                if args.optimized_rotation_path:
+                if optimized_rotation_path:
                     transforms.append(fuse_layer_norms)
                     transforms.append(
-                        get_model_with_r1_r2(args.optimized_rotation_path)
+                        get_model_with_r1_r2(optimized_rotation_path)
                     )
                 # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
                 transforms.append(convert_linear_to_conv2d)
 
-        elif args.mps:
+        elif mps:
             # Currently mps doesn't support sdpa op, use the simpler decomposition
             # to get free perf gain.
             transforms.append(replace_sdpa_with_simple_sdpa)
             transforms.append(replace_causal_mask)
 
-        elif args.coreml:
+        elif coreml:
             # iOS 18 introduced fused sdpa op
-            if args.coreml_ios >= 18:
+            if coreml_ios >= 18:
                 transforms.append(replace_sdpa_with_coreml_sdpa)
             else:
                 transforms.append(replace_sdpa_with_simple_sdpa)
             transforms.append(replace_kv_cache_with_coreml_kv_cache)
 
-        if args.vulkan:
+        if vulkan:
             transforms.append(replace_with_vulkan_rotary_emb)
 
     return transforms
```
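For context on how the returned list is consumed: the call site in the first hunk hands it to the edge manager's `source_transform` step, which applies each callable to the model in turn. A minimal sketch of that application loop follows; `apply_source_transforms` is a hypothetical helper, not the library's implementation.

```python
from typing import Callable, List

import torch


def apply_source_transforms(
    model: torch.nn.Module,
    transforms: List[Callable[[torch.nn.Module], torch.nn.Module]],
) -> torch.nn.Module:
    """Apply each source transform in sequence, threading the module through."""
    for transform in transforms:
        model = transform(model)
    return model
```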
