
Commit 1e2e28f

Change Qwen2RMSNorm to RMSNorm from PyTorch (#40066)
* Unify Qwen2RMSNorm definitions and use RMSNorm from PyTorch
* Subclass RMSNorm

Signed-off-by: cyy <[email protected]>
1 parent 022af24 commit 1e2e28f
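
For context: torch.nn.RMSNorm ships with PyTorch 2.3 and later and computes the same normalization as the hand-rolled Qwen2RMSNorm, so the modular definition can subclass it on new enough PyTorch and keep the manual implementation only as a fallback. A minimal sketch, not part of this commit (module size and input shape are illustrative), comparing the two:

import torch
from torch import nn


# Hand-rolled RMSNorm, as in the pre-2.3 fallback kept by this commit.
class ManualRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)


# Requires torch >= 2.3 for nn.RMSNorm.
manual = ManualRMSNorm(64)
builtin = nn.RMSNorm(64, eps=1e-6)
x = torch.randn(2, 8, 64)  # float32, so both paths compute in float32
torch.testing.assert_close(manual(x), builtin(x))

One behavioral nuance: the hand-rolled class explicitly upcasts half-precision inputs to float32 before normalizing and casts back afterwards, while nn.RMSNorm leaves dtype handling to the underlying kernel.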

File tree

8 files changed: +50 -67 lines changed


docs/source/en/model_doc/qwen2.md

Lines changed: 5 additions & 0 deletions

@@ -160,6 +160,11 @@ print(tokenizer.decode(outputs[0], skip_special_tokens=True))
 
 [[autodoc]] Qwen2TokenizerFast
 
+## Qwen2RMSNorm
+
+[[autodoc]] Qwen2RMSNorm
+    - forward
+
 ## Qwen2Model
 
 [[autodoc]] Qwen2Model

src/transformers/models/dots1/modeling_dots1.py

Lines changed: 2 additions & 2 deletions

@@ -43,15 +43,15 @@
 
 @use_kernel_forward_from_hub("RMSNorm")
 class Dots1RMSNorm(nn.Module):
-    def __init__(self, hidden_size, eps=1e-6):
+    def __init__(self, hidden_size, eps: float = 1e-6) -> None:
         """
         Dots1RMSNorm is equivalent to T5LayerNorm
         """
         super().__init__()
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps
 
-    def forward(self, hidden_states):
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         input_dtype = hidden_states.dtype
         hidden_states = hidden_states.to(torch.float32)
         variance = hidden_states.pow(2).mean(-1, keepdim=True)

src/transformers/models/qwen2/modeling_qwen2.py

Lines changed: 3 additions & 2 deletions

@@ -185,15 +185,15 @@ def forward(
 
 @use_kernel_forward_from_hub("RMSNorm")
 class Qwen2RMSNorm(nn.Module):
-    def __init__(self, hidden_size, eps=1e-6):
+    def __init__(self, hidden_size, eps: float = 1e-6) -> None:
         """
         Qwen2RMSNorm is equivalent to T5LayerNorm
         """
         super().__init__()
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps
 
-    def forward(self, hidden_states):
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         input_dtype = hidden_states.dtype
         hidden_states = hidden_states.to(torch.float32)
         variance = hidden_states.pow(2).mean(-1, keepdim=True)

@@ -497,6 +497,7 @@ class Qwen2ForQuestionAnswering(GenericForQuestionAnswering, Qwen2PreTrainedMode
     "Qwen2PreTrainedModel",
     "Qwen2Model",
     "Qwen2ForCausalLM",
+    "Qwen2RMSNorm",
     "Qwen2ForSequenceClassification",
     "Qwen2ForTokenClassification",
     "Qwen2ForQuestionAnswering",

src/transformers/models/qwen2/modular_qwen2.py

Lines changed: 33 additions & 0 deletions

@@ -2,9 +2,11 @@
 
 import torch
 import torch.utils.checkpoint
+from packaging import version
 from torch import nn
 
 from ...cache_utils import Cache, DynamicCache
+from ...integrations import use_kernel_forward_from_hub
 from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_outputs import (

@@ -15,6 +17,7 @@
 from ...utils import TransformersKwargs, auto_docstring, logging
 from ...utils.deprecation import deprecate_kwarg
 from ...utils.generic import check_model_inputs
+from ...utils.import_utils import get_torch_version
 from ..llama.modeling_llama import (
     LlamaAttention,
     LlamaDecoderLayer,

@@ -97,6 +100,35 @@ def forward(
         return attn_output, attn_weights
 
 
+if version.parse(get_torch_version()) >= version.parse("2.3.0"):
+
+    class Qwen2RMSNorm(nn.RMSNorm):
+        def __init__(self, hidden_size, eps: float = 1e-6) -> None:
+            super().__init__(normalized_shape=hidden_size, eps=eps, elementwise_affine=True)
+
+else:
+
+    @use_kernel_forward_from_hub("RMSNorm")
+    class Qwen2RMSNorm(nn.Module):
+        def __init__(self, hidden_size, eps: float = 1e-6) -> None:
+            """
+            Qwen2RMSNorm is equivalent to T5LayerNorm
+            """
+            super().__init__()
+            self.weight = nn.Parameter(torch.ones(hidden_size))
+            self.variance_epsilon = eps
+
+        def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+            input_dtype = hidden_states.dtype
+            hidden_states = hidden_states.to(torch.float32)
+            variance = hidden_states.pow(2).mean(-1, keepdim=True)
+            hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+            return self.weight * hidden_states.to(input_dtype)
+
+        def extra_repr(self):
+            return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
 class Qwen2DecoderLayer(LlamaDecoderLayer):
     def __init__(self, config: Qwen2Config, layer_idx: int):
         super().__init__()

@@ -206,6 +238,7 @@ class Qwen2ForQuestionAnswering(LlamaForQuestionAnswering):
     "Qwen2PreTrainedModel",
     "Qwen2Model",
     "Qwen2ForCausalLM",
+    "Qwen2RMSNorm",
     "Qwen2ForSequenceClassification",
     "Qwen2ForTokenClassification",
     "Qwen2ForQuestionAnswering",

src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py

Lines changed: 1 addition & 20 deletions

@@ -43,6 +43,7 @@
 from ...utils import TransformersKwargs, auto_docstring, check_torch_load_is_safe, logging
 from ...utils.deprecation import deprecate_kwarg
 from ...utils.hub import cached_file
+from ..qwen2.modeling_qwen2 import Qwen2RMSNorm
 from .configuration_qwen2_5_omni import (
     Qwen2_5OmniAudioEncoderConfig,
     Qwen2_5OmniBigVGANConfig,

@@ -986,26 +987,6 @@ def forward(self, hidden_state):
         return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
 
 
-class Qwen2RMSNorm(nn.Module):
-    def __init__(self, hidden_size, eps=1e-6):
-        """
-        Qwen2RMSNorm is equivalent to T5LayerNorm
-        """
-        super().__init__()
-        self.weight = nn.Parameter(torch.ones(hidden_size))
-        self.variance_epsilon = eps
-
-    def forward(self, hidden_states):
-        input_dtype = hidden_states.dtype
-        hidden_states = hidden_states.to(torch.float32)
-        variance = hidden_states.pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
-        return self.weight * hidden_states.to(input_dtype)
-
-    def extra_repr(self):
-        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
-
-
 class Qwen2_5OmniVisionBlock(GradientCheckpointingLayer):
     def __init__(self, config: Qwen2_5OmniVisionEncoderConfig) -> None:
         super().__init__()

src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py

Lines changed: 1 addition & 20 deletions

@@ -43,6 +43,7 @@
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging
 from ...utils.deprecation import deprecate_kwarg
+from ..qwen2.modeling_qwen2 import Qwen2RMSNorm
 from .configuration_qwen2_5_vl import Qwen2_5_VLConfig, Qwen2_5_VLTextConfig, Qwen2_5_VLVisionConfig
 
 

@@ -103,26 +104,6 @@ def forward(self, seqlen: int) -> torch.Tensor:
         return freqs
 
 
-class Qwen2RMSNorm(nn.Module):
-    def __init__(self, hidden_size, eps=1e-6):
-        """
-        Qwen2RMSNorm is equivalent to T5LayerNorm
-        """
-        super().__init__()
-        self.weight = nn.Parameter(torch.ones(hidden_size))
-        self.variance_epsilon = eps
-
-    def forward(self, hidden_states):
-        input_dtype = hidden_states.dtype
-        hidden_states = hidden_states.to(torch.float32)
-        variance = hidden_states.pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
-        return self.weight * hidden_states.to(input_dtype)
-
-    def extra_repr(self):
-        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
-
-
 class Qwen2_5_VLPatchMerger(nn.Module):
     def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None:
         super().__init__()

src/transformers/models/qwen2_vl/modeling_qwen2_vl.py

Lines changed: 3 additions & 21 deletions

@@ -46,6 +46,9 @@
     logging,
 )
 from ...utils.deprecation import deprecate_kwarg
+from ..qwen2.modeling_qwen2 import (
+    Qwen2RMSNorm,
+)
 from .configuration_qwen2_vl import Qwen2VLConfig, Qwen2VLTextConfig, Qwen2VLVisionConfig
 
 

@@ -441,27 +444,6 @@ def forward(
         return hidden_states
 
 
-# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2RMSNorm
-class Qwen2RMSNorm(nn.Module):
-    def __init__(self, hidden_size, eps=1e-6):
-        """
-        Qwen2RMSNorm is equivalent to T5LayerNorm
-        """
-        super().__init__()
-        self.weight = nn.Parameter(torch.ones(hidden_size))
-        self.variance_epsilon = eps
-
-    def forward(self, hidden_states):
-        input_dtype = hidden_states.dtype
-        hidden_states = hidden_states.to(torch.float32)
-        variance = hidden_states.pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
-        return self.weight * hidden_states.to(input_dtype)
-
-    def extra_repr(self):
-        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
-
-
 # Copied from transformers.models.qwen2.modeling_qwen2.Qwen2MLP
 class Qwen2MLP(nn.Module):
     def __init__(self, config):
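
With the duplicate definitions gone, downstream code imports the single shared class instead of relying on the "# Copied from" consistency mechanism. An illustrative usage (requires a transformers build containing this commit; the size and eps values are made up):

from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm

# Same import path the omni/vl models now use.
norm = Qwen2RMSNorm(hidden_size=64, eps=1e-6)
print(norm)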

src/transformers/models/qwen3/modeling_qwen3.py

Lines changed: 2 additions & 2 deletions

@@ -48,15 +48,15 @@
 
 @use_kernel_forward_from_hub("RMSNorm")
 class Qwen3RMSNorm(nn.Module):
-    def __init__(self, hidden_size, eps=1e-6):
+    def __init__(self, hidden_size, eps: float = 1e-6) -> None:
         """
         Qwen3RMSNorm is equivalent to T5LayerNorm
         """
         super().__init__()
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps
 
-    def forward(self, hidden_states):
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         input_dtype = hidden_states.dtype
         hidden_states = hidden_states.to(torch.float32)
         variance = hidden_states.pow(2).mean(-1, keepdim=True)
