@@ -114,10 +114,12 @@ def __init__(
         self.weight = nn.Parameter(torch.zeros(hidden_size))
         self.variance_epsilon = eps
 
-    def forward_native(
-        self,
+    @staticmethod
+    def forward_static(
+        weight: torch.Tensor,
+        variance_epsilon: float,
         x: torch.Tensor,
-        residual: Optional[torch.Tensor] = None,
+        residual: Optional[torch.Tensor],
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         """PyTorch-native implementation equivalent to forward()."""
         orig_dtype = x.dtype
@@ -127,17 +129,32 @@ def forward_native(
 
         x = x.float()
         variance = x.pow(2).mean(dim=-1, keepdim=True)
-        x = x * torch.rsqrt(variance + self.variance_epsilon)
+        x = x * torch.rsqrt(variance + variance_epsilon)
         # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
         # See https://github.com/huggingface/transformers/pull/29402
-        x = x * (1.0 + self.weight.float())
+        x = x * (1.0 + weight.float())
         x = x.to(orig_dtype)
         return x if residual is None else (x, residual)
 
+    def forward_native(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        """PyTorch-native implementation equivalent to forward()."""
+        return self.forward_static(self.weight.data, self.variance_epsilon, x,
+                                   residual)
+
     def forward_cuda(
         self,
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-        # TODO(woosuk): Implement an optimized kernel for GemmaRMSNorm.
+        if torch.compiler.is_compiling():
+            return self.forward_native(x, residual)
+
+        if not getattr(self, "_is_compiled", False):
+            self.forward_static = torch.compile(  # type: ignore
+                self.forward_static)
+            self._is_compiled = True
         return self.forward_native(x, residual)
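The new forward_cuda compiles the pure forward_static lazily: the first eager call wraps it in torch.compile and caches the compiled callable on the instance, while the torch.compiler.is_compiling() guard falls back to the plain implementation when an outer compilation is already tracing the module, so compilers never nest. A minimal self-contained sketch of the same pattern follows; the class name LazyCompiledRMSNorm and the residual handling shown are illustrative reconstructions from context, not taken verbatim from this commit:

from typing import Optional, Tuple, Union

import torch
import torch.nn as nn


class LazyCompiledRMSNorm(nn.Module):
    """Gemma-style RMSNorm that compiles its static kernel on first use.

    Illustrative sketch of the pattern in the diff above, not the vLLM class.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        super().__init__()
        # Gemma parameterizes the scale as (1.0 + weight), zero-initialized.
        self.weight = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    @staticmethod
    def forward_static(
        weight: torch.Tensor,
        variance_epsilon: float,
        x: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        # A pure function of its inputs: no reads from `self`, so the
        # compiled graph needs no guards on module state.
        orig_dtype = x.dtype
        if residual is not None:
            x = x + residual
            residual = x
        x = x.float()
        variance = x.pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(variance + variance_epsilon)
        x = (x * (1.0 + weight.float())).to(orig_dtype)
        return x if residual is None else (x, residual)

    def forward(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        if torch.compiler.is_compiling():
            # An outer torch.compile trace is already running; compiling
            # here would nest compilers, so run the plain function instead.
            return self.forward_static(self.weight.data,
                                       self.variance_epsilon, x, residual)
        if not getattr(self, "_is_compiled", False):
            # Compile once and cache the compiled callable on the instance,
            # shadowing the class-level staticmethod on later lookups.
            self.forward_static = torch.compile(self.forward_static)
            self._is_compiled = True
        return self.forward_static(self.weight.data, self.variance_epsilon,
                                   x, residual)

Note that the wrapper passes self.weight.data rather than the nn.Parameter itself, handing the compiled function a plain detached tensor view of the weight, and that storing the compiled wrapper as an instance attribute means every call after the first goes straight to the cached compiled version.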
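The retained comment about Llama versus Gemma is a numerics point: Llama casts the normalized activations down to the output dtype before multiplying by the weight, whereas Gemma multiplies in float32 and casts the product, and the two orderings round differently in half precision. A small illustrative check, with made-up shapes and values:

import torch

torch.manual_seed(0)
x = torch.randn(4, 8)  # stand-in for normalized activations, float32
w = torch.randn(8)     # stand-in for the norm weight, float32

# Llama-style: downcast each factor, then multiply in half precision.
llama_style = x.to(torch.float16) * w.to(torch.float16)
# Gemma-style: multiply in float32, then downcast the product.
gemma_style = (x * w).to(torch.float16)

# The orderings round differently, so the results can disagree in the
# low bits even though both compute "the same" product.
print((llama_style - gemma_style).abs().max())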