                     multi_gpu_test)

from .backend import TestBackend

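+# Resolve the platform-specific FP8 dtype once (e.g. torch.float8_e4m3fn on
+# CUDA; ROCm may use float8_e4m3fnuz).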
+FP8_DTYPE = current_platform.fp8_dtype()
+
prompts = [
    "Hello, my name is",
    "The president of the United States is",

class TestMMRSModel(torch.nn.Module):

-    def __init__(self, hidden_size=16):
+    def __init__(self, hidden_size=16, dtype=torch.float16):
        super().__init__()
        self.hidden_size = hidden_size
+        self.dtype = dtype
        self.gate_proj = torch.nn.Parameter(torch.empty(
            (self.hidden_size * 2, hidden_size)),
                                            requires_grad=False)
@@ -64,9 +67,10 @@ def ops_in_model_after(self):

class TestAGMMModel(torch.nn.Module):

-    def __init__(self, hidden_size=16):
+    def __init__(self, hidden_size=16, dtype=torch.float16):
        super().__init__()
        self.hidden_size = hidden_size
+        self.dtype = dtype
        self.weight = torch.nn.Parameter(torch.empty(
            (hidden_size, hidden_size)),
                                         requires_grad=False)
@@ -91,8 +95,125 @@ def ops_in_model_after(self):
        return [torch.ops.symm_mem.fused_all_gather_matmul.default]


+class _BaseScaledMMModel(torch.nn.Module):
+
+    def __init__(self, hidden_size=16, dtype=torch.float16):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.dtype = dtype
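+        # torch._scaled_mm expects its second operand in column-major
+        # layout, hence the transpose of a contiguous (row-major) buffer.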
+        self.weight = torch.empty([hidden_size, hidden_size], dtype=FP8_DTYPE)\
+            .contiguous().transpose(0, 1)
+
+        # Initialize scale_b for _scaled_mm: shape (1, hidden_size), i.e.
+        # one scale per output channel.
+        self.scale_b = torch.ones(1, self.hidden_size, dtype=torch.float32)
+
+
+class TestScaledMMRSModel(_BaseScaledMMModel):
+
+    def forward(self, input: torch.Tensor):
+        """
+        Forward pass implementing scaled_mm + reduce scatter in the FX graph.
+        """
+        fp8_input = input.to(FP8_DTYPE)
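+        # Per-token (row-wise) scales: scale_a has shape (M, 1) and pairs
+        # with scale_b's (1, N) per-channel scales.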
+        scale_a = torch.ones(input.shape[0], 1, dtype=torch.float32)
+        scaled_mm = torch._scaled_mm(fp8_input,
+                                     self.weight,
+                                     scale_a=scale_a,
+                                     scale_b=self.scale_b,
+                                     out_dtype=self.dtype)
+        reduce_scatter = tensor_model_parallel_reduce_scatter(scaled_mm, dim=0)
+        return reduce_scatter
+
+    def ops_in_model_before(self):
+        return [torch.ops.vllm.reduce_scatter.default]
+
+    def ops_in_model_after(self):
+        return [torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter.default]
+
+
+class TestAGScaledMMModel(_BaseScaledMMModel):
+
+    def forward(self, input: torch.Tensor):
+        """
+        Forward pass implementing all gather + scaled_mm in the FX graph.
+        """
+        # Cast the input to FP8 before gathering.
+        fp8_input = input.to(FP8_DTYPE)
+        all_gather = tensor_model_parallel_all_gather(fp8_input, dim=0)
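+        # Gathering the FP8 tensor keeps the all-gather payload at one byte
+        # per element, half of what gathering the fp16/bf16 input would move.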
+
+        scale_a = torch.ones(all_gather.shape[0], 1, dtype=torch.float32)
+        scaled_mm = torch._scaled_mm(all_gather,
+                                     self.weight,
+                                     scale_a=scale_a,
+                                     scale_b=self.scale_b,
+                                     out_dtype=self.dtype)
+        return scaled_mm
+
+    def ops_in_model_before(self):
+        return [torch.ops.vllm.all_gather.default]
+
+    def ops_in_model_after(self):
+        return [torch.ops.symm_mem.fused_all_gather_scaled_matmul.default]
+
+
+class TestCutlassScaledMMRSModel(_BaseScaledMMModel):
+
+    def forward(self, input: torch.Tensor):
+        """
+        Forward pass implementing cutlass_scaled_mm + reduce scatter
+        in the FX graph.
+        """
+        fp8_input = input.to(FP8_DTYPE)
+        scale_a = torch.ones(input.shape[0], 1, dtype=torch.float32)
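+        # cutlass_scaled_mm is an out-variant custom op: allocate the
+        # high-precision output buffer and pass it as the first argument.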
+        mm_out = torch.empty((fp8_input.shape[0], self.weight.shape[1]),
+                             dtype=self.dtype,
+                             device=input.device)
+        torch.ops._C.cutlass_scaled_mm(mm_out, fp8_input, self.weight, scale_a,
+                                       self.scale_b, None)
+        reduce_scatter = tensor_model_parallel_reduce_scatter(mm_out, dim=0)
+        return reduce_scatter
+
+    def ops_in_model_before(self):
+        return [torch.ops.vllm.reduce_scatter.default]
+
+    def ops_in_model_after(self):
+        return [torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter.default]
+
+
+class TestAGCutlassScaledMMModel(_BaseScaledMMModel):
+
+    def forward(self, input: torch.Tensor):
+        """
+        Forward pass implementing all gather + cutlass_scaled_mm
+        in the FX graph.
+        """
+        # Cast the input to FP8 before gathering.
+        fp8_input = input.to(FP8_DTYPE)
+        all_gather = tensor_model_parallel_all_gather(fp8_input, dim=0)
+
+        scale_a = torch.ones(all_gather.shape[0], 1, dtype=torch.float32)
+
+        mm_out = torch.empty((all_gather.shape[0], self.weight.shape[1]),
+                             dtype=self.dtype,
+                             device=all_gather.device)
+        torch.ops._C.cutlass_scaled_mm(mm_out, all_gather, self.weight,
+                                       scale_a, self.scale_b, None)
+        return mm_out
+
+    def ops_in_model_before(self):
+        return [torch.ops.vllm.all_gather.default]
+
+    def ops_in_model_after(self):
+        return [torch.ops.symm_mem.fused_all_gather_scaled_matmul.default]
+
+
@multi_gpu_test(num_gpus=2)
-@pytest.mark.parametrize("test_model", [TestMMRSModel, TestAGMMModel])
+@pytest.mark.parametrize("test_model", [
+    TestMMRSModel, TestAGMMModel, TestScaledMMRSModel, TestAGScaledMMModel,
+    TestCutlassScaledMMRSModel, TestAGCutlassScaledMMModel
+])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seq_len", [16])
@pytest.mark.parametrize("hidden_size", [16])
@@ -101,6 +222,14 @@ def ops_in_model_after(self):
                    reason="Only test on CUDA")
def test_async_tp_pass_replace(test_model: str, batch_size: int, seq_len: int,
                               hidden_size: int, dtype: torch.dtype):
+    if test_model in (TestScaledMMRSModel, TestAGScaledMMModel,
+                      TestCutlassScaledMMRSModel,
+                      TestAGCutlassScaledMMModel) and dtype == torch.float16:
+        pytest.skip("Only bf16 high precision output types are supported for "
+                    "per-token (row-wise) scaling")
+
    num_processes = 2

    def run_torch_spawn(fn, nprocs):
@@ -155,7 +284,8 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int,
    async_tp_pass = AsyncTPPass(vllm_config)
    backend = TestBackend(async_tp_pass)

-    model = test_model_cls(hidden_size)
+    model = test_model_cls(hidden_size,
+                           dtype)  # Pass dtype to the model constructor

    hidden_states = torch.randn((batch_size * seq_len, hidden_size),
                                dtype=dtype,
@@ -174,7 +304,10 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int,


@create_new_process_for_each_test()
-@pytest.mark.parametrize("model_id", ["meta-llama/Llama-3.2-1B-Instruct"])
+@pytest.mark.parametrize("model_id", [
+    "meta-llama/Llama-3.2-1B-Instruct",
+    "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8"
+])
@pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize("async_tp_enabled", [True])
@pytest.mark.parametrize("distributed_backend", ["mp"])