Skip to content

Commit 5838c8a

Browse files
Fix (brevitas_examples/llm): more checks for FX-related args (#1441)
--------- Co-authored-by: Pablo Monteagudo Lago <44771380+pablomlago@users.noreply.github.com>
1 parent 1c0ebbe commit 5838c8a

File tree

2 files changed

+11
-4
lines changed

2 files changed

+11
-4
lines changed

src/brevitas_examples/llm/llm_args.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -485,6 +485,10 @@ def create_args_parser() -> ArgumentParser:
485485
return parser
486486

487487

488+
def fx_required(args: Namespace) -> bool:
    """Return whether any selected option requires FX tracing.

    Args:
        args: Parsed command-line arguments. The attributes read are
            ``weight_equalization``, ``act_equalization``, ``rotation``,
            ``ln_affine_merge``, ``convert_layernorm_to_rmsnorm`` and
            ``quant_sdpa``.

    Returns:
        ``True`` if at least one of the flags is truthy or one of the
        mode options is set to ``'fx'``; ``False`` otherwise.
    """
    # The ``or`` chain is kept (rather than building a tuple for ``any``) so
    # evaluation still short-circuits left to right; ``bool`` normalizes the
    # result so the predicate never leaks a raw flag value.
    return bool(
        args.weight_equalization
        or args.act_equalization == 'fx'
        or args.rotation == 'fx'
        or args.ln_affine_merge
        or args.convert_layernorm_to_rmsnorm
        or args.quant_sdpa == 'fx')
490+
491+
488492
def validate(args: Namespace, extra_args: Optional[List[str]] = None) -> None:
489493
if args.optimize_rotations:
490494
assert args.rotation in ['fx', 'fused_no_fx'], f"Rotations can only be optimized if --rotation=fx or --rotation=fused_no_fx"
@@ -494,6 +498,12 @@ def validate(args: Namespace, extra_args: Optional[List[str]] = None) -> None:
494498
assert args.ln_affine_merge, 'Graph rotation requires to merge LN/RMS norm affine parameters'
495499
assert args.replace_rmsnorm, 'Graph rotation requires to replace HF RMSNorm with PyTorch ones (torch 2.4+ require)'
496500
assert args.convert_layernorm_to_rmsnorm, 'Graph rotation requires to replace LayerNorm with RMSNorm'
501+
# FX is not compatible with few-shot evaluation
502+
assert args.few_shot_eval is None, "FX is not compatible with few shot evaluation, use fused_no_fx"
503+
# Otherwise we might end up tracing through dynamo twice and other weird errors.
504+
# Fused_no_fx takes care of all the rotations-related transformations
505+
if fx_required(args):
506+
assert args.rotation != "fused_no_fx", "fused_no_fx is incompatible with any option that requires FX tracing"
497507
elif args.rotation == 'fused_no_fx':
498508
assert not args.convert_layernorm_to_rmsnorm, 'LayerNorm is automatically replaced with RMSNorm when running with --rotation=fused_no_fx. Remove the flag --convert-layernorm-to-rmsnorm'
499509
assert args.replace_rmsnorm, 'Graph rotation requires to replace HF RMSNorm with PyTorch ones (torch 2.4+ require)'

src/brevitas_examples/llm/main.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from brevitas_examples.common.parse_utils import parse_args
3939
from brevitas_examples.llm.gguf_export.export import save_quantized_as_gguf
4040
from brevitas_examples.llm.llm_args import create_args_parser
41+
from brevitas_examples.llm.llm_args import fx_required
4142
from brevitas_examples.llm.llm_args import validate
4243
from brevitas_examples.llm.llm_quant.awq.pre_quant import apply_awq
4344
from brevitas_examples.llm.llm_quant.bias_corr import apply_bias_correction
@@ -201,10 +202,6 @@ def model_export(model, tokenizer, ref_input, args, config=None):
201202
ds.save(export_path, io_report_callback=None)
202203

203204

204-
def fx_required(args):
205-
return args.weight_equalization or args.act_equalization == 'fx' or args.rotation == 'fx' or args.ln_affine_merge or args.convert_layernorm_to_rmsnorm or args.quant_sdpa == 'fx'
206-
207-
208205
# Recursive function to unwrap equalized layers
209206
def find_equalized_layer(layer):
210207
if hasattr(layer, 'layer'):

0 commit comments

Comments
 (0)