@@ -26,10 +26,21 @@
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.utils import cdiv
-
+from vllm.platforms import current_platform
from .utils import extract_layer_index, maybe_prefix
+import os
+
+if current_platform.is_rocm():
+    from aiter.ops.triton.gemm_a16w16 import gemm_a16w16


+VLLM_USE_AITER_TRITON_FUSED_SPLIT_QKV_ROPE = (os.getenv("VLLM_USE_AITER_TRITON_FUSED_SPLIT_QKV_ROPE", "False").lower() in ("true", "1"))
+VLLM_USE_AITER_TRITON_FUSED_ADD_RMSNORM_PAD = (os.getenv("VLLM_USE_AITER_TRITON_FUSED_ADD_RMSNORM_PAD", "False").lower() in ("true", "1"))
+if VLLM_USE_AITER_TRITON_FUSED_SPLIT_QKV_ROPE:
+    from aiter.ops.triton.fused_qkv_split_qk_rope import fused_qkv_split_qk_rope
+if VLLM_USE_AITER_TRITON_FUSED_ADD_RMSNORM_PAD:
+    from aiter.ops.triton.fused_add_rmsnorm_pad import fused_add_rmsnorm_pad
+
class OAIAttention(nn.Module):

    def __init__(
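Both fused paths are opt-in: the flags above are parsed once at module import time, anything other than "true" or "1" (case-insensitive) leaves them off, and the corresponding aiter Triton kernels are only imported when their flag is set (gemm_a16w16 only on ROCm). A minimal, illustrative sketch of enabling them, assuming the environment is set before vLLM imports this file:

```python
import os

# Illustrative only: the flags are read when this model file is imported, so
# they must be in the environment before vLLM (and this module) is loaded.
os.environ["VLLM_USE_AITER_TRITON_FUSED_SPLIT_QKV_ROPE"] = "1"
os.environ["VLLM_USE_AITER_TRITON_FUSED_ADD_RMSNORM_PAD"] = "1"
```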
@@ -118,15 +129,38 @@ def __init__(

    def forward(self, hidden_states: torch.Tensor,
                positions: torch.Tensor) -> torch.Tensor:
-        t = self.norm(hidden_states)
-
+        # t = self.norm(hidden_states)
+        if isinstance(hidden_states, tuple) and current_platform.is_rocm() and VLLM_USE_AITER_TRITON_FUSED_ADD_RMSNORM_PAD:
+            hidden_states, res = hidden_states
+            t, hidden_states = fused_add_rmsnorm_pad(hidden_states, self.norm.weight, self.norm.variance_epsilon, res)
+        else:
+            t = self.norm(hidden_states)
        qkv, _ = self.qkv(t)
-        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
-        q, k = self.rotary_emb(positions, q, k)
+        # q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        # q, k = self.rotary_emb(positions, q, k)
+        if VLLM_USE_AITER_TRITON_FUSED_SPLIT_QKV_ROPE:
+            cos, sin = self.rotary_emb.cos_sin_cache.chunk(2, dim=-1)
+            q, k, v = fused_qkv_split_qk_rope(
+                qkv,
+                cos,
+                sin,
+                positions,
+                self.num_local_attention_heads, self.num_local_key_value_heads, self.head_dim,
+                is_neox=self.rotary_emb.is_neox_style,
+                offsets=None,
+                reuse_freqs_front_part=(self.head_dim // 2 == cos.shape[-1]),
+                nope_first=False,
+            )
+            q = q.view(-1, self.q_size)
+            k = k.view(-1, self.kv_size)
+        else:
+            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+            q, k = self.rotary_emb(positions, q, k)
        v = v.contiguous()
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
-
+        if current_platform.is_rocm() and VLLM_USE_AITER_TRITON_FUSED_ADD_RMSNORM_PAD:
+            return output, hidden_states
        return output + hidden_states


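For reference, below is an unfused PyTorch sketch of what fused_qkv_split_qk_rope is assumed to compute, based only on the call site above: split the packed QKV projection by head counts and apply NeoX-style rotary embedding to q and k, with cos/sin taken from the rotary cache at each token's position. The helper name, argument names, and shapes are assumptions for illustration, not the aiter kernel's actual API.

```python
import torch

def split_qkv_and_rope_reference(qkv: torch.Tensor,
                                 cos: torch.Tensor,
                                 sin: torch.Tensor,
                                 positions: torch.Tensor,
                                 num_q_heads: int,
                                 num_kv_heads: int,
                                 head_dim: int):
    # Unfused sketch (assumption) of the fused split-QKV + RoPE path above.
    # qkv: [num_tokens, (num_q_heads + 2 * num_kv_heads) * head_dim]
    # cos/sin: [max_position, head_dim // 2], the two halves of cos_sin_cache.
    q, k, v = qkv.split([num_q_heads * head_dim,
                         num_kv_heads * head_dim,
                         num_kv_heads * head_dim], dim=-1)
    q = q.view(-1, num_q_heads, head_dim)
    k = k.view(-1, num_kv_heads, head_dim)

    cos_p = cos[positions].unsqueeze(1)  # broadcast over heads
    sin_p = sin[positions].unsqueeze(1)

    def rotate_neox(x: torch.Tensor) -> torch.Tensor:
        # NeoX-style rotation over the two halves of each head dimension.
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat([x1 * cos_p - x2 * sin_p,
                          x2 * cos_p + x1 * sin_p], dim=-1)

    # v is only split, not rotated; the caller flattens q/k back to
    # [num_tokens, q_size] / [num_tokens, kv_size], as the diff does.
    return rotate_neox(q), rotate_neox(k), v
```

The `reuse_freqs_front_part=(self.head_dim // 2 == cos.shape[-1])` argument in the diff suggests the kernel also accepts caches that store full-width frequencies; the sketch assumes the half-width layout.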
@@ -144,6 +178,7 @@ def __init__(
        self.num_experts = config.num_local_experts
        self.experts_per_token = config.num_experts_per_tok
        self.world_size = dist.get_world_size() if dist.is_initialized() else 1
+        self.hidden_size = config.hidden_size
        self.norm = RMSNorm(config.hidden_size, eps=1e-5)
        self.router = torch.nn.Linear(config.hidden_size,
                                      config.num_local_experts,
@@ -161,10 +196,21 @@ def __init__(
                                has_bias=True,
                                activation="swiglu_oai")

-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        t = self.norm(x)
-        g = self.router(t)
-        t = self.experts(hidden_states=t, router_logits=g)
+    def forward(self, x: torch.Tensor | tuple) -> torch.Tensor:
+        if isinstance(x, tuple) and current_platform.is_rocm() and VLLM_USE_AITER_TRITON_FUSED_ADD_RMSNORM_PAD:
+            x, res = x
+            t, x = fused_add_rmsnorm_pad(x, self.norm.weight, self.norm.variance_epsilon, res, x_pad_to_multiple=256)
+        else:
+            t = self.norm(x)
+
+        if current_platform.is_rocm():
+            g = gemm_a16w16(t[:, :self.hidden_size], self.router.weight, self.router.bias)
+        else:
+            g = self.router(t[:, :self.hidden_size])
+        t = self.experts(hidden_states=t, router_logits=g)[:, :self.hidden_size]
+
+        if current_platform.is_rocm() and VLLM_USE_AITER_TRITON_FUSED_ADD_RMSNORM_PAD:
+            return x, t
        return x + t


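For reference, an unfused sketch of what fused_add_rmsnorm_pad is assumed to do at both call sites: add the incoming residual, RMS-normalize the sum, and (on the MoE path) zero-pad the hidden dimension up to a multiple of 256, returning `(normalized, updated_residual)`. That return pair is why layers hand `(output, residual)` tuples to each other instead of materializing `output + hidden_states`, and why `self.hidden_size` is stored: the padded activations are sliced back with `[:, :self.hidden_size]` before the router and after the experts. The helper below is an assumption inferred from the call sites, not the aiter kernel's API.

```python
import torch

def add_rmsnorm_pad_reference(x: torch.Tensor,
                              weight: torch.Tensor,
                              eps: float,
                              residual: torch.Tensor,
                              x_pad_to_multiple: int = 1):
    # Assumed semantics of fused_add_rmsnorm_pad, inferred from the diff:
    # 1) fold the residual into x, 2) RMSNorm the sum, 3) optionally right-pad
    # the hidden dimension, returning (normed, updated_residual).
    x = x + residual
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    normed = (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight
    pad = (-normed.shape[-1]) % x_pad_to_multiple
    if pad:
        normed = torch.nn.functional.pad(normed, (0, pad))
    return normed, x
```

On ROCm the router matmul is also routed through `gemm_a16w16(t[:, :self.hidden_size], self.router.weight, self.router.bias)`, which the diff uses as a drop-in for the `nn.Linear` router call, so it is presumably the same `x @ W.T + b` in 16-bit precision.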
@@ -222,7 +268,11 @@ def forward(self, input_ids: torch.Tensor,
        x = self.embedding(input_ids)
        for layer in self.layers:
            x = layer(x, positions)
-        x = self.norm(x)
+        if isinstance(x, tuple) and current_platform.is_rocm() and VLLM_USE_AITER_TRITON_FUSED_ADD_RMSNORM_PAD:
+            x, res = x
+            x, _ = fused_add_rmsnorm_pad(x, self.norm.weight, self.norm.variance_epsilon, res)
+        else:
+            x = self.norm(x)
        return x

