Commit d5a0809

Fix consistency (#39995)
* modular
* fix

1 parent b347e93 · commit d5a0809

2 files changed (+4, -2 lines)

src/transformers/models/gpt_oss/modeling_gpt_oss.py

Lines changed: 3 additions & 1 deletion
@@ -160,6 +160,8 @@ def forward(self, hidden_states):
 
 
 class GptOssRotaryEmbedding(nn.Module):
+    inv_freq: torch.Tensor  # fix linting for `register_buffer`
+
     def __init__(self, config: GptOssConfig, device=None):
         super().__init__()
         # BC: "rope_type" was originally "type"

@@ -348,7 +350,7 @@ def forward(
         cache_position: Optional[torch.LongTensor] = None,
         position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         **kwargs: Unpack[TransformersKwargs],
-    ) -> tuple[torch.Tensor]:
+    ) -> torch.Tensor:
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
         # Self Attention
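
Note on the first hunk: `register_buffer` assigns `self.inv_freq` dynamically at runtime, so without a class-level annotation static checkers cannot see the attribute or its type; declaring `inv_freq: torch.Tensor` on the class silences that lint warning. Below is a minimal sketch of the pattern; the `ToyRotaryEmbedding` class, its dimensions, and its forward math are illustrative only, not taken from the commit.

import torch
from torch import nn


class ToyRotaryEmbedding(nn.Module):
    # Declared at class level so linters/type checkers know `self.inv_freq`
    # exists and is a Tensor; `register_buffer` alone assigns it dynamically.
    inv_freq: torch.Tensor

    def __init__(self, dim: int = 64, base: float = 10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, position_ids: torch.Tensor) -> torch.Tensor:
        # Outer product of positions with inverse frequencies -> rotary angles.
        return position_ids[:, :, None].float() * self.inv_freq[None, None, :]


emb = ToyRotaryEmbedding()
angles = emb(torch.arange(10)[None, :])
print(angles.shape)  # torch.Size([1, 10, 32])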

src/transformers/models/gpt_oss/modular_gpt_oss.py

Lines changed: 1 addition & 1 deletion
@@ -307,7 +307,7 @@ def forward(
         cache_position: Optional[torch.LongTensor] = None,
         position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         **kwargs: Unpack[TransformersKwargs],
-    ) -> tuple[torch.Tensor]:
+    ) -> torch.Tensor:
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
         # Self Attention
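
Note on the return-type fix: the decoder layer's forward returns a single hidden-states tensor, so the earlier `tuple[torch.Tensor]` annotation described a 1-tuple the function never actually produces. Below is a minimal sketch of the corrected signature shape, using a hypothetical `ToyDecoderLayer` rather than the real `GptOssDecoderLayer`; the layer body is simplified for illustration.

import torch
from torch import nn


class ToyDecoderLayer(nn.Module):
    """Toy stand-in for a decoder layer, showing the single-tensor return annotation."""

    def __init__(self, hidden_size: int = 8):
        super().__init__()
        self.input_layernorm = nn.LayerNorm(hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Returns one tensor, so annotating the return as `tuple[torch.Tensor]`
        # (a 1-tuple) would mislead type checkers at call sites.
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        return residual + hidden_states


layer = ToyDecoderLayer()
out = layer(torch.randn(2, 4, 8))
print(out.shape)  # torch.Size([2, 4, 8])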
