@@ -264,10 +264,11 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
 
         # Quantization-specific variables (INT4, INT8, etc.)
         int4_algo_config = self.make_int4_algo_config(extra_options.get("int4_algo_config", "default"))
+        self.int4_block_size = extra_options.get("int4_block_size", 32)
         self.quant_attrs = {
             "int4": {
                 "accuracy_level": int(extra_options.get("int4_accuracy_level", 4 if self.ep in ["cpu", "webgpu"] else 0)),
-                "block_size": int(extra_options.get("int4_block_size", 32)),
+                "block_size": int(self.int4_block_size),
                 "is_symmetric": extra_options.get("int4_is_symmetric", True),
                 "op_types_to_quantize": extra_options.get("int4_op_types_to_quantize", ("MatMul",)),
                 "nodes_to_exclude": extra_options.get("int4_nodes_to_exclude", []),
@@ -280,6 +281,13 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
         self.quant_attrs["config"] = config.quantization_config
         self.quant_attrs["use_g_idx"] = config.quantization_config["desc_act"] if "desc_act" in config.quantization_config else False
 
+        self.int4_tied_embeddings = config.tie_word_embeddings if hasattr(config, "tie_word_embeddings") and config.tie_word_embeddings is not None else False
+        self.int4_tied_embeddings = extra_options.get("int4_tied_embeddings", self.int4_tied_embeddings)
+        self.int8_lm_head = extra_options.get("int4_algo_config", "default") in {"k_quant_mixed", "k_quant_last"}
+        if not self.int8_lm_head:
+            # matmul_nbits_quantizer.py uses a different naming scheme for default quantization, so lm_head.MatMul.weight_Q{}G{} would not match.
+            self.int4_tied_embeddings = False
+
     def to_str_dtype(self, dtype: ir.DataType) -> str:
         return dtype.name
 
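The gating above exists because the embedding can only reuse the quantized lm_head initializer if it can predict its name: the k_quant paths emit `lm_head.MatMul.weight_Q{bits}G{block_size}`, while the default path in matmul_nbits_quantizer.py names weights differently. A small sketch of that reasoning, using a hypothetical helper that is not part of the builder:

```python
def tied_embedding_weight_name(algo_config: str, block_size: int) -> str | None:
    """Return the lm_head initializer name the embedding would reuse, or None
    when tying must be disabled (hypothetical helper mirroring the gating above)."""
    if algo_config not in {"k_quant_mixed", "k_quant_last"}:
        return None  # default quantization uses a different naming scheme
    # Under both k_quant configs the lm_head MatMul is quantized to 8 bits.
    return f"lm_head.MatMul.weight_Q8G{block_size}"

assert tied_embedding_weight_name("k_quant_mixed", 32) == "lm_head.MatMul.weight_Q8G32"
assert tied_embedding_weight_name("default", 32) is None
```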
@@ -1069,13 +1077,28 @@ def make_packed_add(self, q_add, k_add, v_add, name, root_input, **kwargs):
         self.make_add_bias(add, name, root_input, **kwargs)
 
     def make_embedding(self, embedding):
-        weight = "model.embed_tokens.weight"
-        self.make_initializer(embedding, weight, to=self.io_dtype)
-
         basename = "/model/embed_tokens"
-        gather_name = f"{basename}/Gather"
-        gather_output = f"{gather_name}/output_0"
-        self.make_node('Gather', inputs=[weight, 'input_ids'], outputs=[gather_output], name=gather_name)
+        if self.int4_tied_embeddings:
+            gather_name = f"{basename}/GatherBlockQuantized"
+            gather_output = f"{gather_name}/output_0"
+
+            weight_reshape_name = f"{basename}/Reshape"
+            bits = 8 if self.int8_lm_head else 4
+            weight_reshape_inputs = [f"lm_head.MatMul.weight_Q{bits}G{self.int4_block_size}", f"/model/constants/INT64/[{self.vocab_size}, {self.hidden_size}]"]
+            weight_reshape_output = f"{weight_reshape_name}/output_0"
+            # The quantized weight dtype is uint8; see
+            # https://github.com/microsoft/onnxruntime/blob/0c9356cb986fd4cd2c5d510909d31186010ba226/onnxruntime/python/tools/quantization/neural_compressor/weight_only.py#L73
+            self.make_reshape(weight_reshape_name, weight_reshape_inputs, dtype=ir.DataType.UINT8, shape=['vocab_size', 'hidden_size'])
+
+            self.make_node('GatherBlockQuantized', inputs=[weight_reshape_output, 'input_ids', 'lm_head.MatMul.weight_scale', 'lm_head.MatMul.weight_zp'], outputs=[gather_output], name=gather_name, domain="com.microsoft", bits=bits, block_size=int(self.int4_block_size))
+        else:
+            weight = "model.embed_tokens.weight"
+            self.make_initializer(embedding, weight, to=self.io_dtype)
+
+            gather_name = f"{basename}/Gather"
+            gather_output = f"{gather_name}/output_0"
+            self.make_node('Gather', inputs=[weight, 'input_ids'], outputs=[gather_output], name=gather_name)
+
         self.make_value(gather_output, self.io_dtype, shape=['batch_size', 'sequence_length', self.hidden_size])
 
         if self.embed_attrs["scale"] != 1:
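For intuition about the new graph path: `GatherBlockQuantized` (a contrib op in the `com.microsoft` domain) looks up token rows directly in the quantized lm_head weight and dequantizes them blockwise, so a separate fp embedding table never has to be materialized. A rough numpy model of the 8-bit case follows; the 4-bit case additionally packs two values per byte, and the names and shapes here are illustrative rather than the op's exact spec:

```python
import numpy as np

def gather_block_quantized_ref(q_weight, input_ids, scales, zero_points, block_size):
    """Gather rows from a blockwise-quantized table, then dequantize.
    q_weight: uint8, (vocab_size, hidden_size)
    scales / zero_points: (vocab_size, hidden_size // block_size)"""
    rows = q_weight[input_ids].astype(np.float32)            # (N, hidden_size)
    s = np.repeat(scales[input_ids], block_size, axis=-1)    # per-block scale, broadcast
    zp = np.repeat(zero_points[input_ids], block_size, axis=-1)
    return (rows - zp) * s

vocab, hidden, bs = 16, 8, 4
q = np.random.randint(0, 256, size=(vocab, hidden), dtype=np.uint8)
scales = np.random.rand(vocab, hidden // bs).astype(np.float32)
zp = np.full((vocab, hidden // bs), 128.0, dtype=np.float32)
print(gather_block_quantized_ref(q, np.array([1, 5, 5]), scales, zp, bs).shape)  # (3, 8)
```

The preceding Reshape exists because the quantized MatMul weight initializer is stored in the quantizer's own layout; reshaping it to (vocab_size, hidden_size) gives the gather a row-per-token table.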
@@ -4172,7 +4195,7 @@ def check_extra_options(kv_pairs):
     """
     bools = [
         "int4_is_symmetric", "exclude_embeds", "exclude_lm_head", "include_hidden_states", "enable_cuda_graph",
-        "use_8bits_moe", "use_qdq", "use_webgpu_fp32", "use_cuda_bf16",
+        "use_8bits_moe", "use_qdq", "use_webgpu_fp32", "use_cuda_bf16", "int4_tied_embeddings",
     ]
     for key in bools:
         if key in kv_pairs:
@@ -4459,6 +4482,8 @@ def get_args():
         Currently supported options are: 'default', 'rtn', 'k_quant_mixed', 'k_quant_last'.
             k_quant_mixed = k_quant algorithm with mixed precision (int4 + int8).
             k_quant_last = k_quant algorithm where only the last MatMul (/lm_head/MatMul) is quantized as int8. Other MatMuls are quantized as int4.
+        int4_tied_embeddings = Enable sharing of the quantized lm_head weights with the embedding. Default is false.
+            Use this option when the embedding and unembedding (lm_head) weights are tied.
         num_hidden_layers = Manually specify the number of layers in your ONNX model.
             Used for unit testing purposes.
         filename = Filename for ONNX model (default is 'model.onnx').
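Assuming the builder's usual `--extra_options key=value` syntax, enabling the new flag could look like this (model name and paths are placeholders):

```
python -m onnxruntime_genai.models.builder -m <hf_model_name> -o <output_dir> -p int4 -e cpu \
    --extra_options int4_algo_config=k_quant_mixed int4_tied_embeddings=true
```

Note that, per the gating in `__init__`, `int4_tied_embeddings` only takes effect together with `int4_algo_config=k_quant_mixed` or `k_quant_last`.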