
Commit 7cfa4b2

[BugFix] Fix de-functionalization pass for rotary_embedding (vllm-project#23953)
Signed-off-by: angelayi <[email protected]>
1 parent b71fcd4 commit 7cfa4b2
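Background for the fix, as an illustrative sketch rather than code from this commit: during torch.compile, functionalization wraps vLLM's mutating custom ops, such as torch.ops._C.rotary_embedding.default, in auto_functionalized higher-order nodes, and FixFunctionalizationPass rewrites those wrappers back into direct in-place calls. The helper below assumes the wrapper is exposed as torch.ops.higher_order.auto_functionalized and shows, roughly, the kind of graph query the test helpers find_auto_fn / find_auto_fn_maybe perform; it is a sketch, not the pass itself.

import torch
from torch import fx


def find_functionalized_calls(graph: fx.Graph, op) -> list[fx.Node]:
    # Nodes of the form auto_functionalized(op, ...) are what the
    # de-functionalization pass rewrites into direct in-place calls.
    auto_fn = torch.ops.higher_order.auto_functionalized
    return [
        node for node in graph.nodes
        if node.op == "call_function" and node.target is auto_fn
        and node.args and node.args[0] is op
    ]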

File tree

3 files changed: +266 -87 lines changed

.buildkite/test-pipeline.yaml

Lines changed: 1 addition & 0 deletions

@@ -397,6 +397,7 @@ steps:
   - pytest -v -s compile/test_pass_manager.py
   - pytest -v -s compile/test_fusion.py
   - pytest -v -s compile/test_fusion_attn.py
+  - pytest -v -s compile/test_functionalization.py
   - pytest -v -s compile/test_silu_mul_quant_fusion.py
   - pytest -v -s compile/test_sequence_parallelism.py
   - pytest -v -s compile/test_async_tp.py
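The new CI entry runs the reworked test file with the same pytest flags as its neighbors. For a local check, the pipeline commands appear to execute from the tests/ directory, so from the repository root the equivalent invocation is presumably:

pytest -v -s tests/compile/test_functionalization.py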

tests/compile/test_functionalization.py

Lines changed: 228 additions & 70 deletions
@@ -5,54 +5,237 @@
 import torch
 
 import vllm.envs as envs
-from vllm import LLM, SamplingParams
 from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass
 from vllm.compilation.fix_functionalization import FixFunctionalizationPass
-from vllm.compilation.fusion import FUSED_OPS, RMSNormQuantFusionPass
+from vllm.compilation.fusion import RMSNormQuantFusionPass
 from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
 from vllm.compilation.noop_elimination import NoOpEliminationPass
 from vllm.compilation.post_cleanup import PostCleanupPass
 from vllm.config import CompilationConfig, PassConfig, VllmConfig
+from vllm.model_executor.layers.activation import SiluAndMul
+from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
-    QuantKey, kFp8DynamicTokenSym, kFp8StaticTensorSym)
+    GroupShape)
+from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
+    Fp8LinearOp)
+from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.platforms import current_platform
 
 from .backend import TestBackend
 
-OPS_IN_MODEL = [
-    torch.ops._C.rotary_embedding.default,
-    torch.ops._C.fused_add_rms_norm.default,
-]
+TEST_FP8 = current_platform.supports_fp8()
+FP8_DTYPE = current_platform.fp8_dtype()
+
+
+class TestSiluMul(torch.nn.Module):
+
+    def __init__(self, hidden_size: int = 128):
+        super().__init__()
+        self.silu_and_mul = SiluAndMul()
+        self.wscale = torch.rand(1, dtype=torch.float32)
+        self.scale = torch.rand(1, dtype=torch.float32)
+
+        if TEST_FP8:
+            self.w = torch.rand(hidden_size,
+                                hidden_size).to(dtype=FP8_DTYPE).t()
+            self.fp8_linear = Fp8LinearOp(
+                act_quant_static=True,
+                act_quant_group_shape=GroupShape.PER_TENSOR,
+            )
+
+    def forward(self, x):
+        y = self.silu_and_mul(x)
+        if TEST_FP8:
+            x2 = self.fp8_linear.apply(y,
+                                       self.w,
+                                       self.wscale,
+                                       input_scale=self.wscale)
+            return x2
+        else:
+            return y
+
+    def example_inputs(self, num_tokens=32, hidden_size=128):
+        dtype = torch.float16 if TEST_FP8 else torch.float32
+        return (torch.rand(num_tokens, hidden_size * 2, dtype=dtype), )
+
+    def ops_in_model(self, do_fusion):
+        if TEST_FP8 and do_fusion:
+            return [torch.ops._C.silu_and_mul_quant.default]
+        else:
+            return [torch.ops._C.silu_and_mul.default]
+
+    def ops_not_in_model(self):
+        return []
+
+
+class TestFusedAddRMSNorm(torch.nn.Module):
+
+    def __init__(self, hidden_size=16, intermediate_size=32):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+
+        dtype = torch.float16 if TEST_FP8 else torch.float32
+
+        self.gate_proj = torch.nn.Parameter(
+            torch.empty((intermediate_size, hidden_size), dtype=dtype))
+        self.norm = RMSNorm(intermediate_size, 1e-05)
+        self.norm.weight = torch.nn.Parameter(
+            torch.ones(intermediate_size, dtype=dtype))
+
+        torch.nn.init.normal_(self.gate_proj, std=0.02)
+
+        if TEST_FP8:
+            self.fp8_linear = Fp8LinearOp(act_quant_static=True)
+
+            self.scale = torch.rand(1, dtype=torch.float32)
+            self.w = torch.rand(hidden_size,
+                                intermediate_size).to(dtype=FP8_DTYPE).t()
+            self.wscale = torch.rand(1, dtype=torch.float32)
+
+    def forward(self, hidden_states, residual):
+        # Reshape input
+        view = hidden_states.reshape(-1, self.hidden_size)
+
+        # matrix multiplication
+        permute = self.gate_proj.permute(1, 0)
+        mm = torch.mm(view, permute)
+
+        # layer normalization
+        norm_output, residual_output = self.norm(mm, residual)
+
+        if TEST_FP8:
+            # scaled_mm with static input quantization
+            fp8_linear_result = self.fp8_linear.apply(
+                norm_output,
+                self.w,
+                self.wscale,
+                input_scale=self.scale.to(norm_output.device),
+            )
+
+            return fp8_linear_result, residual_output
+
+        else:
+            return norm_output, residual_output
+
+    def example_inputs(self, batch_size=8, hidden_size=16, seq_len=16):
+        dtype = torch.float16 if TEST_FP8 else torch.float32
+        hidden_states = torch.randn((batch_size * seq_len, hidden_size),
+                                    dtype=dtype)
+        residual = torch.randn((batch_size * seq_len, hidden_size),
+                               dtype=dtype)
+        return (hidden_states, residual)
 
-RMS_OP = torch.ops._C.rms_norm.default
+    def ops_in_model(self, do_fusion):
+        if TEST_FP8 and do_fusion:
+            return [torch.ops._C.fused_add_rms_norm_static_fp8_quant.default]
+        else:
+            return [torch.ops._C.fused_add_rms_norm.default]
 
-RMS_QUANT_OPS = {
-    "static_fp8": [
-        torch.ops._C.rms_norm_static_fp8_quant.default,
-        torch.ops._C.fused_add_rms_norm_static_fp8_quant.default
-    ],
-}
+    def ops_not_in_model(self):
+        return []
 
-SILU_MUL_OP = torch.ops._C.silu_and_mul.default
 
-SILU_MUL_QUANT_OP = torch.ops._C.silu_and_mul_quant.default
-prompts = [
-    "Hello, my name is",
-    "The president of the United States is",
-    "The capital of France is",
-    "The future of AI is",
+class TestRotaryEmbedding(torch.nn.Module):
+
+    def __init__(self,
+                 head_dim=64,
+                 rotary_dim=None,
+                 max_position=2048,
+                 base=10000):
+        super().__init__()
+        self.head_dim = head_dim
+        self.rotary_dim = rotary_dim or head_dim
+
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.rotary_dim,
+            max_position=max_position,
+            base=base,
+        )
+
+    def forward(self, positions, q, k):
+        q_rotated, k_rotated = self.rotary_emb(positions, q, k)
+        return q_rotated, k_rotated
+
+    def example_inputs(self, num_tokens=32, head_dim=64):
+        dtype = torch.float16
+        positions = torch.arange(num_tokens, dtype=torch.long)
+        q = torch.randn(num_tokens, head_dim, dtype=dtype)
+        k = torch.randn(num_tokens, head_dim, dtype=dtype)
+        return (positions, q, k)
+
+    def ops_in_model(self, do_fusion):
+        return [torch.ops._C.rotary_embedding.default]
+
+    def ops_not_in_model(self):
+        return []
+
+
+class TestRotaryEmbeddingSliceScatter(torch.nn.Module):
+
+    def __init__(self,
+                 head_dim=64,
+                 num_heads=4,
+                 max_position=2048,
+                 base=10000):
+        super().__init__()
+        self.head_dim = head_dim
+        self.num_heads = num_heads
+        self.hidden_size = head_dim * num_heads
+
+        self.qkv_proj = torch.nn.Linear(self.hidden_size,
+                                        self.hidden_size * 3,
+                                        bias=False,
+                                        dtype=torch.float16)
+
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=max_position,
+            base=base,
+        )
+
+    def forward(self, positions, hidden_states):
+        # Simulate the pattern: mm -> split_with_sizes -> rotary_embedding
+        # -> slice_scatter -> split_with_sizes
+
+        qkv = self.qkv_proj(hidden_states)
+        split_sizes = [self.hidden_size, self.hidden_size, self.hidden_size]
+        q, k, v = torch.split(qkv, split_sizes, dim=-1)
+
+        q_rotated, k_rotated = self.rotary_emb(positions, q, k)
+
+        qkv_updated = torch.cat([q_rotated, k_rotated, v], dim=-1)
+        return qkv_updated
+
+    def example_inputs(self, num_tokens=32, head_dim=64, num_heads=4):
+        dtype = torch.float16
+        hidden_size = head_dim * num_heads
+        positions = torch.arange(num_tokens, dtype=torch.long)
+        hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype)
+        return (positions, hidden_states)
+
+    def ops_in_model(self, do_fusion):
+        return [torch.ops._C.rotary_embedding.default]
+
+    def ops_not_in_model(self):
+        return [torch.ops.aten.slice_scatter.default]
+
+
+MODELS = [
+    TestSiluMul,
+    TestFusedAddRMSNorm,
+    TestRotaryEmbedding,
+    TestRotaryEmbeddingSliceScatter,
 ]
 
 
-@pytest.mark.parametrize(
-    "model, quant_key",
-    [("nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e", kFp8StaticTensorSym),
-     ("nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8_DYNAMIC-e2e",
-      kFp8DynamicTokenSym)])
+@pytest.mark.parametrize("model_class", MODELS)
 @pytest.mark.parametrize("do_fusion", [True, False])
 @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda",
                     reason="Only test on CUDA")
-def test_fix_functionalization(model: str, quant_key: QuantKey,
-                               do_fusion: bool):
+def test_fix_functionalization(model_class: torch.nn.Module, do_fusion: bool):
     torch.set_default_device("cuda")
 
     vllm_config = VllmConfig()
@@ -63,56 +246,31 @@ def test_fix_functionalization(model: str, quant_key: QuantKey,
     cleanup_pass = PostCleanupPass(vllm_config)
     act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config)
 
-    passes = [noop_pass, fusion_pass, act_quant_fusion_pass, cleanup_pass
-              ] if do_fusion else [noop_pass, cleanup_pass]
+    passes = ([noop_pass, fusion_pass, act_quant_fusion_pass, cleanup_pass]
+              if do_fusion else [noop_pass, cleanup_pass])
     func_pass = FixFunctionalizationPass(vllm_config)
+
     backend_func = TestBackend(*passes, func_pass)
     backend_no_func = TestBackend(*passes)
 
-    # instantiate a full engine and manually compile the model 2x
-    # (with and without FixFunctionalizationPass)
-    llm = LLM(model=model, enforce_eager=True)
-    model_runner = llm.llm_engine.model_executor.driver_worker.model_runner
-    orig_model = model_runner.model
-    # TODO mark inputs dynamic? (currently torch.compile is triggered 4x)
-    # Can only do that by using the decorator but then we'd have to instantiate
-    # 2 LLM instances.
-
-    sampling_params = SamplingParams(temperature=0.0, top_p=1.0)
-    model_runner.model = torch.compile(orig_model,
-                                       fullgraph=True,
-                                       backend=backend_func)
-    gen_func = llm.generate(prompts, sampling_params)
-
-    model_runner.model = torch.compile(orig_model,
-                                       fullgraph=True,
-                                       backend=backend_no_func)
-
-    gen_no_func = llm.generate(prompts, sampling_params)
-
-    for output_func, output_no_func in zip(gen_func, gen_no_func):
-        assert output_func.outputs[0].text == output_no_func.outputs[0].text
-
-    # OPS_IN_MODEL always appear. RMS_OP is fused away if we run fusion,
-    # and replaced by fused quantized ops in RMS_QUANT_OPS.
-    rms_ops = [FUSED_OPS[(quant_key, True)], FUSED_OPS[(quant_key, False)]
-               ] if do_fusion else [RMS_OP]
-    silu_mul_ops = [SILU_MUL_QUANT_OP] if do_fusion and \
-        quant_key == kFp8StaticTensorSym else [
-        SILU_MUL_OP
-    ]
-
-    ops = OPS_IN_MODEL + rms_ops + silu_mul_ops
-
-    for op in ops:
+    model = model_class()
+    torch.compile(model, backend=backend_func)(*model.example_inputs())
+    torch.compile(model, backend=backend_no_func)(*model.example_inputs())
+
+    # check if the functionalization pass is applied
+    for op in model.ops_in_model(do_fusion):
         find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
-        assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes,
-                                  op) is None  # noqa: E501
+        assert (find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op)
+                is None)  # noqa: E501
 
     # make sure the ops were all de-functionalized
     found = dict()
    for node in backend_func.graph_post_pass.nodes:
-        for op in ops:
+        for op in model.ops_in_model(do_fusion):
+            if is_func(node, op):
+                found[op] = True
+        for op in model.ops_not_in_model():
             if is_func(node, op):
                 found[op] = True
-    assert all(found[op] for op in ops)
+    assert all(found[op] for op in model.ops_in_model(do_fusion))
+    assert all(not found.get(op) for op in model.ops_not_in_model())
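A small usage sketch of the convention the new test models follow: each model reports the ops that must survive de-functionalization (ops_in_model) and the ops that must not appear afterwards (ops_not_in_model). For TestRotaryEmbeddingSliceScatter this means an in-place rotary_embedding call should remain in the post-pass graph while aten.slice_scatter write-backs should be gone. The helper below is hypothetical (not part of this commit) and written as if appended to the test file above, so it reuses that file's imports and the TestBackend instance built in the test.

def summarize_rotary_defunctionalization(backend_func: TestBackend) -> dict:
    # Count how the post-pass graph ended up: the in-place custom op should be
    # present, and no leftover slice_scatter write-backs should remain.
    targets = [
        node.target for node in backend_func.graph_post_pass.nodes
        if node.op == "call_function"
    ]
    return {
        "rotary_embedding_inplace":
        targets.count(torch.ops._C.rotary_embedding.default),
        "leftover_slice_scatter":
        targets.count(torch.ops.aten.slice_scatter.default),
    }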
