
Commit 517b712

cthi authored and facebook-github-bot committed
Blackwell FP4 Grouped refactor (#4847)
Summary:
Pull Request resolved: #4847
X-link: facebookresearch/FBGEMM#1874

We plan to integrate NVFP4 + MXFP4 into torch, which will require updating the kernel a bit. Before we do this, some quick cosmetic and code refactors:

- Remove the TensorList API. There are no plans for its continued use given its inefficiency and the fact that we have token shuffling now; removing it simplifies a lot of the existing code.
- Coalesce the current `set_stacked_kernel_args_kernel` into a single path, as the only difference is whether the global scale is passed in, which we can determine at the call site.
- Compute the alignment for A and B explicitly instead of hard-coding it to 32.

Reviewed By: jiawenliu64

Differential Revision: D82046862

fbshipit-source-id: 31fcbe10c8a3715b59d1b2da0f366af3bc3f134b
1 parent b8c0b33 commit 517b712

16 files changed: +157, -1019 lines
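As a rough illustration of the last bullet in the summary (computing the A/B alignment explicitly rather than hard-coding 32), here is a minimal sketch assuming 128-bit vectorized accesses; the constant and helper below are illustrative and are not the FBGEMM implementation.

#include <cstdint>

// Assumed vector access width. FP4 elements are 4 bits wide, so 128 / 4 = 32,
// which reproduces the previously hard-coded alignment of 32 elements.
constexpr int64_t kAccessBits = 128;

constexpr int64_t alignment_for(int64_t element_bits) {
  return kAccessBits / element_bits;
}

static_assert(alignment_for(/*FP4 e2m1*/ 4) == 32, "matches the old constant");

Deriving the value this way keeps the same result for FP4 while making the access-width assumption explicit for other element types.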

fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py

Lines changed: 0 additions & 101 deletions
@@ -2426,107 +2426,6 @@ def cuda(self) -> bool:
         return True
 
 
-@register_quantize_op
-class MXFP4GroupedGemm(QuantizeOpBase):
-    """
-    MXFP4 grouped matmul with blockwise scaling.
-    """
-
-    def preprocess(self, x, w):
-        wq, w_scale = zip(*[triton_quantize_mx4_unpack(i) for i in w])
-        return x, wq, w_scale
-
-    def quantize(self, x, wq, w_scale):
-        xq, x_scale = zip(*[triton_quantize_mx4_unpack(i) for i in x])
-        return xq, wq, x_scale, w_scale
-
-    def compute(self, xq, wq, x_scale, w_scale):
-        return torch.ops.fbgemm.f4f4bf16_grouped(
-            xq,
-            wq,
-            x_scale,
-            w_scale,
-        )
-
-    def quantize_and_compute(self, x, wq, w_scale):
-        xq, wq, x_scale, w_scale = self.quantize(x, wq, w_scale)
-        return self.compute(xq, wq, x_scale, w_scale)
-
-    @property
-    def name(self) -> str:
-        return "cutlass_f4f4bf16_grouped"
-
-    @property
-    def hip(self) -> bool:
-        # F4F4BF16_grouped only supported for cuda.
-        return False
-
-    @property
-    def cuda(self) -> bool:
-        return True
-
-
-@register_quantize_op
-class NVFP4GroupedGemm(QuantizeOpBase):
-    """
-    NVFP4 grouped matmul with blockwise scaling.
-    """
-
-    def quantize(self, x, w):
-        def get_global_scale(x, w):
-            x_global_scale = (448.0 * 6.0) / torch.amax(
-                torch.abs(x.flatten()), dim=-1
-            ).to(torch.float32)
-            w_global_scale = (448.0 * 6.0) / torch.amax(
-                torch.abs(w.flatten()), dim=-1
-            ).to(torch.float32)
-            global_scale = 1 / (x_global_scale * w_global_scale)
-            return x_global_scale, w_global_scale, global_scale
-
-        # Compute global scale for each group
-        G = len(x)
-        x_global_scale = []
-        w_global_scale = []
-        global_scale = []
-        for i in range(G):
-            x_global_scale_, w_global_scale_, global_scale_ = get_global_scale(
-                x[i], w[i]
-            )
-            x_global_scale.append(x_global_scale_)
-            w_global_scale.append(w_global_scale_)
-            global_scale.append(global_scale_)
-
-        # Quantize weights and activations
-        wq, w_scale = zip(
-            *[triton_scale_nvfp4_quant(w[i], w_global_scale[i]) for i in range(G)]
-        )
-        xq, x_scale = zip(
-            *[triton_scale_nvfp4_quant(x[i], x_global_scale[i]) for i in range(G)]
-        )
-        return xq, wq, x_scale, w_scale, global_scale
-
-    def compute(self, xq, wq, x_scale, w_scale, global_scale):
-        return torch.ops.fbgemm.f4f4bf16_grouped(
-            xq, wq, x_scale, w_scale, global_scale, use_mx=False
-        )
-
-    def quantize_and_compute(self, x, w):
-        xq, wq, x_scale, w_scale, global_scale = self.quantize(x, w)
-        return self.compute(xq, wq, x_scale, w_scale, global_scale)
-
-    @property
-    def name(self) -> str:
-        return "cutlass_nv_f4f4bf16_grouped"
-
-    @property
-    def hip(self) -> bool:
-        return False
-
-    @property
-    def cuda(self) -> bool:
-        return True
-
-
 @register_quantize_op
 class MXFP4StackedGroupedGemm(QuantizeOpBase):
     """

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16_grouped.cu

Lines changed: 18 additions & 111 deletions
@@ -27,8 +27,7 @@ namespace fbgemm_gpu {
 
 #if defined(CUDA_VERSION) && (CUDA_VERSION >= 12080)
 
-template <typename InputType>
-Kernel_f4f4bf16_grouped<InputType>
+Kernel_f4f4bf16_grouped
 get_kernel_via_heuristics(int total_M, int N, int K, int G, bool use_mx) {
   // MXFP4
   if (use_mx) {
@@ -151,40 +150,35 @@ get_kernel_via_heuristics(int total_M, int N, int K, int G, bool use_mx) {
   }
 }
 
-template <typename InputType>
 at::Tensor dispatch_fp4_grouped_kernel(
     int total_M,
     int N,
     int K,
     int G,
-    InputType XQ, // FP4
-    InputType WQ, // FP4
-    InputType x_scale,
-    InputType w_scale,
+    at::Tensor XQ, // FP4
+    at::Tensor WQ, // FP4
+    at::Tensor x_scale,
+    at::Tensor w_scale,
     at::Tensor output,
     std::optional<at::Tensor> zero_start_index_M = std::nullopt,
     std::optional<at::Tensor> M_sizes = std::nullopt,
-    std::optional<InputType> global_scale = std::nullopt,
+    std::optional<at::Tensor> global_scale = std::nullopt,
     std::optional<at::Tensor> starting_row_after_padding = std::nullopt,
     bool use_mx = true) {
-  if constexpr (std::is_same_v<InputType, at::TensorList>) {
-    TORCH_CHECK(WQ.size() == G);
-  } else {
-    TORCH_CHECK(
-        zero_start_index_M.has_value() != M_sizes.has_value(),
-        "One of zero_start_index_M or M_sizes must be provided.");
-    TORCH_CHECK(M_sizes.has_value(), "M_sizes is assumed to be provided.");
-    TORCH_CHECK(
-        starting_row_after_padding.has_value(),
-        "starting_row_after_padding is assumed to be provided.");
-    at::Tensor starting_row_after_padding_actual =
-        starting_row_after_padding.value_or(at::zeros({0}));
-    TORCH_CHECK(starting_row_after_padding_actual.size(0) % (G + 1) == 0);
-  }
+  TORCH_CHECK(
+      zero_start_index_M.has_value() != M_sizes.has_value(),
+      "One of zero_start_index_M or M_sizes must be provided.");
+  TORCH_CHECK(M_sizes.has_value(), "M_sizes is assumed to be provided.");
+  TORCH_CHECK(
+      starting_row_after_padding.has_value(),
+      "starting_row_after_padding is assumed to be provided.");
+  at::Tensor starting_row_after_padding_actual =
+      starting_row_after_padding.value_or(at::zeros({0}));
+  TORCH_CHECK(starting_row_after_padding_actual.size(0) % (G + 1) == 0);
 
   // Select kernel to run via heuristics.
   auto kernel = [&]() {
-    return get_kernel_via_heuristics<InputType>(total_M, N, K, G, use_mx);
+    return get_kernel_via_heuristics(total_M, N, K, G, use_mx);
   }();
   // Invoke kernel
   return kernel(
@@ -200,82 +194,6 @@ at::Tensor dispatch_fp4_grouped_kernel(
       starting_row_after_padding);
 }
 
-template <typename OutputType>
-OutputType _f4f4bf16_grouped(
-    at::TensorList XQ, // FP4
-    at::TensorList WQ, // FP4
-    at::TensorList x_scale,
-    at::TensorList w_scale,
-    std::optional<at::TensorList> global_scale,
-    bool use_mx) {
-  at::Tensor Y;
-  int64_t total_M = 0;
-  int64_t max_N = 0;
-  int64_t max_K = 0;
-  int64_t G = XQ.size();
-
-  // Allocate output tensor.
-  std::vector<int64_t> output_sizes;
-  int64_t total_output_size = 0;
-  for (int i = 0; i < G; ++i) {
-    int64_t M = XQ[i].size(0);
-    int64_t N = WQ[i].size(0);
-    int64_t K = WQ[i].size(1);
-    total_M += M;
-    if (N > max_N) {
-      max_N = N;
-    }
-    if (K > max_K) {
-      max_K = K;
-    }
-    const int64_t output_size = M * N;
-    total_output_size += output_size;
-    output_sizes.push_back(output_size);
-  }
-  Y = at::empty(total_output_size, XQ[0].options().dtype(at::kBFloat16));
-
-  // Run kernel.
-  at::Tensor g_out = dispatch_fp4_grouped_kernel<at::TensorList>(
-      total_M,
-      max_N,
-      max_K * 2, // Since K is packed
-      G,
-      XQ,
-      WQ,
-      x_scale,
-      w_scale,
-      Y,
-      std::nullopt,
-      std::nullopt,
-      global_scale,
-      std::nullopt,
-      use_mx);
-
-  // Return appropriate output type.
-  if constexpr (std::is_same_v<OutputType, at::Tensor>) {
-    int64_t N = WQ[0].size(0);
-    return g_out.view({total_M, N});
-  } else {
-    // Return grouped view of output.
-    std::vector<at::Tensor> output_group = g_out.split(output_sizes);
-    for (int i = 0; i < G; ++i) {
-      output_group[i] = output_group[i].view({XQ[i].size(0), WQ[i].size(0)});
-    }
-    return output_group;
-  }
-}
-
-std::vector<at::Tensor> f4f4bf16_grouped(
-    at::TensorList XQ, // FP4
-    at::TensorList WQ, // FP4
-    at::TensorList x_scale,
-    at::TensorList w_scale,
-    std::optional<at::TensorList> global_scale = std::nullopt,
-    bool use_mx = true) {
-  return _f4f4bf16_grouped<std::vector<at::Tensor>>(
-      XQ, WQ, x_scale, w_scale, global_scale, use_mx);
-}
-
 at::Tensor f4f4bf16_grouped_stacked(
     at::Tensor XQ, // FP4
     at::Tensor WQ, // FP4
@@ -300,7 +218,7 @@ at::Tensor f4f4bf16_grouped_stacked(
     return Y;
   }
   // Return continuous view of output.
-  return dispatch_fp4_grouped_kernel<at::Tensor>(
+  return dispatch_fp4_grouped_kernel(
      total_M,
      N,
      K * 2, // Since K is packed
@@ -319,17 +237,6 @@ at::Tensor f4f4bf16_grouped_stacked(
 
 #else
 
-std::vector<at::Tensor> f4f4bf16_grouped(
-    at::TensorList XQ, // FP4
-    at::TensorList WQ, // FP4
-    at::TensorList x_scale,
-    at::TensorList w_scale,
-    std::optional<at::TensorList> global_scale = std::nullopt,
-    bool use_mx = true) {
-  throw std::runtime_error(
-      "CUDA version is older than 12.8"); // requires CUDA>=12.8
-}
-
 at::Tensor f4f4bf16_grouped_stacked(
     at::Tensor XQ, // FP4
     at::Tensor WQ, // FP4
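With the TensorList overloads removed, dispatch_fp4_grouped_kernel is no longer templated on its input type and there is a single code path for MXFP4 and NVFP4. Below is a minimal sketch of a call site under the new signature, assuming the caller distinguishes the two formats by whether a global scale tensor is available; the wrapper name and argument plumbing are illustrative only, not FBGEMM code.

#include <ATen/ATen.h>
#include <optional>

// Forward declaration matching the refactored signature in the file above.
at::Tensor dispatch_fp4_grouped_kernel(
    int total_M,
    int N,
    int K,
    int G,
    at::Tensor XQ,
    at::Tensor WQ,
    at::Tensor x_scale,
    at::Tensor w_scale,
    at::Tensor output,
    std::optional<at::Tensor> zero_start_index_M,
    std::optional<at::Tensor> M_sizes,
    std::optional<at::Tensor> global_scale,
    std::optional<at::Tensor> starting_row_after_padding,
    bool use_mx);

// Hypothetical call site: NVFP4 when a global scale is provided, MXFP4 otherwise.
at::Tensor run_fp4_grouped(
    at::Tensor XQ,
    at::Tensor WQ,
    at::Tensor x_scale,
    at::Tensor w_scale,
    at::Tensor output,
    at::Tensor M_sizes,
    at::Tensor starting_row_after_padding,
    std::optional<at::Tensor> global_scale,
    int total_M,
    int N,
    int packed_K,
    int G) {
  const bool use_mx = !global_scale.has_value();
  return dispatch_fp4_grouped_kernel(
      total_M,
      N,
      packed_K * 2, // K is packed: two FP4 values per byte
      G,
      XQ,
      WQ,
      x_scale,
      w_scale,
      output,
      /*zero_start_index_M=*/std::nullopt,
      M_sizes,
      global_scale,
      starting_row_after_padding,
      use_mx);
}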

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16_grouped/f4f4bf16_grouped_128_128_256_1_1_1_f.cu

Lines changed: 0 additions & 33 deletions
@@ -24,39 +24,6 @@ at::Tensor f4f4bf16_grouped_128_128_256_1_1_1_f(
     std::optional<at::Tensor> global_scale,
     std::optional<at::Tensor> starting_row_after_padding) {
   return f4f4bf16_grouped_impl<
-      at::Tensor,
-      cutlass::nv_float4_t<cutlass::float_e2m1_t>,
-      128,
-      128,
-      256,
-      1,
-      1,
-      1>(
-      XQ,
-      WQ,
-      x_scale,
-      w_scale,
-      output,
-      G,
-      zero_start_index_M,
-      M_sizes,
-      global_scale,
-      starting_row_after_padding);
-}
-
-at::Tensor f4f4bf16_grouped_128_128_256_1_1_1_f(
-    at::TensorList XQ, // FP4
-    at::TensorList WQ, // FP4
-    at::TensorList x_scale,
-    at::TensorList w_scale,
-    at::Tensor output,
-    int64_t G,
-    std::optional<at::Tensor> zero_start_index_M,
-    std::optional<at::Tensor> M_sizes,
-    std::optional<at::TensorList> global_scale,
-    std::optional<at::Tensor> starting_row_after_padding) {
-  return f4f4bf16_grouped_impl<
-      at::TensorList,
       cutlass::nv_float4_t<cutlass::float_e2m1_t>,
       128,
       128,

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16_grouped/f4f4bf16_grouped_128_128_256_1_1_1_t.cu

Lines changed: 0 additions & 33 deletions
@@ -24,39 +24,6 @@ at::Tensor f4f4bf16_grouped_128_128_256_1_1_1_t(
     std::optional<at::Tensor> global_scale,
     std::optional<at::Tensor> starting_row_after_padding) {
   return f4f4bf16_grouped_impl<
-      at::Tensor,
-      cutlass::mx_float4_t<cutlass::float_e2m1_t>,
-      128,
-      128,
-      256,
-      1,
-      1,
-      1>(
-      XQ,
-      WQ,
-      x_scale,
-      w_scale,
-      output,
-      G,
-      zero_start_index_M,
-      M_sizes,
-      global_scale,
-      starting_row_after_padding);
-}
-
-at::Tensor f4f4bf16_grouped_128_128_256_1_1_1_t(
-    at::TensorList XQ, // FP4
-    at::TensorList WQ, // FP4
-    at::TensorList x_scale,
-    at::TensorList w_scale,
-    at::Tensor output,
-    int64_t G,
-    std::optional<at::Tensor> zero_start_index_M,
-    std::optional<at::Tensor> M_sizes,
-    std::optional<at::TensorList> global_scale,
-    std::optional<at::Tensor> starting_row_after_padding) {
-  return f4f4bf16_grouped_impl<
-      at::TensorList,
       cutlass::mx_float4_t<cutlass::float_e2m1_t>,
       128,
       128,
