
Commit 9be4b92

Add tie_word_embedding option for Qwen2 model (#2535)
1 parent a096c91

File tree

1 file changed: +31 -5 lines


python/mlc_llm/model/qwen2/qwen2_model.py

Lines changed: 31 additions & 5 deletions
@@ -33,6 +33,7 @@ class QWen2Config(ConfigBase): # pylint: disable=too-many-instance-attributes
     rms_norm_eps: float
     rope_theta: int
     vocab_size: int
+    tie_word_embeddings: bool = False
     context_window_size: int = 0
     prefill_chunk_size: int = 0
     tensor_parallel_shards: int = 1
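
For context, this flag normally mirrors the tie_word_embeddings field of the checkpoint's Hugging Face config.json, with False used when the field is absent. A minimal sketch of reading it (the checkpoint path below is illustrative, not part of this commit):

import json

# Hypothetical local checkpoint directory; substitute your own Qwen2 download.
with open("Qwen2-0.5B-Instruct/config.json", encoding="utf-8") as f:
    hf_config = json.load(f)

# Checkpoints that share the input embedding with the output head set this to true;
# QWen2Config above defaults it to False when the key is missing.
print(hf_config.get("tie_word_embeddings", False))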
@@ -120,6 +121,19 @@ def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id:
         }
 
 
+class Qwen2Embedding(nn.Embedding):
+    """The embedding module specialized for Qwen2 so that
+    it can be shared with the final lm_head.
+    """
+
+    def lm_head_forward(self, x: nn.Tensor):
+        """The lm_head forwarding, which transposes the weight and multiplies
+        with the input tensor.
+        """
+        weight = nn.op.permute_dims(self.weight)
+        return nn.op.matmul(x, weight, out_dtype="float32")
+
+
 class QWen2MLP(nn.Module):
     def __init__(self, config: QWen2Config):
         self.intermediate_size = config.intermediate_size // config.tensor_parallel_shards
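
For intuition, lm_head_forward reuses the (vocab_size, hidden_size) embedding table as the output projection: the weight is transposed and the hidden states are projected back onto the vocabulary, i.e. logits = x @ W_E.T. A rough NumPy sketch of the same arithmetic (toy shapes, not Qwen2's real ones):

import numpy as np

vocab_size, hidden_size, seq_len = 1024, 64, 4       # toy shapes for illustration
w_embed = np.random.randn(vocab_size, hidden_size)   # the shared embedding table
x = np.random.randn(1, seq_len, hidden_size)         # decoder hidden states

# Same idea as permute_dims + matmul above: transpose the (vocab, hidden)
# table and project hidden states back onto the vocabulary.
logits = x @ w_embed.T
print(logits.shape)  # (1, 4, 1024)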
@@ -185,7 +199,7 @@ def _apply_residual(self, out, residual):
 
 class QWen2Model(nn.Module):
     def __init__(self, config: QWen2Config):
-        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
+        self.embed_tokens = Qwen2Embedding(config.vocab_size, config.hidden_size)
         self.layers = nn.ModuleList(
             [QWen2DecoderLayer(config) for _ in range(config.num_hidden_layers)]
         )
@@ -202,7 +216,9 @@ def forward(self, inputs: Tensor, paged_kv_cache: PagedKVCache):
 class QWen2LMHeadModel(nn.Module): # pylint: disable=too-many-instance-attributes
     def __init__(self, config: QWen2Config):
         self.model = QWen2Model(config)
-        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+        self.tie_word_embeddings = config.tie_word_embeddings
+        if not config.tie_word_embeddings:
+            self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
         self.dtype = config.dtype
         self.hidden_size = config.hidden_size
         self.num_hidden_layers = config.num_hidden_layers
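
Besides sharing weights, skipping the separate lm_head when the embeddings are tied removes a full vocab_size x hidden_size parameter matrix from the compiled model. A back-of-the-envelope example (the 0.5B-style shapes below are an assumption, not stated in this diff):

# Assumed Qwen2-0.5B-like shapes; check the checkpoint's config.json for real values.
vocab_size, hidden_size = 151936, 896
saved_params = vocab_size * hidden_size               # 136,134,656 parameters
saved_mib_fp16 = saved_params * 2 / (1024 ** 2)       # ~259.7 MiB at 2 bytes each
print(f"{saved_params:,} params, {saved_mib_fp16:.1f} MiB in float16")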
@@ -231,7 +247,11 @@ def batch_forward(
         hidden_states = self.model(input_embeds, paged_kv_cache)
         if logit_positions is not None:
             hidden_states = op.take(hidden_states, logit_positions, axis=1)
-        logits = self.lm_head(hidden_states)
+
+        if self.tie_word_embeddings:
+            logits = self.model.embed_tokens.lm_head_forward(hidden_states)
+        else:
+            logits = self.lm_head(hidden_states)
         if logits.dtype != "float32":
             logits = logits.astype("float32")
         return logits
@@ -250,7 +270,10 @@ def _index(x: te.Tensor): # x[:-1,:]
 
         hidden_states = self.model(input_embed, paged_kv_cache)
         hidden_states = op.tensor_expr_op(_index, name_hint="index", args=[hidden_states])
-        logits = self.lm_head(hidden_states)
+        if self.tie_word_embeddings:
+            logits = self.model.embed_tokens.lm_head_forward(hidden_states)
+        else:
+            logits = self.lm_head(hidden_states)
         if logits.dtype != "float32":
             logits = logits.astype("float32")
         return logits, paged_kv_cache
@@ -259,7 +282,10 @@ def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):
         op_ext.configure()
 
         hidden_states = self.model(input_embed, paged_kv_cache)
-        logits = self.lm_head(hidden_states)
+        if self.tie_word_embeddings:
+            logits = self.model.embed_tokens.lm_head_forward(hidden_states)
+        else:
+            logits = self.lm_head(hidden_states)
         if logits.dtype != "float32":
             logits = logits.astype("float32")
         return logits, paged_kv_cache
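
The same tied/untied branch now appears in batch_forward, in the prefill path, and in decode above. For readers following the control flow, one way it could be collapsed into a single place is a small method on QWen2LMHeadModel; the helper below is only an illustrative sketch, not something this commit adds:

    # Hypothetical refactor, not part of this commit: the diff instead repeats
    # the branch at each call site that previously called self.lm_head directly.
    def get_logits(self, hidden_states: Tensor):
        if self.tie_word_embeddings:
            logits = self.model.embed_tokens.lm_head_forward(hidden_states)
        else:
            logits = self.lm_head(hidden_states)
        if logits.dtype != "float32":
            logits = logits.astype("float32")
        return logits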
