import torch

import vllm.envs as envs
-from vllm import LLM, SamplingParams
from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
-from vllm.compilation.fusion import FUSED_OPS, RMSNormQuantFusionPass
+from vllm.compilation.fusion import RMSNormQuantFusionPass
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.config import CompilationConfig, PassConfig, VllmConfig
+from vllm.model_executor.layers.activation import SiluAndMul
+from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.quant_utils import (
-    QuantKey, kFp8DynamicTokenSym, kFp8StaticTensorSym)
+    GroupShape)
+from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
+    Fp8LinearOp)
+from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.platforms import current_platform

from .backend import TestBackend

-OPS_IN_MODEL = [
-    torch.ops._C.rotary_embedding.default,
-    torch.ops._C.fused_add_rms_norm.default,
-]
+TEST_FP8 = current_platform.supports_fp8()
+FP8_DTYPE = current_platform.fp8_dtype()
+
+
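+# Each test model below defines example_inputs(), ops_in_model(do_fusion) and
+# ops_not_in_model(), so the test can compile it and check which custom ops
+# must (or must not) remain in the post-pass graph.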
+class TestSiluMul(torch.nn.Module):
+
+    def __init__(self, hidden_size: int = 128):
+        super().__init__()
+        self.silu_and_mul = SiluAndMul()
+        self.wscale = torch.rand(1, dtype=torch.float32)
+        self.scale = torch.rand(1, dtype=torch.float32)
+
+        if TEST_FP8:
+            self.w = torch.rand(hidden_size,
+                                hidden_size).to(dtype=FP8_DTYPE).t()
+            self.fp8_linear = Fp8LinearOp(
+                act_quant_static=True,
+                act_quant_group_shape=GroupShape.PER_TENSOR,
+            )
+
+    def forward(self, x):
+        y = self.silu_and_mul(x)
+        if TEST_FP8:
+            x2 = self.fp8_linear.apply(y,
+                                       self.w,
+                                       self.wscale,
+                                       input_scale=self.wscale)
+            return x2
+        else:
+            return y
+
+    def example_inputs(self, num_tokens=32, hidden_size=128):
+        dtype = torch.float16 if TEST_FP8 else torch.float32
+        return (torch.rand(num_tokens, hidden_size * 2, dtype=dtype), )
+
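+    # With fusion enabled (and FP8 supported), silu_and_mul followed by fp8
+    # quantization is expected to be fused into silu_and_mul_quant by
+    # ActivationQuantFusionPass.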
+    def ops_in_model(self, do_fusion):
+        if TEST_FP8 and do_fusion:
+            return [torch.ops._C.silu_and_mul_quant.default]
+        else:
+            return [torch.ops._C.silu_and_mul.default]
+
+    def ops_not_in_model(self):
+        return []
+
+
+class TestFusedAddRMSNorm(torch.nn.Module):
+
+    def __init__(self, hidden_size=16, intermediate_size=32):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+
+        dtype = torch.float16 if TEST_FP8 else torch.float32
+
+        self.gate_proj = torch.nn.Parameter(
+            torch.empty((intermediate_size, hidden_size), dtype=dtype))
+        self.norm = RMSNorm(intermediate_size, 1e-05)
+        self.norm.weight = torch.nn.Parameter(
+            torch.ones(intermediate_size, dtype=dtype))
+
+        torch.nn.init.normal_(self.gate_proj, std=0.02)
+
+        if TEST_FP8:
+            self.fp8_linear = Fp8LinearOp(act_quant_static=True)
+
+            self.scale = torch.rand(1, dtype=torch.float32)
+            self.w = torch.rand(hidden_size,
+                                intermediate_size).to(dtype=FP8_DTYPE).t()
+            self.wscale = torch.rand(1, dtype=torch.float32)
+
+    def forward(self, hidden_states, residual):
+        # Reshape input
+        view = hidden_states.reshape(-1, self.hidden_size)
+
+        # matrix multiplication
+        permute = self.gate_proj.permute(1, 0)
+        mm = torch.mm(view, permute)
+
+        # layer normalization
+        norm_output, residual_output = self.norm(mm, residual)
+
+        if TEST_FP8:
+            # scaled_mm with static input quantization
+            fp8_linear_result = self.fp8_linear.apply(
+                norm_output,
+                self.w,
+                self.wscale,
+                input_scale=self.scale.to(norm_output.device),
+            )
+
+            return fp8_linear_result, residual_output
+
+        else:
+            return norm_output, residual_output
+
+    def example_inputs(self, batch_size=8, hidden_size=16, seq_len=16):
+        dtype = torch.float16 if TEST_FP8 else torch.float32
+        hidden_states = torch.randn((batch_size * seq_len, hidden_size),
+                                    dtype=dtype)
+        residual = torch.randn((batch_size * seq_len, hidden_size),
+                               dtype=dtype)
+        return (hidden_states, residual)

-RMS_OP = torch.ops._C.rms_norm.default
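+    # With fusion enabled (and FP8 supported), fused_add_rms_norm followed by
+    # static fp8 quantization is expected to be replaced by the fused
+    # quantized op (see RMSNormQuantFusionPass).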
+    def ops_in_model(self, do_fusion):
+        if TEST_FP8 and do_fusion:
+            return [torch.ops._C.fused_add_rms_norm_static_fp8_quant.default]
+        else:
+            return [torch.ops._C.fused_add_rms_norm.default]

-RMS_QUANT_OPS = {
-    "static_fp8": [
-        torch.ops._C.rms_norm_static_fp8_quant.default,
-        torch.ops._C.fused_add_rms_norm_static_fp8_quant.default
-    ],
-}
+    def ops_not_in_model(self):
+        return []

-SILU_MUL_OP = torch.ops._C.silu_and_mul.default

-SILU_MUL_QUANT_OP = torch.ops._C.silu_and_mul_quant.default
-prompts = [
-    "Hello, my name is",
-    "The president of the United States is",
-    "The capital of France is",
-    "The future of AI is",
+class TestRotaryEmbedding(torch.nn.Module):
+
+    def __init__(self,
+                 head_dim=64,
+                 rotary_dim=None,
+                 max_position=2048,
+                 base=10000):
+        super().__init__()
+        self.head_dim = head_dim
+        self.rotary_dim = rotary_dim or head_dim
+
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.rotary_dim,
+            max_position=max_position,
+            base=base,
+        )
+
+    def forward(self, positions, q, k):
+        q_rotated, k_rotated = self.rotary_emb(positions, q, k)
+        return q_rotated, k_rotated
+
+    def example_inputs(self, num_tokens=32, head_dim=64):
+        dtype = torch.float16
+        positions = torch.arange(num_tokens, dtype=torch.long)
+        q = torch.randn(num_tokens, head_dim, dtype=dtype)
+        k = torch.randn(num_tokens, head_dim, dtype=dtype)
+        return (positions, q, k)
+
+    def ops_in_model(self, do_fusion):
+        return [torch.ops._C.rotary_embedding.default]
+
+    def ops_not_in_model(self):
+        return []
+
+
+class TestRotaryEmbeddingSliceScatter(torch.nn.Module):
+
+    def __init__(self,
+                 head_dim=64,
+                 num_heads=4,
+                 max_position=2048,
+                 base=10000):
+        super().__init__()
+        self.head_dim = head_dim
+        self.num_heads = num_heads
+        self.hidden_size = head_dim * num_heads
+
+        self.qkv_proj = torch.nn.Linear(self.hidden_size,
+                                        self.hidden_size * 3,
+                                        bias=False,
+                                        dtype=torch.float16)
+
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=max_position,
+            base=base,
+        )
+
+    def forward(self, positions, hidden_states):
+        # Simulate the pattern: mm -> split_with_sizes -> rotary_embedding
+        # -> slice_scatter -> split_with_sizes
+
+        qkv = self.qkv_proj(hidden_states)
+        split_sizes = [self.hidden_size, self.hidden_size, self.hidden_size]
+        q, k, v = torch.split(qkv, split_sizes, dim=-1)
+
+        q_rotated, k_rotated = self.rotary_emb(positions, q, k)
+
+        qkv_updated = torch.cat([q_rotated, k_rotated, v], dim=-1)
+        return qkv_updated
+
+    def example_inputs(self, num_tokens=32, head_dim=64, num_heads=4):
+        dtype = torch.float16
+        hidden_size = head_dim * num_heads
+        positions = torch.arange(num_tokens, dtype=torch.long)
+        hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype)
+        return (positions, hidden_states)
+
+    def ops_in_model(self, do_fusion):
+        return [torch.ops._C.rotary_embedding.default]
+
+    def ops_not_in_model(self):
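+        # aten.slice_scatter should not remain in the graph once
+        # FixFunctionalizationPass has run; the test asserts it is absent
+        # from the post-pass graph.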
+        return [torch.ops.aten.slice_scatter.default]
+
+
+MODELS = [
+    TestSiluMul,
+    TestFusedAddRMSNorm,
+    TestRotaryEmbedding,
+    TestRotaryEmbeddingSliceScatter,
]


-@pytest.mark.parametrize(
-    "model, quant_key",
-    [("nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e", kFp8StaticTensorSym),
-     ("nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8_DYNAMIC-e2e",
-      kFp8DynamicTokenSym)])
+@pytest.mark.parametrize("model_class", MODELS)
@pytest.mark.parametrize("do_fusion", [True, False])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda",
                    reason="Only test on CUDA")
-def test_fix_functionalization(model: str, quant_key: QuantKey,
-                               do_fusion: bool):
+def test_fix_functionalization(model_class: torch.nn.Module, do_fusion: bool):
    torch.set_default_device("cuda")

    vllm_config = VllmConfig()
@@ -63,56 +246,31 @@ def test_fix_functionalization(model: str, quant_key: QuantKey,
    cleanup_pass = PostCleanupPass(vllm_config)
    act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config)

-    passes = [noop_pass, fusion_pass, act_quant_fusion_pass, cleanup_pass
-              ] if do_fusion else [noop_pass, cleanup_pass]
+    passes = ([noop_pass, fusion_pass, act_quant_fusion_pass, cleanup_pass]
+              if do_fusion else [noop_pass, cleanup_pass])
    func_pass = FixFunctionalizationPass(vllm_config)
+
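+    # Compile the test model twice, with and without FixFunctionalizationPass,
+    # so both post-pass graphs can be inspected.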
    backend_func = TestBackend(*passes, func_pass)
    backend_no_func = TestBackend(*passes)

-    # instantiate a full engine and manually compile the model 2x
-    # (with and without FixFunctionalizationPass)
-    llm = LLM(model=model, enforce_eager=True)
-    model_runner = llm.llm_engine.model_executor.driver_worker.model_runner
-    orig_model = model_runner.model
-    # TODO mark inputs dynamic? (currently torch.compile is triggered 4x)
-    # Can only do that by using the decorator but then we'd have to instantiate
-    # 2 LLM instances.
-
-    sampling_params = SamplingParams(temperature=0.0, top_p=1.0)
-    model_runner.model = torch.compile(orig_model,
-                                       fullgraph=True,
-                                       backend=backend_func)
-    gen_func = llm.generate(prompts, sampling_params)
-
-    model_runner.model = torch.compile(orig_model,
-                                       fullgraph=True,
-                                       backend=backend_no_func)
-
-    gen_no_func = llm.generate(prompts, sampling_params)
-
-    for output_func, output_no_func in zip(gen_func, gen_no_func):
-        assert output_func.outputs[0].text == output_no_func.outputs[0].text
-
-    # OPS_IN_MODEL always appear. RMS_OP is fused away if we run fusion,
-    # and replaced by fused quantized ops in RMS_QUANT_OPS.
-    rms_ops = [FUSED_OPS[(quant_key, True)], FUSED_OPS[(quant_key, False)]
-               ] if do_fusion else [RMS_OP]
-    silu_mul_ops = [SILU_MUL_QUANT_OP] if do_fusion and \
-        quant_key == kFp8StaticTensorSym else [
-        SILU_MUL_OP
-    ]
-
-    ops = OPS_IN_MODEL + rms_ops + silu_mul_ops
-
-    for op in ops:
+    model = model_class()
+    torch.compile(model, backend=backend_func)(*model.example_inputs())
+    torch.compile(model, backend=backend_no_func)(*model.example_inputs())
+
+    # check if the functionalization pass is applied
+    for op in model.ops_in_model(do_fusion):
        find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
-        assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes,
-                                  op) is None  # noqa: E501
+        assert (find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op)
+                is None)  # noqa: E501

    # make sure the ops were all de-functionalized
    found = dict()
    for node in backend_func.graph_post_pass.nodes:
-        for op in ops:
+        for op in model.ops_in_model(do_fusion):
+            if is_func(node, op):
+                found[op] = True
+        for op in model.ops_not_in_model():
            if is_func(node, op):
                found[op] = True
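+    # ops_in_model must remain as plain (de-functionalized) calls, while
+    # ops_not_in_model must be absent from the graph entirely.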
-    assert all(found[op] for op in ops)
+    assert all(found[op] for op in model.ops_in_model(do_fusion))
+    assert all(not found.get(op) for op in model.ops_not_in_model())