Skip to content

Commit 95594fd

Browse files
author
Alexander
committed
fixed code formatting
1 parent 870de14 commit 95594fd

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

examples/Bert.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,9 @@ def __init__(
5555

5656
def __call__(
5757
self,
58-
token_ids: Int[Array, " seq_len"],
59-
position_ids: Int[Array, " seq_len"],
60-
segment_ids: Int[Array, " seq_len"],
58+
token_ids: Int[Array, "seq_len"],
59+
position_ids: Int[Array, "seq_len"],
60+
segment_ids: Int[Array, "seq_len"],
6161
enable_dropout: bool = False,
6262
key: jax.random.PRNGKey | None = None,
6363
) -> Float[Array, "seq_len hidden_size"]:
@@ -124,9 +124,9 @@ def __init__(
124124

125125
def __call__(
126126
self,
127-
token_ids: Int[Array, " seq_len"],
128-
position_ids: Int[Array, " seq_len"],
129-
segment_ids: Int[Array, " seq_len"],
127+
token_ids: Int[Array, "seq_len"],
128+
position_ids: Int[Array, "seq_len"],
129+
segment_ids: Int[Array, "seq_len"],
130130
*,
131131
enable_dropout: bool = False,
132132
key: jax.random.PRNGKey | None = None,
@@ -190,10 +190,10 @@ def __init__(self, config: Mapping, num_classes: int, key: jax.random.PRNGKey):
190190

191191
def __call__(
192192
self,
193-
inputs: dict[str, Int[Array, " seq_len"]],
193+
inputs: dict[str, Int[Array, "seq_len"]],
194194
enable_dropout: bool = True,
195195
key: jax.random.PRNGKey = None,
196-
) -> Float[Array, " num_classes"]:
196+
) -> Float[Array, "num_classes"]:
197197
seq_len = inputs["token_ids"].shape[-1]
198198
position_ids = jnp.arange(seq_len)
199199

0 commit comments

Comments
 (0)