@@ -33,6 +33,7 @@ class QWen2Config(ConfigBase):  # pylint: disable=too-many-instance-attributes
     rms_norm_eps: float
     rope_theta: int
     vocab_size: int
+    tie_word_embeddings: bool = False
     context_window_size: int = 0
     prefill_chunk_size: int = 0
     tensor_parallel_shards: int = 1
@@ -120,6 +121,19 @@ def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id:
         }
 
 
+class Qwen2Embedding(nn.Embedding):
+    """The embedding module specialized for Qwen2 so that
+    it can be shared with the final lm_head.
+    """
+
+    def lm_head_forward(self, x: nn.Tensor):
+        """The lm_head forwarding, which transposes the weight and multiplies
+        with the input tensor.
+        """
+        weight = nn.op.permute_dims(self.weight)
+        return nn.op.matmul(x, weight, out_dtype="float32")
+
+
 class QWen2MLP(nn.Module):
     def __init__(self, config: QWen2Config):
         self.intermediate_size = config.intermediate_size // config.tensor_parallel_shards
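
Note on the hunk above: `Qwen2Embedding.lm_head_forward` lets the single `(vocab_size, hidden_size)` embedding table double as the output projection by transposing it with `permute_dims` before the matmul. A minimal NumPy sketch of that shape arithmetic follows; it is illustration only (toy sizes, plain NumPy rather than TVM's `nn` ops) and not part of this diff.

```python
import numpy as np

# Toy sizes for illustration only; real values come from QWen2Config.
vocab_size, hidden_size = 8, 4
weight = np.random.randn(vocab_size, hidden_size).astype("float16")

# Embedding lookup: token ids select rows of the (vocab_size, hidden_size) table.
token_ids = np.array([3, 5])
embedded = weight[token_ids]            # shape (2, hidden_size)

# Tied lm_head: multiply hidden states by the transposed table, accumulating in
# float32, mirroring out_dtype="float32" in lm_head_forward above.
hidden_states = np.random.randn(2, hidden_size).astype("float16")
logits = hidden_states.astype("float32") @ weight.T.astype("float32")
assert logits.shape == (2, vocab_size)  # one logit per vocabulary entry
```
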
@@ -185,7 +199,7 @@ def _apply_residual(self, out, residual):
 
 class QWen2Model(nn.Module):
     def __init__(self, config: QWen2Config):
-        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
+        self.embed_tokens = Qwen2Embedding(config.vocab_size, config.hidden_size)
         self.layers = nn.ModuleList(
             [QWen2DecoderLayer(config) for _ in range(config.num_hidden_layers)]
         )
@@ -202,7 +216,9 @@ def forward(self, inputs: Tensor, paged_kv_cache: PagedKVCache):
 class QWen2LMHeadModel(nn.Module):  # pylint: disable=too-many-instance-attributes
     def __init__(self, config: QWen2Config):
         self.model = QWen2Model(config)
-        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+        self.tie_word_embeddings = config.tie_word_embeddings
+        if not config.tie_word_embeddings:
+            self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
         self.dtype = config.dtype
         self.hidden_size = config.hidden_size
         self.num_hidden_layers = config.num_hidden_layers
@@ -231,7 +247,11 @@ def batch_forward(
         hidden_states = self.model(input_embeds, paged_kv_cache)
         if logit_positions is not None:
             hidden_states = op.take(hidden_states, logit_positions, axis=1)
-        logits = self.lm_head(hidden_states)
+
+        if self.tie_word_embeddings:
+            logits = self.model.embed_tokens.lm_head_forward(hidden_states)
+        else:
+            logits = self.lm_head(hidden_states)
         if logits.dtype != "float32":
             logits = logits.astype("float32")
         return logits
@@ -250,7 +270,10 @@ def _index(x: te.Tensor):  # x[:-1,:]
 
         hidden_states = self.model(input_embed, paged_kv_cache)
         hidden_states = op.tensor_expr_op(_index, name_hint="index", args=[hidden_states])
-        logits = self.lm_head(hidden_states)
+        if self.tie_word_embeddings:
+            logits = self.model.embed_tokens.lm_head_forward(hidden_states)
+        else:
+            logits = self.lm_head(hidden_states)
         if logits.dtype != "float32":
             logits = logits.astype("float32")
         return logits, paged_kv_cache
@@ -259,7 +282,10 @@ def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):
         op_ext.configure()
 
         hidden_states = self.model(input_embed, paged_kv_cache)
-        logits = self.lm_head(hidden_states)
+        if self.tie_word_embeddings:
+            logits = self.model.embed_tokens.lm_head_forward(hidden_states)
+        else:
+            logits = self.lm_head(hidden_states)
         if logits.dtype != "float32":
             logits = logits.astype("float32")
         return logits, paged_kv_cache
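
Each logits call site (`batch_forward`, `prefill`, `decode`) now branches on `self.tie_word_embeddings`, and the constructor only allocates `lm_head` when the weights are untied, so tied checkpoints carry no separate output-projection parameter. As a point of comparison only (an assumed PyTorch analogue, not MLC/TVM code and not part of this diff), the same pattern looks like this:

```python
import torch
import torch.nn.functional as F
from torch import nn


class TinyTiedLM(nn.Module):
    """Hypothetical toy module mirroring the skip-lm_head-when-tied pattern."""

    def __init__(self, vocab_size: int, hidden_size: int, tie_word_embeddings: bool):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        self.tie_word_embeddings = tie_word_embeddings
        if not tie_word_embeddings:
            self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    def logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        if self.tie_word_embeddings:
            # F.linear multiplies by weight.T, matching the permute_dims + matmul
            # in Qwen2Embedding.lm_head_forward.
            return F.linear(hidden_states, self.embed_tokens.weight)
        return self.lm_head(hidden_states)


model = TinyTiedLM(vocab_size=8, hidden_size=4, tie_word_embeddings=True)
assert model.logits(torch.randn(2, 4)).shape == (2, 8)
```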