
Commit b347e93

[typing] Fix return typehint for decoder and inv_freq annotation (#39610)
* fix return typehint for decoder and annotate inv_freq
* fix modular
* Fix consistency
* Move annotation on class level
* missing annotations
* add comment
1 parent 7188e2e commit b347e93
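
The recurring change in this commit is the class-level `inv_freq: torch.Tensor` annotation on the rotary-embedding classes: `register_buffer` assigns the buffer dynamically, so static checkers cannot infer its type without an explicit declaration. The following self-contained sketch (illustrative only; `ToyRotaryEmbedding` and its sizes are made up, not taken from the commit) shows the pattern:

import torch
import torch.nn as nn


class ToyRotaryEmbedding(nn.Module):
    # Class-level annotation: without it, linters treat `self.inv_freq` as an
    # unknown attribute because `register_buffer` assigns it dynamically.
    inv_freq: torch.Tensor

    def __init__(self, dim: int = 64, base: float = 10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, position_ids: torch.Tensor) -> torch.Tensor:
        # The annotation lets checkers verify this indexing happens on a Tensor.
        return position_ids[..., None].float() * self.inv_freq[None, :]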


93 files changed: +229, -59 lines changed


src/transformers/models/arcee/modeling_arcee.py

Lines changed: 3 additions & 1 deletion
@@ -82,6 +82,8 @@ def extra_repr(self):
 
 
 class ArceeRotaryEmbedding(nn.Module):
+    inv_freq: torch.Tensor  # fix linting for `register_buffer`
+
     def __init__(self, config: ArceeConfig, device=None):
         super().__init__()
         # BC: "rope_type" was originally "type"
@@ -278,7 +280,7 @@ def forward(
         cache_position: Optional[torch.LongTensor] = None,
         position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         **kwargs: Unpack[TransformersKwargs],
-    ) -> tuple[torch.Tensor]:
+    ) -> torch.Tensor:
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
         # Self Attention
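
The other half of the commit, visible in the hunk above, tightens the decoder layer's return annotation from `tuple[torch.Tensor]` to `torch.Tensor`, since the layer returns the hidden-states tensor itself rather than a one-element tuple. A toy sketch of the distinction (not library code):

import torch


def layer_forward_old_annotation() -> tuple[torch.Tensor]:
    # What the old annotation promised: a 1-tuple wrapping the tensor.
    return (torch.zeros(2, 8),)


def layer_forward_new_annotation() -> torch.Tensor:
    # What the layers actually return: the hidden states directly.
    return torch.zeros(2, 8)


hidden_states = layer_forward_new_annotation()
print(hidden_states.shape)  # type-checks cleanly; no spurious tuple indexing needed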

src/transformers/models/aria/modeling_aria.py

Lines changed: 3 additions & 1 deletion
@@ -598,7 +598,7 @@ def forward(
         cache_position: Optional[torch.LongTensor] = None,
         position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         **kwargs: Unpack[TransformersKwargs],
-    ) -> tuple[torch.Tensor]:
+    ) -> torch.Tensor:
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
         # Self Attention
@@ -668,6 +668,8 @@ def _init_weights(self, module):
 
 
 class AriaTextRotaryEmbedding(nn.Module):
+    inv_freq: torch.Tensor  # fix linting for `register_buffer`
+
     def __init__(self, config: AriaTextConfig, device=None):
         super().__init__()
         # BC: "rope_type" was originally "type"

src/transformers/models/bamba/modeling_bamba.py

Lines changed: 2 additions & 0 deletions
@@ -189,6 +189,8 @@ def from_legacy_cache(cls, past_key_values: Optional[tuple[tuple[torch.FloatTens
 
 
 class BambaRotaryEmbedding(nn.Module):
+    inv_freq: torch.Tensor  # fix linting for `register_buffer`
+
     def __init__(self, config: BambaConfig, device=None):
         super().__init__()
         # BC: "rope_type" was originally "type"

src/transformers/models/bitnet/modeling_bitnet.py

Lines changed: 4 additions & 2 deletions
@@ -184,7 +184,7 @@ def forward(
         past_key_value: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs: Unpack[FlashAttentionKwargs],
-    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
         input_shape = hidden_states.shape[:-1]
         hidden_shape = (*input_shape, -1, self.head_dim)
 
@@ -243,7 +243,7 @@ def forward(
         cache_position: Optional[torch.LongTensor] = None,
         position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         **kwargs: Unpack[TransformersKwargs],
-    ) -> tuple[torch.Tensor]:
+    ) -> torch.Tensor:
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
         # Self Attention
@@ -268,6 +268,8 @@ def forward(
 
 
 class BitNetRotaryEmbedding(nn.Module):
+    inv_freq: torch.Tensor  # fix linting for `register_buffer`
+
     def __init__(self, config: BitNetConfig, device=None):
         super().__init__()
         # BC: "rope_type" was originally "type"

src/transformers/models/bitnet/modular_bitnet.py

Lines changed: 1 addition & 1 deletion
@@ -66,7 +66,7 @@ def forward(
         past_key_value: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs: Unpack[FlashAttentionKwargs],
-    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
         input_shape = hidden_states.shape[:-1]
         hidden_shape = (*input_shape, -1, self.head_dim)
 

src/transformers/models/chameleon/modeling_chameleon.py

Lines changed: 2 additions & 0 deletions
@@ -67,6 +67,8 @@ def extra_repr(self):
 # copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Chameleon
 # TODO(joao): add me back asap :)
 class ChameleonRotaryEmbedding(nn.Module):
+    inv_freq: torch.Tensor  # fix linting for `register_buffer`
+
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
         super().__init__()
         self.scaling_factor = scaling_factor

src/transformers/models/cohere/modeling_cohere.py

Lines changed: 3 additions & 1 deletion
@@ -65,6 +65,8 @@ def forward(self, hidden_states):
 
 
 class CohereRotaryEmbedding(nn.Module):
+    inv_freq: torch.Tensor  # fix linting for `register_buffer`
+
     def __init__(self, config: CohereConfig, device=None):
         super().__init__()
         # BC: "rope_type" was originally "type"
@@ -233,7 +235,7 @@ def forward(
         past_key_value: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs: Unpack[FlashAttentionKwargs],
-    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
         input_shape = hidden_states.shape[:-1]
         hidden_shape = (*input_shape, -1, self.head_dim)
 

src/transformers/models/cohere/modular_cohere.py

Lines changed: 1 addition & 1 deletion
@@ -153,7 +153,7 @@ def forward(
         past_key_value: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs: Unpack[FlashAttentionKwargs],
-    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
         input_shape = hidden_states.shape[:-1]
         hidden_shape = (*input_shape, -1, self.head_dim)
 

src/transformers/models/cohere2/modeling_cohere2.py

Lines changed: 2 additions & 0 deletions
@@ -40,6 +40,8 @@
 
 
 class Cohere2RotaryEmbedding(nn.Module):
+    inv_freq: torch.Tensor  # fix linting for `register_buffer`
+
     def __init__(self, config: Cohere2Config, device=None):
         super().__init__()
         # BC: "rope_type" was originally "type"

src/transformers/models/csm/modeling_csm.py

Lines changed: 3 additions & 1 deletion
@@ -118,6 +118,8 @@ def extra_repr(self):
 
 
 class CsmRotaryEmbedding(nn.Module):
+    inv_freq: torch.Tensor  # fix linting for `register_buffer`
+
     def __init__(self, config: CsmConfig, device=None):
         super().__init__()
         # BC: "rope_type" was originally "type"
@@ -330,7 +332,7 @@ def forward(
         cache_position: Optional[torch.LongTensor] = None,
         position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         **kwargs: Unpack[TransformersKwargs],
-    ) -> tuple[torch.Tensor]:
+    ) -> torch.Tensor:
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
         # Self Attention

0 commit comments
