## Commit d373cf9: Change IO dtype for INT4 CUDA models (microsoft#1629)
### Description

This PR allows a user to set the IO dtype (i.e. the input/output dtype) for an INT4 CUDA ONNX model to bfloat16 precision instead of float16 precision. To enable it, pass `-p/--precision int4` and `-e/--execution_provider cuda`, and then set `--extra_options use_cuda_bf16=true/True/1`.

### Motivation and Context

Models lose accuracy when their weights are converted from their native bfloat16 precision to float16 precision. With the [recent support](microsoft/onnxruntime#25161) for bfloat16 precision in `MatMulNBits`, this conversion is no longer always needed.
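For illustration, a minimal sketch of a matching builder invocation driven from Python. The model id and output directory are placeholders, and the `-m`/`-o` flags are assumptions about the builder's standard interface; `-p`, `-e`, and `--extra_options` are the flags named above:

```python
import subprocess

# Hypothetical invocation of the model builder with the new option.
# "-m" (model id) and "-o" (output dir) are assumed builder flags;
# the precision/EP/extra_options values come from the commit description.
subprocess.run(
    [
        "python", "-m", "onnxruntime_genai.models.builder",
        "-m", "some-org/some-bf16-model",   # placeholder HF model id
        "-o", "./int4-cuda-bf16-model",     # placeholder output directory
        "-p", "int4",
        "-e", "cuda",
        "--extra_options", "use_cuda_bf16=true",
    ],
    check=True,
)
```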
1 parent 95d7a33 commit d373cf9

1 file changed — src/python/py/models/builder.py: 11 additions & 6 deletions
```diff
@@ -2417,8 +2417,7 @@ def make_relu(self, layer_id, root_input, activation):
 
     def make_relu_squared(self, layer_id, root_input, activation):
         relu_name = self.make_relu(layer_id, root_input, "Relu")
-        basename = f"/model/layers.{layer_id}/mlp/square/{activation}"
-        pow_name = f"{basename}/pow"
+        pow_name = f"/model/layers.{layer_id}/mlp/act_fn/Pow"
         pow_inputs = [f"{relu_name}/output_0", "/model/constants/INT32/[2]"]
         self.make_node("Pow", inputs=pow_inputs, outputs=[f"{pow_name}/output_0"], name=pow_name, domain="")
         self.make_value(f"{pow_name}/output_0", self.io_dtype, shape=['batch_size', 'sequence_length', self.intermediate_size])
```
```diff
@@ -3645,7 +3644,7 @@ def check_extra_options(kv_pairs):
     """
     Check key-value pairs and set values correctly
     """
-    bools = ["int4_is_symmetric", "exclude_embeds", "exclude_lm_head", "include_hidden_states", "enable_cuda_graph", "use_8bits_moe", "use_qdq", "use_webgpu_fp32"]
+    bools = ["int4_is_symmetric", "exclude_embeds", "exclude_lm_head", "include_hidden_states", "enable_cuda_graph", "use_8bits_moe", "use_qdq", "use_webgpu_fp32", "use_cuda_bf16"]
     for key in bools:
         if key in kv_pairs:
             if kv_pairs[key] in {"false", "False", "0"}:
```
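Since extra options arrive as strings from the command line, the check above normalizes them into booleans. A minimal standalone sketch of that normalization, assuming values outside the falsy set map to True (the falsy set comes from the hunk above; the truthy branch is an assumption):

```python
def normalize_bool_options(kv_pairs: dict, bools: list) -> dict:
    # Assumed behavior: "false"/"False"/"0" -> False (per the diff above),
    # any other string value -> True.
    for key in bools:
        if key in kv_pairs:
            kv_pairs[key] = kv_pairs[key] not in {"false", "False", "0"}
    return kv_pairs

# "use_cuda_bf16=1" on the command line becomes a real boolean here.
opts = normalize_bool_options({"use_cuda_bf16": "1"}, ["use_cuda_bf16"])
assert opts["use_cuda_bf16"] is True
```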
```diff
@@ -3710,11 +3709,15 @@ def parse_hf_token(hf_token):
 
 
 def set_io_dtype(precision, execution_provider, extra_options) -> ir.DataType:
-    if precision in {"int8", "fp32"} or (precision == "int4" and execution_provider == "cpu") or extra_options.get("use_webgpu_fp32", False):
+    int4_cpu = precision == "int4" and execution_provider == "cpu"
+    fp32_webgpu = execution_provider == "webgpu" and extra_options.get("use_webgpu_fp32", False)
+    bf16_cuda = precision == "int4" and execution_provider == "cuda" and extra_options.get("use_cuda_bf16", False)
+
+    if precision in {"int8", "fp32"} or int4_cpu or fp32_webgpu:
         # FP32 precision
         return ir.DataType.FLOAT
 
-    if precision == "bf16":
+    if precision == "bf16" or bf16_cuda:
         # BF16 precision
         return ir.DataType.BFLOAT16
 
```
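To make the new branch concrete, a self-contained sketch of the selection logic with a stub enum standing in for `ir.DataType` (the FLOAT16 fallback for the remaining cases is an assumption; the diff does not show the tail of the function):

```python
from enum import Enum

class DataType(Enum):
    # Stand-in for ir.DataType; values mirror ONNX element types.
    FLOAT = 1
    FLOAT16 = 10
    BFLOAT16 = 16

def set_io_dtype(precision, execution_provider, extra_options) -> DataType:
    int4_cpu = precision == "int4" and execution_provider == "cpu"
    fp32_webgpu = execution_provider == "webgpu" and extra_options.get("use_webgpu_fp32", False)
    bf16_cuda = precision == "int4" and execution_provider == "cuda" and extra_options.get("use_cuda_bf16", False)

    if precision in {"int8", "fp32"} or int4_cpu or fp32_webgpu:
        return DataType.FLOAT
    if precision == "bf16" or bf16_cuda:
        return DataType.BFLOAT16
    return DataType.FLOAT16  # assumed fallback for the remaining fp16/int4 GPU cases

# With the new flag, an INT4 CUDA model gets BF16 inputs/outputs...
assert set_io_dtype("int4", "cuda", {"use_cuda_bf16": True}) is DataType.BFLOAT16
# ...while the old FP16 behavior is unchanged without it.
assert set_io_dtype("int4", "cuda", {}) is DataType.FLOAT16
```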
```diff
@@ -3951,8 +3954,10 @@ def get_args():
                 If true, the QMoE op will use 8-bit quantization. If false, the QMoE op will use 4-bit quantization.
                 use_qdq = Use the QDQ decomposition for ops.
                 Use this option when you want to use quantize-dequantize ops. For example, you will have a quantized MatMul op instead of the MatMulNBits op.
-                use_webgpu_fp32 = Use FP32 for WebGPU EP.
+                use_webgpu_fp32 = Use FP32 I/O precision for WebGPU EP.
                 Use this option to enable GPUs that do not support FP16 on WebGPU (e.g. GTX 10xx).
+                use_cuda_bf16 = Use BF16 I/O precision in quantized ONNX models for CUDA EP.
+                Use this option to create quantized ONNX models that use BF16 precision.
                 adapter_path = Path to folder on disk containing the adapter files (adapter_config.json and adapter model weights).
                 Use this option for LoRA models.
             """),
```
