Skip to content

Commit fd2f145

Browse files
authored
fix rope device for long sequence (#2514)
* fix rope device for long sequence
* restore device removed by mistake
1 parent f3059a5 commit fd2f145

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

onmt/modules/multi_headed_attn.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -421,7 +421,7 @@ def forward(
421421
if seqlen > self.rope.size(0):
422422
self.rope = rotaryembeddings(
423423
self.dim_per_head, maxseqlen=(seqlen + 2048)
424-
)
424+
).to(self.rope.device)
425425
rope = self.rope[start_pos : start_pos + seqlen]
426426
query, key = apply_rotary_emb(
427427
query, key, rope, interleave=self.rotary_interleave
@@ -465,8 +465,8 @@ def forward(
465465
if seqlen > self.rope.size(0):
466466
self.rope = rotaryembeddings(
467467
self.dim_per_head, maxseqlen=(seqlen + 2048)
468-
)
469-
rope = self.rope[start_pos : start_pos + seqlen]
468+
).to(self.rope.device)
469+
rope = self.rope[start_pos : start_pos + seqlen].to(query.device)
470470
query, key = apply_rotary_emb(
471471
query, key, rope, interleave=self.rotary_interleave
472472
)

0 commit comments

Comments
 (0)