29 changes: 18 additions & 11 deletions csrc/nv_internal/cpp/kernels/quantization.cu
@@ -181,7 +181,8 @@ void invokeFP4Quantization(int m, int n, T const* input, float const* SFScale, i
template <typename T, int SF_VEC_SIZE>
void invokeBatchedFP4Quantization(int b, int m, int n, T const* input, float const* SFScale,
int64_t* output, int32_t* SFOuput, bool useUE8M0,
int multiProcessorCount, cudaStream_t stream) {
int multiProcessorCount, cudaStream_t stream,
FP4QuantizationSFLayout layout) {
#ifdef ENABLE_FP8
if constexpr (std::is_same_v<T, __nv_fp8_e4m3>) {
// Grid, Block size.
@@ -194,9 +195,9 @@ void invokeBatchedFP4Quantization(int b, int m, int n, T const* input, float con
// Launch the cvt kernel.
auto* kernel_instance =
useUE8M0 ? &cvt_fp8_to_fp4_3d<SF_VEC_SIZE, true> : &cvt_fp8_to_fp4_3d<SF_VEC_SIZE, false>;
kernel_instance<<<grid, block, 0, stream>>>(
b, m, n, input, SFScale, reinterpret_cast<uint32_t*>(output),
reinterpret_cast<uint32_t*>(SFOuput), FP4QuantizationSFLayout::SWIZZLED);
kernel_instance<<<grid, block, 0, stream>>>(b, m, n, input, SFScale,
reinterpret_cast<uint32_t*>(output),
reinterpret_cast<uint32_t*>(SFOuput), layout);
} else
#endif
{
@@ -222,7 +223,7 @@ void invokeBatchedFP4Quantization(int b, int m, int n, T const* input, float con
config.attrs = attrs;
cudaLaunchKernelEx(&config, kernel_instance, b, m, n, input, SFScale,
reinterpret_cast<uint32_t*>(output), reinterpret_cast<uint32_t*>(SFOuput),
FP4QuantizationSFLayout::SWIZZLED);
layout);
}
}

@@ -316,11 +317,13 @@ template void invokeFP4Quantization<half, 32>(int m, int n, half const* input, f
template void invokeBatchedFP4Quantization<half, 16>(int b, int m, int n, half const* input,
float const* SFScale, int64_t* output,
int32_t* SFOuput, bool useUE8M0,
int multiProcessorCount, cudaStream_t stream);
int multiProcessorCount, cudaStream_t stream,
FP4QuantizationSFLayout layout);
template void invokeBatchedFP4Quantization<half, 32>(int b, int m, int n, half const* input,
float const* SFScale, int64_t* output,
int32_t* SFOuput, bool useUE8M0,
int multiProcessorCount, cudaStream_t stream);
int multiProcessorCount, cudaStream_t stream,
FP4QuantizationSFLayout layout);
#ifdef ENABLE_BF16
template void invokeFP4Quantization<__nv_bfloat16, 16>(int m, int n, __nv_bfloat16 const* input,
float const* SFScale, int64_t* output,
@@ -336,10 +339,12 @@ template void invokeFP4Quantization<__nv_bfloat16, 32>(int m, int n, __nv_bfloat
cudaStream_t stream);
template void invokeBatchedFP4Quantization<__nv_bfloat16, 16>(
int b, int m, int n, __nv_bfloat16 const* input, float const* SFScale, int64_t* output,
int32_t* SFOuput, bool useUE8M0, int multiProcessorCount, cudaStream_t stream);
int32_t* SFOuput, bool useUE8M0, int multiProcessorCount, cudaStream_t stream,
FP4QuantizationSFLayout layout);
template void invokeBatchedFP4Quantization<__nv_bfloat16, 32>(
int b, int m, int n, __nv_bfloat16 const* input, float const* SFScale, int64_t* output,
int32_t* SFOuput, bool useUE8M0, int multiProcessorCount, cudaStream_t stream);
int32_t* SFOuput, bool useUE8M0, int multiProcessorCount, cudaStream_t stream,
FP4QuantizationSFLayout layout);
#endif

#ifdef ENABLE_FP8
@@ -357,10 +362,12 @@ template void invokeFP4Quantization<__nv_fp8_e4m3, 32>(int m, int n, __nv_fp8_e4
cudaStream_t stream);
template void invokeBatchedFP4Quantization<__nv_fp8_e4m3, 16>(
int b, int m, int n, __nv_fp8_e4m3 const* input, float const* SFScale, int64_t* output,
int32_t* SFOuput, bool useUE8M0, int multiProcessorCount, cudaStream_t stream);
int32_t* SFOuput, bool useUE8M0, int multiProcessorCount, cudaStream_t stream,
FP4QuantizationSFLayout layout);
template void invokeBatchedFP4Quantization<__nv_fp8_e4m3, 32>(
int b, int m, int n, __nv_fp8_e4m3 const* input, float const* SFScale, int64_t* output,
int32_t* SFOuput, bool useUE8M0, int multiProcessorCount, cudaStream_t stream);
int32_t* SFOuput, bool useUE8M0, int multiProcessorCount, cudaStream_t stream,
FP4QuantizationSFLayout layout);
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////
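Reviewer note: the batched entry point now takes the scale-factor layout as an explicit argument instead of hard-coding the swizzled 128x4 layout. Below is a minimal host-side sketch of a call that opts into the new 8x4 layout; the buffer names and include path are illustrative assumptions, and the namespace is assumed to match the existing invokeFP4Quantization usage.

```cpp
// Sketch only: device buffers are assumed to be allocated and filled elsewhere.
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include "tensorrt_llm/kernels/quantization.h"  // include path assumed

void quantizeBatchedHalfToFp4(int b, int m, int n, half const* dInput, float const* dGlobalScale,
                              int64_t* dOutput, int32_t* dScaleFactors, int smCount,
                              cudaStream_t stream) {
  // The new trailing argument selects the SF layout; omitting it keeps the old behavior
  // because the default is FP4QuantizationSFLayout::SWIZZLED_128x4.
  tensorrt_llm::kernels::invokeBatchedFP4Quantization<half, 16>(
      b, m, n, dInput, dGlobalScale, dOutput, dScaleFactors,
      /*useUE8M0=*/false, smCount, stream,
      tensorrt_llm::FP4QuantizationSFLayout::SWIZZLED_8x4);
}
```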
42 changes: 40 additions & 2 deletions csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh
@@ -653,6 +653,40 @@ inline __device__ __host__ int64_t get_sf_out_offset_128x4(std::optional<int> ba
return SFOffset;
}

template <int SF_VEC_SIZE>
inline __device__ __host__ int64_t get_sf_out_offset_8x4(std::optional<int> batchIdx, int mIdx,
int kIdx, std::optional<int> numRows,
int numCols) {
// SF layout [numMTiles, numKTiles, 8 (mTile), 4(kTile)]
// --> index [mTileIdx, kTileIdx, innerMIdx, innerKIdx]

// batched tensor
// SF layout [numBTiles, numMTiles, numKTiles, 8 (mTile), 4(kTile)]
// --> index [bTileIdx, mTileIdx, kTileIdx, innerMIdx, innerKIdx]
const int32_t mTile = 8;
int32_t innerKIdx = (kIdx % 4);
int64_t innerKStride = 1;

int32_t innerMIdx = (mIdx % mTile);
int64_t mStride = 4 * innerKStride;

int32_t kTileIdx = (kIdx / 4);
int64_t kTileStride = mTile * mStride;

int factor = SF_VEC_SIZE * 4;
int32_t numKTiles = (numCols + factor - 1) / factor;
int32_t mTileIdx = mIdx / mTile;
int64_t mTileStride = numKTiles * kTileStride;

int32_t numMTiles = (numRows.value_or(0) + 8 - 1) / 8;
int64_t bTileStride = numMTiles * mTileStride;

int64_t SFOffset = batchIdx.value_or(0) * bTileStride + mTileIdx * mTileStride +
kTileIdx * kTileStride + innerMIdx * mStride + innerKIdx * innerKStride;

return SFOffset;
}

template <class SFType, int CVT_FP4_NUM_THREADS_PER_SF, int SF_VEC_SIZE>
__device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(std::optional<int> batchIdx, int rowIdx,
int colIdx, std::optional<int> numRows,
@@ -666,13 +700,17 @@ __device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(std::optional<int> batchI
// TODO: stage through smem for packed STG.32
// is it better than STG.8 from 4 threads ?
if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF == 0) {
if (layout == FP4QuantizationSFLayout::SWIZZLED) {
if (layout == FP4QuantizationSFLayout::SWIZZLED_128x4 ||
layout == FP4QuantizationSFLayout::SWIZZLED_8x4) {
// SF vector index (16 elements share one SF in the K dimension).
// numRows and numCols are unpadded.
int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF;
int32_t mIdx = rowIdx;

auto SFOffset = get_sf_out_offset_128x4<SF_VEC_SIZE>(batchIdx, mIdx, kIdx, numRows, numCols);
auto SFOffset =
layout == FP4QuantizationSFLayout::SWIZZLED_128x4
? get_sf_out_offset_128x4<SF_VEC_SIZE>(batchIdx, mIdx, kIdx, numRows, numCols)
: get_sf_out_offset_8x4<SF_VEC_SIZE>(batchIdx, mIdx, kIdx, numRows, numCols);
return reinterpret_cast<uint8_t*>(SFout) + SFOffset;
} else if (layout == FP4QuantizationSFLayout::LINEAR) {
// Linear row-major layout, no padding required.
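Reviewer note: a host-side re-derivation of the new get_sf_out_offset_8x4 indexing makes the tile math easy to sanity-check on the CPU. The sketch below mirrors the strides in the device function; the shapes and the checked index are made-up examples, not taken from this diff.

```cpp
// CPU reference for the 8x4 swizzled SF offset: layout [numMTiles, numKTiles, 8, 4],
// where kIdx counts SF vectors (groups of SF_VEC_SIZE elements) along K.
#include <cassert>
#include <cstdint>

int64_t sfOffset8x4Ref(int batchIdx, int mIdx, int kIdx, int numRows, int numCols,
                       int sfVecSize) {
  int factor = sfVecSize * 4;                    // elements covered by one 4-wide k tile
  int64_t numKTiles = (numCols + factor - 1) / factor;
  int64_t numMTiles = (numRows + 8 - 1) / 8;
  int64_t kTileStride = 8 * 4;                   // 32 scale factors per 8x4 tile
  int64_t mTileStride = numKTiles * kTileStride;
  int64_t bTileStride = numMTiles * mTileStride;
  return batchIdx * bTileStride + (mIdx / 8) * mTileStride + (kIdx / 4) * kTileStride +
         (mIdx % 8) * 4 + (kIdx % 4);
}

int main() {
  // Example: 16 x 64 activation with 16-wide SF vectors -> 4 SF columns, 1 k tile, 2 m tiles.
  // Row 9, SF column 2 falls in the second m tile at inner position (1, 2): 32 + 4 + 2 = 38.
  assert(sfOffset8x4Ref(/*batchIdx=*/0, /*mIdx=*/9, /*kIdx=*/2, /*numRows=*/16,
                        /*numCols=*/64, /*sfVecSize=*/16) == 38);
  return 0;
}
```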
19 changes: 11 additions & 8 deletions csrc/nv_internal/tensorrt_llm/kernels/quantization.h
@@ -31,9 +31,11 @@ enum class FP4QuantizationSFLayout {
// The scale factor block rows map to data block rows in an interleaved pattern:
// For a scale factor row 'i', it maps to data block row: (i % 4) * 32 + (i / 4)
// Column 'j' in the scale factor block corresponds to scaling the j-th block in the data tensor.
//
// Please refer to https://nvbugs/4165523 for more details about the swizzled layout.
SWIZZLED,
SWIZZLED_128x4,

// Similar to SWIZZLED_128x4, but with 8x4 scale factor blocks.
SWIZZLED_8x4,

// Block scale factors are stored in linear layout (row-major). This is used as the standard
// in some trtllm-gen kernels.
LINEAR
@@ -42,8 +44,8 @@
#define PadUpFn(X, Y) ((X + Y - 1) / (Y) * (Y))

// totalColumn should be in SFMatrix, not activation Matrix, so no sfVecSize needed.
inline int computeFP4SwizzledLayoutSFSize(int totalRow, int totalColumn) {
int paddedRow = PadUpFn(totalRow, 128);
inline int computeFP4SwizzledLayoutSFSize(int totalRow, int totalColumn, int rowSize = 128) {
int paddedRow = PadUpFn(totalRow, rowSize);
int paddedColumn = PadUpFn(totalColumn, 4);
return paddedRow * paddedColumn;
}
@@ -70,9 +72,10 @@ void invokeFP4Quantization(int m, int n, T const* input, float const* globalScal
int multiProcessorCount, cudaStream_t stream = 0);

template <typename T, int SF_VEC_SIZE = 16>
void invokeBatchedFP4Quantization(int b, int m, int n, T const* input, float const* globalScale,
int64_t* output, int32_t* SFOuput, bool useUE8M0,
int multiProcessorCount, cudaStream_t stream = 0);
void invokeBatchedFP4Quantization(
int b, int m, int n, T const* input, float const* globalScale, int64_t* output,
int32_t* SFOuput, bool useUE8M0, int multiProcessorCount, cudaStream_t stream = 0,
FP4QuantizationSFLayout layout = FP4QuantizationSFLayout::SWIZZLED_128x4);

void invokeNVFP4BlockScaleInterleave(int b, int m, int m_padded, int n, int n_padded,
uint8_t const* SFIn, uint8_t* SFOutput,
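Reviewer note: the new rowSize parameter is what lets the 8x4 layout skip the 128-row padding. A quick sketch of the size difference for an illustrative 200 x 7168 activation with 16-element SF vectors (the shape and include path are assumptions):

```cpp
#include <cstdio>
#include "tensorrt_llm/kernels/quantization.h"  // include path assumed

int main() {
  int m = 200, k = 7168, sfVecSize = 16;
  int sfCols = k / sfVecSize;  // 448 scale factors per row
  // Default rowSize = 128 pads rows up to 256: 256 * 448 = 114688 scale-factor entries.
  int swizzled128x4 = tensorrt_llm::computeFP4SwizzledLayoutSFSize(m, sfCols);
  // rowSize = 8 only pads rows up to a multiple of 8: 200 * 448 = 89600 entries.
  int swizzled8x4 = tensorrt_llm::computeFP4SwizzledLayoutSFSize(m, sfCols, /*rowSize=*/8);
  std::printf("SWIZZLED_128x4: %d, SWIZZLED_8x4: %d\n", swizzled128x4, swizzled8x4);
  return 0;
}
```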
10 changes: 5 additions & 5 deletions csrc/nv_internal/tensorrt_llm/thop/fp4Op.cpp
@@ -106,7 +106,7 @@ int computeSFIndex(int rowIdx, int colIdx, int totalRow, int totalColumn,
constexpr int kRowGroup1Size = kRowGroup0Size * 4;

// Swizzled layout is used as default layout.
if (layout == tensorrt_llm::FP4QuantizationSFLayout::SWIZZLED) {
if (layout == tensorrt_llm::FP4QuantizationSFLayout::SWIZZLED_128x4) {
// int paddedRow = PadUpFn(totalRow, 128);
int paddedColumn = PadUpFn(totalColumn, 4);

@@ -179,7 +179,7 @@ at::Tensor NVFP4BlockScaleInterleave(at::Tensor const& blockScale) {
sf_ori = blockScalePtr[cIdx];
}
int sf_index = computeSFIndex(rIdx, cIdx, rows, cols,
tensorrt_llm::FP4QuantizationSFLayout::SWIZZLED);
tensorrt_llm::FP4QuantizationSFLayout::SWIZZLED_128x4);
interleavedBlockScalePtr[sf_index] = sf_ori;
}
}
@@ -225,7 +225,7 @@ at::Tensor NVFP4BlockScaleInterleaveReverse(at::Tensor const& blockScale) {
for (int rIdx = 0; rIdx < rows; ++rIdx) {
for (int cIdx = 0; cIdx < cols; ++cIdx) {
int sf_index = computeSFIndex(rIdx, cIdx, rows, cols,
tensorrt_llm::FP4QuantizationSFLayout::SWIZZLED);
tensorrt_llm::FP4QuantizationSFLayout::SWIZZLED_128x4);
identity[eIdx * expert_out_size + sf_index] = std::array<int, 3>{eIdx, rIdx, cIdx};
}
}
@@ -267,7 +267,7 @@ at::Tensor E2M1AndUFP8SFScaleToFloat(at::Tensor valueE2M1, at::Tensor scaleFP8SF
uint8_t* scaleFP8SFPtr = scaleFP8SF.data_ptr<uint8_t>();
uint8_t fp8Scale =
scaleFP8SFPtr[computeSFIndex(vIdx, group, packedShape[0], groupsPerHiddenDim,
tensorrt_llm::FP4QuantizationSFLayout::SWIZZLED)];
tensorrt_llm::FP4QuantizationSFLayout::SWIZZLED_128x4)];
int scale = fp8Scale;
if (sfType == 0) {
scale -= 127;
@@ -311,7 +311,7 @@ at::Tensor E2M1AndUFP8SFScaleToFloatV2(at::Tensor valueE2M1, at::Tensor scaleFP8
int groupsPerHiddenDim = hiddenDim / sfVecSize;

tensorrt_llm::FP4QuantizationSFLayout layout =
isSfSwizzledLayout ? tensorrt_llm::FP4QuantizationSFLayout::SWIZZLED
isSfSwizzledLayout ? tensorrt_llm::FP4QuantizationSFLayout::SWIZZLED_128x4
: tensorrt_llm::FP4QuantizationSFLayout::LINEAR;

for (size_t vIdx = 0; vIdx < static_cast<size_t>(packedShape[0]); ++vIdx) {
20 changes: 14 additions & 6 deletions csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.cpp
@@ -40,7 +40,8 @@ namespace torch_ext {
// ceil(M / 128) * 128 * ceil(K / sfVecSize / 4) * 4, SF_DTYPE (UE4M3 or UE8M0)
std::tuple<at::Tensor, at::Tensor> fp4_quantize(at::Tensor const& self,
at::Tensor const& globalScale, int64_t sfVecSize,
bool sfUseUE8M0, bool isSfSwizzledLayout) {
bool sfUseUE8M0, bool isSfSwizzledLayout,
bool isSf8x4Layout) {
CHECK_TH_CUDA(self);
CHECK_CONTIGUOUS(self);
CHECK_INPUT_TYPE(globalScale, c10::ScalarType::Float);
Expand All @@ -63,17 +64,24 @@ std::tuple<at::Tensor, at::Tensor> fp4_quantize(at::Tensor const& self,
at::Tensor valueE2M1 =
at::detail::empty_cuda(outputShape, FLOAT4_E2M1X2, self.device(), /* stride */ std::nullopt);

int64_t SFSize = isSfSwizzledLayout
? tensorrt_llm::computeFP4SwizzledLayoutSFSize(m, k / sfVecSize)
: tensorrt_llm::computeFP4LinearLayoutSFSize(m, k / sfVecSize);
int64_t SFSize =
isSfSwizzledLayout
? tensorrt_llm::computeFP4SwizzledLayoutSFSize(m, k / sfVecSize, isSf8x4Layout ? 8 : 128)
: tensorrt_llm::computeFP4LinearLayoutSFSize(m, k / sfVecSize);

at::Tensor scaleFP8SF = at::detail::empty_cuda({SFSize}, SF_DTYPE, self.device(),
/* stride */ std::nullopt); // 1D tensor

const thread_local int mMultiProcessorCount = tensorrt_llm::common::getMultiProcessorCount();

auto const layout = isSfSwizzledLayout ? tensorrt_llm::FP4QuantizationSFLayout::SWIZZLED
: tensorrt_llm::FP4QuantizationSFLayout::LINEAR;
auto layout = tensorrt_llm::FP4QuantizationSFLayout::LINEAR;
if (isSf8x4Layout) {
TORCH_CHECK(isSfSwizzledLayout, "The 8x4 layout must be a swizzled layout");
layout = tensorrt_llm::FP4QuantizationSFLayout::SWIZZLED_8x4;
} else {
layout = isSfSwizzledLayout ? tensorrt_llm::FP4QuantizationSFLayout::SWIZZLED_128x4
: tensorrt_llm::FP4QuantizationSFLayout::LINEAR;
}

#define LAUNCH_FP4_QUANTIZE_KERNEL(T, SF_VEC_SIZE) \
tensorrt_llm::kernels::invokeFP4Quantization<T, SF_VEC_SIZE>( \
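Reviewer note: from the op's point of view the new flag composes with the existing one; isSf8x4Layout is only valid when isSfSwizzledLayout is also true (enforced by the TORCH_CHECK above). A hedged sketch of a C++ caller using the extended signature follows; the wrapper name and include path are assumptions.

```cpp
#include <tuple>
#include <ATen/ATen.h>
#include "fp4Quantize.h"  // include path assumed

// Hypothetical helper: quantize to FP4 and emit scale factors in the 8x4 swizzled layout.
std::tuple<at::Tensor, at::Tensor> quantizeWith8x4Sf(at::Tensor const& input,
                                                     at::Tensor const& globalScale) {
  return torch_ext::fp4_quantize(input, globalScale, /*sfVecSize=*/16,
                                 /*sfUseUE8M0=*/false,
                                 /*isSfSwizzledLayout=*/true,
                                 /*isSf8x4Layout=*/true);
}
```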
3 changes: 2 additions & 1 deletion csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.h
@@ -24,5 +24,6 @@
namespace torch_ext {
std::tuple<at::Tensor, at::Tensor> fp4_quantize(at::Tensor const& self,
at::Tensor const& globalScale, int64_t sfVecSize,
bool sfUseUE8M0, bool isSfSwizzledLayout);
bool sfUseUE8M0, bool isSfSwizzledLayout,
bool isSf8x4Layout);
} // namespace torch_ext