@@ -55,9 +55,9 @@ def __init__(
5555
5656 def __call__ (
5757 self ,
58- token_ids : Int [Array , " seq_len" ],
59- position_ids : Int [Array , " seq_len" ],
60- segment_ids : Int [Array , " seq_len" ],
58+ token_ids : Int [Array , "seq_len" ],
59+ position_ids : Int [Array , "seq_len" ],
60+ segment_ids : Int [Array , "seq_len" ],
6161 enable_dropout : bool = False ,
6262 key : jax .random .PRNGKey | None = None ,
6363 ) -> Float [Array , "seq_len hidden_size" ]:
@@ -124,9 +124,9 @@ def __init__(
124124
125125 def __call__ (
126126 self ,
127- token_ids : Int [Array , " seq_len" ],
128- position_ids : Int [Array , " seq_len" ],
129- segment_ids : Int [Array , " seq_len" ],
127+ token_ids : Int [Array , "seq_len" ],
128+ position_ids : Int [Array , "seq_len" ],
129+ segment_ids : Int [Array , "seq_len" ],
130130 * ,
131131 enable_dropout : bool = False ,
132132 key : jax .random .PRNGKey | None = None ,
@@ -190,10 +190,10 @@ def __init__(self, config: Mapping, num_classes: int, key: jax.random.PRNGKey):
190190
191191 def __call__ (
192192 self ,
193- inputs : dict [str , Int [Array , " seq_len" ]],
193+ inputs : dict [str , Int [Array , "seq_len" ]],
194194 enable_dropout : bool = True ,
195195 key : jax .random .PRNGKey = None ,
196- ) -> Float [Array , " num_classes" ]:
196+ ) -> Float [Array , "num_classes" ]:
197197 seq_len = inputs ["token_ids" ].shape [- 1 ]
198198 position_ids = jnp .arange (seq_len )
199199
0 commit comments