Skip to content

Commit e78fc82

Browse files
committed
解决ELMO不支持使用cuda (Fix: ELMo failed when run on CUDA — the sequence-length tensor passed to `pack_padded_sequence` must live on CPU, so the diffs below change `sort_lens` to `sort_lens.cpu()`)
1 parent c421a6d commit e78fc82

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

fastNLP/models/biaffine_parser.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -376,7 +376,7 @@ def forward(self, words1, words2, seq_len, target1=None):
376376
if self.encoder_name.endswith('lstm'):
377377
sort_lens, sort_idx = torch.sort(seq_len, dim=0, descending=True)
378378
x = x[sort_idx]
379-
x = nn.utils.rnn.pack_padded_sequence(x, sort_lens, batch_first=True)
379+
x = nn.utils.rnn.pack_padded_sequence(x, sort_lens.cpu(), batch_first=True)
380380
feat, _ = self.encoder(x) # -> [N,L,C]
381381
feat, _ = nn.utils.rnn.pad_packed_sequence(feat, batch_first=True)
382382
_, unsort_idx = torch.sort(sort_idx, dim=0, descending=False)

fastNLP/modules/encoder/_elmo.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ def __init__(self, config):
251251
def forward(self, inputs, seq_len):
252252
sort_lens, sort_idx = torch.sort(seq_len, dim=0, descending=True)
253253
inputs = inputs[sort_idx]
254-
inputs = nn.utils.rnn.pack_padded_sequence(inputs, sort_lens, batch_first=self.batch_first)
254+
inputs = nn.utils.rnn.pack_padded_sequence(inputs, sort_lens.cpu(), batch_first=self.batch_first)
255255
output, hx = self.encoder(inputs, None) # -> [N,L,C]
256256
output, _ = nn.utils.rnn.pad_packed_sequence(output, batch_first=self.batch_first)
257257
_, unsort_idx = torch.sort(sort_idx, dim=0, descending=False)
@@ -316,7 +316,7 @@ def forward(self, inputs, seq_len):
316316
max_len = inputs.size(1)
317317
sort_lens, sort_idx = torch.sort(seq_len, dim=0, descending=True)
318318
inputs = inputs[sort_idx]
319-
inputs = nn.utils.rnn.pack_padded_sequence(inputs, sort_lens, batch_first=True)
319+
inputs = nn.utils.rnn.pack_padded_sequence(inputs, sort_lens.cpu(), batch_first=True)
320320
output, _ = self._lstm_forward(inputs, None)
321321
_, unsort_idx = torch.sort(sort_idx, dim=0, descending=False)
322322
output = output[:, unsort_idx]

0 commit comments

Comments (0)