Skip to content

Commit 6906e46

Browse files
author
Alexander
committed
temp
1 parent 95594fd commit 6906e46

File tree

2 files changed

+16
-16
lines changed

2 files changed

+16
-16
lines changed

examples/Bert.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -55,12 +55,12 @@ def __init__(
5555

5656
def __call__(
5757
self,
58-
token_ids: Int[Array, "seq_len"],
59-
position_ids: Int[Array, "seq_len"],
60-
segment_ids: Int[Array, "seq_len"],
58+
token_ids: Array,
59+
position_ids: Array,
60+
segment_ids: Array,
6161
enable_dropout: bool = False,
6262
key: jax.random.PRNGKey | None = None,
63-
) -> Float[Array, "seq_len hidden_size"]:
63+
) -> Array:
6464
tokens = jax.vmap(self.token_embedder)(token_ids)
6565
segments = jax.vmap(self.segment_embedder)(segment_ids)
6666
positions = jax.vmap(self.position_embedder)(position_ids)
@@ -124,9 +124,9 @@ def __init__(
124124

125125
def __call__(
126126
self,
127-
token_ids: Int[Array, "seq_len"],
128-
position_ids: Int[Array, "seq_len"],
129-
segment_ids: Int[Array, "seq_len"],
127+
token_ids: Array,
128+
position_ids: Array,
129+
segment_ids: Array,
130130
*,
131131
enable_dropout: bool = False,
132132
key: jax.random.PRNGKey | None = None,
@@ -190,10 +190,10 @@ def __init__(self, config: Mapping, num_classes: int, key: jax.random.PRNGKey):
190190

191191
def __call__(
192192
self,
193-
inputs: dict[str, Int[Array, "seq_len"]],
193+
inputs: dict[str, Array],
194194
enable_dropout: bool = True,
195195
key: jax.random.PRNGKey = None,
196-
) -> Float[Array, "num_classes"]:
196+
) -> Array:
197197
seq_len = inputs["token_ids"].shape[-1]
198198
position_ids = jnp.arange(seq_len)
199199

examples/transformer.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -93,11 +93,11 @@ def __init__(
9393

9494
def __call__(
9595
self,
96-
inputs: Float[Array, "seq_len hidden_size"],
97-
mask: Int[Array, " seq_len"] | None,
96+
inputs: Array,
97+
mask: Array | None,
9898
enable_dropout: bool = False,
9999
key: "jax.random.PRNGKey" = None,
100-
) -> Float[Array, "seq_len hidden_size"]:
100+
) -> Array:
101101
if mask is not None:
102102
mask = self.make_self_attention_mask(mask)
103103
attention_key, dropout_key = (
@@ -132,7 +132,7 @@ def _single_head_attention(q, k, v):
132132
return result
133133

134134
def make_self_attention_mask(
135-
self, mask: Int[Array, " seq_len"]
135+
self, mask: Array
136136
) -> Float[Array, "num_heads seq_len seq_len"]:
137137
"""Create self-attention mask from sequence-level mask."""
138138
mask = jnp.multiply(
@@ -176,12 +176,12 @@ def __init__(
176176

177177
def __call__(
178178
self,
179-
inputs: Float[Array, "seq_len hidden_size"],
180-
mask: Int[Array, " seq_len"] | None = None,
179+
inputs: Array,
180+
mask: Array | None = None,
181181
*,
182182
enable_dropout: bool = False,
183183
key: jax.random.PRNGKey | None = None,
184-
) -> Float[Array, "seq_len hidden_size"]:
184+
) -> Array:
185185
attn_key, ff_key = (None, None) if key is None else jax.random.split(key)
186186
attention_output = self.attention_block(
187187
inputs, mask, enable_dropout=enable_dropout, key=attn_key

0 commit comments

Comments
 (0)