
Commit d17d502

Codle authored and gpengzhi committed
Fix bug in rnn (#252)
* Fix bug in rnn: when creating sequence_length, it should be on the same device as inputs; otherwise mask_sequence will raise an error.
* Fix all creation ops
* Revert "Fix all creation ops" (this reverts commit 51c67f4)
* Fix all creation ops in rnn.py
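The bug described above can be illustrated with a minimal sketch. The helper name `make_sequence_length` is hypothetical (not Texar's actual API); it mirrors the creation pattern the patch fixes, assuming `inputs` is batch-major:

```python
import torch

def make_sequence_length(inputs: torch.Tensor, time_steps: int) -> torch.Tensor:
    """Sketch of the fixed pattern: the default sequence_length is
    created on the same device as `inputs` (CPU or GPU)."""
    batch_size = inputs.size(0)
    return torch.tensor([time_steps] * batch_size,
                        dtype=torch.int32,
                        device=inputs.device)

# On a CUDA machine `inputs` could live on the GPU; without `device=`,
# torch.tensor(...) defaults to the CPU, and a later masking op that
# mixes the two tensors raises a device-mismatch error.
inputs = torch.randn(4, 10, 8)  # (batch, time, features)
seq_len = make_sequence_length(inputs, time_steps=10)
assert seq_len.device == inputs.device
```

On a CPU-only machine both tensors land on the CPU either way, which is why the mismatch only surfaced when running on a GPU.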
1 parent 8097309 commit d17d502

File tree

1 file changed: +7 −3 lines changed


texar/torch/utils/rnn.py

Lines changed: 7 additions & 3 deletions
```diff
@@ -166,7 +166,8 @@ def bidirectional_dynamic_rnn(

     if sequence_length is None:
         sequence_length = torch.tensor([time_steps] * batch_size,
-                                       dtype=torch.int32)
+                                       dtype=torch.int32,
+                                       device=inputs.device)

     # Backward direction
     inputs_reverse = reverse_sequence(inputs=inputs,
```
```diff
@@ -278,7 +279,9 @@ def dynamic_rnn(

     if sequence_length is not None:
         if not isinstance(sequence_length, torch.Tensor):
-            sequence_length = torch.tensor(sequence_length, dtype=torch.int32)
+            sequence_length = torch.tensor(sequence_length,
+                                           dtype=torch.int32,
+                                           device=inputs.device)

         if sequence_length.dim() != 1:
             raise ValueError(
```
```diff
@@ -290,7 +293,8 @@ def dynamic_rnn(
                % sequence_length.shape)
         else:
             sequence_length = torch.tensor([time_steps] * batch_size,
-                                           dtype=torch.int32)
+                                           dtype=torch.int32,
+                                           device=inputs.device)

     if initial_state is not None:
         state = initial_state
```
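As an aside, the same constant-filled tensor can be created a little more directly with `torch.full`, which accepts `device=` and skips the intermediate Python list used in the patch. This is an alternative sketch, not what the commit does:

```python
import torch

inputs = torch.randn(4, 10, 8)  # stand-in batch of inputs
batch_size, time_steps = inputs.size(0), inputs.size(1)

# torch.full builds the constant-filled tensor directly on the
# target device, with no [time_steps] * batch_size list allocation.
sequence_length = torch.full((batch_size,), time_steps,
                             dtype=torch.int32, device=inputs.device)
assert sequence_length.device == inputs.device
```

Either form fixes the device-mismatch; the patch keeps `torch.tensor(...)` so the diff stays minimal.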

0 commit comments
