Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 20 additions & 16 deletions cpp/tensorrt_llm/kernels/fusedQKNormRopeKernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,9 @@ __global__ void fusedQKNormRopeKernel(
float factor, // factor in rope_scaling in config.json. When it is not 1.0, it means the model is using yarn.
float low, // threshold for high frequency
float high, // threshold for low frequency
float attention_factor // attention_factor applied on cos and sin
float attention_factor, // attention_factor applied on cos and sin
// stop of parameters for yarn
bool is_qk_norm // Whether to apply QK norm
)
{
int const warpsPerBlock = blockDim.x / 32;
Expand Down Expand Up @@ -136,20 +138,22 @@ __global__ void fusedQKNormRopeKernel(
}
}

// Reduce sum across warp using the utility function
sumOfSquares = tensorrt_llm::common::warpReduceSum(sumOfSquares);
if (is_qk_norm)
{
// Reduce sum across warp using the utility function
sumOfSquares = tensorrt_llm::common::warpReduceSum(sumOfSquares);

// Compute RMS normalization factor
float rms_rcp = rsqrtf(sumOfSquares / static_cast<float>(head_dim) + eps);
// Compute RMS normalization factor
float rms_rcp = rsqrtf(sumOfSquares / static_cast<float>(head_dim) + eps);

// Normalize elements
for (int i = 0; i < numElemsPerThread; i++)
{
int dim = laneId * numElemsPerThread + i;
float weight = isQ ? __bfloat162float(q_weight[dim]) : __bfloat162float(k_weight[dim]);
elements[i] *= rms_rcp * weight;
// Normalize elements
for (int i = 0; i < numElemsPerThread; i++)
{
int dim = laneId * numElemsPerThread + i;
float weight = isQ ? __bfloat162float(q_weight[dim]) : __bfloat162float(k_weight[dim]);
elements[i] *= rms_rcp * weight;
}
}

// Apply RoPE to normalized elements
float elements2[numElemsPerThread]; // Additional buffer required for RoPE.
float cos_vals[numElemsPerThread];
Expand Down Expand Up @@ -276,7 +280,7 @@ __global__ void fusedQKNormRopeKernel(
void launchFusedQKNormRope(void* qkv, int const num_tokens, int const num_heads_q, int const num_heads_k,
int const num_heads_v, int const head_dim, float const eps, void const* q_weight, void const* k_weight,
float const base, bool const interleave, int const* position_ids, float factor, float low, float high,
float attention_factor, cudaStream_t stream)
float attention_factor, cudaStream_t stream, bool is_qk_norm)
{
if (factor == 1.0f)
{
Expand All @@ -301,23 +305,23 @@ void launchFusedQKNormRope(void* qkv, int const num_tokens, int const num_heads_
fusedQKNormRopeKernel<64, INTERLEAVE><<<gridDim, blockDim, 0, stream>>>(
reinterpret_cast<__nv_bfloat16*>(qkv), num_heads_q, num_heads_k, num_heads_v, eps,
reinterpret_cast<__nv_bfloat16 const*>(q_weight), reinterpret_cast<__nv_bfloat16 const*>(k_weight),
base, position_ids, num_tokens, factor, low, high, attention_factor);
base, position_ids, num_tokens, factor, low, high, attention_factor, is_qk_norm);
});
break;
case 128:
DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
fusedQKNormRopeKernel<128, INTERLEAVE><<<gridDim, blockDim, 0, stream>>>(
reinterpret_cast<__nv_bfloat16*>(qkv), num_heads_q, num_heads_k, num_heads_v, eps,
reinterpret_cast<__nv_bfloat16 const*>(q_weight), reinterpret_cast<__nv_bfloat16 const*>(k_weight),
base, position_ids, num_tokens, factor, low, high, attention_factor);
base, position_ids, num_tokens, factor, low, high, attention_factor, is_qk_norm);
});
break;
case 256:
DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
fusedQKNormRopeKernel<256, INTERLEAVE><<<gridDim, blockDim, 0, stream>>>(
reinterpret_cast<__nv_bfloat16*>(qkv), num_heads_q, num_heads_k, num_heads_v, eps,
reinterpret_cast<__nv_bfloat16 const*>(q_weight), reinterpret_cast<__nv_bfloat16 const*>(k_weight),
base, position_ids, num_tokens, factor, low, high, attention_factor);
base, position_ids, num_tokens, factor, low, high, attention_factor, is_qk_norm);
});
break;
default: TLLM_THROW("Unsupported head dimension for fusedQKNormRope: %d", head_dim);
Expand Down
3 changes: 2 additions & 1 deletion cpp/tensorrt_llm/kernels/fusedQKNormRopeKernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ void launchFusedQKNormRope(
float low, // threshold for high frequency
float high, // threshold for low frequency
float attention_factor, // attention_factor applied on cos and sin
cudaStream_t stream); // CUDA stream
cudaStream_t stream, // CUDA stream
bool is_qk_norm);

} // namespace kernels
} // namespace tensorrt_llm
7 changes: 4 additions & 3 deletions cpp/tensorrt_llm/thop/fusedQKNormRopeOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ void fused_qk_norm_rope(
double factor, // factor in rope_scaling in config.json. When it is not 1.0, it means the model is using yarn.
double low, // threshold for high frequency
double high, // threshold for low frequency
double attention_factor // attention_factor applied on cos and sin
double attention_factor, // attention_factor applied on cos and sin
bool is_qk_norm // Whether to apply QK norm
)
{
// Input validation
Expand Down Expand Up @@ -74,7 +75,7 @@ void fused_qk_norm_rope(
static_cast<float>(base),
!is_neox, // interleave
reinterpret_cast<int const*>(position_ids.data_ptr()), static_cast<float>(factor), static_cast<float>(low),
static_cast<float>(high), static_cast<float>(attention_factor), stream);
static_cast<float>(high), static_cast<float>(attention_factor), stream, is_qk_norm);
}

// Register the PyTorch operators
Expand All @@ -83,7 +84,7 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m)
m.def(
"fused_qk_norm_rope(Tensor(a!) qkv, int num_heads_q, int num_heads_k, int num_heads_v, int head_dim, float "
"eps, Tensor q_weight, Tensor k_weight, float base, bool is_neox, Tensor position_ids, float factor, float "
"low, float high, float attention_factor) -> ()");
"low, float high, float attention_factor, bool is_qk_norm) -> ()");
}

// Register the CUDA implementation
Expand Down
67 changes: 63 additions & 4 deletions tensorrt_llm/_torch/models/modeling_qwen.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from ..modules.embedding import Embedding
from ..modules.gated_mlp import GatedMLP
from ..modules.linear import Linear, TensorParallelMode
from ..modules.qk_norm_attention import QKNormRoPEAttention
from ..modules.rms_norm import RMSNorm
from .modeling_utils import (DecoderModel, DecoderModelForCausalLM,
register_auto_model)
Expand Down Expand Up @@ -53,6 +54,56 @@ def __init__(
)


# TODO this is a workaround to support yarn on Qwen2.
# We need refactor the codes to merge QwenYarnAttention and QwenAttention.
class QwenYarnAttention(QKNormRoPEAttention):
    """Attention variant used when a Qwen2 checkpoint configures yarn rope scaling.

    Builds positional-embedding parameters from ``config.rope_scaling`` (falling
    back to plain GPT-NeoX RoPE when no scaling is configured) and delegates the
    rest of the setup to QKNormRoPEAttention with QK-norm disabled.
    """

    def __init__(
        self,
        model_config: ModelConfig[Qwen2Config],
        layer_idx: Optional[int] = None,
        fuse_qk_norm_rope: bool = True,
    ):
        config = model_config.pretrained_config

        scaling_cfg = getattr(config, "rope_scaling", None)
        if scaling_cfg is None:
            # No rope_scaling entry in config.json: plain GPT-NeoX style RoPE.
            embedding_type = PositionEmbeddingType.rope_gpt_neox
        else:
            # HF configs spell the scaling kind as either "type" or "rope_type".
            if "type" in scaling_cfg:
                scaling_kind = scaling_cfg["type"]
            elif "rope_type" in scaling_cfg:
                scaling_kind = scaling_cfg["rope_type"]
            else:
                raise ValueError(
                    "rope_scaling must have type or rope_type field")
            embedding_type = PositionEmbeddingType.from_string(scaling_kind)

        pos_embd_params = PositionalEmbeddingParams(
            type=embedding_type,
            rope=RopeParams.from_config(config),
        )

        # Qwen3 has accuracy issues with deep_gemm (see: https://nvbugspro.nvidia.com/bug/5461712
        # and https://nvbugspro.nvidia.com/bug/5505402)
        disable_deep_gemm = True
        super().__init__(
            hidden_size=config.hidden_size,
            num_attention_heads=config.num_attention_heads,
            num_key_value_heads=config.num_key_value_heads,
            max_position_embeddings=config.max_position_embeddings,
            bias=True,
            pos_embd_params=pos_embd_params,
            fuse_qk_norm_rope=fuse_qk_norm_rope,
            layer_idx=layer_idx,
            dtype=config.torch_dtype,
            dense_bias=False,
            config=model_config,
            disable_deep_gemm=disable_deep_gemm,
            # Reuse the fused qk-norm+rope path for yarn, but skip the RMS norm
            # itself — Qwen2 has no per-head q/k norm weights.
            is_qk_norm=False,
        )


class QwenDecoderLayer(DecoderLayer):

def __init__(
Expand All @@ -63,10 +114,18 @@ def __init__(
super().__init__()
self.layer_idx = layer_idx
config = model_config.pretrained_config
self.self_attn = QwenAttention(
model_config,
layer_idx=layer_idx,
)

if getattr(config, "rope_scaling", None) is not None and getattr(
config.rope_scaling, "rope_type", None) == "yarn":
self.self_attn = QwenYarnAttention(
model_config,
layer_idx=layer_idx,
)
else:
self.self_attn = QwenAttention(
model_config,
layer_idx=layer_idx,
)

self.mlp = GatedMLP(
hidden_size=config.hidden_size,
Expand Down
9 changes: 6 additions & 3 deletions tensorrt_llm/_torch/modules/qk_norm_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ def __init__(
disable_deep_gemm: bool = False,
use_gemma_rms_norm: bool = False,
attn_output_gate: Optional[bool] = None,
is_qk_norm: bool = True,
):
self.pretrained_config = config.pretrained_config

Expand All @@ -169,6 +170,7 @@ def __init__(
# If fuse_qk_norm_rope is true, do not apply fused RoPE in attention OP, and self.rotary_emb
# will be skipped in the overridden apply_rope.
rope_fusion = not self.fuse_qk_norm_rope and not skip_rope and not attn_output_gate and not use_gemma_rms_norm
self.is_qk_norm = is_qk_norm
assert not (fuse_qk_norm_rope and skip_rope
), "Fusing qk norm and skipping rope is not supported"

Expand All @@ -192,12 +194,12 @@ def __init__(
self.q_norm = RMSNorm(hidden_size=self.head_dim,
eps=self.pretrained_config.rms_norm_eps,
dtype=self.pretrained_config.torch_dtype,
has_weights=True,
has_weights=is_qk_norm,
use_gemma=use_gemma_rms_norm)
self.k_norm = RMSNorm(hidden_size=self.head_dim,
eps=self.pretrained_config.rms_norm_eps,
dtype=self.pretrained_config.torch_dtype,
has_weights=True,
has_weights=is_qk_norm,
use_gemma=use_gemma_rms_norm)
self.aux_stream = torch.cuda.Stream()
self.ln_events = [torch.cuda.Event(), torch.cuda.Event()]
Expand Down Expand Up @@ -231,7 +233,8 @@ def apply_qk_norm_rope(self, qkv, position_ids):
self.q_norm.variance_epsilon, self.q_norm.weight,
self.k_norm.weight,
self.pos_embd_params.rope.theta, self.pos_embd_params.is_neox,
position_ids.view(-1), factor, low, high, attention_factor)
position_ids.view(-1), factor, low, high, attention_factor,
self.is_qk_norm)
return qkv, None, None

def apply_rope(self, q: torch.Tensor, k: Optional[torch.Tensor],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,8 @@ def test_fused_qk_norm_rope(head_dim, num_heads_group, num_tokens, is_neox,
torch.ops.trtllm.fused_qk_norm_rope(qkv, num_heads_q, num_heads_k,
num_heads_v, head_dim, eps, q_weight,
k_weight, base, is_neox, position_ids,
factor, low, high, attention_factor)
factor, low, high, attention_factor,
True)
output = qkv # This op is inplace

# Compute reference output using TensorRT LLM modules
Expand Down