Skip to content

Commit 0b0b9d5

Browse files
authored
Quick fix for 0.1.7 keras-core release (#1251)
1 parent 5cf2c5e commit 0b0b9d5

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

keras_nlp/models/xlnet/xlnet_content_and_query_embedding.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,12 +67,12 @@ def positional_embedding(self, pos_seq, inv_freq, bsz=None):
6767

6868
def relative_positional_encoding(self, qlen, klen, bsz=None, clamp_len=-1):
6969
"""create relative positional encoding."""
70-
freq_seq = ops.arange(0, self.hidden_dim, 2.0)
70+
freq_seq = ops.arange(0, self.hidden_dim, 2.0, dtype=self.compute_dtype)
7171
inv_freq = 1 / (10000 ** (freq_seq / self.hidden_dim))
7272

7373
beg, end = klen, -qlen
7474

75-
fwd_pos_seq = ops.arange(beg, end, -1.0)
75+
fwd_pos_seq = ops.arange(beg, end, -1.0, dtype=self.compute_dtype)
7676
if clamp_len > 0:
7777
fwd_pos_seq = ops.clip(
7878
fwd_pos_seq, x_min=-clamp_len, x_max=clamp_len

0 commit comments

Comments
 (0)