
Commit 23640f9

Revert Tensor Descriptor
Signed-off-by: Whitney Tsang <[email protected]>
1 parent 2d1ba45 commit 23640f9

File tree

752 files changed, +138185 -58 lines changed


benchmarks/third_party/sglang/scaled_mm_benchmark.py

Lines changed: 13 additions & 27 deletions
@@ -100,37 +100,26 @@ def scaled_mm_kernel_td(
     offsets_scale_bn = tl.arange(0, BLOCK_SIZE_SCALE_B) + (BLOCK_SIZE_SCALE_B > 1) * pid_n * BLOCK_SIZE_N
     masks_scale_bn = offsets_scale_bn < N

-    a_desc = tl.make_tensor_descriptor(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
-                                       block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K))
-    b_desc = tl.make_tensor_descriptor(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
-                                       block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N))
-
-    # a_ptrs = a_ptr + offsets_a
-    # b_ptrs = b_ptr + offsets_b
+    a_ptrs = a_ptr + offsets_a
+    b_ptrs = b_ptr + offsets_b

     scale_a_ptrs = scale_a_ptr + offsets_scale_am
     scale_b_ptrs = scale_b_ptr + offsets_scale_bn

-    off_k = 0
     for _ in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
-        # masks_k = offsets_k < K
-        # masks_a = masks_am[:, None] & masks_k[None, :]
-        # a = tl.load(a_ptrs, mask=masks_a)
-
-        # masks_b = masks_k[:, None] & masks_bn[None, :]
-        # b = tl.load(b_ptrs, mask=masks_b)
+        masks_k = offsets_k < K
+        masks_a = masks_am[:, None] & masks_k[None, :]
+        a = tl.load(a_ptrs, mask=masks_a)

-        a = a_desc.load([pid_m * BLOCK_SIZE_M, off_k])
-        b = b_desc.load([off_k, pid_n * BLOCK_SIZE_N])
-        # accumulator += tl.dot(a, b)
+        masks_b = masks_k[:, None] & masks_bn[None, :]
+        b = tl.load(b_ptrs, mask=masks_b)

         # Accumulate results.
         accumulator = tl.dot(a, b, accumulator, out_dtype=accumulator_dtype)
-        off_k += BLOCK_SIZE_K

-        # offsets_k += BLOCK_SIZE_K
-        # a_ptrs += BLOCK_SIZE_K * stride_ak
-        # b_ptrs += BLOCK_SIZE_K * stride_bk
+        offsets_k += BLOCK_SIZE_K
+        a_ptrs += BLOCK_SIZE_K * stride_ak
+        b_ptrs += BLOCK_SIZE_K * stride_bk

     # Apply scale at end.
     masks_scale_a = masks_scale_am[:, None] & (tl.arange(0, 1) < 1)[:, None]
@@ -162,13 +151,10 @@ def scaled_mm_kernel_td(
     offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)
     offs_cm = offs_cm.to(tl.int64)
     offs_cn = offs_cn.to(tl.int64)
-    # c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
-    # c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
+    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
+    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)

-    # tl.store(c_ptrs, c, mask=c_mask)
-    c_desc = tl.make_tensor_descriptor(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn),
-                                       block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N))
-    c_desc.store([pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N], c)
+    tl.store(c_ptrs, c, mask=c_mask)


 # input - [M, K]
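
The hunk above reverts the kernel from Triton's block tensor-descriptor API (tl.make_tensor_descriptor with desc.load/desc.store, as seen in the removed lines) back to classic masked pointer arithmetic. A minimal sketch contrasting the two styles on a 2-D block copy, assuming a Triton build that exposes the in-kernel descriptor API; the kernel names and launch harness are illustrative, not from the commit, and on some backends the descriptor path additionally needs a host-side allocator (triton.set_allocator) and suitably aligned rows:

import torch
import triton
import triton.language as tl


@triton.jit
def copy_masked(src, dst, M, N, stride_m, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    # The style the commit reverts *to*: explicit offsets plus a boundary mask.
    pid = tl.program_id(0)
    offs_m = pid * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    ptrs = offs_m[:, None] * stride_m + offs_n[None, :]
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    x = tl.load(src + ptrs, mask=mask, other=0.0)
    tl.store(dst + ptrs, x, mask=mask)


@triton.jit
def copy_td(src, dst, M, N, stride_m, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    # The style the commit reverts *from*: the descriptor owns shape, strides
    # and block shape; load/store take only the block origin, and out-of-bounds
    # handling is implicit.
    pid = tl.program_id(0)
    src_desc = tl.make_tensor_descriptor(base=src, shape=(M, N), strides=(stride_m, 1),
                                         block_shape=(BLOCK_M, BLOCK_N))
    dst_desc = tl.make_tensor_descriptor(base=dst, shape=(M, N), strides=(stride_m, 1),
                                         block_shape=(BLOCK_M, BLOCK_N))
    dst_desc.store([pid * BLOCK_M, 0], src_desc.load([pid * BLOCK_M, 0]))


if __name__ == "__main__":
    # 'cuda' is illustrative; on the Intel XPU backend this would be 'xpu'.
    x = torch.randn(100, 64, device="cuda")
    y = torch.empty_like(x)
    grid = (triton.cdiv(x.shape[0], 32), )
    copy_masked[grid](x, y, x.shape[0], x.shape[1], x.stride(0), BLOCK_M=32, BLOCK_N=64)
    assert torch.equal(x, y)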

benchmarks/third_party/vllm/batched_moe_benchmark.py

Lines changed: 38 additions & 29 deletions
@@ -31,8 +31,8 @@

 @triton.jit
 def moe_mmk(
-    a_desc,
-    b_desc,
+    a_ptrs,
+    b_ptrs,
     K,
     expert_id,
     a_scale_ptr,
@@ -41,6 +41,9 @@ def moe_mmk(
     # moving by 1 element in a particular dimension. E.g. `stride_am` is
     # how much to increase `a_ptr` by to get the element one row down
     # (A has M rows).
+    stride_ak: tl.int64,
+    stride_bk: tl.int64,
+    stride_ase: tl.int64,
     stride_asm: tl.int64,
     stride_ask: tl.int64,
     stride_bse: tl.int64,
@@ -65,6 +68,7 @@ def moe_mmk(
     use_w8a16: tl.constexpr,
     per_act_token_quant: tl.constexpr,
 ):
+    offs_k = tl.arange(0, BLOCK_K)

     if use_w8a16:
         b_scale_ptrs = b_scale_ptr + expert_id * stride_bse + offs_n[None, :] * stride_bsn
@@ -99,8 +103,12 @@ def moe_mmk(
     accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
     for k in range(0, tl.cdiv(K, BLOCK_K)):
         # Load the next block of A and B using tensor descriptors
-        a = a_desc.load([pid_m * BLOCK_M, k * BLOCK_K])
-        b = b_desc.load([k * BLOCK_K, pid_n * BLOCK_N])
+        a = tl.load(
+            a_ptrs,
+            mask=mask_m[:, None] & (offs_k[None, :] < K - k * BLOCK_K),
+            other=0.0,
+        )
+        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0.0)

         # We accumulate along the K dimension.
         if use_w8a16:
@@ -119,6 +127,9 @@ def moe_mmk(
         else:
             accumulator += tl.dot(a, b)

+        a_ptrs += BLOCK_K * stride_ak
+        b_ptrs += BLOCK_K * stride_bk
+
     if use_w8a16:
         accumulator = (accumulator * b_scale).to(compute_type)
     elif use_w8a8:
@@ -134,9 +145,9 @@ def moe_mmk(

 @triton.jit
 def expert_triton_kernel(
-    a_desc, #[max_tokens, K]
-    b_desc, #[K, N]
-    c_desc, #[max_tokens, N]
+    a_ptr,
+    b_ptr,
+    c_ptr,
     expert_id,
     compute_type: tl.constexpr,
     # Dimensions
@@ -147,8 +158,12 @@ def expert_triton_kernel(
     a_scale_ptr,
     b_scale_ptr,
     # strides
+    stride_am: tl.int64,
     stride_ak: tl.int64,
     stride_bk: tl.int64,
+    stride_bn: tl.int64,
+    stride_cm: tl.int64,
+    stride_cn: tl.int64,
     stride_ase: tl.int64,
     stride_asm: tl.int64,
     stride_ask: tl.int64,
@@ -174,15 +189,19 @@ def expert_triton_kernel(

     offs_m = tl.arange(0, BLOCK_M)
     offs_n = tl.arange(0, BLOCK_N) % N
-    # offs_k = tl.arange(0, BLOCK_K)
+    offs_k = tl.arange(0, BLOCK_K)
     mask_m = offs_m < M

+    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
+    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
+
     accumulator = moe_mmk(
-        a_desc, b_desc, K, expert_id, a_scale_ptr, b_scale_ptr,
+        a_ptrs, b_ptrs, K, expert_id, a_scale_ptr, b_scale_ptr,
         # The stride variables represent how much to increase the ptr by when
         # moving by 1 element in a particular dimension. E.g. `stride_am` is
         # how much to increase `a_ptr` by to get the element one row down
         # (A has M rows).
+        stride_ak, stride_bk, stride_ase,
         stride_asm, stride_ask, stride_bse, stride_bsk, stride_bsn,
         # Offsets and masks
         offs_m, offs_n, offs_bn, mask_m,
@@ -192,11 +211,10 @@ def expert_triton_kernel(
         BLOCK_M, BLOCK_N, BLOCK_K, compute_type, use_fp8_w8a8, use_int8_w8a16, per_act_token_quant)

     # store in C
-    # offs_cn = tl.arange(0, BLOCK_N)
-    # c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_cn[None, :] * stride_cn
-    # c_mask = mask_m[:, None] & (offs_cn[None, :] < N)
-    c_desc.store([pid_m * BLOCK_M, pid_n * BLOCK_N], accumulator)
-    # tl.store(c_ptrs, accumulator, mask=c_mask)
+    offs_cn = tl.arange(0, BLOCK_N)
+    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_cn[None, :] * stride_cn
+    c_mask = mask_m[:, None] & (offs_cn[None, :] < N)
+    tl.store(c_ptrs, accumulator, mask=c_mask)


 def get_matmul_batched_autotune_configs() -> List[triton.Config]:
@@ -292,17 +310,10 @@ def batched_triton_kernel(
     cta_m_size = min(BLOCK_M, e_num_tokens - cta_m_start)
     cta_n_size = min(BLOCK_N, N - cta_n_start)

-    a_desc = tl.make_tensor_descriptor(base=a_ptr + expert_id * stride_ae, shape=(e_num_tokens, K),
-                                       strides=(stride_am, stride_ak), block_shape=(BLOCK_M, BLOCK_K))
-    b_desc = tl.make_tensor_descriptor(base=b_ptr + expert_id * stride_be, shape=(K, N), strides=(stride_bk, stride_bn),
-                                       block_shape=(BLOCK_K, BLOCK_N))
-    c_desc = tl.make_tensor_descriptor(base=c_ptr + expert_id * stride_ce, shape=(e_num_tokens, N),
-                                       strides=(stride_cm, stride_cn), block_shape=(BLOCK_M, BLOCK_N))
-
-    # a_ptr = a_ptr + expert_id * stride_ae + cta_m_start * stride_am
-    # b_ptr = b_ptr + expert_id * stride_be + cta_n_start * stride_bn
-    # c_ptr = (c_ptr + expert_id * stride_ce + cta_m_start * stride_cm +
-    #          cta_n_start * stride_cn)
+    a_ptr = a_ptr + expert_id * stride_ae + cta_m_start * stride_am
+    b_ptr = b_ptr + expert_id * stride_be + cta_n_start * stride_bn
+    c_ptr = (c_ptr + expert_id * stride_ce + cta_m_start * stride_cm +
+             cta_n_start * stride_cn)

     offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N).to(tl.int64)) % N

@@ -314,12 +325,12 @@ def batched_triton_kernel(
     if group_k > 0 and group_n > 0 or per_act_token_quant:
         a_scale_ptr = a_scale_ptr + cta_m_start * stride_asm

-    expert_triton_kernel(a_desc, b_desc, c_desc, expert_id, compute_type, cta_m_size,  # M
+    expert_triton_kernel(a_ptr, b_ptr, c_ptr, expert_id, compute_type, cta_m_size,  # M
                          cta_n_size,  # N
                          K,  # K
                          a_scale_ptr, b_scale_ptr,
                          # Strides
-                         stride_ak, stride_bk, stride_ase, stride_asm, stride_ask, stride_bse, stride_bsk, stride_bsn,
+                         stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, stride_ase, stride_asm, stride_ask, stride_bse, stride_bsk, stride_bsn,
                          # offsets
                          offs_bn,
                          # Blockwise quantization data
@@ -502,9 +513,7 @@ def get_batched_mm_benchmark(
     Returns a Mark object containing a Benchmark object for batched matrix multiplication.
     """
     supported_providers = {
-        'triton': 'triton',
         'triton-td': 'triton-td',
-        'pytorch': 'pytorch',
     }
     if fp8:
         # pytorch is very slow with fp8 case, for (8, 64, 1024, 2048) case it has ~0.15 TFlops vs 1.5 for triton
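
In moe_mmk the revert replaces descriptor block loads with tl.load plus a per-iteration tail mask, and advances a_ptrs/b_ptrs by BLOCK_K * stride on every trip around the K loop. A quick plain-Python sketch (values illustrative, not from the commit) of the tail-mask arithmetic the restored code relies on:

# On iteration k, a lane with offset `off` in [0, BLOCK_K) is in bounds only
# when off < K - k * BLOCK_K; masked lanes load `other=0.0`, which is harmless
# to tl.dot because zero entries contribute nothing to the accumulator.
K, BLOCK_K = 70, 32
for k in range((K + BLOCK_K - 1) // BLOCK_K):  # mirrors tl.cdiv(K, BLOCK_K)
    valid = sum(off < K - k * BLOCK_K for off in range(BLOCK_K))
    print(k, valid)  # prints: 0 32, 1 32, 2 6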

benchmarks/third_party/vllm/tests/__init__.py

Whitespace-only changes.

benchmarks/third_party/vllm/tests/async_engine/__init__.py

Whitespace-only changes.
Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""vllm.entrypoints.api_server with some extra logging for testing."""
+from collections.abc import Iterable
+from typing import Any
+
+import uvicorn
+from fastapi.responses import JSONResponse, Response
+
+import vllm.entrypoints.api_server
+import vllm.envs as envs
+from vllm.engine.arg_utils import AsyncEngineArgs
+from vllm.engine.async_llm_engine import AsyncLLMEngine
+from vllm.utils import FlexibleArgumentParser
+
+app = vllm.entrypoints.api_server.app
+
+
+class AsyncLLMEngineWithStats(AsyncLLMEngine):
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._num_aborts = 0
+
+    async def _engine_abort(self, request_ids: Iterable[str]):
+        ids = list(request_ids)
+        self._num_aborts += len(ids)
+        await super()._engine_abort(ids)
+
+    def testing_stats(self) -> dict[str, Any]:
+        return {"num_aborted_requests": self._num_aborts}
+
+
+@app.get("/stats")
+def stats() -> Response:
+    """Get the statistics of the engine."""
+    return JSONResponse(engine.testing_stats())
+
+
+if __name__ == "__main__":
+    parser = FlexibleArgumentParser()
+    parser.add_argument("--host", type=str, default="localhost")
+    parser.add_argument("--port", type=int, default=8000)
+    parser = AsyncEngineArgs.add_cli_args(parser)
+    args = parser.parse_args()
+
+    engine_args = AsyncEngineArgs.from_cli_args(args)
+    engine = AsyncLLMEngineWithStats.from_engine_args(engine_args)
+    vllm.entrypoints.api_server.engine = engine
+    uvicorn.run(app,
+                host=args.host,
+                port=args.port,
+                log_level="debug",
+                timeout_keep_alive=envs.VLLM_HTTP_TIMEOUT_KEEP_ALIVE)
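
The new file above wraps vllm.entrypoints.api_server with an engine subclass that counts aborted requests and exposes the counter on a /stats route. A hypothetical invocation (the script filename and model name are illustrative, not from the commit; the --model flag comes from AsyncEngineArgs.add_cli_args):

$ python api_server_with_stats.py --host localhost --port 8000 --model facebook/opt-125m
$ curl http://localhost:8000/stats
{"num_aborted_requests": 0}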
Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import pytest
+
+
+@pytest.fixture(scope="function", autouse=True)
+def use_v0_only(monkeypatch):
+    """
+    Since this module is V0 only, set VLLM_USE_V1=0 for
+    all tests in the module.
+    """
+    monkeypatch.setenv('VLLM_USE_V1', '0')
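
Because this conftest fixture is autouse, every test collected in the same directory runs with VLLM_USE_V1=0 without requesting the fixture by name, and monkeypatch restores the environment after each test. A minimal sketch of a test that would observe the pinned value (hypothetical, not from the commit):

import os


def test_runs_on_v0_engine():
    # The autouse fixture has already set the variable before this body runs.
    assert os.environ["VLLM_USE_V1"] == "0"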
