diff --git a/examples/windows/onnx_ptq/genai_llm/README.md b/examples/windows/onnx_ptq/genai_llm/README.md index f5f012e11..8c0f8e94d 100644 --- a/examples/windows/onnx_ptq/genai_llm/README.md +++ b/examples/windows/onnx_ptq/genai_llm/README.md @@ -58,6 +58,8 @@ The table below lists key command-line arguments of the ONNX PTQ example script. | `--no_position_ids` | Default: position_ids input enabled | Use this option to disable position_ids input in calibration data| | `--enable_mixed_quant` | Default: disabled mixed quant | Use this option to enable mixed precsion quantization| | `--layers_8bit` | Default: None | Use this option to Overrides default mixed quant strategy| +| `--gather_quantize_axis` | Default: None | Use this option to enable INT4 quantization of Gather nodes - choose 0 or 1| +| `--gather_block_size` | Default: 32 | Block size for INT4 quantization of Gather nodes (when it is enabled using the `--gather_quantize_axis` option)| Run the following command to view all available parameters in the script: diff --git a/examples/windows/onnx_ptq/genai_llm/quantize.py b/examples/windows/onnx_ptq/genai_llm/quantize.py index 870baa7b9..57021ed4d 100644 --- a/examples/windows/onnx_ptq/genai_llm/quantize.py +++ b/examples/windows/onnx_ptq/genai_llm/quantize.py @@ -441,6 +441,8 @@ def main(args): awqclip_bsz_col=args.awqclip_bsz_col, enable_mixed_quant=args.enable_mixed_quant, layers_8bit=args.layers_8bit, + gather_block_size=args.gather_block_size, + gather_quantize_axis=args.gather_quantize_axis, ) logging.info(f"\nQuantization process took {time.time() - t} seconds") @@ -553,7 +555,19 @@ def main(args): "--block_size", type=int, default=128, - help="Block size for AWQ quantization", + help="Block size for INT4 quantization of MatMul/Gemm nodes", + ) + parser.add_argument( + "--gather_block_size", + type=int, + default=32, + help="Block size for INT4 quantization of Gather nodes", + ) + parser.add_argument( + "--gather_quantize_axis", + type=int, + default=None, + 
help="Quantization axis for INT4 quantization of Gather nodes", ) parser.add_argument( "--use_zero_point",