
Commit c8da820

cascade812 authored and diegocastanibm committed
[Feature] Add async tensor parallelism for scaled mm (vllm-project#20155)
Signed-off-by: cascade812 <[email protected]>
Signed-off-by: Diego-Castan <[email protected]>
1 parent 3f61c50 commit c8da820

File tree

3 files changed (+381, -8 lines)


tests/compile/test_async_tp.py

Lines changed: 138 additions & 5 deletions
@@ -22,6 +22,8 @@
                     multi_gpu_test)
 from .backend import TestBackend
 
+FP8_DTYPE = current_platform.fp8_dtype()
+
 prompts = [
     "Hello, my name is",
     "The president of the United States is",
@@ -32,9 +34,10 @@
 
 class TestMMRSModel(torch.nn.Module):
 
-    def __init__(self, hidden_size=16):
+    def __init__(self, hidden_size=16, dtype=torch.float16):
         super().__init__()
         self.hidden_size = hidden_size
+        self.dtype = dtype
         self.gate_proj = torch.nn.Parameter(torch.empty(
             (self.hidden_size * 2, hidden_size)),
                                             requires_grad=False)
@@ -64,9 +67,10 @@ def ops_in_model_after(self):
 
 class TestAGMMModel(torch.nn.Module):
 
-    def __init__(self, hidden_size=16):
+    def __init__(self, hidden_size=16, dtype=torch.float16):
         super().__init__()
         self.hidden_size = hidden_size
+        self.dtype = dtype
         self.weight = torch.nn.Parameter(torch.empty(
             (hidden_size, hidden_size)),
                                          requires_grad=False)
@@ -91,8 +95,125 @@ def ops_in_model_after(self):
         return [torch.ops.symm_mem.fused_all_gather_matmul.default]
 
 
+class _BaseScaledMMModel(torch.nn.Module):
+
+    def __init__(self, hidden_size=16, dtype=torch.float16):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.dtype = dtype
+        self.weight = torch.empty([hidden_size, hidden_size], dtype=FP8_DTYPE)\
+            .contiguous().transpose(0, 1)
+
+        # Initialize scale_b for _scaled_mm.
+        self.scale_b = torch.ones(1, self.hidden_size, dtype=torch.float32)
+
+
+class TestScaledMMRSModel(_BaseScaledMMModel):
+
+    def forward(self, input: torch.Tensor):
+        """
+        Forward pass implementing the scaled_mm + reduce scatter in the FX graph
+
+        """
+        fp8_input = input.to(FP8_DTYPE)
+        scale_a = torch.ones(input.shape[0], 1, dtype=torch.float32)
+        scaled_mm = torch._scaled_mm(fp8_input,
+                                     self.weight,
+                                     scale_a=scale_a,
+                                     scale_b=self.scale_b,
+                                     out_dtype=self.dtype)
+        reduce_scatter = tensor_model_parallel_reduce_scatter(scaled_mm, dim=0)
+        return reduce_scatter
+
+    def ops_in_model_before(self):
+        return [torch.ops.vllm.reduce_scatter.default]
+
+    def ops_in_model_after(self):
+        return [torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter.default]
+
+
+class TestAGScaledMMModel(_BaseScaledMMModel):
+
+    def forward(self, input: torch.Tensor):
+        """
+        Forward pass implementing the all gather + scaled_mm in the FX graph
+        """
+        # Reshape input
+        fp8_input = input.to(FP8_DTYPE)
+        all_gather = tensor_model_parallel_all_gather(fp8_input, dim=0)
+
+        scale_a = torch.ones(all_gather.shape[0], 1, dtype=torch.float32)
+        scaled_mm = torch._scaled_mm(all_gather,
+                                     self.weight,
+                                     scale_a=scale_a,
+                                     scale_b=self.scale_b,
+                                     out_dtype=self.dtype)
+        return scaled_mm
+
+    def ops_in_model_before(self):
+        return [torch.ops.vllm.all_gather.default]
+
+    def ops_in_model_after(self):
+        return [torch.ops.symm_mem.fused_all_gather_scaled_matmul.default]
+
+
+class TestCutlassScaledMMRSModel(_BaseScaledMMModel):
+
+    def forward(self, input: torch.Tensor):
+        """
+        Forward pass implementing the cutlass_scaled_mm + reduce scatter
+        in the FX graph
+
+        """
+        fp8_input = input.to(FP8_DTYPE)
+        scale_a = torch.ones(input.shape[0], 1, dtype=torch.float32)
+        mm_out = torch.empty((fp8_input.shape[0], self.weight.shape[1]),
+                             dtype=self.dtype,
+                             device=input.device)
+        torch.ops._C.cutlass_scaled_mm(mm_out, fp8_input, self.weight, scale_a,
+                                       self.scale_b, None)
+        reduce_scatter = tensor_model_parallel_reduce_scatter(mm_out, dim=0)
+        return reduce_scatter
+
+    def ops_in_model_before(self):
+        return [torch.ops.vllm.reduce_scatter.default]
+
+    def ops_in_model_after(self):
+        return [torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter.default]
+
+
+class TestAGCutlassScaledMMModel(_BaseScaledMMModel):
+
+    def forward(self, input: torch.Tensor):
+        """
+        Forward pass implementing the all gather + cutlass_scaled_mm
+        in the FX graph
+        """
+        # Reshape input
+        fp8_input = input.to(FP8_DTYPE)
+        all_gather = tensor_model_parallel_all_gather(fp8_input, dim=0)
+
+        scale_a = torch.ones(all_gather.shape[0], 1, dtype=torch.float32)
+
+        mm_out = torch.empty((all_gather.shape[0], self.weight.shape[1]),
+                             dtype=self.dtype,
+                             device=all_gather.device)
+        torch.ops._C.cutlass_scaled_mm(mm_out, all_gather, self.weight,
+                                       scale_a, self.scale_b, None)
+        return mm_out
+
+    def ops_in_model_before(self):
+        return [torch.ops.vllm.all_gather.default]
+
+    def ops_in_model_after(self):
+        return [torch.ops.symm_mem.fused_all_gather_scaled_matmul.default]
+
+
 @multi_gpu_test(num_gpus=2)
-@pytest.mark.parametrize("test_model", [TestMMRSModel, TestAGMMModel])
+@pytest.mark.parametrize("test_model", [
+    TestMMRSModel, TestAGMMModel, TestScaledMMRSModel, TestAGScaledMMModel,
+    TestCutlassScaledMMRSModel, TestAGCutlassScaledMMModel
+])
 @pytest.mark.parametrize("batch_size", [8])
 @pytest.mark.parametrize("seq_len", [16])
 @pytest.mark.parametrize("hidden_size", [16])
@@ -101,6 +222,14 @@ def ops_in_model_after(self):
                     reason="Only test on CUDA")
 def test_async_tp_pass_replace(test_model: str, batch_size: int, seq_len: int,
                                hidden_size: int, dtype: torch.dtype):
+    if test_model in (TestScaledMMRSModel, TestAGScaledMMModel,
+                      TestCutlassScaledMMRSModel,
+                      TestAGCutlassScaledMMModel) and dtype == torch.float16:
+        pytest.skip(
+            "Only bf16 high precision output types are supported for " \
+            "per-token (row-wise) scaling"
+        )
+
     num_processes = 2
 
     def run_torch_spawn(fn, nprocs):
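Note (not part of the commit): the float16 skip added in the hunk above reflects a constraint of torch._scaled_mm itself. With per-token (row-wise) scale_a and per-channel scale_b, only bf16 is accepted as the high-precision output dtype. A minimal single-GPU sketch, assuming an FP8-capable CUDA GPU and that current_platform.fp8_dtype() resolves to torch.float8_e4m3fn:

import torch

M, K, N = 8, 16, 16
a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
# The second operand must be column-major, which is why _BaseScaledMMModel
# stores its weight as .contiguous().transpose(0, 1).
b = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn).t()
scale_a = torch.ones(M, 1, device="cuda", dtype=torch.float32)  # per-token (row-wise)
scale_b = torch.ones(1, N, device="cuda", dtype=torch.float32)  # per-channel (column-wise)
# Row-wise scaling only supports bf16 output; out_dtype=torch.float16 raises,
# which is exactly what the pytest.skip above guards against.
out = torch._scaled_mm(a, b, scale_a=scale_a, scale_b=scale_b,
                       out_dtype=torch.bfloat16)
print(out.shape)  # torch.Size([8, 16])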
@@ -155,7 +284,8 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int,
     async_tp_pass = AsyncTPPass(vllm_config)
     backend = TestBackend(async_tp_pass)
 
-    model = test_model_cls(hidden_size)
+    model = test_model_cls(hidden_size,
+                           dtype)  # Pass dtype to model constructor
 
     hidden_states = torch.randn((batch_size * seq_len, hidden_size),
                                 dtype=dtype,
@@ -174,7 +304,10 @@
 
 
 @create_new_process_for_each_test()
-@pytest.mark.parametrize("model_id", ["meta-llama/Llama-3.2-1B-Instruct"])
+@pytest.mark.parametrize("model_id", [
+    "meta-llama/Llama-3.2-1B-Instruct",
+    "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8"
+])
 @pytest.mark.parametrize("tp_size", [2])
 @pytest.mark.parametrize("async_tp_enabled", [True])
 @pytest.mark.parametrize("distributed_backend", ["mp"])
