|
117 | 117 | FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
|
118 | 118 | logging.basicConfig(level=logging.INFO, format=FORMAT)
|
119 | 119 | logging.getLogger().setLevel(logging.INFO)
|
| 120 | +# Avoid the error message "Could not initialize NNPACK! Reason: Unsupported hardware." |
| 121 | +torch.backends.nnpack.set_flags(False) |
120 | 122 |
|
121 | 123 |
|
122 | 124 | def next_power_of_two(n):
|
@@ -235,10 +237,16 @@ def quantize(
|
235 | 237 | ).module()
|
236 | 238 |
|
237 | 239 | if quant_dtype == QuantDtype.use_16a4w_block:
|
| 240 | + if args.group_size is None: |
| 241 | + raise ValueError( |
| 242 | +                "Group size is required when using quant_dtype 16a4w_block" |
| 243 | + ) |
238 | 244 | conv_nodes = [
|
239 | 245 | n for n in fx_graph_module.graph.nodes if "conv" in n.name
|
240 | 246 | ]
|
241 |
| - block_size_map = {n.name: (1, 64, 1, 1) for n in conv_nodes} |
| 247 | + block_size_map = { |
| 248 | + n.name: (1, args.group_size, 1, 1) for n in conv_nodes |
| 249 | + } |
242 | 250 | quantizer.set_block_size_map(block_size_map)
|
243 | 251 |
|
244 | 252 | fx_graph_module = prepare_pt2e(fx_graph_module, quantizer)
|
@@ -635,7 +643,7 @@ def permute(w, heads):
|
635 | 643 | if args.ptq != "16a8w":
|
636 | 644 | # 16a8w use 16bit kv io, so skip this custom annotation
|
637 | 645 | custom_annotations = custom_annotations + (annotate_matmul_16a8w,)
|
638 |
| - if args.decoder_model in {"stories110m", "stories260k"}: |
| 646 | + if args.decoder_model in {"stories110m", "stories260k", "phi_4_mini"}: |
639 | 647 | custom_annotations = custom_annotations + (
|
640 | 648 | annotate_linear_16a8w_in_affine_layer,
|
641 | 649 | )
|
@@ -853,12 +861,20 @@ def post_process():
|
853 | 861 |
|
854 | 862 | seq_len = args.max_seq_len
|
855 | 863 | multi_prompts = " ".join([f'--prompt "{prompt}"' for prompt in args.prompt])
|
| 864 | + lookahead_args = " ".join( |
| 865 | + [ |
| 866 | + f"--window {args.window}", |
| 867 | + f"--gcap {args.gcap}", |
| 868 | + f"--ngram {args.ngram}", |
| 869 | + ] |
| 870 | + ) |
856 | 871 | runner_args = " ".join(
|
857 | 872 | [
|
858 | 873 | multi_prompts,
|
859 | 874 | f"--eval_mode {EVAL_MODE[args.model_mode]}",
|
860 | 875 | f"--temperature {args.temperature}",
|
861 | 876 | f"--system_prompt '{args.system_prompt}'",
|
| 877 | + lookahead_args if args.model_mode == "lookahead" else "", |
862 | 878 | ]
|
863 | 879 | )
|
864 | 880 |
|
@@ -908,9 +924,6 @@ def post_process():
|
908 | 924 | "--output_path outputs/outputs.txt",
|
909 | 925 | f"--performance_output_path {performance_output_path}",
|
910 | 926 | f"--kv_updater {'SmartMask' if args.kv_updater == smart_mask_updater else 'ShiftPointer'}",
|
911 |
| - f"--window {args.window}", |
912 |
| - f"--gcap {args.gcap}", |
913 |
| - f"--ngram {args.ngram}", |
914 | 927 | runner_args,
|
915 | 928 | ]
|
916 | 929 | )
|
@@ -1175,6 +1188,13 @@ def _build_parser():
|
1175 | 1188 | action="store_true",
|
1176 | 1189 | default=False,
|
1177 | 1190 | )
|
| 1191 | + parser.add_argument( |
| 1192 | + "-G", |
| 1193 | + "--group_size", |
| 1194 | + type=int, |
| 1195 | + default=None, |
| 1196 | +        help="Group size used for block-wise weight quantization.", |
| 1197 | + ) |
1178 | 1198 |
|
1179 | 1199 | parser.add_argument("-v", "--verbose", action="store_true")
|
1180 | 1200 |
|
|
0 commit comments