
Commit b3ddb21

Tie embedding weight sharing (microsoft#1690)
Follow-up to the idea in microsoft#1461. Now that GatherBlockQuantized is implemented in microsoft/onnxruntime#25214, we can tie the embedding weights here. Tested on phi-4-mini-instruct: the CPU model size drops from 5.15 GB to 2.69 GB (a 47.8% reduction).
1 parent: 13fe160
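
For context, an illustrative way to produce a tied-embedding model (the flags follow the existing builder.py CLI and the Hugging Face model id is an assumption, so treat this as a sketch rather than the canonical invocation):

```bash
# Illustrative only: build an INT4 CPU model with the embedding tied to the
# quantized lm_head weight. int4_algo_config must be k_quant_mixed or
# k_quant_last; otherwise lm_head.MatMul.weight_Q{}G{} does not exist and the
# builder falls back to the untied Gather path.
python3 src/python/py/models/builder.py \
    -m microsoft/Phi-4-mini-instruct \
    -o ./phi4-mini-int4-cpu \
    -p int4 -e cpu -c ./cache_dir \
    --extra_options int4_algo_config=k_quant_mixed int4_tied_embeddings=true
```

Passing int4_tied_embeddings explicitly is optional for models whose config already sets tie_word_embeddings; the builder picks that value up as the default.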

1 file changed

src/python/py/models/builder.py

Lines changed: 33 additions & 8 deletions
@@ -264,10 +264,11 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
 
         # Quantization-specific variables (INT4, INT8, etc.)
         int4_algo_config = self.make_int4_algo_config(extra_options.get("int4_algo_config", "default"))
+        self.int4_block_size = extra_options.get("int4_block_size", 32)
         self.quant_attrs = {
             "int4": {
                 "accuracy_level": int(extra_options.get("int4_accuracy_level", 4 if self.ep in ["cpu", "webgpu"] else 0)),
-                "block_size": int(extra_options.get("int4_block_size", 32)),
+                "block_size": int(self.int4_block_size),
                 "is_symmetric": extra_options.get("int4_is_symmetric", True),
                 "op_types_to_quantize": extra_options.get("int4_op_types_to_quantize", ("MatMul", )),
                 "nodes_to_exclude": extra_options.get("int4_nodes_to_exclude", []),
@@ -280,6 +281,13 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
             self.quant_attrs["config"] = config.quantization_config
             self.quant_attrs["use_g_idx"] = config.quantization_config["desc_act"] if "desc_act" in config.quantization_config else False
 
+        self.int4_tied_embeddings = config.tie_word_embeddings if hasattr(config, "tie_word_embeddings") and config.tie_word_embeddings is not None else False
+        self.int4_tied_embeddings = extra_options.get("int4_tied_embeddings", self.int4_tied_embeddings)
+        self.int8_lm_head = extra_options.get("int4_algo_config", "default") in {"k_quant_mixed", "k_quant_last"}
+        if not self.int8_lm_head:
+            # matmul_nbits_quantizer.py has a different naming for default quantization, so lm_head.MatMul.weight_Q{}G{} does not match.
+            self.int4_tied_embeddings = False
+
     def to_str_dtype(self, dtype: ir.DataType) -> str:
         return dtype.name
 
@@ -1069,13 +1077,28 @@ def make_packed_add(self, q_add, k_add, v_add, name, root_input, **kwargs):
         self.make_add_bias(add, name, root_input, **kwargs)
 
     def make_embedding(self, embedding):
-        weight = "model.embed_tokens.weight"
-        self.make_initializer(embedding, weight, to=self.io_dtype)
-
         basename = "/model/embed_tokens"
-        gather_name = f"{basename}/Gather"
-        gather_output = f"{gather_name}/output_0"
-        self.make_node('Gather', inputs=[weight, 'input_ids'], outputs=[gather_output], name=gather_name)
+        if self.int4_tied_embeddings:
+            gather_name = f"{basename}/GatherBlockQuantized"
+            gather_output = f"{gather_name}/output_0"
+
+            weight_reshape_name = f"{basename}/Reshape"
+            bits = 8 if self.int8_lm_head else 4
+            weight_reshape_inputs = [f"lm_head.MatMul.weight_Q{bits}G{self.int4_block_size}", f"/model/constants/INT64/[{self.vocab_size}, {self.hidden_size}]"]
+            weight_reshape_output = f"{weight_reshape_name}/output_0"
+            # quantized weight dtype is uint8, see here
+            # https://github.com/microsoft/onnxruntime/blob/0c9356cb986fd4cd2c5d510909d31186010ba226/onnxruntime/python/tools/quantization/neural_compressor/weight_only.py#L73
+            self.make_reshape(weight_reshape_name, weight_reshape_inputs, dtype=ir.DataType.UINT8, shape=['vocab_size', 'hidden_size'])
+
+            self.make_node('GatherBlockQuantized', inputs=[weight_reshape_output, 'input_ids', 'lm_head.MatMul.weight_scale', 'lm_head.MatMul.weight_zp'], outputs=[gather_output], name=gather_name, domain="com.microsoft", bits=bits, block_size=int(self.int4_block_size))
+        else:
+            weight = "model.embed_tokens.weight"
+            self.make_initializer(embedding, weight, to=self.io_dtype)
+
+            gather_name = f"{basename}/Gather"
+            gather_output = f"{gather_name}/output_0"
+            self.make_node('Gather', inputs=[weight, 'input_ids'], outputs=[gather_output], name=gather_name)
+
         self.make_value(gather_output, self.io_dtype, shape=['batch_size', 'sequence_length', self.hidden_size])
 
         if self.embed_attrs["scale"] != 1:
@@ -4172,7 +4195,7 @@ def check_extra_options(kv_pairs):
     """
     bools = [
         "int4_is_symmetric", "exclude_embeds", "exclude_lm_head", "include_hidden_states", "enable_cuda_graph",
-        "use_8bits_moe", "use_qdq", "use_webgpu_fp32", "use_cuda_bf16",
+        "use_8bits_moe", "use_qdq", "use_webgpu_fp32", "use_cuda_bf16", "int4_tied_embeddings"
     ]
     for key in bools:
         if key in kv_pairs:
@@ -4459,6 +4482,8 @@ def get_args():
             Currently supported options are: 'default', 'rtn', 'k_quant_mixed', 'k_quant_last'.
                 k_quant_mixed = k_quant algorithm with mixed precision (int4 + int8).
                 k_quant_last = k_quant algorithm where only the last MatMul (/lm_head/MatMul) is quantized as int8. Other MatMuls are quantized as int4.
+        int4_tied_embeddings = Enable weight sharing for quantization. Default is false.
+            Use this option when you want to share the weights in the embedding and unembedding.
         num_hidden_layers = Manually specify the number of layers in your ONNX model.
             Used for unit testing purposes.
         filename = Filename for ONNX model (default is 'model.onnx').
