@@ -251,7 +251,7 @@ def __init__(self, config):
251
251
def forward (self , inputs , seq_len ):
252
252
sort_lens , sort_idx = torch .sort (seq_len , dim = 0 , descending = True )
253
253
inputs = inputs [sort_idx ]
254
- inputs = nn .utils .rnn .pack_padded_sequence (inputs , sort_lens , batch_first = self .batch_first )
254
+ inputs = nn .utils .rnn .pack_padded_sequence (inputs , sort_lens . cpu () , batch_first = self .batch_first )
255
255
output , hx = self .encoder (inputs , None ) # -> [N,L,C]
256
256
output , _ = nn .utils .rnn .pad_packed_sequence (output , batch_first = self .batch_first )
257
257
_ , unsort_idx = torch .sort (sort_idx , dim = 0 , descending = False )
@@ -316,7 +316,7 @@ def forward(self, inputs, seq_len):
316
316
max_len = inputs .size (1 )
317
317
sort_lens , sort_idx = torch .sort (seq_len , dim = 0 , descending = True )
318
318
inputs = inputs [sort_idx ]
319
- inputs = nn .utils .rnn .pack_padded_sequence (inputs , sort_lens , batch_first = True )
319
+ inputs = nn .utils .rnn .pack_padded_sequence (inputs , sort_lens . cpu () , batch_first = True )
320
320
output , _ = self ._lstm_forward (inputs , None )
321
321
_ , unsort_idx = torch .sort (sort_idx , dim = 0 , descending = False )
322
322
output = output [:, unsort_idx ]
0 commit comments