
Commit 0f04865

Author: Doug Lehr (committed)
Add Day 0 vllm changes to 355_wip branch
1 parent 8eb058e commit 0f04865

File tree

3 files changed: +82 additions, -16 deletions


vllm/model_executor/layers/quantization/mxfp4.py

Lines changed: 1 addition & 1 deletion
@@ -148,7 +148,7 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
             hidden_size = round_up(hidden_size, 256)
         elif current_platform.is_rocm():
             intermediate_size_per_partition_after_pad = round_up(
-                intermediate_size_per_partition, 128)
+                intermediate_size_per_partition, 256)
         else:
             intermediate_size_per_partition_after_pad = round_up(
                 intermediate_size_per_partition, 64)
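For context: round_up pads its first argument to the next multiple of the second, so this one-number change raises the ROCm padding of the per-partition intermediate size from a 128-element to a 256-element boundary. A minimal sketch of the effect (round_up is re-implemented here only for illustration; vLLM's own helper lives in vllm.utils):

def round_up(x: int, multiple: int) -> int:
    # Round x up to the nearest multiple of `multiple`.
    return ((x + multiple - 1) // multiple) * multiple

# Illustrative width only (2880 also appears in the GEMM shape checks later in this commit):
print(round_up(2880, 128))  # 2944 -- old ROCm padding
print(round_up(2880, 256))  # 3072 -- new ROCm padding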

vllm/model_executor/layers/utils.py

Lines changed: 20 additions & 4 deletions
@@ -9,7 +9,10 @@
 from vllm import envs
 from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op
-
+import os
+if current_platform.is_rocm():
+    from aiter.ops.triton.gemm_a16w16 import gemm_a16w16
+VLLM_USE_AITER_TRITON_GEMM = (os.getenv("VLLM_USE_AITER_TRITON_GEMM", "False").lower() in ("true", "1"))
 
 def shuffle_weight(w: torch.Tensor) -> torch.Tensor:
     # Shuffle weight along the last dimension so that
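The new flag and the ROCm-guarded aiter import are both evaluated once at module import time. A hedged sketch of the same pattern with a defensive fallback for machines without aiter installed (the try/except is not part of the diff, only an illustration):

import os

from vllm.platforms import current_platform

# Read the opt-in flag once at import time, exactly as the diff does.
VLLM_USE_AITER_TRITON_GEMM = (
    os.getenv("VLLM_USE_AITER_TRITON_GEMM", "False").lower() in ("true", "1"))

gemm_a16w16 = None
if current_platform.is_rocm():
    try:
        # Only ROCm builds are expected to ship the AITER Triton kernels.
        from aiter.ops.triton.gemm_a16w16 import gemm_a16w16
    except ImportError:
        # Hypothetical fallback; the diff itself imports unconditionally on ROCm.
        gemm_a16w16 = None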
@@ -91,23 +94,36 @@ def default_unquantized_gemm(layer: torch.nn.Module,
                              bias: Optional[torch.Tensor] = None):
     return torch.nn.functional.linear(x, weight, bias)
 
+def aiter_GEMM_check(m, n, k):
+    if ((n == 5120 and k == 2880)
+            or (n == 2880 and k == 4096)
+            or (n == 128 and k == 2880)
+            or (n == 640 and k == 2880)
+            or (n == 2880 and k == 512)):
+        return True
+    return False
+
+
 
 def rocm_unquantized_gemm_impl(
         x: torch.Tensor,
         weight: torch.Tensor,
         bias: Optional[torch.Tensor] = None) -> torch.Tensor:
     from vllm.platforms.rocm import on_gfx9
     k = weight.shape[1]
+    m = weight.shape[0]
+    x_view = x.view(-1, x.size(-1))
+    n = x_view.shape[0]
     use_skinny = (envs.VLLM_ROCM_USE_SKINNY_GEMM and on_gfx9() and \
                   x.dtype in [torch.float16, torch.bfloat16] \
                   and k % 8 == 0 and bias is None)
 
     if use_skinny is not True:
         return torch.nn.functional.linear(x, weight, bias)
 
-    x_view = x.view(-1, x.size(-1))
-    n = x_view.shape[0]
-    m = weight.shape[0]
+    # x_view = x.view(-1, x.size(-1))
+    # n = x_view.shape[0]
+    # m = weight.shape[0]
     cu_count = current_platform.get_cu_count()
 
     if m > 8 and 0 < n <= 4:
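aiter_GEMM_check whitelists a handful of (n, k) weight shapes that match GPT-OSS projection sizes (fused QKV, o_proj, router). Its call site is not part of this hunk, so the dispatch below is only a hypothetical sketch of how such a shape check is typically paired with the VLLM_USE_AITER_TRITON_GEMM flag and the gemm_a16w16 kernel imported above; the argument order and the three-argument gemm_a16w16 call are assumptions (the latter mirrors the router call in gpt_oss.py later in this commit):

from typing import Optional

import torch
import torch.nn.functional as F

def rocm_unquantized_gemm_sketch(x: torch.Tensor,
                                 weight: torch.Tensor,
                                 bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Hypothetical dispatch: route whitelisted shapes to the AITER Triton GEMM,
    # everything else to the default torch.nn.functional.linear path.
    x_view = x.view(-1, x.size(-1))
    m = x_view.shape[0]   # tokens (assumed mapping; the real call site is not shown)
    n, k = weight.shape   # output features, reduction dimension
    if VLLM_USE_AITER_TRITON_GEMM and aiter_GEMM_check(m, n, k):
        return gemm_a16w16(x_view, weight, bias).reshape(*x.shape[:-1], n)
    return F.linear(x, weight, bias)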

vllm/model_executor/models/gpt_oss.py

Lines changed: 61 additions & 11 deletions
@@ -26,10 +26,21 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 from vllm.utils import cdiv
-
+from vllm.platforms import current_platform
 from .utils import extract_layer_index, maybe_prefix
+import os
+
+if current_platform.is_rocm():
+    from aiter.ops.triton.gemm_a16w16 import gemm_a16w16
 
 
+VLLM_USE_AITER_TRITON_FUSED_SPLIT_QKV_ROPE = (os.getenv("VLLM_USE_AITER_TRITON_FUSED_SPLIT_QKV_ROPE", "False").lower() in ("true", "1"))
+VLLM_USE_AITER_TRITON_FUSED_ADD_RMSNORM_PAD = (os.getenv("VLLM_USE_AITER_TRITON_FUSED_ADD_RMSNORM_PAD", "False").lower() in ("true", "1"))
+if VLLM_USE_AITER_TRITON_FUSED_SPLIT_QKV_ROPE:
+    from aiter.ops.triton.fused_qkv_split_qk_rope import fused_qkv_split_qk_rope
+if VLLM_USE_AITER_TRITON_FUSED_ADD_RMSNORM_PAD:
+    from aiter.ops.triton.fused_add_rmsnorm_pad import fused_add_rmsnorm_pad
+
 class OAIAttention(nn.Module):
 
     def __init__(
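Both new flags are read once when vllm.model_executor.models.gpt_oss is imported, so they must be set in the environment before vLLM is loaded. An illustrative (not authoritative) way to opt in from Python; the model id is only an example:

import os

# Set the opt-in flags before vLLM (and therefore gpt_oss.py) is imported.
os.environ.setdefault("VLLM_USE_AITER_TRITON_FUSED_SPLIT_QKV_ROPE", "1")
os.environ.setdefault("VLLM_USE_AITER_TRITON_FUSED_ADD_RMSNORM_PAD", "1")

from vllm import LLM

llm = LLM(model="openai/gpt-oss-120b")  # example model id
print(llm.generate(["Hello"])[0].outputs[0].text)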
@@ -118,15 +129,38 @@ def __init__(
 
     def forward(self, hidden_states: torch.Tensor,
                 positions: torch.Tensor) -> torch.Tensor:
-        t = self.norm(hidden_states)
-
+        # t = self.norm(hidden_states)
+        if isinstance(hidden_states, tuple) and current_platform.is_rocm() and VLLM_USE_AITER_TRITON_FUSED_ADD_RMSNORM_PAD:
+            hidden_states, res = hidden_states
+            t, hidden_states = fused_add_rmsnorm_pad(hidden_states, self.norm.weight, self.norm.variance_epsilon, res)
+        else:
+            t = self.norm(hidden_states)
         qkv, _ = self.qkv(t)
-        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
-        q, k = self.rotary_emb(positions, q, k)
+        # q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        # q, k = self.rotary_emb(positions, q, k)
+        if VLLM_USE_AITER_TRITON_FUSED_SPLIT_QKV_ROPE:
+            cos, sin = self.rotary_emb.cos_sin_cache.chunk(2, dim = -1)
+            q, k, v = fused_qkv_split_qk_rope(
+                qkv,
+                cos,
+                sin,
+                positions,
+                self.num_local_attention_heads, self.num_local_key_value_heads, self.head_dim,
+                is_neox=self.rotary_emb.is_neox_style,
+                offsets = None,
+                reuse_freqs_front_part = (self.head_dim // 2 == cos.shape[-1]),
+                nope_first = False,
+            )
+            q = q.view(-1, self.q_size)
+            k = k.view(-1, self.kv_size)
+        else:
+            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+            q, k = self.rotary_emb(positions, q, k)
         v = v.contiguous()
         attn_output = self.attn(q, k, v)
         output, _ = self.o_proj(attn_output)
-
+        if current_platform.is_rocm() and VLLM_USE_AITER_TRITON_FUSED_ADD_RMSNORM_PAD:
+            return output, hidden_states
         return output + hidden_states

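With VLLM_USE_AITER_TRITON_FUSED_ADD_RMSNORM_PAD enabled, each block now returns an (output, residual) tuple instead of already-summed hidden states, and the next block's fused kernel performs the deferred residual add together with its own RMSNorm. fused_add_rmsnorm_pad is aiter's Triton kernel; the reference below is only a sketch of the semantics the call sites in this diff appear to rely on (dtype upcasting and the zero pad value are assumptions):

from typing import Optional

import torch

def fused_add_rmsnorm_pad_reference(x: torch.Tensor,
                                    weight: torch.Tensor,
                                    eps: float,
                                    residual: torch.Tensor,
                                    x_pad_to_multiple: Optional[int] = None):
    # Assumed semantics, inferred from the call sites in this diff:
    #   1) add the deferred residual into x,
    #   2) RMS-normalize the sum and scale by the norm weight,
    #   3) optionally zero-pad the last dim to a multiple (used before the MoE),
    # returning (normed_and_padded, new_residual).
    new_residual = x + residual
    variance = new_residual.float().pow(2).mean(dim=-1, keepdim=True)
    normed = (new_residual.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight
    if x_pad_to_multiple:
        pad = (-normed.shape[-1]) % x_pad_to_multiple
        if pad:
            normed = torch.nn.functional.pad(normed, (0, pad))
    return normed, new_residual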
@@ -144,6 +178,7 @@ def __init__(
         self.num_experts = config.num_local_experts
         self.experts_per_token = config.num_experts_per_tok
         self.world_size = dist.get_world_size() if dist.is_initialized() else 1
+        self.hidden_size = config.hidden_size
         self.norm = RMSNorm(config.hidden_size, eps=1e-5)
         self.router = torch.nn.Linear(config.hidden_size,
                                       config.num_local_experts,
@@ -161,10 +196,21 @@
                                  has_bias=True,
                                  activation="swiglu_oai")
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        t = self.norm(x)
-        g = self.router(t)
-        t = self.experts(hidden_states=t, router_logits=g)
+    def forward(self, x: torch.Tensor | tuple) -> torch.Tensor:
+        if isinstance(x, tuple) and current_platform.is_rocm() and VLLM_USE_AITER_TRITON_FUSED_ADD_RMSNORM_PAD:
+            x, res = x
+            t, x = fused_add_rmsnorm_pad(x, self.norm.weight, self.norm.variance_epsilon, res, x_pad_to_multiple=256)
+        else:
+            t = self.norm(x)
+
+        if current_platform.is_rocm():
+            g = gemm_a16w16(t[:, :self.hidden_size], self.router.weight, self.router.bias)
+        else:
+            g = self.router(t[:, :self.hidden_size])
+        t = self.experts(hidden_states=t, router_logits=g)[:, :self.hidden_size]
+
+        if current_platform.is_rocm() and VLLM_USE_AITER_TRITON_FUSED_ADD_RMSNORM_PAD:
+            return x, t
         return x + t

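Because the fused norm can hand the MLP a hidden state padded to a multiple of 256, both the router input and the experts' output are sliced back to self.hidden_size, and on ROCm the router projection goes through aiter's gemm_a16w16 rather than the nn.Linear forward. A small shape sketch of the non-ROCm branch (2880 and the 128-expert router are taken from the shapes in aiter_GEMM_check above; the batch size is arbitrary):

import torch

hidden_size = 2880   # un-padded model width
padded = 3072        # next multiple of 256, matching x_pad_to_multiple=256

t = torch.randn(4, padded, dtype=torch.bfloat16)                     # normed + padded activations
router_weight = torch.randn(128, hidden_size, dtype=torch.bfloat16)  # one row per local expert
router_bias = torch.zeros(128, dtype=torch.bfloat16)

# Equivalent to the non-ROCm branch `self.router(t[:, :self.hidden_size])`:
# only the un-padded prefix of the hidden dimension reaches the router.
g = torch.nn.functional.linear(t[:, :hidden_size], router_weight, router_bias)
print(g.shape)  # torch.Size([4, 128]) -- one routing logit per local expert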
@@ -222,7 +268,11 @@ def forward(self, input_ids: torch.Tensor,
         x = self.embedding(input_ids)
         for layer in self.layers:
             x = layer(x, positions)
-        x = self.norm(x)
+        if isinstance(x, tuple) and current_platform.is_rocm() and VLLM_USE_AITER_TRITON_FUSED_ADD_RMSNORM_PAD:
+            x, res = x
+            x, _ = fused_add_rmsnorm_pad(x, self.norm.weight, self.norm.variance_epsilon, res)
+        else:
+            x = self.norm(x)
         return x