
Commit 903002a

cthimeta-codesync[bot] authored and committed
Improve API for f4f4bf16 (#5163)
Summary: Pull Request resolved: #5163

X-link: https://github.com/facebookresearch/FBGEMM/pull/2162

We add some improvements for the FP4 GEMM:

- Remove the need to pass `use_mx`; we can infer this based on `global_scale`.
- As a follow-up, we should improve the assertions on the proper FP4 dtypes, similar to what we have with [FP4 group gemm](https://www.internalfb.com/code/fbsource/[addad803d330]/fbcode/deeplearning/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16_grouped.cu?lines=388-409).
- Add an optional `output` argument to the API, in line with other torch APIs.
- Move the function declaration to `torch_ops.h`, which removes the need for the forward declaration in Blas.cpp.
- Small code clean-ups.

Misc:

- Later we should likely clean up and re-evaluate the heuristic for the kernel; right now it is almost identical (and duplicated) for MX and NV FP4, and we are likely instantiating more instances than needed.

Reviewed By: slayton58

Differential Revision: D87655845

fbshipit-source-id: d3ddd1f7efe7683fb8615ab2f6febe438ce6b380
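To make the call-site impact concrete, here is a minimal sketch of how the updated `torch.ops.fbgemm.f4f4bf16` op might be invoked after this change. Only the keyword changes reflect the diff (`use_mx` removed, the MX/NV path inferred from `global_scale`, and an optional `output`); the tensor shapes, dtypes, and scale layouts below are illustrative assumptions, not taken from this commit.

import torch

# Sketch only: shapes, dtypes, and scale layouts are assumptions for
# illustration; real inputs would come from fbgemm's FP4 quantization helpers.
M, N, K = 128, 256, 512
group_size = 16

# Packed FP4 operands: two 4-bit values per uint8 byte (assumed layout).
xq = torch.randint(0, 255, (M, K // 2), dtype=torch.uint8, device="cuda")
wq = torch.randint(0, 255, (N, K // 2), dtype=torch.uint8, device="cuda")

# Per-group scales (assumed layout).
x_scale = torch.randint(0, 255, (M, K // group_size), dtype=torch.uint8, device="cuda")
w_scale = torch.randint(0, 255, (N, K // group_size), dtype=torch.uint8, device="cuda")

# NV FP4 path: passing global_scale now implies use_mx=False (the flag is gone).
global_scale = torch.ones(1, dtype=torch.float32, device="cuda")
y_nv = torch.ops.fbgemm.f4f4bf16(xq, wq, x_scale, w_scale, global_scale=global_scale)

# MX FP4 path: omitting global_scale selects the MX variant.
y_mx = torch.ops.fbgemm.f4f4bf16(xq, wq, x_scale, w_scale)

# Optional preallocated output, in line with other torch APIs (assumed shape/dtype).
out = torch.empty(M, N, dtype=torch.bfloat16, device="cuda")
torch.ops.fbgemm.f4f4bf16(xq, wq, x_scale, w_scale, output=out)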
1 parent 1179289 commit 903002a

File tree

53 files changed (+313, -210 lines)


fbgemm_gpu/experimental/gemm/triton_gemm/fp4_quantize.py

Lines changed: 1 addition & 1 deletion
@@ -289,7 +289,7 @@ def triton_quantize_mx4_unpack(
         stochastic_casting (bool): Whether to use stochastic casting.
 
     Returns:
-        torch.Tensor: [M / 2] mx4 scaled tensor packed into in8
+        torch.Tensor: [M / 2] mx4 scaled tensor packed into uint8
         torch.Tensor: [M / group_size] mx4 shared exponents into int8
 
         eg.
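For context, a hedged sketch of calling the quantization helper whose docstring is fixed above. The exact signature, defaults, and expected input shape of `triton_quantize_mx4_unpack` are assumptions here; only the returned dtypes follow the corrected docstring.

import torch
from fbgemm_gpu.experimental.gemm.triton_gemm.fp4_quantize import (
    triton_quantize_mx4_unpack,
)

# Assumed call pattern: the helper returns the packed MX4 values and the
# shared exponents described in the docstring above.
x = torch.randn(128, 512, dtype=torch.bfloat16, device="cuda")
xq, x_scale = triton_quantize_mx4_unpack(x)

# Per the corrected docstring: values packed into uint8, shared exponents into int8.
print(xq.dtype, x_scale.dtype)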

fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py

Lines changed: 2 additions & 2 deletions
@@ -2385,7 +2385,7 @@ def quantize(self, x, w):
 
     def compute(self, xq, wq, x_scale, w_scale, global_scale):
         return torch.ops.fbgemm.f4f4bf16(
-            xq, wq, x_scale, w_scale, global_scale=global_scale, use_mx=False
+            xq, wq, x_scale, w_scale, global_scale=global_scale
         )
 
     def quantize_and_compute(self, x, w):
@@ -2471,7 +2471,7 @@ def quantize(self, x, w):
 
     def compute(self, xq, wq, x_scale, w_scale, global_scale):
         return torch.ops.fbgemm.f4f4bf16(
-            xq, wq, x_scale, w_scale, global_scale=global_scale, use_mx=False
+            xq, wq, x_scale, w_scale, global_scale=global_scale
         )
 
     def quantize_and_compute(self, x, w):
