
Commit b3856be

[Misc] Use torch.compile for GemmaRMSNorm (vllm-project#7642)
1 parent 8c6f694 commit b3856be

1 file changed: +23 −6

vllm/model_executor/layers/layernorm.py

Lines changed: 23 additions & 6 deletions
@@ -114,10 +114,12 @@ def __init__(
         self.weight = nn.Parameter(torch.zeros(hidden_size))
         self.variance_epsilon = eps
 
-    def forward_native(
-        self,
+    @staticmethod
+    def forward_static(
+        weight: torch.Tensor,
+        variance_epsilon: float,
         x: torch.Tensor,
-        residual: Optional[torch.Tensor] = None,
+        residual: Optional[torch.Tensor],
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         """PyTorch-native implementation equivalent to forward()."""
         orig_dtype = x.dtype
@@ -127,17 +129,32 @@ def forward_native(
 
         x = x.float()
         variance = x.pow(2).mean(dim=-1, keepdim=True)
-        x = x * torch.rsqrt(variance + self.variance_epsilon)
+        x = x * torch.rsqrt(variance + variance_epsilon)
         # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
         # See https://github.com/huggingface/transformers/pull/29402
-        x = x * (1.0 + self.weight.float())
+        x = x * (1.0 + weight.float())
         x = x.to(orig_dtype)
         return x if residual is None else (x, residual)
 
+    def forward_native(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        """PyTorch-native implementation equivalent to forward()."""
+        return self.forward_static(self.weight.data, self.variance_epsilon, x,
+                                   residual)
+
     def forward_cuda(
         self,
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-        # TODO(woosuk): Implement an optimized kernel for GemmaRMSNorm.
+        if torch.compiler.is_compiling():
+            return self.forward_native(x, residual)
+
+        if not getattr(self, "_is_compiled", False):
+            self.forward_static = torch.compile(  # type: ignore
+                self.forward_static)
+            self._is_compiled = True
         return self.forward_native(x, residual)
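
For context, the pattern this commit applies is: keep the normalization math in a @staticmethod that reads nothing from self, compile that function lazily with torch.compile on the first eager call, and skip the lazy compilation when torch.compiler.is_compiling() reports that an outer compile is already tracing the module. The following is a minimal, self-contained sketch of that pattern outside vLLM; the class name MiniGemmaRMSNorm, the toy sizes, and the __main__ usage are illustrative assumptions, not part of the commit.

# Minimal sketch of the lazy torch.compile pattern used in this commit.
# MiniGemmaRMSNorm and the toy sizes are illustrative, not vLLM code.
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn


class MiniGemmaRMSNorm(nn.Module):

    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        super().__init__()
        # Gemma stores the weight as a zero-initialized offset: y = x * (1 + w).
        self.weight = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    @staticmethod
    def forward_static(
        weight: torch.Tensor,
        variance_epsilon: float,
        x: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        # Pure function of its arguments: it never touches `self`, so the
        # compiled code is not tied to any particular layer instance.
        orig_dtype = x.dtype
        if residual is not None:
            x = x + residual
            residual = x
        x = x.float()
        variance = x.pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(variance + variance_epsilon)
        # Gemma applies the weight before casting back: (x * (1 + w)).to(dtype).
        x = (x * (1.0 + weight.float())).to(orig_dtype)
        return x if residual is None else (x, residual)

    def forward(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        # If an outer torch.compile is already tracing this module, fall
        # through to the plain math so it gets inlined into the outer graph.
        if torch.compiler.is_compiling():
            return self.forward_static(self.weight.data,
                                       self.variance_epsilon, x, residual)
        # Otherwise, compile the static function lazily on the first eager
        # call; the compiled callable shadows the staticmethod on the instance.
        if not getattr(self, "_is_compiled", False):
            self.forward_static = torch.compile(self.forward_static)
            self._is_compiled = True
        return self.forward_static(self.weight.data, self.variance_epsilon, x,
                                   residual)


# Usage: the first call triggers compilation; later calls reuse the cached graph.
if __name__ == "__main__":
    norm = MiniGemmaRMSNorm(hidden_size=8)
    out = norm(torch.randn(2, 8))
    print(out.shape)

Passing the weight and epsilon as explicit arguments, rather than reading them off self, is presumably why the commit introduces forward_static: the compiled function is keyed on a plain function of tensors and scalars, which keeps it reusable across the many GemmaRMSNorm layers in a model instead of behaving like a per-instance bound method.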
