Commit 5800181

Rebase the ck_tile_gemm branch to rocm/355_wip
1 parent f4a4bdb commit 5800181

14 files changed: +146 -105 lines changed

csrc/layernorm_kernels.cu
Lines changed: 8 additions & 7 deletions

@@ -51,8 +51,8 @@ __global__ void rms_norm_kernel(
 template <typename scalar_t, int width>
 __global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
 fused_add_rms_norm_kernel(
-    scalar_t* __restrict__ output,       // [..., hidden_size]
-    const scalar_t* __restrict__ input,  // [..., hidden_size]
+    scalar_t* __restrict__ output,       // [..., hidden_size]
+    const scalar_t* __restrict__ input,  // [..., hidden_size]
     const int64_t input_stride,
     scalar_t* __restrict__ residual_out,    // [..., hidden_size]
     const scalar_t* __restrict__ residual,  // [..., hidden_size]
@@ -114,8 +114,8 @@ fused_add_rms_norm_kernel(
 template <typename scalar_t, int width>
 __global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
 fused_add_rms_norm_kernel(
-    scalar_t* __restrict__ output,       // [..., hidden_size]
-    const scalar_t* __restrict__ input,  // [..., hidden_size]
+    scalar_t* __restrict__ output,       // [..., hidden_size]
+    const scalar_t* __restrict__ input,  // [..., hidden_size]
     const int64_t input_stride,
     scalar_t* __restrict__ residual_out,    // [..., hidden_size]
     const scalar_t* __restrict__ residual,  // [..., hidden_size]
@@ -221,9 +221,10 @@ void fused_add_rms_norm(torch::Tensor& out,  // [..., hidden_size]
   constexpr int req_alignment_bytes =
       vector_width * 2;  // vector_width * sizeof(bfloat16 or float16) (float32
                          // falls back to non-vectorized version anyway)
-  bool ptrs_are_aligned = out_ptr % 16 == 0 && inp_ptr % req_alignment_bytes == 0 &&
-                          res_out_ptr % 16 == 0 && res_ptr % req_alignment_bytes == 0 &&
-                          wt_ptr % req_alignment_bytes == 0;
+  bool ptrs_are_aligned =
+      out_ptr % 16 == 0 && inp_ptr % req_alignment_bytes == 0 &&
+      res_out_ptr % 16 == 0 && res_ptr % req_alignment_bytes == 0 &&
+      wt_ptr % req_alignment_bytes == 0;
   bool offsets_are_multiple_of_vector_width =
       hidden_size % vector_width == 0 && input_stride % vector_width == 0;
   if (ptrs_are_aligned && offsets_are_multiple_of_vector_width) {
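
The host-side check above gates the vectorized fused_add_rms_norm path: every pointer must satisfy the 16-byte or vector-width alignment, and both hidden_size and input_stride must be divisible by the vector width, otherwise the launcher falls back to the scalar kernel (as the float32 comment notes). A rough Python mirror of that gating, purely illustrative and not part of this commit (the helper name and the default vector_width are made up; the arguments are torch.Tensors):

    # Illustrative mirror of the C++ gating above; not part of the commit.
    def can_use_vectorized_kernel(out, inp, residual_out, residual, weight,
                                  vector_width: int = 8) -> bool:
        req_alignment_bytes = vector_width * 2  # 2 bytes per fp16/bf16 element
        ptrs_are_aligned = (out.data_ptr() % 16 == 0
                            and inp.data_ptr() % req_alignment_bytes == 0
                            and residual_out.data_ptr() % 16 == 0
                            and residual.data_ptr() % req_alignment_bytes == 0
                            and weight.data_ptr() % req_alignment_bytes == 0)
        hidden_size = inp.shape[-1]
        input_stride = inp.stride(0) if inp.dim() > 1 else hidden_size
        offsets_ok = (hidden_size % vector_width == 0
                      and input_stride % vector_width == 0)
        return ptrs_are_aligned and offsets_ok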

csrc/layernorm_quant_kernels.cu
Lines changed: 6 additions & 6 deletions

@@ -64,8 +64,8 @@ __global__ void rms_norm_static_fp8_quant_kernel(
 template <typename scalar_t, int width, typename fp8_type>
 __global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
 fused_add_rms_norm_static_fp8_quant_kernel(
-    fp8_type* __restrict__ out,    // [..., hidden_size]
-    scalar_t* __restrict__ input,  // [..., hidden_size]
+    fp8_type* __restrict__ out,    // [..., hidden_size]
+    scalar_t* __restrict__ input,  // [..., hidden_size]
     const int input_stride,
     scalar_t* __restrict__ residual_out,  // [..., hidden_size]
     scalar_t* __restrict__ residual,      // [..., hidden_size]
@@ -132,8 +132,8 @@ fused_add_rms_norm_static_fp8_quant_kernel(
 template <typename scalar_t, int width, typename fp8_type>
 __global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
 fused_add_rms_norm_static_fp8_quant_kernel(
-    fp8_type* __restrict__ out,    // [..., hidden_size]
-    scalar_t* __restrict__ input,  // [..., hidden_size]
+    fp8_type* __restrict__ out,    // [..., hidden_size]
+    scalar_t* __restrict__ input,  // [..., hidden_size]
     const int input_stride,
     scalar_t* __restrict__ residual_out,  // [..., hidden_size]
     scalar_t* __restrict__ residual,      // [..., hidden_size]
@@ -210,8 +210,8 @@ void rms_norm_static_fp8_quant(torch::Tensor& out,  // [..., hidden_size]
               width, fp8_t>                                           \
           <<<grid, block, 0, stream>>>(                               \
               out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),      \
-              input_stride, residual_out.data_ptr<scalar_t>(),        \
-              residual.data_ptr<scalar_t>(),                          \
+              input_stride, residual_out.data_ptr<scalar_t>(),        \
+              residual.data_ptr<scalar_t>(),                          \
               weight.data_ptr<scalar_t>(), scale.data_ptr<float>(),   \
               epsilon, num_tokens, hidden_size);                      \
         });                                                           \

tests/kernels/core/test_layernorm.py
Lines changed: 2 additions & 1 deletion

@@ -130,7 +130,8 @@ def test_fused_rms_norm_quant(
     out_unfused = torch.empty_like(x_unfused)
     torch.ops._C.fused_add_rms_norm(out_unfused, x_unfused, residual_out,
                                     residual, weight, 1e-6)
-    torch.ops._C.static_scaled_fp8_quant(out_quant, out_unfused.contiguous(),
+    torch.ops._C.static_scaled_fp8_quant(out_quant,
+                                         out_unfused.contiguous(),
                                          quant_scale_t)

    torch.cuda.synchronize()
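
For reference, the unfused path the test compares against performs a residual add plus RMSNorm, then a static FP8 quantization of the normalized output. A hedged pure-PyTorch sketch of that computation (the eps value, fp8 dtype, and clamping behaviour are assumptions, not read from the kernels):

    # Rough reference for fused_add_rms_norm followed by static_scaled_fp8_quant.
    import torch

    def unfused_reference(x, residual, weight, scale, eps=1e-6,
                          fp8_dtype=torch.float8_e4m3fn):
        x = x + residual                 # residual add; also the new residual
        residual_out = x
        var = x.float().pow(2).mean(dim=-1, keepdim=True)
        normed = (x.float() * torch.rsqrt(var + eps)).to(x.dtype) * weight
        info = torch.finfo(fp8_dtype)
        quant = (normed.float() / scale).clamp(info.min, info.max).to(fp8_dtype)
        return quant, residual_out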

vllm/attention/ops/chunked_prefill_paged_decode.py
Lines changed: 41 additions & 41 deletions

@@ -25,47 +25,47 @@ def cdiv_fn(x, y):


 @triton.jit
 def kernel_paged_attention_2d(
-        output_ptr,  # [num_tokens, num_query_heads, head_size]
-        query_ptr,  # [num_tokens, num_query_heads, head_size]
-        key_cache_ptr,  # [num_blks, num_kv_heads, head_size // x, blk_size, x]
-        value_cache_ptr,  # [num_blks, num_kv_heads, head_size, blk_size]
-        sink_ptr,  # [num_query_heads]
-        block_tables_ptr,  # [num_seqs, max_num_blocks_per_seq]
-        seq_lens_ptr,  # [num_seqs]
-        alibi_slopes_ptr,  # [num_query_heads]
-        scale,  # float32
-        k_scale,  # float32
-        v_scale,  # float32
-        out_scale,
-        num_query_heads: tl.constexpr,  # int
-        num_queries_per_kv: tl.constexpr,  # int
-        num_queries_per_kv_padded: tl.constexpr,  # int
-        block_table_stride: tl.int64,  # int
-        query_stride_0: tl.int64,  # int
-        query_stride_1: tl.int64,  # int, should be equal to head_size
-        output_stride_0: tl.int64,  # int
-        output_stride_1: tl.int64,  # int, should be equal to head_size
-        BLOCK_SIZE: tl.constexpr,  # int
-        HEAD_SIZE: tl.constexpr,  # int
-        HEAD_SIZE_PADDED: tl.constexpr,  # int, must be power of 2
-        USE_ALIBI_SLOPES: tl.constexpr,  # bool
-        SLIDING_WINDOW: tl.constexpr,  # int
-        x: tl.constexpr,  # int
-        stride_k_cache_0: tl.int64,  # int
-        stride_k_cache_1: tl.int64,  # int
-        stride_k_cache_2: tl.int64,  # int
-        stride_k_cache_3: tl.int64,  # int
-        stride_k_cache_4: tl.int64,  # int
-        stride_v_cache_0: tl.int64,  # int
-        stride_v_cache_1: tl.int64,  # int
-        stride_v_cache_2: tl.int64,  # int
-        stride_v_cache_3: tl.int64,  # int
-        filter_by_query_len: tl.constexpr,  # bool
-        query_start_len_ptr,  # [num_seqs+1]
-        USE_FP8: tl.constexpr,
-        USE_SINKS: tl.constexpr,  # bool
-        FP8_MIN: tl.constexpr = float8_info.min,
-        FP8_MAX: tl.constexpr = float8_info.max,
+        output_ptr,  # [num_tokens, num_query_heads, head_size]
+        query_ptr,  # [num_tokens, num_query_heads, head_size]
+        key_cache_ptr,  # [num_blks, num_kv_heads, head_size // x, blk_size, x]
+        value_cache_ptr,  # [num_blks, num_kv_heads, head_size, blk_size]
+        sink_ptr,  # [num_query_heads]
+        block_tables_ptr,  # [num_seqs, max_num_blocks_per_seq]
+        seq_lens_ptr,  # [num_seqs]
+        alibi_slopes_ptr,  # [num_query_heads]
+        scale,  # float32
+        k_scale,  # float32
+        v_scale,  # float32
+        out_scale,
+        num_query_heads: tl.constexpr,  # int
+        num_queries_per_kv: tl.constexpr,  # int
+        num_queries_per_kv_padded: tl.constexpr,  # int
+        block_table_stride: tl.int64,  # int
+        query_stride_0: tl.int64,  # int
+        query_stride_1: tl.int64,  # int, should be equal to head_size
+        output_stride_0: tl.int64,  # int
+        output_stride_1: tl.int64,  # int, should be equal to head_size
+        BLOCK_SIZE: tl.constexpr,  # int
+        HEAD_SIZE: tl.constexpr,  # int
+        HEAD_SIZE_PADDED: tl.constexpr,  # int, must be power of 2
+        USE_ALIBI_SLOPES: tl.constexpr,  # bool
+        SLIDING_WINDOW: tl.constexpr,  # int
+        x: tl.constexpr,  # int
+        stride_k_cache_0: tl.int64,  # int
+        stride_k_cache_1: tl.int64,  # int
+        stride_k_cache_2: tl.int64,  # int
+        stride_k_cache_3: tl.int64,  # int
+        stride_k_cache_4: tl.int64,  # int
+        stride_v_cache_0: tl.int64,  # int
+        stride_v_cache_1: tl.int64,  # int
+        stride_v_cache_2: tl.int64,  # int
+        stride_v_cache_3: tl.int64,  # int
+        filter_by_query_len: tl.constexpr,  # bool
+        query_start_len_ptr,  # [num_seqs+1]
+        USE_FP8: tl.constexpr,
+        USE_SINKS: tl.constexpr,  # bool
+        FP8_MIN: tl.constexpr = float8_info.min,
+        FP8_MAX: tl.constexpr = float8_info.max,
 ):
     seq_idx = tl.program_id(0)
     kv_head_idx = tl.program_id(1)
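
This hunk only re-indents the kernel_paged_attention_2d parameter list; no parameter is added or removed. The comments in the signature document the paged KV-cache layouts, and the stride_k_cache_* / stride_v_cache_* arguments are simply the per-dimension strides of those cache tensors. An illustrative snippet (the shapes are made-up values that match the commented layouts):

    import torch

    num_blks, num_kv_heads, head_size, blk_size, x = 128, 8, 128, 16, 8
    key_cache = torch.empty(num_blks, num_kv_heads, head_size // x, blk_size, x)
    value_cache = torch.empty(num_blks, num_kv_heads, head_size, blk_size)

    stride_k_cache = [key_cache.stride(i) for i in range(5)]    # stride_k_cache_0..4
    stride_v_cache = [value_cache.stride(i) for i in range(4)]  # stride_v_cache_0..3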

vllm/attention/ops/triton_unified_attention.py
Lines changed: 0 additions & 1 deletion

@@ -12,7 +12,6 @@
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
-from vllm.platforms import current_platform

 logger = init_logger(__name__)
 float8_info = torch.finfo(current_platform.fp8_dtype())

vllm/envs.py
Lines changed: 1 addition & 1 deletion

@@ -754,7 +754,7 @@ def get_vllm_port() -> Optional[int]:
     "VLLM_ROCM_USE_AITER_CK_TILE_LINEAR":
     lambda: (os.getenv("VLLM_ROCM_USE_AITER_CK_TILE_LINEAR", "True").lower() in
              ("true", "1")),
-
+
     # Whether to use aiter moe ops.
     # By default is enabled.
     "VLLM_ROCM_USE_AITER_MOE":

vllm/model_executor/layers/quantization/fp8.py
Lines changed: 5 additions & 5 deletions

@@ -205,11 +205,11 @@ def __init__(self, quant_config: Fp8Config):
                                           and envs.VLLM_ROCM_USE_AITER
                                           and envs.VLLM_ROCM_USE_AITER_LINEAR
                                           and current_platform.is_fp8_fnuz())
-        self.use_ck_tile_and_is_supported = (current_platform.is_rocm()
-                                             and envs.VLLM_ROCM_USE_AITER
-                                             and envs.VLLM_ROCM_USE_AITER_CK_TILE_LINEAR
-                                             and current_platform.is_fp8_fnuz())
-
+        self.use_ck_tile_and_is_supported = (
+            current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER
+            and envs.VLLM_ROCM_USE_AITER_CK_TILE_LINEAR
+            and current_platform.is_fp8_fnuz())
+
         self.block_quant = self.quant_config.weight_block_size is not None
         self.act_q_static = self.quant_config.activation_scheme == "static"
         # Use per-token quantization for better perf if dynamic and cutlass
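
The use_ck_tile_and_is_supported flag mirrors the existing aiter gating: the ck_tile linear path is taken only on ROCm, when both VLLM_ROCM_USE_AITER and the new VLLM_ROCM_USE_AITER_CK_TILE_LINEAR environment flag are enabled, and when the platform uses fp8 fnuz. A minimal sketch of how the flag is parsed, mirroring the lambda added in vllm/envs.py (the standalone variable name below is illustrative):

    import os

    # Defaults to enabled; "0" or "false" (any case) disables the ck_tile path.
    use_aiter_ck_tile_linear = (
        os.getenv("VLLM_ROCM_USE_AITER_CK_TILE_LINEAR", "True").lower()
        in ("true", "1"))

Setting VLLM_ROCM_USE_AITER_CK_TILE_LINEAR=0 therefore disables only the ck_tile linear path; the other aiter paths keep their own flags.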

vllm/model_executor/layers/quantization/utils/fp8_utils.py
Lines changed: 8 additions & 2 deletions

@@ -90,6 +90,7 @@ def rocm_aiter_gemm_w8a8_blockscale_fake(

 aiter_per1x128_quant = get_hip_quant(rocm_aiter.QuantType.per_1x128)

+
 def rocm_aiter_ck_tile_gemm_w8a8_blockscale_impl(
     A: torch.Tensor,
     B: torch.Tensor,
@@ -100,7 +101,11 @@ def rocm_aiter_ck_tile_gemm_w8a8_blockscale_impl(
 ) -> torch.Tensor:
     import aiter as rocm_aiter

-    return rocm_aiter.gemm_a8w8_blockscale_ck_tile(A, B, As, Bs, dtype=output_dtype)
+    return rocm_aiter.gemm_a8w8_blockscale_ck_tile(A,
+                                                   B,
+                                                   As,
+                                                   Bs,
+                                                   dtype=output_dtype)


 def rocm_aiter_ck_tile_gemm_w8a8_blockscale_fake(
@@ -136,7 +141,8 @@ def rocm_aiter_ck_tile_gemm_w8a8_blockscale_fake(


 def dispatch_w8a8_blockscale_func(
-    use_cutlass: bool, use_aiter_and_is_supported: bool, use_ck_tile_and_is_supported: bool
+    use_cutlass: bool, use_aiter_and_is_supported: bool,
+    use_ck_tile_and_is_supported: bool
 ) -> Callable[[
     torch.Tensor,
     torch.Tensor,
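
dispatch_w8a8_blockscale_func returns the blockscale GEMM callable to use for a given configuration; this diff only re-wraps its signature. A hypothetical sketch of the selection the three flags imply, where rocm_aiter_ck_tile_gemm_w8a8_blockscale_impl is the helper defined earlier in this file and the other branch targets plus the priority order are assumptions:

    # Assumed dispatch order; only the ck_tile helper name comes from this diff.
    from typing import Callable

    def dispatch_w8a8_blockscale_func(
            use_cutlass: bool, use_aiter_and_is_supported: bool,
            use_ck_tile_and_is_supported: bool) -> Callable:
        if use_cutlass:
            return cutlass_scaled_mm_blockscale          # assumed helper name
        if use_ck_tile_and_is_supported:
            return rocm_aiter_ck_tile_gemm_w8a8_blockscale_impl
        if use_aiter_and_is_supported:
            return rocm_aiter_gemm_w8a8_blockscale_impl  # impl for the *_fake above
        return w8a8_block_fp8_matmul                     # assumed fallback name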

vllm/model_executor/layers/quantization/utils/mxfp4_utils.py
Lines changed: 4 additions & 3 deletions

@@ -34,8 +34,8 @@ def _swizzle_mxfp4(quant_tensor, scale, num_warps):
     elif current_platform.is_rocm():
         from triton_kernels.target_info import is_hip
         from triton_kernels.tensor_details.layout import (
-            BlackwellMXScaleLayout, HopperMXScaleLayout, HopperMXValueLayout,
-            GFX950MXScaleLayout)
+            BlackwellMXScaleLayout, GFX950MXScaleLayout, HopperMXScaleLayout,
+            HopperMXValueLayout)
         value_layout = StridedLayout
         scale_layout = StridedLayout
         if not is_hip():
@@ -53,7 +53,8 @@ def _swizzle_mxfp4(quant_tensor, scale, num_warps):
     else:
         """ weight swizzle for mxfp4 moe, used for OAI mxfp4 kernel
         """
-        value_layout, value_layout_opts = layout.make_default_matmul_mxfp4_w_layout(
+        value_layout, value_layout_opts = \
+            layout.make_default_matmul_mxfp4_w_layout(
             mx_axis=1)
         scale_layout, scale_layout_opts = (
             layout.make_default_matmul_mxfp4_w_scale_layout(

vllm/model_executor/layers/rotary_embedding/base.py
Lines changed: 8 additions & 3 deletions

@@ -9,7 +9,9 @@
 from vllm.model_executor.custom_op import CustomOp

 from .common import apply_rotary_emb_dispatch, apply_rotary_emb_torch
-from .rocm_aiter_rope_ops import is_rocm_rotary_embedding_enabled, is_rocm_triton_rotary_embedding_enabled
+from .rocm_aiter_rope_ops import (is_rocm_rotary_embedding_enabled,
+                                  is_rocm_triton_rotary_embedding_enabled)
+

 @CustomOp.register("rotary_embedding")
 class RotaryEmbedding(CustomOp):
@@ -36,8 +38,11 @@ def __init__(
         cache = cache.to(dtype)
         self.cos_sin_cache: torch.Tensor
         self.register_buffer("cos_sin_cache", cache, persistent=False)
-        self.is_rocm_aiter_enabled = is_rocm_rotary_embedding_enabled()
-        self.is_rocm_aiter_triton_enabled = is_rocm_triton_rotary_embedding_enabled()
+        self.is_rocm_aiter_enabled = \
+            is_rocm_rotary_embedding_enabled()
+        self.is_rocm_aiter_triton_enabled = \
+            is_rocm_triton_rotary_embedding_enabled(
+            )

     def _compute_inv_freq(self, base: float) -> torch.Tensor:
         """Compute the inverse frequency."""
