Skip to content

Commit a126971

Browse files
author
Alexander
committed
temp
1 parent 6906e46 commit a126971

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

examples/transformer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,10 @@ def __init__(
3434

3535
def __call__(
3636
self,
37-
inputs: Float[Array, " hidden_size"],
37+
inputs: Array,
3838
enable_dropout: bool = True,
3939
key: jax.random.PRNGKey | None = None,
40-
) -> Float[Array, " hidden_size"]:
40+
) -> Array:
4141
# Feed-forward.
4242
hidden = self.mlp(inputs)
4343
hidden = jax.nn.gelu(hidden)

0 commit comments

Comments
 (0)