@@ -55,12 +55,12 @@ def __init__(
5555
5656 def __call__ (
5757 self ,
58- token_ids : Int [ Array , "seq_len" ] ,
59- position_ids : Int [ Array , "seq_len" ] ,
60- segment_ids : Int [ Array , "seq_len" ] ,
58+ token_ids : Array ,
59+ position_ids : Array ,
60+ segment_ids : Array ,
6161 enable_dropout : bool = False ,
6262 key : jax .random .PRNGKey | None = None ,
63- ) -> Float [ Array , "seq_len hidden_size" ] :
63+ ) -> Array :
6464 tokens = jax .vmap (self .token_embedder )(token_ids )
6565 segments = jax .vmap (self .segment_embedder )(segment_ids )
6666 positions = jax .vmap (self .position_embedder )(position_ids )
@@ -124,9 +124,9 @@ def __init__(
124124
125125 def __call__ (
126126 self ,
127- token_ids : Int [ Array , "seq_len" ] ,
128- position_ids : Int [ Array , "seq_len" ] ,
129- segment_ids : Int [ Array , "seq_len" ] ,
127+ token_ids : Array ,
128+ position_ids : Array ,
129+ segment_ids : Array ,
130130 * ,
131131 enable_dropout : bool = False ,
132132 key : jax .random .PRNGKey | None = None ,
@@ -190,10 +190,10 @@ def __init__(self, config: Mapping, num_classes: int, key: jax.random.PRNGKey):
190190
191191 def __call__ (
192192 self ,
193- inputs : dict [str , Int [ Array , "seq_len" ] ],
193+ inputs : dict [str , Array ],
194194 enable_dropout : bool = True ,
195195 key : jax .random .PRNGKey = None ,
196- ) -> Float [ Array , "num_classes" ] :
196+ ) -> Array :
197197 seq_len = inputs ["token_ids" ].shape [- 1 ]
198198 position_ids = jnp .arange (seq_len )
199199
0 commit comments