 import megengine.hub as hub
 import numpy as np
 from megengine import Parameter
-from megengine.functional import cross_entropy_with_softmax
+from megengine.functional.loss import cross_entropy
 from megengine.module import Dropout, Embedding, Linear, Module, Sequential
 from megengine.module.activation import Softmax


 def transpose(inp, a, b):
-    cur_shape = list(range(0, len(inp.shape)))
+    cur_shape = list(range(0, inp.ndim))
     cur_shape[a], cur_shape[b] = cur_shape[b], cur_shape[a]
-    return inp.dimshuffle(*cur_shape)
+    return inp.transpose(cur_shape)


-def matmul(a, b, transpose_b=None):
-    dim = len(b.shape)
-
-    if transpose_b:
-        b = transpose(b, dim - 1, dim - 2)
-
-    if dim > 3:
-        a_shape = list(a.shape)
-        b_shape = list(b.shape)
-        reshape_batch_size = 1
-        for i in a_shape[0 : dim - 2]:
-            reshape_batch_size *= i
-        a = a.reshape(*([reshape_batch_size] + a_shape[dim - 2 : dim]))
-        b = b.reshape(*([reshape_batch_size] + b_shape[dim - 2 : dim]))
-        c = F.batched_matrix_mul(a, b)
-        c = c.reshape(*(a_shape[0 : dim - 1] + b_shape[dim - 1 : dim]))
-        return c
-    elif dim == 3:
-        return F.batched_matrix_mul(a, b)
-    else:
-        return F.matrix_mul(a, b)
-
-def zeros_like(inp):
-    return mge.zeros(inp.shape, dtype=inp.dtype)
-
-def ones_like(inp):
-    return mge.ones(inp.shape, dtype=inp.dtype)
-
 def gelu(x):
     """Implementation of the gelu activation function.
     For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
-    x * 0.5 * (1.0 + F.tanh((np.sqrt(2 / np.pi) * (x + 0.044715 * (x ** 3)))))
+    x * 0.5 * (1.0 + F.tanh((F.sqrt(2 / math.pi) * (x + 0.044715 * (x ** 3)))))
     Also see https://arxiv.org/abs/1606.08415
     """
-    return x * 0.5 * (1.0 + F.tanh((np.sqrt(2 / np.pi) * (x + 0.044715 * (x ** 3)))))
+    return x * 0.5 * (1.0 + F.tanh(F.sqrt(2 / math.pi) * (x + 0.044715 * (x ** 3))))


 ACT2FN = {"gelu": gelu, "relu": F.relu}
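
# Note (not part of the patch): a minimal NumPy-only sketch of the same tanh
# approximation of GELU, useful for sanity-checking the MegEngine gelu() above.
# The helper name gelu_ref and the sample values are illustrative, not from the repo.
import math
import numpy as np

def gelu_ref(x):
    # same tanh-based approximation as gelu() in this diff, in plain NumPy
    return x * 0.5 * (1.0 + np.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * x ** 3)))

print(gelu_ref(np.array([-2.0, -1.0, 0.0, 1.0, 2.0])))
# roughly [-0.0454, -0.1588, 0.0, 0.8412, 1.9546]
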
@@ -221,10 +193,10 @@ def forward(self, input_ids, token_type_ids=None):
         seq_length = input_ids.shape[1]

         if token_type_ids is None:
-            token_type_ids = zeros_like(input_ids)
+            token_type_ids = F.zeros_like(input_ids)

         position_ids = F.linspace(0, seq_length - 1, seq_length).astype(np.int32)
-        position_ids = F.add_axis(position_ids, 0).broadcast(*input_ids.shape)
+        position_ids = F.broadcast_to(F.expand_dims(position_ids, 0), input_ids.shape)
         words_embeddings = self.word_embeddings(input_ids)

         position_embeddings = self.position_embeddings(position_ids)
@@ -255,12 +227,11 @@ def __init__(self, config):
         self.dropout = Dropout(config.attention_probs_dropout_prob)

     def transpose_for_scores(self, x):
-        new_x_shape = x.shape[:-1] + (
-            self.num_attention_heads,
-            self.attention_head_size,
-        )
-        x = x.reshape(*new_x_shape)
-        return x.dimshuffle(0, 2, 1, 3)
+        # using symbolic shapes to make trace happy
+        x_shape = mge.tensor(x.shape)
+        new_x_shape = F.concat([x_shape[:-1], (self.num_attention_heads, self.attention_head_size)])
+        x = x.reshape(new_x_shape)
+        return x.transpose(0, 2, 1, 3)

     def forward(self, hidden_states, attention_mask):
         mixed_query_layer = self.query(hidden_states)
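
# Note (not part of the patch): what transpose_for_scores computes, sketched with
# NumPy shapes only. Wrapping x.shape in mge.tensor keeps the shape symbolic so the
# reshape stays trace-friendly (per the comment above); the concrete sizes below are made up.
import numpy as np

batch, seq_len, num_heads, head_size = 2, 4, 3, 5
hidden = num_heads * head_size

x = np.zeros((batch, seq_len, hidden))
# split the hidden dimension into heads, then move the head axis forward
x = x.reshape(batch, seq_len, num_heads, head_size).transpose(0, 2, 1, 3)
print(x.shape)  # (2, 3, 4, 5) == (batch, num_heads, seq_len, head_size)
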
@@ -272,7 +243,7 @@ def forward(self, hidden_states, attention_mask):
         value_layer = self.transpose_for_scores(mixed_value_layer)

         # Take the dot product between "query" and "key" to get the raw attention scores.
-        attention_scores = matmul(query_layer, transpose(key_layer, -1, -2))
+        attention_scores = F.matmul(query_layer, transpose(key_layer, -1, -2))
         attention_scores = attention_scores / math.sqrt(self.attention_head_size)
         # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
         attention_scores = attention_scores + attention_mask
@@ -284,10 +255,12 @@ def forward(self, hidden_states, attention_mask):
         # seem a bit unusual, but is taken from the original Transformer paper.
         attention_probs = self.dropout(attention_probs)

-        context_layer = matmul(attention_probs, value_layer)
-        context_layer = context_layer.dimshuffle(0, 2, 1, 3)
-        new_context_layer_shape = context_layer.shape[:-2] + (self.all_head_size,)
-        context_layer = context_layer.reshape(*new_context_layer_shape)
+        context_layer = F.matmul(attention_probs, value_layer)
+        context_layer = context_layer.transpose(0, 2, 1, 3)
+        # using symbolic shapes to make trace happy
+        context_shape = mge.tensor(context_layer.shape)
+        new_context_layer_shape = F.concat([context_shape[:-2], self.all_head_size])
+        context_layer = context_layer.reshape(new_context_layer_shape)
         return context_layer


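
# Note (not part of the patch): the attention math above, end to end, in NumPy.
# Masked positions get a large negative additive bias (the (1 - mask) * -10000.0
# trick prepared in BertModel.forward), so softmax drives their weights to ~0.
# All sizes and the helper name are illustrative.
import numpy as np

def scaled_dot_product_attention(q, k, v, additive_mask):
    head_size = q.shape[-1]
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(head_size) + additive_mask
    probs = np.exp(scores - scores.max(axis=-1, keepdims=True))
    probs = probs / probs.sum(axis=-1, keepdims=True)  # softmax over the key axis
    return probs @ v

batch, heads, seq_len, head_size = 1, 2, 4, 8
rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((batch, heads, seq_len, head_size)) for _ in range(3))
mask = np.array([[1.0, 1.0, 1.0, 0.0]])          # last token is padding
additive_mask = (1.0 - mask)[:, None, None, :] * -10000.0
out = scaled_dot_product_attention(q, k, v, additive_mask)
print(out.shape)  # (1, 2, 4, 8)
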
@@ -453,17 +426,17 @@ def forward(
         output_all_encoded_layers=True,
     ):
         if attention_mask is None:
-            attention_mask = ones_like(input_ids)
+            attention_mask = F.ones_like(input_ids)
         if token_type_ids is None:
-            token_type_ids = zeros_like(input_ids)
+            token_type_ids = F.zeros_like(input_ids)
         # print('input_ids', input_ids.sum())
         # We create a 3D attention mask from a 2D tensor mask.
         # Sizes are [batch_size, 1, 1, to_seq_length]
         # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
         # this attention mask is more simple than the triangular masking of causal attention
         # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
         # print('attention_mask', attention_mask.sum())
-        extended_attention_mask = F.add_axis(attention_mask, (1, 2))
+        extended_attention_mask = F.expand_dims(attention_mask, (1, 2))

         # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
         # masked positions, this operation will create a tensor which is 0.0 for
@@ -554,7 +527,7 @@ def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=No
         logits = self.classifier(pooled_output)

         if labels is not None:
-            loss = cross_entropy_with_softmax(
+            loss = cross_entropy(
                 logits.reshape(-1, self.num_labels), labels.reshape(-1)
             )
             return logits, loss
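
# Note (not part of the patch): minimal usage of the renamed loss, mirroring the call
# above. Assumes cross_entropy from megengine.functional.loss applies softmax to raw
# logits internally (as cross_entropy_with_softmax did); the tensor values are made up.
import numpy as np
import megengine as mge
from megengine.functional.loss import cross_entropy

logits = mge.tensor(np.random.randn(8, 2).astype("float32"))   # (batch, num_labels)
labels = mge.tensor(np.random.randint(0, 2, size=(8,)).astype("int32"))
loss = cross_entropy(logits.reshape(-1, 2), labels.reshape(-1))
print(loss)  # scalar loss tensor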