75 changes: 49 additions & 26 deletions csrc/trtllm_fmha_kernel_launcher.cu
@@ -75,13 +75,14 @@ class TllmGenFmhaRunnerCache {

void trtllm_paged_attention_launcher(
void* out, void* out_scale_factor, void* query, void* key_cache, void* value_cache,
void* workspace_buffer, int* block_tables, int* seq_lens, int* cum_seq_lens_q,
int* cum_seq_lens_kv, float* attention_sinks, Data_type q_data_type, Data_type kv_data_type,
Data_type o_data_type, TllmPagedAttentionMode mode, int64_t batch_size, int64_t max_q_len,
int64_t max_kv_len, int64_t num_pages_in_mem_pool, int64_t num_qo_heads, int64_t num_kv_heads,
int64_t head_dim_qk, int64_t head_dim_vo, int64_t page_size, int64_t q_stride_tokens,
int64_t q_stride_heads, int64_t kv_stride_keys_values, int64_t kv_stride_heads,
int64_t kv_stride_batch, int64_t max_num_blocks_per_seq, double bmm1_scale, double bmm2_scale,
void* workspace_buffer, int* block_tables, const void* k_block_scales_ptr,
const void* v_block_scales_ptr, int* seq_lens, int* cum_seq_lens_q, int* cum_seq_lens_kv,
float* attention_sinks, Data_type q_data_type, Data_type kv_data_type, Data_type o_data_type,
TllmPagedAttentionMode mode, int64_t batch_size, int64_t max_q_len, int64_t max_kv_len,
int64_t num_pages_in_mem_pool, int64_t num_qo_heads, int64_t num_kv_heads, int64_t head_dim_qk,
int64_t head_dim_vo, int64_t page_size, int64_t q_stride_tokens, int64_t q_stride_heads,
int64_t kv_stride_keys_values, int64_t kv_stride_heads, int64_t kv_stride_batch,
int64_t max_num_blocks_per_seq, double bmm1_scale, double bmm2_scale,
const float* bmm1_scale_log2_ptr, const float* bmm2_scale_ptr, double o_sf_scale,
int64_t o_sf_vec_size, int64_t o_sf_start_index, int64_t window_left, int64_t sum_seq_q,
int64_t sparse_mla_top_k, int64_t sm_count, bool enable_pdl, int64_t workspace_size,
@@ -101,6 +102,8 @@ void trtllm_paged_attention_launcher(
runner_params.kPtr = key_cache;
runner_params.vPtr = value_cache;
runner_params.kvPageIdxPtr = block_tables;
runner_params.kSfBasePtr = k_block_scales_ptr;
runner_params.vSfBasePtr = v_block_scales_ptr;
runner_params.seqLensKvPtr = seq_lens;
runner_params.oPtr = out;
runner_params.mHeadDimQk = head_dim_qk;
@@ -128,6 +131,8 @@ void trtllm_paged_attention_launcher(
// outputScale. if they are not nullptr, then scaleSoftmaxLog2 and outputScale will be ignored
runner_params.outputScale = bmm2_scale;
runner_params.outputScalePtr = bmm2_scale_ptr;
runner_params.mScaleSfKv = 1.0f;  // KV scale factors are expected to be fused into bmm1_scale (K) / bmm2_scale (V/O)
runner_params.kvSfScalePtr = nullptr;
runner_params.scaleSoftmaxLog2 = bmm1_scale * M_LOG2E;
runner_params.scaleSoftmaxLog2Ptr = bmm1_scale_log2_ptr;
runner_params.oSfPtr = out_scale_factor;
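A brief aside on the `M_LOG2E` factor above: judging from the `scaleSoftmaxLog2` name, the kernel evaluates the softmax logits with `exp2` rather than the natural exponent, so the host pre-multiplies the scale by log2(e). A minimal Python sketch of that equivalence (not part of the PR, purely illustrative):

```python
import math

def softmax_logit_base_e(qk: float, bmm1_scale: float) -> float:
    # Reference form: natural exponent of the scaled Q*K^T logit.
    return math.exp(bmm1_scale * qk)

def softmax_logit_base_2(qk: float, bmm1_scale: float) -> float:
    # Kernel-side form: exp2 with the scale converted to the log2 domain,
    # mirroring `scaleSoftmaxLog2 = bmm1_scale * M_LOG2E` above.
    return 2.0 ** (bmm1_scale * math.log2(math.e) * qk)

assert math.isclose(softmax_logit_base_e(1.7, 0.125), softmax_logit_base_2(1.7, 0.125))
```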
@@ -159,8 +164,11 @@ void trtllm_paged_attention_launcher(
runner_params.cumSeqLensQPtr = cum_seq_lens_q;
runner_params.cumSeqLensKvPtr = cum_seq_lens_kv;
} else {
// ForGen
runner_params.mMaskType = TrtllmGenAttentionMaskType::Dense;
// Generation.
// Note that kernel names are still labeled as using a dense mask even when maskType is
// specified as Causal. This is expected: for better performance, each CTA processes only one
// query token in these cases, so a dense mask behaves the same as a causal mask.
runner_params.mMaskType = TrtllmGenAttentionMaskType::Causal;
runner_params.mKernelType = FmhaKernelType::Generation;
bool use_multi_block = true;
runner_params.mTileScheduler =
@@ -214,17 +222,17 @@ inline Data_type dl_dtype_to_tllm_data_type(const DLDataType dtype) {

inline bool is_4bit(Data_type data_type) { return data_type == Data_type::DATA_TYPE_E2M1; }

void trtllm_paged_attention_decode(TensorView out, Optional<TensorView> out_scale_factor,
TensorView query, TensorView key_cache, TensorView value_cache,
TensorView workspace_buffer, TensorView block_tables,
TensorView seq_lens, int64_t max_q_len, int64_t max_kv_len,
Variant<double, ffi::Tensor> bmm1_scale,
Variant<double, ffi::Tensor> bmm2_scale, double o_sf_scale,
int64_t o_sf_vec_size, int64_t o_sf_start_index,
int64_t batch_size, int64_t window_left,
int64_t sparse_mla_top_k, int64_t sm_count, bool enable_pdl,
int64_t workspace_size, Optional<TensorView> attention_sinks,
Optional<TensorView> cum_seq_lens_q) {
void trtllm_paged_attention_decode(
TensorView out, Optional<TensorView> out_scale_factor, TensorView query, TensorView key_cache,
TensorView value_cache, TensorView workspace_buffer, TensorView block_tables,
TensorView seq_lens, int64_t max_q_len, int64_t max_kv_len,
Variant<double, ffi::Tensor> bmm1_scale, Variant<double, ffi::Tensor> bmm2_scale,
double o_sf_scale, int64_t o_sf_vec_size, int64_t o_sf_start_index, int64_t batch_size,
int64_t window_left, int64_t sparse_mla_top_k, int64_t sm_count, bool enable_pdl,
int64_t workspace_size, Optional<TensorView> attention_sinks,
Optional<TensorView> cum_seq_lens_q, Optional<TensorView> key_block_scales,
Optional<TensorView> value_block_scales) {
fflush(stdout);
Contributor review comment (severity: medium):
This fflush(stdout) call appears to be a leftover debugging statement. It should be removed from production code to avoid potential performance impacts and unnecessary console output.

auto q_data_type = dl_dtype_to_tllm_data_type(query.dtype());
auto kv_data_type = dl_dtype_to_tllm_data_type(key_cache.dtype());
TVM_FFI_ICHECK_EQ(key_cache.ndim(), value_cache.ndim());
@@ -252,18 +260,32 @@ void trtllm_paged_attention_decode(TensorView out, Optional<TensorView> out_scal
int max_num_blocks_per_seq = block_tables.size(-1);
bool is_shared_kv = key_cache.data_ptr() == value_cache.data_ptr();
int num_pages_in_mem_pool = is_shared_kv ? key_cache.size(0) : key_cache.size(0) * 2;
bool is_fp4_kv = is_4bit(kv_data_type);
int stride_idx_factor = is_fp4_kv ? 2 : 1;

// Assume NHD layout: [..., H, N, D]
// Assume HND layout: [..., H, N, D]
int page_size = key_cache.size(-2);
int num_kv_heads = key_cache.size(-3);
int kv_stride_keys_values = key_cache.stride(-2); // key/values
int kv_stride_heads = key_cache.stride(-3); // head
int kv_stride_batch = key_cache.stride(0); // batch
int kv_stride_keys_values = key_cache.stride(-2) * stride_idx_factor; // key/values
int kv_stride_heads = key_cache.stride(-3) * stride_idx_factor; // head
int kv_stride_batch = key_cache.stride(0) * stride_idx_factor; // batch
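Worth spelling out why the strides are doubled for FP4: the NVFP4 KV cache packs two 4-bit values into each `uint8` element, so strides taken from the byte-typed tensor must be scaled by 2 to count logical elements. A hedged sketch of the same arithmetic in Python (shapes are made up for illustration):

```python
import torch

# Illustrative HND-layout FP4 KV cache: two 4-bit values per uint8,
# so the last dimension stores head_dim // 2 bytes.
num_pages, num_kv_heads, page_size, head_dim = 8, 4, 32, 128
key_cache = torch.zeros(
    num_pages, num_kv_heads, page_size, head_dim // 2, dtype=torch.uint8
)

stride_idx_factor = 2  # FP4: convert byte strides into logical 4-bit element strides
kv_stride_keys_values = key_cache.stride(-2) * stride_idx_factor
kv_stride_heads = key_cache.stride(-3) * stride_idx_factor
kv_stride_batch = key_cache.stride(0) * stride_idx_factor

# One token row now spans head_dim logical elements rather than head_dim // 2 bytes.
assert kv_stride_keys_values == head_dim
```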

// Query stride: [num_tokens, num_heads, head_dim]
int q_stride_tokens = query.stride(0); // stride between tokens
int q_stride_heads = query.stride(1); // stride between heads

// kv block scales
if (is_fp4_kv) {
TVM_FFI_ICHECK(key_block_scales.has_value())
<< "key_block_scales must be provided for FP4 kv cache";
TVM_FFI_ICHECK(value_block_scales.has_value())
<< "value_block_scales must be provided for FP4 kv cache";
}
const void* k_block_scales_ptr =
key_block_scales.has_value() ? key_block_scales.value().data_ptr() : nullptr;
const void* v_block_scales_ptr =
value_block_scales.has_value() ? value_block_scales.value().data_ptr() : nullptr;

const auto stream = get_stream(query.device());
void* output_sf_ptr =
out_scale_factor.has_value() ? out_scale_factor.value().data_ptr() : nullptr;
@@ -295,8 +317,8 @@ void trtllm_paged_attention_decode(TensorView out, Optional<TensorView> out_scal
: nullptr;
trtllm_paged_attention_launcher(
out.data_ptr(), output_sf_ptr, query.data_ptr(), key_cache.data_ptr(), value_cache.data_ptr(),
workspace_buffer.data_ptr(), static_cast<int*>(block_tables.data_ptr()),
static_cast<int*>(seq_lens.data_ptr()), cum_seq_lens_q_ptr,
workspace_buffer.data_ptr(), static_cast<int*>(block_tables.data_ptr()), k_block_scales_ptr,
v_block_scales_ptr, static_cast<int*>(seq_lens.data_ptr()), cum_seq_lens_q_ptr,
/*cum_seq_lens_kv*/ nullptr, attention_sinks_ptr, q_data_type, kv_data_type, o_data_type,
TllmPagedAttentionMode::ForGen, batch_size, max_q_len, max_kv_len, num_pages_in_mem_pool,
num_qo_heads, num_kv_heads, head_dim_q, head_dim_o, page_size, q_stride_tokens,
Expand Down Expand Up @@ -379,6 +401,7 @@ void trtllm_paged_attention_context(
trtllm_paged_attention_launcher(
out.data_ptr(), output_sf_ptr, query.data_ptr(), key_cache.data_ptr(), value_cache.data_ptr(),
workspace_buffer.data_ptr(), static_cast<int*>(block_tables.data_ptr()),
/*k_block_scales*/ nullptr, /*v_block_scales*/ nullptr,
static_cast<int*>(seq_lens.data_ptr()),
/*cum_seq_lens_q=*/static_cast<int*>(cum_seq_lens_q.data_ptr()),
/*cum_seq_lens_kv=*/static_cast<int*>(cum_seq_lens_kv.data_ptr()), attention_sinks_ptr,
4 changes: 2 additions & 2 deletions flashinfer/artifacts.py
@@ -87,7 +87,7 @@ class ArtifactPath:
When compiling new cubins for backend directories, update the corresponding path.
"""

TRTLLM_GEN_FMHA: str = "9f1b6ddaa1592a8339a82fcab7d27a57eff445fd/fmha/trtllm-gen/"
TRTLLM_GEN_FMHA: str = "81d3504ccf84d3ea0ff2ff4e2b15df2b63fb4160/fmha/trtllm-gen/"
TRTLLM_GEN_BMM: str = (
"ccae3ed120a12a2c6922b458086b460413dbf731/batched_gemm-0d275a2-9936841"
)
@@ -107,7 +107,7 @@ class CheckSumHash:
"""

TRTLLM_GEN_FMHA: str = (
"a5a60600a80076317703695f56bbef2f0a44075ef4e24d7b06ba67ff68bc9da2"
"376d4de5a1bbb2a651bfd3c11d62cd55a0fe919c4669671675fc80c9934cd845"
)
TRTLLM_GEN_BMM: str = (
"b7689d3046493806251351c2744c6d7faed6af25518647a955b35c4919b014fc"
35 changes: 35 additions & 0 deletions flashinfer/decode.py
@@ -2087,6 +2087,9 @@ def trtllm_batch_decode_with_kv_cache(
mask: Optional[torch.Tensor] = None,
max_q_len: Optional[int] = None,
cum_seq_lens_q: Optional[torch.Tensor] = None,
kv_block_scales: Optional[
Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
] = None,
) -> Union[torch.Tensor, FP4Tensor]:
"""
Parameters
@@ -2190,6 +2193,33 @@ def trtllm_batch_decode_with_kv_cache(
# it doesn't change underlying storage
k_cache, v_cache = kv_cache.unbind(dim=1)

is_nvfp4_kvcache = (
k_cache.dtype == torch.uint8
and v_cache.dtype == torch.uint8
and kv_block_scales is not None
)

k_block_scales = None
v_block_scales = None
if is_nvfp4_kvcache:
if isinstance(kv_block_scales, tuple):
k_block_scales, v_block_scales = kv_block_scales
else:
if kv_block_scales.shape[1] == 1:
k_block_scales, v_block_scales = kv_block_scales, kv_block_scales
else:
assert kv_block_scales.shape[1] == 2, (
"When kv_block_scales is a single tensor, the second dimension must be 1 or 2"
)
# NOTE(Zihao): unbind transforms [num_pages, 2, ...] to ([num_pages, ...], [num_pages, ...])
# it doesn't change underlying storage
k_block_scales, v_block_scales = kv_block_scales.unbind(dim=1)

assert (
k_block_scales.dtype == torch.float8_e4m3fn
and v_block_scales.dtype == torch.float8_e4m3fn
), "k/v_block_scales should be float8 dtype."

if backend == "auto":
backend = (
"trtllm-gen" if get_compute_capability(query.device)[0] == 10 else "xqa"
@@ -2235,6 +2265,9 @@ def trtllm_batch_decode_with_kv_cache(
# For NHD: [..., N, H, D] -> HND: [..., H, N, D]
k_cache = k_cache.transpose(-3, -2)
v_cache = v_cache.transpose(-3, -2)
if is_nvfp4_kvcache:
k_block_scales = k_block_scales.transpose(-3, -2)
v_block_scales = v_block_scales.transpose(-3, -2)
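Since the block-scale tensors are expected to follow the KV-cache layout, the same NHD-to-HND transpose is applied to them. A tiny sketch of what that view change does (shapes illustrative, not part of the PR):

```python
import torch

# NHD [num_pages, N=page_size, H=num_kv_heads, D] -> HND [num_pages, H, N, D]
x_nhd = torch.zeros(8, 32, 4, 16)
x_hnd = x_nhd.transpose(-3, -2)
assert x_hnd.shape == (8, 4, 32, 16)
assert x_hnd.data_ptr() == x_nhd.data_ptr()  # a view; no data is copied
```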

run_func = get_trtllm_gen_fmha_module().trtllm_paged_attention_decode
sm_count = get_device_sm_count(query.device)
@@ -2350,6 +2383,8 @@ def trtllm_batch_decode_with_kv_cache(
workspace_buffer.numel() * workspace_buffer.element_size(),
sinks,
cum_seq_lens_q,
k_block_scales,
v_block_scales,
)

return (
10 changes: 10 additions & 0 deletions flashinfer/jit/cubin_loader.py
@@ -174,6 +174,7 @@ def load_cubin(cubin_path: str, sha256: str) -> bytes:
try:
with open(cubin_path, mode="rb") as f:
cubin = f.read()
if os.getenv("FLASHINFER_CUBIN_CHECKSUM_DISABLED"):
return cubin
m = hashlib.sha256()
@@ -205,6 +206,15 @@ def get_cubin(file_name: str, sha256: str, session=None) -> bytes:
return cubin
# either the file does not exist or it is corrupted, we'll download a new one.

# Check if cubin download is disabled
return b""
if os.getenv("FLASHINFER_DISABLE_CUBIN_DOWNLOAD"):
logger.error(
f"Cubin download is disabled (FLASHINFER_DISABLE_CUBIN_DOWNLOAD is set), "
f"but cubin file not found or corrupted: {cubin_path}"
)
return b""

uri = safe_urljoin(FLASHINFER_CUBINS_REPOSITORY, file_name)
logger.info(f"Fetching cubin {file_name} from {uri}")
download_file(uri, cubin_path, session=session)
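For completeness, a hedged sketch of how the new flag is expected to behave from a caller's perspective (the environment variable, module path, and `get_cubin` signature are taken from this diff; the file name and hash are placeholders, not real artifacts):

```python
import os

# With the flag set, get_cubin() should not reach the network: it returns a
# locally cached, checksum-verified cubin, or empty bytes if none is available.
os.environ["FLASHINFER_DISABLE_CUBIN_DOWNLOAD"] = "1"

from flashinfer.jit.cubin_loader import get_cubin

cubin = get_cubin("fmha/trtllm-gen/example_kernel.cubin", sha256="0" * 64)
if not cubin:
    print("cubin not found locally and downloads are disabled")
```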