Commit b392e26

Revert Tensor Descriptor
1 parent 14c6acb commit b392e26

4 files changed: +58 −67 lines
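For context on what the revert swaps out: Triton's tensor-descriptor API describes a 2-D view once (base, shape, strides, block shape) and then loads blocks by coordinates, with out-of-bounds handling done implicitly; the classic style computes per-element pointers and bounds masks by hand. A minimal illustrative sketch of the two load styles — not code from this commit, all shapes and names assumed:

import triton
import triton.language as tl


@triton.jit
def load_block_both_ways(a_ptr, M, K, stride_am, stride_ak,
                         BLOCK_M: tl.constexpr, BLOCK_K: tl.constexpr):
    # Descriptor style: layout is declared once, loads take block
    # coordinates, and out-of-bounds elements are filled implicitly.
    a_desc = tl.make_tensor_descriptor(base=a_ptr, shape=(M, K),
                                       strides=(stride_am, stride_ak),
                                       block_shape=(BLOCK_M, BLOCK_K))
    a_td = a_desc.load([0, 0])

    # Pointer style: offsets, addresses, and bounds masks are explicit.
    offs_m = tl.arange(0, BLOCK_M)
    offs_k = tl.arange(0, BLOCK_K)
    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    masks = (offs_m[:, None] < M) & (offs_k[None, :] < K)
    a_pa = tl.load(a_ptrs, mask=masks, other=0.0)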

benchmarks/third_party/sglang/scaled_mm_benchmark.py
19 additions, 33 deletions

@@ -81,15 +81,15 @@ def scaled_mm_kernel_td(
     # eventually occur.
 
     # Offsets and masks.
-    # offsets_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
-    # masks_am = offsets_am < M
+    offsets_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
+    masks_am = offsets_am < M
 
     offsets_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)
-    # masks_bn = offsets_bn < N
+    masks_bn = offsets_bn < N
 
-    # offsets_k = tl.arange(0, BLOCK_SIZE_K).to(tl.int64)
-    # offsets_a = stride_am * offsets_am[:, None] + stride_ak * offsets_k[None, :]
-    # offsets_b = stride_bk * offsets_k[:, None] + stride_bn * offsets_bn[None, :]
+    offsets_k = tl.arange(0, BLOCK_SIZE_K).to(tl.int64)
+    offsets_a = stride_am * offsets_am[:, None] + stride_ak * offsets_k[None, :]
+    offsets_b = stride_bk * offsets_k[:, None] + stride_bn * offsets_bn[None, :]
 
     # NOTE: BLOCK_SIZE_SCALE_A could be 1 or BLOCK_SIZE_M, so need to create
     # appropriate offsets and masks for each case. Same goes for

@@ -100,37 +100,26 @@ def scaled_mm_kernel_td(
     offsets_scale_bn = tl.arange(0, BLOCK_SIZE_SCALE_B) + (BLOCK_SIZE_SCALE_B > 1) * pid_n * BLOCK_SIZE_N
     masks_scale_bn = offsets_scale_bn < N
 
-    a_desc = tl.make_tensor_descriptor(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
-                                       block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K))
-    b_desc = tl.make_tensor_descriptor(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
-                                       block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N))
-
-    # a_ptrs = a_ptr + offsets_a
-    # b_ptrs = b_ptr + offsets_b
+    a_ptrs = a_ptr + offsets_a
+    b_ptrs = b_ptr + offsets_b
 
     scale_a_ptrs = scale_a_ptr + offsets_scale_am
     scale_b_ptrs = scale_b_ptr + offsets_scale_bn
 
-    off_k = 0
     for _ in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
-        # masks_k = offsets_k < K
-        # masks_a = masks_am[:, None] & masks_k[None, :]
-        # a = tl.load(a_ptrs, mask=masks_a)
-
-        # masks_b = masks_k[:, None] & masks_bn[None, :]
-        # b = tl.load(b_ptrs, mask=masks_b)
+        masks_k = offsets_k < K
+        masks_a = masks_am[:, None] & masks_k[None, :]
+        a = tl.load(a_ptrs, mask=masks_a)
 
-        a = a_desc.load([pid_m * BLOCK_SIZE_M, off_k])
-        b = b_desc.load([off_k, pid_n * BLOCK_SIZE_N])
-        # accumulator += tl.dot(a, b)
+        masks_b = masks_k[:, None] & masks_bn[None, :]
+        b = tl.load(b_ptrs, mask=masks_b)
 
         # Accumulate results.
         accumulator = tl.dot(a, b, accumulator, out_dtype=accumulator_dtype)
-        off_k += BLOCK_SIZE_K
 
-        # offsets_k += BLOCK_SIZE_K
-        # a_ptrs += BLOCK_SIZE_K * stride_ak
-        # b_ptrs += BLOCK_SIZE_K * stride_bk
+        offsets_k += BLOCK_SIZE_K
+        a_ptrs += BLOCK_SIZE_K * stride_ak
+        b_ptrs += BLOCK_SIZE_K * stride_bk
 
     # Apply scale at end.
     masks_scale_a = masks_scale_am[:, None] & (tl.arange(0, 1) < 1)[:, None]
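The two loop bodies track progress along K differently: the removed descriptor loads take absolute block coordinates, so the kernel had to maintain a separate off_k counter, while the restored pointer version folds that state into the pointers themselves (and into offsets_k for the tail mask). A reduced sketch of the contrast, using this kernel's names (fragment, not complete code):

# Descriptor style (removed): index by absolute tile origin each trip.
a = a_desc.load([pid_m * BLOCK_SIZE_M, off_k])
off_k += BLOCK_SIZE_K

# Pointer style (restored): the loop state lives in the pointers.
a = tl.load(a_ptrs, mask=masks_am[:, None] & (offsets_k < K)[None, :])
a_ptrs += BLOCK_SIZE_K * stride_ak
offsets_k += BLOCK_SIZE_K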
@@ -162,13 +151,10 @@ def scaled_mm_kernel_td(
     offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)
     offs_cm = offs_cm.to(tl.int64)
     offs_cn = offs_cn.to(tl.int64)
-    # c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
-    # c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
+    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
+    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
 
-    # tl.store(c_ptrs, c, mask=c_mask)
-    c_desc = tl.make_tensor_descriptor(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn),
-                                       block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N))
-    c_desc.store([pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N], c)
+    tl.store(c_ptrs, c, mask=c_mask)
 
 
 # input - [M, K]
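Two details in the restored epilogue are worth noting: the offsets are widened to tl.int64 before the address multiply, because for a large C the product can wrap in 32-bit arithmetic, and c_mask clips the final partial tile when M or N is not a multiple of the block size. A worked overflow example under assumed sizes:

# Assume C is 50_000 x 50_000 and row-major, so stride_cm = 50_000.
# The largest row offset term is then
#     stride_cm * offs_cm = 50_000 * 49_999 = 2_499_950_000,
# which exceeds 2**31 - 1 = 2_147_483_647 and would wrap as int32;
# as int64 it is exact, hence the .to(tl.int64) casts before c_ptrs.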

benchmarks/third_party/vllm/batched_moe_benchmark.py
38 additions, 32 deletions

@@ -31,8 +31,8 @@
 
 @triton.jit
 def moe_mmk(
-    a_desc,
-    b_desc,
+    a_ptrs,
+    b_ptrs,
     K,
     expert_id,
     a_scale_ptr,

@@ -41,6 +41,9 @@ def moe_mmk(
     # moving by 1 element in a particular dimension. E.g. `stride_am` is
     # how much to increase `a_ptr` by to get the element one row down
     # (A has M rows).
+    stride_ak: tl.int64,
+    stride_bk: tl.int64,
+    stride_ase: tl.int64,
     stride_asm: tl.int64,
     stride_ask: tl.int64,
     stride_bse: tl.int64,

@@ -65,6 +68,7 @@ def moe_mmk(
     use_w8a16: tl.constexpr,
     per_act_token_quant: tl.constexpr,
 ):
+    offs_k = tl.arange(0, BLOCK_K)
 
     if use_w8a16:
         b_scale_ptrs = b_scale_ptr + expert_id * stride_bse + offs_n[None, :] * stride_bsn

@@ -99,8 +103,12 @@ def moe_mmk(
     accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
     for k in range(0, tl.cdiv(K, BLOCK_K)):
         # Load the next block of A and B using tensor descriptors
-        a = a_desc.load([pid_m * BLOCK_M, k * BLOCK_K])
-        b = b_desc.load([k * BLOCK_K, pid_n * BLOCK_N])
+        a = tl.load(
+            a_ptrs,
+            mask=mask_m[:, None] & (offs_k[None, :] < K - k * BLOCK_K),
+            other=0.0,
+        )
+        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0.0)
 
         # We accumulate along the K dimension.
         if use_w8a16:
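Note how the tail along K is handled after the revert: offs_k itself never changes, only the pointers advance, so the bound the mask compares against shrinks by BLOCK_K per iteration. A reduced sketch of the pattern (fragment; the accumulate line stands in for the real dot-product body):

for k in range(0, tl.cdiv(K, BLOCK_K)):
    # On the last iteration K - k * BLOCK_K < BLOCK_K, so lanes whose
    # offs_k falls past the end of the K axis load 0.0 instead.
    a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=0.0)
    accumulator += a  # stand-in for the tl.dot accumulation
    a_ptrs += BLOCK_K * stride_ak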
@@ -119,6 +127,9 @@ def moe_mmk(
         else:
             accumulator += tl.dot(a, b)
 
+        a_ptrs += BLOCK_K * stride_ak
+        b_ptrs += BLOCK_K * stride_bk
+
     if use_w8a16:
         accumulator = (accumulator * b_scale).to(compute_type)
     elif use_w8a8:

@@ -134,9 +145,9 @@ def moe_mmk(
 
 @triton.jit
 def expert_triton_kernel(
-    a_desc,  #[max_tokens, K]
-    b_desc,  #[K, N]
-    c_desc,  #[max_tokens, N]
+    a_ptr,
+    b_ptr,
+    c_ptr,
     expert_id,
     compute_type: tl.constexpr,
     # Dimensions

@@ -147,8 +158,12 @@ def expert_triton_kernel(
     a_scale_ptr,
     b_scale_ptr,
     # strides
+    stride_am: tl.int64,
     stride_ak: tl.int64,
     stride_bk: tl.int64,
+    stride_bn: tl.int64,
+    stride_cm: tl.int64,
+    stride_cn: tl.int64,
     stride_ase: tl.int64,
     stride_asm: tl.int64,
     stride_ask: tl.int64,
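The widened signatures are the direct cost of dropping descriptors: a tensor descriptor carries base pointer, shape, strides, and block shape in a single argument, whereas raw pointers need every stride passed separately so the kernel can rebuild addresses. A sketch of the trade, using names from this diff (fragment):

# Descriptor style (removed): layout metadata travels inside one object.
#     expert_triton_kernel(a_desc, b_desc, c_desc, ...)
# Pointer style (restored): layout travels as individual scalars
#     expert_triton_kernel(a_ptr, ..., stride_am, stride_ak, ...)
# and is reassembled into addresses in the kernel body:
a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak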
@@ -174,15 +189,19 @@ def expert_triton_kernel(
 
     offs_m = tl.arange(0, BLOCK_M)
     offs_n = tl.arange(0, BLOCK_N) % N
-    # offs_k = tl.arange(0, BLOCK_K)
+    offs_k = tl.arange(0, BLOCK_K)
     mask_m = offs_m < M
 
+    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
+    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
+
     accumulator = moe_mmk(
-        a_desc, b_desc, K, expert_id, a_scale_ptr, b_scale_ptr,
+        a_ptrs, b_ptrs, K, expert_id, a_scale_ptr, b_scale_ptr,
         # The stride variables represent how much to increase the ptr by when
         # moving by 1 element in a particular dimension. E.g. `stride_am` is
         # how much to increase `a_ptr` by to get the element one row down
         # (A has M rows).
+        stride_ak, stride_bk, stride_ase,
         stride_asm, stride_ask, stride_bse, stride_bsk, stride_bsn,
         # Offsets and masks
         offs_m, offs_n, offs_bn, mask_m,

@@ -192,11 +211,10 @@ def expert_triton_kernel(
         BLOCK_M, BLOCK_N, BLOCK_K, compute_type, use_fp8_w8a8, use_int8_w8a16, per_act_token_quant)
 
     # store in C
-    # offs_cn = tl.arange(0, BLOCK_N)
-    # c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_cn[None, :] * stride_cn
-    # c_mask = mask_m[:, None] & (offs_cn[None, :] < N)
-    c_desc.store([pid_m * BLOCK_M, pid_n * BLOCK_N], accumulator)
-    # tl.store(c_ptrs, accumulator, mask=c_mask)
+    offs_cn = tl.arange(0, BLOCK_N)
+    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_cn[None, :] * stride_cn
+    c_mask = mask_m[:, None] & (offs_cn[None, :] < N)
+    tl.store(c_ptrs, accumulator, mask=c_mask)
 
 
 def get_matmul_batched_autotune_configs() -> List[triton.Config]:

@@ -292,17 +310,10 @@ def batched_triton_kernel(
     cta_m_size = min(BLOCK_M, e_num_tokens - cta_m_start)
     cta_n_size = min(BLOCK_N, N - cta_n_start)
 
-    a_desc = tl.make_tensor_descriptor(base=a_ptr + expert_id * stride_ae, shape=(e_num_tokens, K),
-                                       strides=(stride_am, stride_ak), block_shape=(BLOCK_M, BLOCK_K))
-    b_desc = tl.make_tensor_descriptor(base=b_ptr + expert_id * stride_be, shape=(K, N), strides=(stride_bk, stride_bn),
-                                       block_shape=(BLOCK_K, BLOCK_N))
-    c_desc = tl.make_tensor_descriptor(base=c_ptr + expert_id * stride_ce, shape=(e_num_tokens, N),
-                                       strides=(stride_cm, stride_cn), block_shape=(BLOCK_M, BLOCK_N))
-
-    # a_ptr = a_ptr + expert_id * stride_ae + cta_m_start * stride_am
-    # b_ptr = b_ptr + expert_id * stride_be + cta_n_start * stride_bn
-    # c_ptr = (c_ptr + expert_id * stride_ce + cta_m_start * stride_cm +
-    #          cta_n_start * stride_cn)
+    a_ptr = a_ptr + expert_id * stride_ae + cta_m_start * stride_am
+    b_ptr = b_ptr + expert_id * stride_be + cta_n_start * stride_bn
+    c_ptr = (c_ptr + expert_id * stride_ce + cta_m_start * stride_cm +
+             cta_n_start * stride_cn)
 
     offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N).to(tl.int64)) % N
 
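Here A, B, and C are batched 3-D tensors with one slice per expert (stride_ae/stride_be/stride_ce step between experts), so the three per-expert descriptors collapse into plain base-pointer offsets. A small worked example of the arithmetic under an assumed layout:

# Assume A is [E, max_tokens, K] with max_tokens = 512 and K = 1024,
# contiguous, so stride_ae = 512 * 1024 = 524_288 and stride_am = 1024
# (in elements). For expert_id = 3 and cta_m_start = 128:
#     a_ptr + 3 * 524_288 + 128 * 1024 = a_ptr + 1_703_936
# i.e. the pointer now sits at row 128 of expert 3's slice, and all
# later indexing in the kernel is relative to that origin.
a_ptr = a_ptr + expert_id * stride_ae + cta_m_start * stride_am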
@@ -314,12 +325,12 @@ def batched_triton_kernel(
     if group_k > 0 and group_n > 0 or per_act_token_quant:
         a_scale_ptr = a_scale_ptr + cta_m_start * stride_asm
 
-    expert_triton_kernel(a_desc, b_desc, c_desc, expert_id, compute_type, cta_m_size,  # M
+    expert_triton_kernel(a_ptr, b_ptr, c_ptr, expert_id, compute_type, cta_m_size,  # M
                          cta_n_size,  # N
                          K,  # K
                          a_scale_ptr, b_scale_ptr,
                          # Strides
-                         stride_ak, stride_bk, stride_ase, stride_asm, stride_ask, stride_bse, stride_bsk, stride_bsn,
+                         stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, stride_ase, stride_asm, stride_ask, stride_bse, stride_bsk, stride_bsn,
                          # offsets
                          offs_bn,
                          # Blockwise quantization data

@@ -502,13 +513,8 @@ def get_batched_mm_benchmark(
     Returns a Mark object containing a Benchmark object for batched matrix multiplication.
     """
     supported_providers = {
-        'triton': 'triton',
         'triton-td': 'triton-td',
-        'pytorch': 'pytorch',
     }
-    if fp8:
-        # pytorch is very slow with fp8 case, for (8, 64, 1024, 2048) case it has ~0.15 TFlops vs 1.5 for triton
-        del supported_providers['pytorch']
 
     providers = benchmark_suite.filter_providers(supported_providers, providers_filter)
     configs = MM_CONFIGS_FP8 if fp8 else MM_CONFIGS_BF16

benchmarks/third_party/vllm/unified_attention_benchmark.py
0 additions, 1 deletion

@@ -1065,7 +1065,6 @@ def get_unified_attention_benchmark(
     supported_providers = {
         'triton': 'triton',
         'triton-td': 'triton-td',
-        'pytorch': 'pytorch',
     }
     if os.getenv("TRITON_INTERPRET", "0") == "1":
         # Skip triton providers if interpreter is used

benchmarks/triton_kernels_benchmark/flex_attention_benchmark_causal_mask.py
1 addition, 1 deletion

@@ -167,7 +167,7 @@ def benchmark(Z, H_q, H_kv, N_CTX_q, N_CTX_kv, D_HEAD_qk, D_HEAD_v, MODE, provid
         _, min_ms, max_ms, mean, cv = do_bench(torch_fn, device=DEVICE)
 
     elif provider == 'triton':
-        kernel_options = {'BLOCKS_ARE_CONTIGUOUS': True, 'USE_TMA': True}
+        kernel_options = {'BLOCKS_ARE_CONTIGUOUS': True}
         triton_fn = lambda: compiled_flex_attention(q, k, v, block_mask=block_mask, scale=sm_scale, enable_gqa=(
             not H_q == H_kv), kernel_options=kernel_options)
         if MODE == 'bwd':
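kernel_options here are performance hints forwarded to the Triton kernels that torch.compile generates for flex_attention; removing 'USE_TMA': True matches the rest of the commit in backing off the TMA/tensor-descriptor path. A hedged usage sketch (device, shapes, and dtype assumed):

import torch
from torch.nn.attention.flex_attention import flex_attention

compiled_flex_attention = torch.compile(flex_attention)
# Hypothetical tensors: batch 1, 8 heads, 1024 tokens, head_dim 64.
q = torch.randn(1, 8, 1024, 64, device='xpu', dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)
out = compiled_flex_attention(
    q, k, v,
    # Hints forwarded to the generated Triton kernel.
    kernel_options={'BLOCKS_ARE_CONTIGUOUS': True},
)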
