
Commit 8cec2da

bkryu and nekorobov authored
[None][feat] Port fp4 quantization kernel optimization from FlashInfer (#9854)
Signed-off-by: Brian Ryu <[email protected]>
Co-authored-by: Nikita Korobov <[email protected]>
1 parent 8fefa2c commit 8cec2da

File tree

1 file changed: +77, -42 lines


cpp/tensorrt_llm/kernels/quantization.cuh

Lines changed: 77 additions & 42 deletions
@@ -794,67 +794,102 @@ quantize_with_block_size(
 
     asm volatile("griddepcontrol.wait;");
     // Input tensor batch/row/col loops.
+    // Optimization: Iterate over actual rows first (hot path), then padding rows (cold path)
+    // This improves performance for small batch sizes with swizzled layout
     for (int rowIdx = blockIdx.x; rowIdx < numPaddedRowsForSf; rowIdx += gridDim.x)
     {
-        for (int batchIdx = 0; batchIdx < numbatches; batchIdx++)
+        // Early exit for padding-only blocks: if this block only processes padding rows,
+        // we can skip the batch loop and just zero out the scale factors
+        bool isRowPadding = (rowIdx >= numRows);
+
+        if (isRowPadding)
         {
-            for (int colIdx = threadIdx.x; colIdx < numColThreadsForSf; colIdx += blockDim.x)
+            // Fast path: This row is entirely padding, only zero out scale factors.
+            // Note: Padding rows do NOT exist in the output tensor (which is sized [numRows, K]),
+            // they only exist in the swizzled scale factor layout. Do NOT write to output buffer here.
+            for (int batchIdx = 0; batchIdx < numbatches; batchIdx++)
             {
-                std::optional<int> optionalBatchIdx = batchIdx;
-                std::optional<int> optionalNumRows = numRows;
-
-                // The SF output pointer.
-                auto sf_out = cvt_quant_get_sf_out_offset<uint32_t, CVT_NUM_THREADS_PER_SF>(
-                    optionalBatchIdx, rowIdx, colIdx, optionalNumRows, numPaddedCols / SF_VEC_SIZE, SFout, layout);
-
-                // The input tensor offset.
-                int64_t inOffset = static_cast<int64_t>(batchIdx * numRows + rowIdx) * numColThreads + colIdx;
-                int64_t outOffset = static_cast<int64_t>(batchIdx * numRows + rowIdx) * numPaddedColThreads + colIdx;
-
-                // Set the values to 0 of those are padded columns.
-                if (rowIdx < numRows && colIdx >= numColThreads && colIdx < numPaddedColThreads)
+                for (int colIdx = threadIdx.x; colIdx < numColThreadsForSf; colIdx += blockDim.x)
                 {
-                    // Dispatch the quantization kernel.
-                    if constexpr (quantization_type == BlockScaleQuantizationType::FP16_TO_FP4)
-                    {
-                        reinterpret_cast<uint32_t*>(out)[outOffset] = 0u;
-                    }
-                    else if constexpr (quantization_type == BlockScaleQuantizationType::FP8_TO_FP4
-                        || quantization_type == BlockScaleQuantizationType::FP16_TO_MXFP8)
-                    {
-                        reinterpret_cast<uint64_t*>(out)[outOffset] = 0ull;
-                    }
-                }
+                    std::optional<int> optionalBatchIdx = batchIdx;
+                    std::optional<int> optionalNumRows = numRows;
+
+                    // The SF output pointer.
+                    auto sf_out = cvt_quant_get_sf_out_offset<uint32_t, CVT_NUM_THREADS_PER_SF>(
+                        optionalBatchIdx, rowIdx, colIdx, optionalNumRows, numPaddedCols / SF_VEC_SIZE, SFout, layout);
 
-                // Set the SF padding to 0.
-                if (rowIdx >= numRows || colIdx >= numColThreads)
-                {
                     // Set the SF padding to 0.
                     if (sf_out != nullptr)
                     {
                         sf_out[0] = 0x00;
                     }
                 }
-                else
+            }
+        }
+        else
+        {
+            // Normal path: This row contains actual data
+            for (int batchIdx = 0; batchIdx < numbatches; batchIdx++)
+            {
+                for (int colIdx = threadIdx.x; colIdx < numColThreadsForSf; colIdx += blockDim.x)
                 {
-                    // Load the input vector.
-                    PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
+                    std::optional<int> optionalBatchIdx = batchIdx;
+                    std::optional<int> optionalNumRows = numRows;
+
+                    // The SF output pointer.
+                    auto sf_out = cvt_quant_get_sf_out_offset<uint32_t, CVT_NUM_THREADS_PER_SF>(
+                        optionalBatchIdx, rowIdx, colIdx, optionalNumRows, numPaddedCols / SF_VEC_SIZE, SFout, layout);
+
+                    // The input tensor offset.
+                    int64_t inOffset = static_cast<int64_t>(batchIdx * numRows + rowIdx) * numColThreads + colIdx;
+                    int64_t outOffset
+                        = static_cast<int64_t>(batchIdx * numRows + rowIdx) * numPaddedColThreads + colIdx;
 
-                    // Dispatch the quantization kernel.
-                    if constexpr (quantization_type == BlockScaleQuantizationType::FP16_TO_FP4)
+                    // Set the values to 0 of those are padded columns.
+                    if (colIdx >= numColThreads && colIdx < numPaddedColThreads)
                     {
-                        reinterpret_cast<uint32_t*>(out)[outOffset]
-                            = cvt_warp_fp16_to_fp4<Type, SF_VEC_SIZE, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
+                        // Dispatch the quantization kernel.
+                        if constexpr (quantization_type == BlockScaleQuantizationType::FP16_TO_FP4)
+                        {
+                            reinterpret_cast<uint32_t*>(out)[outOffset] = 0u;
+                        }
+                        else if constexpr (quantization_type == BlockScaleQuantizationType::FP8_TO_FP4
+                            || quantization_type == BlockScaleQuantizationType::FP16_TO_MXFP8)
+                        {
+                            reinterpret_cast<uint64_t*>(out)[outOffset] = 0ull;
+                        }
                     }
-                    else if constexpr (quantization_type == BlockScaleQuantizationType::FP8_TO_FP4)
+
+                    // Set the SF padding to 0.
+                    if (colIdx >= numColThreads)
                     {
-                        reinterpret_cast<uint64_t*>(out)[outOffset]
-                            = cvt_warp_fp8_to_fp4<__nv_fp8_e4m3, SF_VEC_SIZE, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
+                        // Set the SF padding to 0.
+                        if (sf_out != nullptr)
+                        {
+                            sf_out[0] = 0x00;
+                        }
                     }
-                    else if constexpr (quantization_type == BlockScaleQuantizationType::FP16_TO_MXFP8)
+                    else
                     {
-                        reinterpret_cast<uint64_t*>(out)[outOffset]
-                            = cvt_warp_fp16_to_mxfp8<Type, SF_VEC_SIZE>(in_vec, sf_out);
+                        // Load the input vector.
+                        PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
+
+                        // Dispatch the quantization kernel.
+                        if constexpr (quantization_type == BlockScaleQuantizationType::FP16_TO_FP4)
+                        {
+                            reinterpret_cast<uint32_t*>(out)[outOffset]
+                                = cvt_warp_fp16_to_fp4<Type, SF_VEC_SIZE, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
+                        }
+                        else if constexpr (quantization_type == BlockScaleQuantizationType::FP8_TO_FP4)
+                        {
+                            reinterpret_cast<uint64_t*>(out)[outOffset]
+                                = cvt_warp_fp8_to_fp4<__nv_fp8_e4m3, SF_VEC_SIZE, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
+                        }
+                        else if constexpr (quantization_type == BlockScaleQuantizationType::FP16_TO_MXFP8)
+                        {
+                            reinterpret_cast<uint64_t*>(out)[outOffset]
+                                = cvt_warp_fp16_to_mxfp8<Type, SF_VEC_SIZE>(in_vec, sf_out);
+                        }
                     }
                 }
             }
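
The core of the change is easier to see outside the full kernel: the row-padding check moves from the innermost element loop up to a single per-row branch, so thread blocks that land on padding rows only zero out scale factors and never enter the quantization work. Below is a minimal, self-contained CUDA sketch of that hot/cold split. It is not the TensorRT-LLM kernel: it assumes a plain row-major scale-factor buffer instead of the swizzled layout, a toy fp32-to-int8 block quantization instead of the fp4/mxfp8 converters, and the names (quantizeRowsSplitPath, kSfVecSize) are made up for illustration.

// Minimal sketch of the hot/cold row split, NOT the TensorRT-LLM kernel.
// Assumptions (hypothetical, for illustration only): a row-major scale-factor
// buffer instead of the real swizzled layout, a toy fp32 -> int8 block-scale
// quantization, and numCols divisible by kSfVecSize.
#include <cstdint>
#include <cuda_runtime.h>

constexpr int kSfVecSize = 16; // elements sharing one scale factor (assumed)

__global__ void quantizeRowsSplitPath(float const* in, int8_t* out, float* sfOut,
                                      int numRows, int numPaddedRows, int numCols)
{
    int const numSfPerRow = numCols / kSfVecSize;

    // Grid-stride loop over the padded row range, mirroring the diff's rowIdx loop.
    for (int rowIdx = blockIdx.x; rowIdx < numPaddedRows; rowIdx += gridDim.x)
    {
        // Decide once per row whether this is a padding row (cold path).
        bool const isRowPadding = (rowIdx >= numRows);

        if (isRowPadding)
        {
            // Cold path: padding rows exist only in the scale-factor buffer,
            // so zero their scale factors and write nothing to `out`.
            for (int sfIdx = threadIdx.x; sfIdx < numSfPerRow; sfIdx += blockDim.x)
            {
                sfOut[rowIdx * numSfPerRow + sfIdx] = 0.f;
            }
        }
        else
        {
            // Hot path: real data rows, no per-element row check needed.
            for (int sfIdx = threadIdx.x; sfIdx < numSfPerRow; sfIdx += blockDim.x)
            {
                // Compute the per-block scale from the max magnitude in the vector.
                float maxAbs = 0.f;
                int const base = rowIdx * numCols + sfIdx * kSfVecSize;
                for (int i = 0; i < kSfVecSize; ++i)
                {
                    maxAbs = fmaxf(maxAbs, fabsf(in[base + i]));
                }
                float const scale = maxAbs > 0.f ? maxAbs / 127.f : 1.f;
                sfOut[rowIdx * numSfPerRow + sfIdx] = scale;

                // Quantize the vector with that scale.
                for (int i = 0; i < kSfVecSize; ++i)
                {
                    out[base + i] = static_cast<int8_t>(lrintf(in[base + i] / scale));
                }
            }
        }
    }
}

The design point carried over from the diff is that the rowIdx >= numRows test is evaluated once per row rather than once per element and batch, and padding rows skip the output writes entirely, which matches the commit's note that this helps small batch sizes with the swizzled scale-factor layout.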
