@@ -794,67 +794,102 @@ quantize_with_block_size(

     asm volatile("griddepcontrol.wait;");
     // Input tensor batch/row/col loops.
+    // Optimization: handle actual rows on a hot path and padding rows on a cold path.
+    // This improves performance for small batch sizes with the swizzled layout.
     for (int rowIdx = blockIdx.x; rowIdx < numPaddedRowsForSf; rowIdx += gridDim.x)
     {
-        for (int batchIdx = 0; batchIdx < numbatches; batchIdx++)
+        // Early exit for padding rows: if this row lies past the real data,
+        // skip the quantization path and just zero out the scale factors.
+        bool isRowPadding = (rowIdx >= numRows);
+
+        if (isRowPadding)
         {
-            for (int colIdx = threadIdx.x; colIdx < numColThreadsForSf; colIdx += blockDim.x)
+            // Fast path: this row is entirely padding, so only zero out the scale factors.
+            // Note: padding rows do NOT exist in the output tensor (which is sized [numRows, K]);
+            // they only exist in the swizzled scale factor layout. Do NOT write to the output buffer here.
+            for (int batchIdx = 0; batchIdx < numbatches; batchIdx++)
             {
-                std::optional<int> optionalBatchIdx = batchIdx;
-                std::optional<int> optionalNumRows = numRows;
-
-                // The SF output pointer.
-                auto sf_out = cvt_quant_get_sf_out_offset<uint32_t, CVT_NUM_THREADS_PER_SF>(
-                    optionalBatchIdx, rowIdx, colIdx, optionalNumRows, numPaddedCols / SF_VEC_SIZE, SFout, layout);
-
-                // The input tensor offset.
-                int64_t inOffset = static_cast<int64_t>(batchIdx * numRows + rowIdx) * numColThreads + colIdx;
-                int64_t outOffset = static_cast<int64_t>(batchIdx * numRows + rowIdx) * numPaddedColThreads + colIdx;
-
-                // Set the values to 0 of those are padded columns.
-                if (rowIdx < numRows && colIdx >= numColThreads && colIdx < numPaddedColThreads)
+                for (int colIdx = threadIdx.x; colIdx < numColThreadsForSf; colIdx += blockDim.x)
                 {
-                    // Dispatch the quantization kernel.
-                    if constexpr (quantization_type == BlockScaleQuantizationType::FP16_TO_FP4)
-                    {
-                        reinterpret_cast<uint32_t*>(out)[outOffset] = 0u;
-                    }
-                    else if constexpr (quantization_type == BlockScaleQuantizationType::FP8_TO_FP4
-                        || quantization_type == BlockScaleQuantizationType::FP16_TO_MXFP8)
-                    {
-                        reinterpret_cast<uint64_t*>(out)[outOffset] = 0ull;
-                    }
-                }
+                    std::optional<int> optionalBatchIdx = batchIdx;
+                    std::optional<int> optionalNumRows = numRows;
+
+                    // The SF output pointer.
+                    auto sf_out = cvt_quant_get_sf_out_offset<uint32_t, CVT_NUM_THREADS_PER_SF>(
+                        optionalBatchIdx, rowIdx, colIdx, optionalNumRows, numPaddedCols / SF_VEC_SIZE, SFout, layout);

-                // Set the SF padding to 0.
-                if (rowIdx >= numRows || colIdx >= numColThreads)
-                {
                     // Set the SF padding to 0.
                     if (sf_out != nullptr)
                     {
                         sf_out[0] = 0x00;
                     }
                 }
-                else
+            }
+        }
+        else
+        {
+            // Normal path: this row contains actual data.
+            for (int batchIdx = 0; batchIdx < numbatches; batchIdx++)
+            {
+                for (int colIdx = threadIdx.x; colIdx < numColThreadsForSf; colIdx += blockDim.x)
                 {
-                    // Load the input vector.
-                    PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
+                    std::optional<int> optionalBatchIdx = batchIdx;
+                    std::optional<int> optionalNumRows = numRows;
+
+                    // The SF output pointer.
+                    auto sf_out = cvt_quant_get_sf_out_offset<uint32_t, CVT_NUM_THREADS_PER_SF>(
+                        optionalBatchIdx, rowIdx, colIdx, optionalNumRows, numPaddedCols / SF_VEC_SIZE, SFout, layout);
+
+                    // The input tensor offset.
+                    int64_t inOffset = static_cast<int64_t>(batchIdx * numRows + rowIdx) * numColThreads + colIdx;
+                    int64_t outOffset
+                        = static_cast<int64_t>(batchIdx * numRows + rowIdx) * numPaddedColThreads + colIdx;

-                    // Dispatch the quantization kernel.
-                    if constexpr (quantization_type == BlockScaleQuantizationType::FP16_TO_FP4)
+                    // Set the values to 0 for padded columns.
+                    if (colIdx >= numColThreads && colIdx < numPaddedColThreads)
                     {
-                        reinterpret_cast<uint32_t*>(out)[outOffset]
-                            = cvt_warp_fp16_to_fp4<Type, SF_VEC_SIZE, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
+                        // Zero the packed output word; its width depends on the quantization type.
+                        if constexpr (quantization_type == BlockScaleQuantizationType::FP16_TO_FP4)
+                        {
+                            reinterpret_cast<uint32_t*>(out)[outOffset] = 0u;
+                        }
+                        else if constexpr (quantization_type == BlockScaleQuantizationType::FP8_TO_FP4
+                            || quantization_type == BlockScaleQuantizationType::FP16_TO_MXFP8)
+                        {
+                            reinterpret_cast<uint64_t*>(out)[outOffset] = 0ull;
+                        }
                     }
-                    else if constexpr (quantization_type == BlockScaleQuantizationType::FP8_TO_FP4)
+
+                    // Set the SF padding to 0.
+                    if (colIdx >= numColThreads)
                     {
-                        reinterpret_cast<uint64_t*>(out)[outOffset]
-                            = cvt_warp_fp8_to_fp4<__nv_fp8_e4m3, SF_VEC_SIZE, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
+                        // Set the SF padding to 0.
+                        if (sf_out != nullptr)
+                        {
+                            sf_out[0] = 0x00;
+                        }
                     }
-                    else if constexpr (quantization_type == BlockScaleQuantizationType::FP16_TO_MXFP8)
+                    else
                     {
-                        reinterpret_cast<uint64_t*>(out)[outOffset]
-                            = cvt_warp_fp16_to_mxfp8<Type, SF_VEC_SIZE>(in_vec, sf_out);
+                        // Load the input vector.
+                        PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
+
+                        // Dispatch the quantization kernel.
+                        if constexpr (quantization_type == BlockScaleQuantizationType::FP16_TO_FP4)
+                        {
+                            reinterpret_cast<uint32_t*>(out)[outOffset]
+                                = cvt_warp_fp16_to_fp4<Type, SF_VEC_SIZE, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
+                        }
+                        else if constexpr (quantization_type == BlockScaleQuantizationType::FP8_TO_FP4)
+                        {
+                            reinterpret_cast<uint64_t*>(out)[outOffset]
+                                = cvt_warp_fp8_to_fp4<__nv_fp8_e4m3, SF_VEC_SIZE, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
+                        }
+                        else if constexpr (quantization_type == BlockScaleQuantizationType::FP16_TO_MXFP8)
+                        {
+                            reinterpret_cast<uint64_t*>(out)[outOffset]
+                                = cvt_warp_fp16_to_mxfp8<Type, SF_VEC_SIZE>(in_vec, sf_out);
+                        }
                     }
                 }
             }
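
To see the control-flow change in isolation from the surrounding template machinery, here is a minimal, self-contained CUDA sketch of the same padding-handling pattern. Everything in it is an illustrative assumption rather than the kernel's real API: the sketch uses one fp32 scale factor per element in a flat [numPaddedRows, numPaddedCols] buffer and per-element absolute-value scaling, whereas the actual kernel works on packed vectors, computes swizzled SF offsets via cvt_quant_get_sf_out_offset, and quantizes with the cvt_warp_* helpers.

// Hypothetical sketch of the hot/cold row split, not the PR's kernel.
#include <cuda_runtime.h>
#include <cstdio>

__global__ void quantize_padded_sketch(float const* in, float* out, float* sf,
    int numRows, int numPaddedRows, int numCols, int numPaddedCols)
{
    // Grid-stride loop over the padded row range required by the SF layout.
    for (int row = blockIdx.x; row < numPaddedRows; row += gridDim.x)
    {
        if (row >= numRows)
        {
            // Cold path: the row exists only in the SF layout. Zero its scale
            // factors and never touch `out`, which has no storage for it.
            for (int col = threadIdx.x; col < numPaddedCols; col += blockDim.x)
            {
                sf[row * numPaddedCols + col] = 0.0f;
            }
            continue;
        }

        // Hot path: a real row. Padded columns get zeroed output and SF;
        // real columns get quantized.
        for (int col = threadIdx.x; col < numPaddedCols; col += blockDim.x)
        {
            if (col >= numCols)
            {
                out[row * numPaddedCols + col] = 0.0f;
                sf[row * numPaddedCols + col] = 0.0f;
                continue;
            }
            float v = in[row * numCols + col];
            float scale = fmaxf(fabsf(v), 1e-8f); // stand-in for block-max scaling
            sf[row * numPaddedCols + col] = scale;
            out[row * numPaddedCols + col] = v / scale;
        }
    }
}

int main()
{
    int numRows = 3, numPaddedRows = 8, numCols = 48, numPaddedCols = 64;
    float *in, *out, *sf;
    cudaMallocManaged(&in, sizeof(float) * numRows * numCols);
    cudaMallocManaged(&out, sizeof(float) * numRows * numPaddedCols);
    cudaMallocManaged(&sf, sizeof(float) * numPaddedRows * numPaddedCols);
    for (int i = 0; i < numRows * numCols; ++i)
        in[i] = 0.25f * (i % 17);

    quantize_padded_sketch<<<4, 64>>>(in, out, sf, numRows, numPaddedRows, numCols, numPaddedCols);
    cudaDeviceSynchronize();

    printf("sf(padding row)  = %f\n", sf[(numPaddedRows - 1) * numPaddedCols]); // 0.0
    printf("sf(real element) = %f\n", sf[1 * numPaddedCols + 1]);               // > 0
    cudaFree(in);
    cudaFree(out);
    cudaFree(sf);
    return 0;
}

The design point the diff insists on carries over: padding rows write only to the scale-factor buffer, never to the output tensor, since the output is padded in the column dimension but not in the row dimension.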