@@ -142,6 +142,82 @@ std::tuple<at::Tensor, at::Tensor> fp4_quantize(at::Tensor const& self,
return {valueE2M1, scaleFP8SF};
}
+
+// self: [B, M, K], fp16/bf16/fp8_quantized
+// globalScale: [1] float, = (448 * 6) / self.abs().max()
+// nvfp4: sfVecSize = 16, sfUseUE8M0 = false
+// mxfp4: sfVecSize = 32 (not supported yet), sfUseUE8M0 = true
+// alignment: sfVecSize
+// returns self_fp4, self_block_scale_factors
+// self_fp4: [B, M, K / 2], FLOAT4_E2M1X2
+// self_block_scale_factors:
+//   [B, ceil(M / 128) * 128 * ceil(K / sfVecSize / 4) * 4], SF_DTYPE (UE4M3 or UE8M0)
+std::tuple<at::Tensor, at::Tensor> fp4_batched_quantize(at::Tensor const& self,
+                                                        at::Tensor const& globalScale,
+                                                        int64_t sfVecSize, bool sfUseUE8M0) {
158
+ CHECK_TH_CUDA (self);
159
+ CHECK_CONTIGUOUS (self);
160
+ CHECK_INPUT_TYPE (globalScale, c10::ScalarType::Float);
161
+ TORCH_CHECK (sfVecSize == 16 , " sfVecSize can only be 16" );
162
+
163
+ auto const & inputShape = self.sizes ();
164
+ auto const & rank = inputShape.size ();
165
+
166
+ TORCH_CHECK (rank == 3 , " Input should be 3D tensor." );
167
+
168
+ int64_t b = inputShape[0 ];
169
+ int64_t m = inputShape[1 ];
170
+ int64_t k = inputShape[2 ];
171
+
172
+ TORCH_CHECK (k % sfVecSize == 0 );
173
+
174
+ std::vector<int64_t > outputShape (inputShape.begin (), inputShape.end ());
175
+ outputShape[rank - 1 ] = k / 2 ;
176
+
177
+ at::Tensor valueE2M1 =
178
+ at::detail::empty_cuda (outputShape, FLOAT4_E2M1X2, self.device (), /* stride */ std::nullopt );
179
+ at::Tensor scaleFP8SF =
180
+ at::detail::empty_cuda ({b, tensorrt_llm::computeSwizzledLayoutSFSize (m, k / sfVecSize)},
181
+ SF_DTYPE, self.device (), /* stride */ std::nullopt ); // 2D tensor
182
+
183
+ const thread_local int mMultiProcessorCount = tensorrt_llm::common::getMultiProcessorCount ();
184
+ auto layout = tensorrt_llm::QuantizationSFLayout::SWIZZLED_128x4;
185
+
186
+ #define LAUNCH_FP4_QUANTIZE_KERNEL (T, SF_VEC_SIZE ) \
187
+ tensorrt_llm::kernels::invokeFP4Quantization<T, SF_VEC_SIZE>( \
188
+ b, m, k, reinterpret_cast <T*>(self.data_ptr ()), globalScale.data_ptr <float >(), \
189
+ reinterpret_cast <int64_t *>(valueE2M1.data_ptr ()), \
190
+ reinterpret_cast <int32_t *>(scaleFP8SF.data_ptr ()), sfUseUE8M0, layout, mMultiProcessorCount , \
191
+ at::cuda::getCurrentCUDAStream (self.get_device ()));
192
+
193
+ if (self.scalar_type () == at::ScalarType::Half) {
194
+ LAUNCH_FP4_QUANTIZE_KERNEL (half, 16 )
195
+ } else if (self.scalar_type () == at::ScalarType::BFloat16) {
196
+ #ifdef ENABLE_BF16
197
+ LAUNCH_FP4_QUANTIZE_KERNEL (__nv_bfloat16, 16 )
198
+ #else
199
+ C10_THROW_ERROR (NotImplementedError,
200
+ " BFloat16 must be enabled to quantize an bf16 tensor to fp4." );
201
+ #endif
202
+ } else if (self.scalar_type () == at::ScalarType::Float8_e4m3fn) {
203
+ #ifdef ENABLE_FP8
204
+ LAUNCH_FP4_QUANTIZE_KERNEL (__nv_fp8_e4m3, 16 )
205
+ #else
206
+ C10_THROW_ERROR (NotImplementedError, " FP8 must be enabled to quantize an fp8 tensor to fp4." );
207
+ #endif
208
+ } else {
209
+ C10_THROW_ERROR (NotImplementedError,
210
+ " fp4_quantize only supports input tensor with dtypes fp16/bf16/e4m3." );
211
+ }
212
+
213
+ #undef LAUNCH_FP4_QUANTIZE_KERNEL
214
+
215
+ return {valueE2M1, scaleFP8SF};
216
+ }
217
+
} // namespace torch_ext
-TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) { m.def("fp4_quantize", &torch_ext::fp4_quantize); }
+TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) {
+  m.def("fp4_quantize", &torch_ext::fp4_quantize);
+  m.def("fp4_batched_quantize", &torch_ext::fp4_batched_quantize);
+}
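For orientation, here is a minimal caller-side sketch (not part of this diff) following the conventions in the comment block above: nvfp4 path with sfVecSize = 16 and the documented globalScale formula. The wrapper name quantize_batch_nvfp4 is hypothetical, and it assumes the declaration of torch_ext::fp4_batched_quantize is visible to the caller (e.g. compiled into the same extension).

#include <torch/torch.h>
#include <tuple>

// Hypothetical caller; assumes torch_ext::fp4_batched_quantize from the diff
// above is declared and linked into this translation unit.
std::tuple<at::Tensor, at::Tensor> quantize_batch_nvfp4(at::Tensor x) {
  // Input: [B, M, K] half/bfloat16/e4m3 CUDA tensor, K divisible by sfVecSize (16).
  TORCH_CHECK(x.is_cuda() && x.dim() == 3 && x.size(2) % 16 == 0);
  x = x.contiguous();

  // globalScale = (448 * 6) / x.abs().max(), as documented above
  // (448 = max |E4M3| for the block scales, 6 = max |E2M1| for the elements).
  at::Tensor amax = x.abs().max().to(at::kFloat);
  at::Tensor globalScale = at::full({1}, 448.0f * 6.0f, amax.options()) / amax;

  // nvfp4: sfVecSize = 16, sfUseUE8M0 = false.
  // Returns {x_fp4 [B, M, K / 2], swizzled 128x4 block scale factors}.
  return torch_ext::fp4_batched_quantize(x, globalScale, /*sfVecSize=*/16,
                                         /*sfUseUE8M0=*/false);
}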
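And a small sketch of the per-batch scale-factor buffer size stated in the comment "[B, ceil(M / 128) * 128 * ceil(K / sfVecSize / 4) * 4]". The helper below is an illustrative stand-in for tensorrt_llm::computeSwizzledLayoutSFSize(m, k / sfVecSize), not its actual implementation.

#include <cstdint>

// Illustrative stand-in for the documented size formula; the real code calls
// tensorrt_llm::computeSwizzledLayoutSFSize(m, k / sfVecSize).
inline int64_t swizzledSFSizePerBatch(int64_t m, int64_t k, int64_t sfVecSize) {
  int64_t rows = ((m + 127) / 128) * 128;        // ceil(M / 128) * 128
  int64_t cols = ((k / sfVecSize + 3) / 4) * 4;  // ceil(K / sfVecSize / 4) * 4
  return rows * cols;
}

// Example: m = 100, k = 64, sfVecSize = 16
//   rows = 128, cols = 4  ->  512 scale factors (UE4M3) per batch entry.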