
Commit 5ad2323

feat: add support of fp4_batched_quantize (#1633)
1 parent 17978a3 commit 5ad2323

5 files changed, +218 -1 lines changed

csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.cpp

Lines changed: 77 additions & 1 deletion
@@ -142,6 +142,82 @@ std::tuple<at::Tensor, at::Tensor> fp4_quantize(at::Tensor const& self,

   return {valueE2M1, scaleFP8SF};
 }
+
+// self: [B, M, K], fp16/bf16/fp8_quantized
+// globalScale: [1] float, = (448 * 6) / self.abs().max()
+// nvfp4: sfVecSize = 16, sfUseUE8M0 = false
+// mxfp4: sfVecSize = 32 (not supported yet), sfUseUE8M0 = true
+// alignment: sfVecSize
+// returns self_fp4, self_block_scale_factors
+// self_fp4: [B, M, K / 2], FLOAT4_E2M1X2
+// self_block_scale_factors:
+//   [B, ceil(M / 128) * 128 * ceil(K / sfVecSize / 4) * 4], SF_DTYPE (UE4M3 or UE8M0)
+std::tuple<at::Tensor, at::Tensor> fp4_batched_quantize(at::Tensor const& self,
+                                                        at::Tensor const& globalScale,
+                                                        int64_t sfVecSize, bool sfUseUE8M0) {
+  CHECK_TH_CUDA(self);
+  CHECK_CONTIGUOUS(self);
+  CHECK_INPUT_TYPE(globalScale, c10::ScalarType::Float);
+  TORCH_CHECK(sfVecSize == 16, "sfVecSize can only be 16");
+
+  auto const& inputShape = self.sizes();
+  auto const& rank = inputShape.size();
+
+  TORCH_CHECK(rank == 3, "Input should be 3D tensor.");
+
+  int64_t b = inputShape[0];
+  int64_t m = inputShape[1];
+  int64_t k = inputShape[2];
+
+  TORCH_CHECK(k % sfVecSize == 0);
+
+  std::vector<int64_t> outputShape(inputShape.begin(), inputShape.end());
+  outputShape[rank - 1] = k / 2;
+
+  at::Tensor valueE2M1 =
+      at::detail::empty_cuda(outputShape, FLOAT4_E2M1X2, self.device(), /* stride */ std::nullopt);
+  at::Tensor scaleFP8SF =
+      at::detail::empty_cuda({b, tensorrt_llm::computeSwizzledLayoutSFSize(m, k / sfVecSize)},
+                             SF_DTYPE, self.device(), /* stride */ std::nullopt);  // 2D tensor
+
+  const thread_local int mMultiProcessorCount = tensorrt_llm::common::getMultiProcessorCount();
+  auto layout = tensorrt_llm::QuantizationSFLayout::SWIZZLED_128x4;
+
+#define LAUNCH_FP4_QUANTIZE_KERNEL(T, SF_VEC_SIZE)                                                 \
+  tensorrt_llm::kernels::invokeFP4Quantization<T, SF_VEC_SIZE>(                                    \
+      b, m, k, reinterpret_cast<T*>(self.data_ptr()), globalScale.data_ptr<float>(),               \
+      reinterpret_cast<int64_t*>(valueE2M1.data_ptr()),                                            \
+      reinterpret_cast<int32_t*>(scaleFP8SF.data_ptr()), sfUseUE8M0, layout, mMultiProcessorCount, \
+      at::cuda::getCurrentCUDAStream(self.get_device()));
+
+  if (self.scalar_type() == at::ScalarType::Half) {
+    LAUNCH_FP4_QUANTIZE_KERNEL(half, 16)
+  } else if (self.scalar_type() == at::ScalarType::BFloat16) {
+#ifdef ENABLE_BF16
+    LAUNCH_FP4_QUANTIZE_KERNEL(__nv_bfloat16, 16)
+#else
+    C10_THROW_ERROR(NotImplementedError,
+                    "BFloat16 must be enabled to quantize an bf16 tensor to fp4.");
+#endif
+  } else if (self.scalar_type() == at::ScalarType::Float8_e4m3fn) {
+#ifdef ENABLE_FP8
+    LAUNCH_FP4_QUANTIZE_KERNEL(__nv_fp8_e4m3, 16)
+#else
+    C10_THROW_ERROR(NotImplementedError, "FP8 must be enabled to quantize an fp8 tensor to fp4.");
+#endif
+  } else {
+    C10_THROW_ERROR(NotImplementedError,
+                    "fp4_quantize only supports input tensor with dtypes fp16/bf16/e4m3.");
+  }
+
+#undef LAUNCH_FP4_QUANTIZE_KERNEL
+
+  return {valueE2M1, scaleFP8SF};
+}
+
 }  // namespace torch_ext

-TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) { m.def("fp4_quantize", &torch_ext::fp4_quantize); }
+TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) {
+  m.def("fp4_quantize", &torch_ext::fp4_quantize);
+  m.def("fp4_batched_quantize", &torch_ext::fp4_batched_quantize);
+}
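The block-scale buffer allocated above follows the swizzled 128x4 layout described in the comment: per batch entry, M is padded up to a multiple of 128 and K / sfVecSize is padded up to a multiple of 4. A minimal Python sketch of that sizing follows; the helper name is illustrative only and simply assumes tensorrt_llm::computeSwizzledLayoutSFSize applies exactly the padding stated in the comment.

import math

def swizzled_sf_size(m: int, k: int, sf_vec_size: int = 16) -> int:
    # Mirrors the [ceil(M / 128) * 128 * ceil(K / sfVecSize / 4) * 4] comment above.
    padded_m = math.ceil(m / 128) * 128
    padded_groups = math.ceil(k / sf_vec_size / 4) * 4
    return padded_m * padded_groups

# Example: M = 120, K = 64, sf_vec_size = 16 -> 128 * 4 = 512 scale-factor bytes per batch entry.
print(swizzled_sf_size(120, 64))  # 512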

csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.h

Lines changed: 4 additions & 0 deletions
@@ -28,4 +28,8 @@ std::tuple<at::Tensor, at::Tensor> fp4_quantize(at::Tensor const& self,
                                                 int64_t sfVecSize, bool sfUseUE8M0,
                                                 bool isSfSwizzledLayout, bool isSf8x4Layout,
                                                 bool enable_pdl);
+
+std::tuple<at::Tensor, at::Tensor> fp4_batched_quantize(at::Tensor const& self,
+                                                        at::Tensor const& globalScale,
+                                                        int64_t sfVecSize, bool sfUseUE8M0);
 }  // namespace torch_ext

flashinfer/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -58,6 +58,7 @@
     mxfp4_dequantize,
     mxfp4_quantize,
     nvfp4_quantize,
+    nvfp4_batched_quantize,
     shuffle_matrix_a,
     shuffle_matrix_sf_a,
 )

flashinfer/fp4_quantization.py

Lines changed: 93 additions & 0 deletions
@@ -251,6 +251,71 @@ def _fake_block_scale_interleave_sm100(
             [unswizzled_sf.shape[0] * unswizzled_sf.shape[1] // 16], dtype=torch.uint8
         )

+    @register_custom_op(
+        "flashinfer::fp4_batched_quantize_sm100",
+        mutates_args=("",),
+    )
+    def fp4_batched_quantize_sm100(
+        input: torch.Tensor,
+        global_scale: Optional[torch.Tensor] = None,
+        sf_vec_size: int = 16,
+        sf_use_ue8m0: bool = False,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Quantize a batched tensor to FP4 (E2M1x2) with per-block scale factors.
+
+        This function converts a float/bfloat16 (or FP8-quantized) input tensor into a
+        packed FP4 tensor using the E2M1 format (two 4-bit values per byte), along with
+        per-block scale factors. Scale factors are encoded as UE4M3 by default, or UE8M0
+        when requested, and an optional global scale can be applied.
+
+        Args:
+            input (torch.Tensor): Input tensor of shape [B, M, K] with dtype torch.float16,
+                torch.bfloat16, or an FP8-quantized dtype supported by the kernel.
+            global_scale (torch.Tensor, optional): Global scale factor of shape [1] and
+                dtype float32.
+            sf_vec_size (int, optional): Scale-factor vector size and alignment unit along K.
+                Supported/expected values:
+                - 16 (NVFP4 path; supported)
+                - 32 (MXFP4 path; not supported yet)
+                Defaults to 16.
+            sf_use_ue8m0 (bool, optional): Scale-factor encoding type.
+                False → UE4M3 (default), True → UE8M0.
+
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor]:
+                - self_fp4 (torch.Tensor): Packed FP4 tensor in E2M1x2 format of shape
+                  [B, M, K // 2] with dtype torch.uint8 (two FP4 lanes per byte).
+                - self_block_scale_factors (torch.Tensor): Block scale factors with dtype
+                  uint8 (UE4M3 or UE8M0), laid out as a flat buffer of shape
+                  [B, ceil(M / 128) * 128 * ceil(K / sf_vec_size / 4) * 4].
+
+        Notes:
+            - K must be even (because outputs pack two FP4 values per byte).
+            - For best performance, K should be a multiple of sf_vec_size; the scale-factor
+              buffer is aligned to sf_vec_size along K, pads M to multiples of 128, and
+              rounds (K / sf_vec_size) up to a multiple of 4 for storage.
+            - The batch dimension B is preserved for both outputs.
+        """
+        return module.fp4_batched_quantize(
+            input,
+            global_scale,
+            sf_vec_size,
+            sf_use_ue8m0,
+        )
+
+    @register_fake_op("flashinfer::fp4_batched_quantize_sm100")
+    def _fp4_batched_quantize_sm100(
+        input: torch.Tensor,
+        global_scale: Optional[torch.Tensor] = None,
+        sf_vec_size: int = 16,
+        sf_use_ue8m0: bool = False,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        b, m, k = input.shape
+        return (
+            input.new_empty([b, m, k // 2], dtype=torch.int64),  # float4_e2m1_x2
+            input.new_empty([b, m * k // sf_vec_size], dtype=torch.int32),  # Scale factors
+        )
+
     @register_custom_op(
         "flashinfer::e2m1_and_ufp8sf_scale_to_float_sm100",
         mutates_args=(""),
@@ -307,6 +372,7 @@ def _fake_e2m1_and_ufp8sf_scale_to_float_sm100(
         block_scale_interleave_sm100=block_scale_interleave_sm100,
         e2m1_and_ufp8sf_scale_to_float_sm100=e2m1_and_ufp8sf_scale_to_float_sm100,
         mxfp4_dequantize_host=mxfp4_dequantize_host,
+        fp4_batched_quantize_sm100=fp4_batched_quantize_sm100,
     )


@@ -610,3 +676,30 @@ def mxfp4_dequantize_host(
         scale,
         group_size,
     )
+
+
+def nvfp4_batched_quantize(
+    a,
+    a_global_sf,
+    sf_vec_size=16,
+):
+    """
+    Quantize batched input tensor to NVFP4 format.
+
+    Parameters:
+        a (torch.Tensor): Input tensor of shape [B, M, K] with dtype fp16/bf16.
+        a_global_sf (torch.Tensor): Global scale factor of shape [1] with dtype float32.
+        sf_vec_size (int, optional): Scale factor vector size. Defaults to 16.
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
+            - Quantized tensor of shape [B, M, K/2] with dtype FLOAT4_E2M1X2
+            - Scale factors tensor with shape determined by layout and sf_vec_size
+    """
+    a_fp4, a_sf = get_fp4_quantization_module().fp4_batched_quantize_sm100(
+        a,
+        a_global_sf,
+        sf_vec_size,
+        False,
+    )
+    return a_fp4, a_sf
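For reference, a minimal usage sketch of the new nvfp4_batched_quantize wrapper, modeled on the test added below. It requires a GPU supported by the NVFP4 kernel, and the constants FLOAT8_E4M3_MAX = 448 and FLOAT4_E2M1_MAX = 6 are assumptions here, matching the (448 * 6) / self.abs().max() global-scale comment in fp4Quantize.cpp.

import torch
from flashinfer import nvfp4_batched_quantize

FLOAT8_E4M3_MAX = 448.0  # assumed constant, per the global-scale comment
FLOAT4_E2M1_MAX = 6.0    # assumed constant, per the global-scale comment

# [B, M, K] input; K must be a multiple of sf_vec_size (16 for NVFP4).
x = torch.randn(2, 128, 64, dtype=torch.bfloat16, device="cuda")
global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / x.abs().max().to(torch.float32)

x_fp4, x_sf = nvfp4_batched_quantize(x, global_scale)
# x_fp4: [2, 128, 32] uint8, two E2M1 values packed per byte
# x_sf:  [2, 512] uint8 block scale factors (ceil(128/128)*128 * ceil(64/16/4)*4 = 512)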

tests/test_fp4_quantize.py

Lines changed: 43 additions & 0 deletions
@@ -10,12 +10,14 @@
     fp4_quantize,
     mxfp4_quantize,
     mxfp4_dequantize,
+    nvfp4_batched_quantize,
 )
 from flashinfer.utils import is_sm100a_supported

 DTYPES = [torch.float16, torch.bfloat16]
 # The batch dimension doesn't need to be multiple of 128
 SHAPES = [(128, 64), (256, 128), (120, 64), (200, 256)]
+BATCH_SHAPES = [(2, 128, 64), (3, 256, 128), (1, 120, 64)]
 SEEDS = [42]
 CUDA_DEVICES = ["cuda:0"]

@@ -310,5 +312,46 @@ def test_mxfp4_quantize_roundtrip(device: str):
     )


+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("batch_shape", BATCH_SHAPES)
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@torch.inference_mode()
+def test_nvfp4_batched_quantize(
+    dtype: torch.dtype,
+    batch_shape: tuple[int, int, int],
+    seed: int,
+    device: str,
+) -> None:
+    """Test nvfp4_batched_quantize function."""
+    if not is_sm100a_supported(torch.device(device)):
+        pytest.skip("Nvfp4 Requires compute capability of 10 or above")
+    torch.set_default_device(device)
+    torch.manual_seed(seed)
+
+    b, m, n = batch_shape
+    x = torch.randn(batch_shape, dtype=dtype)
+    tensor_amax = torch.abs(x).max().to(torch.float32)
+    global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax
+
+    # Test the batched quantization
+    out, out_scale = nvfp4_batched_quantize(x, global_scale)
+
+    # Basic shape checks
+    assert out.shape == (b, m, n // 2), (
+        f"Expected shape {(b, m, n // 2)}, got {out.shape}"
+    )
+    assert out.dtype == torch.uint8, f"Expected uint8, got {out.dtype}"
+    assert out_scale.dtype == torch.uint8, f"Expected uint8, got {out_scale.dtype}"
+
+    # Compare with single tensor quantization for each batch
+    for i in range(b):
+        single_out, single_scale = fp4_quantize(x[i], global_scale, 16, False, True)
+        torch.testing.assert_close(out[i], single_out, rtol=1e-5, atol=1e-5)
+        torch.testing.assert_close(
+            out_scale[i], single_scale.flatten(), rtol=1e-5, atol=1e-5
+        )
+
+
 if __name__ == "__main__":
     pytest.main([__file__, "-v"])
