
Commit e8cb9ef
Author: Alexander
Commit message: temp
1 parent: a126971

File tree

1 file changed: +1, -1 lines


examples/transformer.py

Lines changed: 1 addition & 1 deletion
@@ -133,7 +133,7 @@ def _single_head_attention(q, k, v):

     def make_self_attention_mask(
         self, mask: Array
-    ) -> Float[Array, "num_heads seq_len seq_len"]:
+    ) -> Array:
         """Create self-attention mask from sequence-level mask."""
         mask = jnp.multiply(
             jnp.expand_dims(mask, axis=-1), jnp.expand_dims(mask, axis=-2)
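The body of `make_self_attention_mask` builds a pairwise mask as an outer product of the sequence-level mask with itself. A minimal standalone sketch of that step (the example mask values below are hypothetical, not from the commit):

```python
import jax.numpy as jnp

# Hypothetical sequence-level mask for a length-4 sequence,
# where the last position is padding.
mask = jnp.array([1.0, 1.0, 1.0, 0.0])

# Outer product of the mask with itself, as in the diffed code:
# (seq_len, 1) * (1, seq_len) broadcasts to (seq_len, seq_len).
attn_mask = jnp.multiply(
    jnp.expand_dims(mask, axis=-1), jnp.expand_dims(mask, axis=-2)
)
# attn_mask[i, j] is 1 only when both positions i and j are real tokens,
# so attention between a real token and a padding token is masked out.
```

This also suggests why the return annotation was loosened to `Array`: the expression above yields a `(seq_len, seq_len)` array, not one with a leading `num_heads` axis as the old `Float[Array, "num_heads seq_len seq_len"]` annotation claimed.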
