Skip to content

Commit d8e7d6a

Browse files
authored
feature: add fp4 mm using trtllm backend (#1355)
feature: add fp4 mm using trtllm backend <!-- .github/pull_request_template.md --> ## 📌 Description 1. support both 128x4 and 8x4 block quant layout 2. support autotuning ## 🔍 Related Issues <!-- Link any related issues here --> ## 🚀 Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### ✅ Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## 🧪 Tests - [x] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. --> --------- Signed-off-by: Vincent Huang <[email protected]>
1 parent ac7dc76 commit d8e7d6a

File tree

16 files changed

+2958
-279
lines changed

16 files changed

+2958
-279
lines changed

csrc/nv_internal/cpp/kernels/quantization.cu

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,8 @@ void invokeFP4Quantization(int m, int n, T const* input, float const* SFScale, i
181181
template <typename T, int SF_VEC_SIZE>
182182
void invokeBatchedFP4Quantization(int b, int m, int n, T const* input, float const* SFScale,
183183
int64_t* output, int32_t* SFOuput, bool useUE8M0,
184-
int multiProcessorCount, cudaStream_t stream) {
184+
int multiProcessorCount, FP4QuantizationSFLayout layout,
185+
cudaStream_t stream) {
185186
#ifdef ENABLE_FP8
186187
if constexpr (std::is_same_v<T, __nv_fp8_e4m3>) {
187188
// Grid, Block size.
@@ -194,9 +195,9 @@ void invokeBatchedFP4Quantization(int b, int m, int n, T const* input, float con
194195
// Launch the cvt kernel.
195196
auto* kernel_instance =
196197
useUE8M0 ? &cvt_fp8_to_fp4_3d<SF_VEC_SIZE, true> : &cvt_fp8_to_fp4_3d<SF_VEC_SIZE, false>;
197-
kernel_instance<<<grid, block, 0, stream>>>(
198-
b, m, n, input, SFScale, reinterpret_cast<uint32_t*>(output),
199-
reinterpret_cast<uint32_t*>(SFOuput), FP4QuantizationSFLayout::SWIZZLED);
198+
kernel_instance<<<grid, block, 0, stream>>>(b, m, n, input, SFScale,
199+
reinterpret_cast<uint32_t*>(output),
200+
reinterpret_cast<uint32_t*>(SFOuput), layout);
200201
} else
201202
#endif
202203
{
@@ -222,7 +223,7 @@ void invokeBatchedFP4Quantization(int b, int m, int n, T const* input, float con
222223
config.attrs = attrs;
223224
cudaLaunchKernelEx(&config, kernel_instance, b, m, n, input, SFScale,
224225
reinterpret_cast<uint32_t*>(output), reinterpret_cast<uint32_t*>(SFOuput),
225-
FP4QuantizationSFLayout::SWIZZLED);
226+
layout);
226227
}
227228
}
228229

@@ -313,14 +314,12 @@ template void invokeFP4Quantization<half, 32>(int m, int n, half const* input, f
313314
int64_t* output, int32_t* SFOuput, bool useUE8M0,
314315
FP4QuantizationSFLayout layout,
315316
int multiProcessorCount, cudaStream_t stream);
316-
template void invokeBatchedFP4Quantization<half, 16>(int b, int m, int n, half const* input,
317-
float const* SFScale, int64_t* output,
318-
int32_t* SFOuput, bool useUE8M0,
319-
int multiProcessorCount, cudaStream_t stream);
320-
template void invokeBatchedFP4Quantization<half, 32>(int b, int m, int n, half const* input,
321-
float const* SFScale, int64_t* output,
322-
int32_t* SFOuput, bool useUE8M0,
323-
int multiProcessorCount, cudaStream_t stream);
317+
template void invokeBatchedFP4Quantization<half, 16>(
318+
int b, int m, int n, half const* input, float const* SFScale, int64_t* output, int32_t* SFOuput,
319+
bool useUE8M0, int multiProcessorCount, FP4QuantizationSFLayout layout, cudaStream_t stream);
320+
template void invokeBatchedFP4Quantization<half, 32>(
321+
int b, int m, int n, half const* input, float const* SFScale, int64_t* output, int32_t* SFOuput,
322+
bool useUE8M0, int multiProcessorCount, FP4QuantizationSFLayout layout, cudaStream_t stream);
324323
#ifdef ENABLE_BF16
325324
template void invokeFP4Quantization<__nv_bfloat16, 16>(int m, int n, __nv_bfloat16 const* input,
326325
float const* SFScale, int64_t* output,
@@ -336,10 +335,12 @@ template void invokeFP4Quantization<__nv_bfloat16, 32>(int m, int n, __nv_bfloat
336335
cudaStream_t stream);
337336
template void invokeBatchedFP4Quantization<__nv_bfloat16, 16>(
338337
int b, int m, int n, __nv_bfloat16 const* input, float const* SFScale, int64_t* output,
339-
int32_t* SFOuput, bool useUE8M0, int multiProcessorCount, cudaStream_t stream);
338+
int32_t* SFOuput, bool useUE8M0, int multiProcessorCount, FP4QuantizationSFLayout layout,
339+
cudaStream_t stream);
340340
template void invokeBatchedFP4Quantization<__nv_bfloat16, 32>(
341341
int b, int m, int n, __nv_bfloat16 const* input, float const* SFScale, int64_t* output,
342-
int32_t* SFOuput, bool useUE8M0, int multiProcessorCount, cudaStream_t stream);
342+
int32_t* SFOuput, bool useUE8M0, int multiProcessorCount, FP4QuantizationSFLayout layout,
343+
cudaStream_t stream);
343344
#endif
344345

345346
#ifdef ENABLE_FP8
@@ -357,10 +358,12 @@ template void invokeFP4Quantization<__nv_fp8_e4m3, 32>(int m, int n, __nv_fp8_e4
357358
cudaStream_t stream);
358359
template void invokeBatchedFP4Quantization<__nv_fp8_e4m3, 16>(
359360
int b, int m, int n, __nv_fp8_e4m3 const* input, float const* SFScale, int64_t* output,
360-
int32_t* SFOuput, bool useUE8M0, int multiProcessorCount, cudaStream_t stream);
361+
int32_t* SFOuput, bool useUE8M0, int multiProcessorCount, FP4QuantizationSFLayout layout,
362+
cudaStream_t stream);
361363
template void invokeBatchedFP4Quantization<__nv_fp8_e4m3, 32>(
362364
int b, int m, int n, __nv_fp8_e4m3 const* input, float const* SFScale, int64_t* output,
363-
int32_t* SFOuput, bool useUE8M0, int multiProcessorCount, cudaStream_t stream);
365+
int32_t* SFOuput, bool useUE8M0, int multiProcessorCount, FP4QuantizationSFLayout layout,
366+
cudaStream_t stream);
364367
#endif
365368

366369
////////////////////////////////////////////////////////////////////////////////////////////////////

csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -653,6 +653,40 @@ inline __device__ __host__ int64_t get_sf_out_offset_128x4(std::optional<int> ba
653653
return SFOffset;
654654
}
655655

656+
template <int SF_VEC_SIZE>
657+
inline __device__ __host__ int64_t get_sf_out_offset_8x4(std::optional<int> batchIdx, int mIdx,
658+
int kIdx, std::optional<int> numRows,
659+
int numCols) {
660+
// SF layout [numMTiles, numKTiles, 8 (mTile), 4(kTile)]
661+
// --> index [mTileIdx, kTileIdx, innerMIdx, innerKIdx]
662+
663+
// batched tensor
664+
// SF layout [numBTiles, numMTiles, numKTiles, 8 (mTile), 4(kTile)]
665+
// --> index [bTileIdx, mTileIdx, kTileIdx, innerMIdx, innerKIdx]
666+
const int32_t mTile = 8;
667+
int32_t innerKIdx = (kIdx % 4);
668+
int64_t innerKStride = 1;
669+
670+
int32_t innerMIdx = (mIdx % mTile);
671+
int64_t mStride = 4 * innerKStride;
672+
673+
int32_t kTileIdx = (kIdx / 4);
674+
int64_t kTileStride = mTile * mStride;
675+
676+
int factor = SF_VEC_SIZE * 4;
677+
int32_t numKTiles = (numCols + factor - 1) / factor;
678+
int32_t mTileIdx = mIdx / mTile;
679+
int64_t mTileStride = numKTiles * kTileStride;
680+
681+
int32_t numMTiles = (numRows.value_or(0) + 8 - 1) / 8;
682+
int64_t bTileStride = numMTiles * mTileStride;
683+
684+
int64_t SFOffset = batchIdx.value_or(0) * bTileStride + mTileIdx * mTileStride +
685+
kTileIdx * kTileStride + innerMIdx * mStride + innerKIdx * innerKStride;
686+
687+
return SFOffset;
688+
}
689+
656690
template <class SFType, int CVT_FP4_NUM_THREADS_PER_SF, int SF_VEC_SIZE>
657691
__device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(std::optional<int> batchIdx, int rowIdx,
658692
int colIdx, std::optional<int> numRows,
@@ -666,13 +700,17 @@ __device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(std::optional<int> batchI
666700
// TODO: stage through smem for packed STG.32
667701
// is it better than STG.8 from 4 threads ?
668702
if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF == 0) {
669-
if (layout == FP4QuantizationSFLayout::SWIZZLED) {
703+
if (layout == FP4QuantizationSFLayout::SWIZZLED_128x4 ||
704+
layout == FP4QuantizationSFLayout::SWIZZLED_8x4) {
670705
// SF vector index (16 elements share one SF in the K dimension).
671706
// numRows and numCols are unpadded.
672707
int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF;
673708
int32_t mIdx = rowIdx;
674709

675-
auto SFOffset = get_sf_out_offset_128x4<SF_VEC_SIZE>(batchIdx, mIdx, kIdx, numRows, numCols);
710+
auto SFOffset =
711+
layout == FP4QuantizationSFLayout::SWIZZLED_128x4
712+
? get_sf_out_offset_128x4<SF_VEC_SIZE>(batchIdx, mIdx, kIdx, numRows, numCols)
713+
: get_sf_out_offset_8x4<SF_VEC_SIZE>(batchIdx, mIdx, kIdx, numRows, numCols);
676714
return reinterpret_cast<uint8_t*>(SFout) + SFOffset;
677715
} else if (layout == FP4QuantizationSFLayout::LINEAR) {
678716
// Linear row-major layout, no padding required.

csrc/nv_internal/tensorrt_llm/kernels/quantization.h

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,11 @@ enum class FP4QuantizationSFLayout {
3131
// The scale factor block rows map to data block rows in an interleaved pattern:
3232
// For a scale factor row 'i', it maps to data block row: (i % 4) * 32 + (i / 4)
3333
// Column 'j' in the scale factor block corresponds to scaling the j-th block in the data tensor.
34-
//
35-
// Please refer to https://nvbugs/4165523 for more details about the swizzled layout.
36-
SWIZZLED,
34+
SWIZZLED_128x4,
35+
36+
// Similar to SWIZZLED_128x4, but with 8x4 scale factor blocks.
37+
SWIZZLED_8x4,
38+
3739
// Block scale factors are stored in linear layout (row-major). This is used in some trtllm-gen
3840
// kernels standard.
3941
LINEAR
@@ -42,8 +44,8 @@ enum class FP4QuantizationSFLayout {
4244
#define PadUpFn(X, Y) ((X + Y - 1) / (Y) * (Y))
4345

4446
// totalColumn is the column count of the SF matrix, not the activation matrix, so no sfVecSize is needed.
45-
inline int computeFP4SwizzledLayoutSFSize(int totalRow, int totalColumn) {
46-
int paddedRow = PadUpFn(totalRow, 128);
47+
inline int computeFP4SwizzledLayoutSFSize(int totalRow, int totalColumn, int rowSize = 128) {
48+
int paddedRow = PadUpFn(totalRow, rowSize);
4749
int paddedColumn = PadUpFn(totalColumn, 4);
4850
return paddedRow * paddedColumn;
4951
}
@@ -70,9 +72,11 @@ void invokeFP4Quantization(int m, int n, T const* input, float const* globalScal
7072
int multiProcessorCount, cudaStream_t stream = 0);
7173

7274
template <typename T, int SF_VEC_SIZE = 16>
73-
void invokeBatchedFP4Quantization(int b, int m, int n, T const* input, float const* globalScale,
74-
int64_t* output, int32_t* SFOuput, bool useUE8M0,
75-
int multiProcessorCount, cudaStream_t stream = 0);
75+
void invokeBatchedFP4Quantization(
76+
int b, int m, int n, T const* input, float const* globalScale, int64_t* output,
77+
int32_t* SFOuput, bool useUE8M0, int multiProcessorCount,
78+
FP4QuantizationSFLayout layout = FP4QuantizationSFLayout::SWIZZLED_128x4,
79+
cudaStream_t stream = 0);
7680

7781
void invokeNVFP4BlockScaleInterleave(int b, int m, int m_padded, int n, int n_padded,
7882
uint8_t const* SFIn, uint8_t* SFOutput,

csrc/nv_internal/tensorrt_llm/thop/fp4Op.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ int computeSFIndex(int rowIdx, int colIdx, int totalRow, int totalColumn,
106106
constexpr int kRowGroup1Size = kRowGroup0Size * 4;
107107

108108
// Swizzled layout is used as default layout.
109-
if (layout == tensorrt_llm::FP4QuantizationSFLayout::SWIZZLED) {
109+
if (layout == tensorrt_llm::FP4QuantizationSFLayout::SWIZZLED_128x4) {
110110
// int paddedRow = PadUpFn(totalRow, 128);
111111
int paddedColumn = PadUpFn(totalColumn, 4);
112112

@@ -179,7 +179,7 @@ at::Tensor NVFP4BlockScaleInterleave(at::Tensor const& blockScale) {
179179
sf_ori = blockScalePtr[cIdx];
180180
}
181181
int sf_index = computeSFIndex(rIdx, cIdx, rows, cols,
182-
tensorrt_llm::FP4QuantizationSFLayout::SWIZZLED);
182+
tensorrt_llm::FP4QuantizationSFLayout::SWIZZLED_128x4);
183183
interleavedBlockScalePtr[sf_index] = sf_ori;
184184
}
185185
}
@@ -225,7 +225,7 @@ at::Tensor NVFP4BlockScaleInterleaveReverse(at::Tensor const& blockScale) {
225225
for (int rIdx = 0; rIdx < rows; ++rIdx) {
226226
for (int cIdx = 0; cIdx < cols; ++cIdx) {
227227
int sf_index = computeSFIndex(rIdx, cIdx, rows, cols,
228-
tensorrt_llm::FP4QuantizationSFLayout::SWIZZLED);
228+
tensorrt_llm::FP4QuantizationSFLayout::SWIZZLED_128x4);
229229
identity[eIdx * expert_out_size + sf_index] = std::array<int, 3>{eIdx, rIdx, cIdx};
230230
}
231231
}
@@ -267,7 +267,7 @@ at::Tensor E2M1AndUFP8SFScaleToFloat(at::Tensor valueE2M1, at::Tensor scaleFP8SF
267267
uint8_t* scaleFP8SFPtr = scaleFP8SF.data_ptr<uint8_t>();
268268
uint8_t fp8Scale =
269269
scaleFP8SFPtr[computeSFIndex(vIdx, group, packedShape[0], groupsPerHiddenDim,
270-
tensorrt_llm::FP4QuantizationSFLayout::SWIZZLED)];
270+
tensorrt_llm::FP4QuantizationSFLayout::SWIZZLED_128x4)];
271271
int scale = fp8Scale;
272272
if (sfType == 0) {
273273
scale -= 127;
@@ -311,7 +311,7 @@ at::Tensor E2M1AndUFP8SFScaleToFloatV2(at::Tensor valueE2M1, at::Tensor scaleFP8
311311
int groupsPerHiddenDim = hiddenDim / sfVecSize;
312312

313313
tensorrt_llm::FP4QuantizationSFLayout layout =
314-
isSfSwizzledLayout ? tensorrt_llm::FP4QuantizationSFLayout::SWIZZLED
314+
isSfSwizzledLayout ? tensorrt_llm::FP4QuantizationSFLayout::SWIZZLED_128x4
315315
: tensorrt_llm::FP4QuantizationSFLayout::LINEAR;
316316

317317
for (size_t vIdx = 0; vIdx < static_cast<size_t>(packedShape[0]); ++vIdx) {

csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.cpp

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@ namespace torch_ext {
4040
// ceil(M / 128) * 128 * ceil(K / sfVecSize / 4) * 4, SF_DTYPE (UE4M3 or UE8M0)
4141
std::tuple<at::Tensor, at::Tensor> fp4_quantize(at::Tensor const& self,
4242
at::Tensor const& globalScale, int64_t sfVecSize,
43-
bool sfUseUE8M0, bool isSfSwizzledLayout) {
43+
bool sfUseUE8M0, bool isSfSwizzledLayout,
44+
bool isSf8x4Layout) {
4445
CHECK_TH_CUDA(self);
4546
CHECK_CONTIGUOUS(self);
4647
CHECK_INPUT_TYPE(globalScale, c10::ScalarType::Float);
@@ -63,17 +64,24 @@ std::tuple<at::Tensor, at::Tensor> fp4_quantize(at::Tensor const& self,
6364
at::Tensor valueE2M1 =
6465
at::detail::empty_cuda(outputShape, FLOAT4_E2M1X2, self.device(), /* stride */ std::nullopt);
6566

66-
int64_t SFSize = isSfSwizzledLayout
67-
? tensorrt_llm::computeFP4SwizzledLayoutSFSize(m, k / sfVecSize)
68-
: tensorrt_llm::computeFP4LinearLayoutSFSize(m, k / sfVecSize);
67+
int64_t SFSize =
68+
isSfSwizzledLayout
69+
? tensorrt_llm::computeFP4SwizzledLayoutSFSize(m, k / sfVecSize, isSf8x4Layout ? 8 : 128)
70+
: tensorrt_llm::computeFP4LinearLayoutSFSize(m, k / sfVecSize);
6971

7072
at::Tensor scaleFP8SF = at::detail::empty_cuda({SFSize}, SF_DTYPE, self.device(),
7173
/* stride */ std::nullopt); // 1D tensor
7274

7375
const thread_local int mMultiProcessorCount = tensorrt_llm::common::getMultiProcessorCount();
7476

75-
auto const layout = isSfSwizzledLayout ? tensorrt_llm::FP4QuantizationSFLayout::SWIZZLED
76-
: tensorrt_llm::FP4QuantizationSFLayout::LINEAR;
77+
auto layout = tensorrt_llm::FP4QuantizationSFLayout::LINEAR;
78+
if (isSf8x4Layout) {
79+
TORCH_CHECK(isSfSwizzledLayout, "8x4layout must be swizzled layout");
80+
layout = tensorrt_llm::FP4QuantizationSFLayout::SWIZZLED_8x4;
81+
} else {
82+
layout = isSfSwizzledLayout ? tensorrt_llm::FP4QuantizationSFLayout::SWIZZLED_128x4
83+
: tensorrt_llm::FP4QuantizationSFLayout::LINEAR;
84+
}
7785

7886
#define LAUNCH_FP4_QUANTIZE_KERNEL(T, SF_VEC_SIZE) \
7987
tensorrt_llm::kernels::invokeFP4Quantization<T, SF_VEC_SIZE>( \

csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,5 +24,6 @@
2424
namespace torch_ext {
2525
std::tuple<at::Tensor, at::Tensor> fp4_quantize(at::Tensor const& self,
2626
at::Tensor const& globalScale, int64_t sfVecSize,
27-
bool sfUseUE8M0, bool isSfSwizzledLayout);
27+
bool sfUseUE8M0, bool isSfSwizzledLayout,
28+
bool isSf8x4Layout);
2829
} // namespace torch_ext

0 commit comments

Comments
(0)