Skip to content

Speed up nvfp4 pack/unpack w/ torch.compile #400

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ def decompress_weight(
return decompressed_weight


@torch.compile(fullgraph=True)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You may want to default to dynamic=True to avoid recompilation

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added

def pack_fp4_to_uint8(x: torch.Tensor) -> torch.Tensor:
"""
Packs a tensor with values in the fp4 range into uint8.
Expand All @@ -127,12 +128,11 @@ def pack_fp4_to_uint8(x: torch.Tensor) -> torch.Tensor:

# Find closest valid FP4 value index for each element
abs_x = torch.abs(x)
abs_indices = torch.zeros_like(abs_x, dtype=torch.long)
for i, val in enumerate(kE2M1):
abs_indices = torch.where(torch.isclose(abs_x, val), i, abs_indices)
abs_diff_x = torch.abs(abs_x.unsqueeze(-1) - kE2M1) # [m, n, 8]
abs_indices = torch.argmin(abs_diff_x, dim=-1) # [m, n]

# Apply sign bit (bit 3) to get final 4-bit representation
indices = abs_indices + (torch.signbit(x) << 3).to(torch.long)
indices = abs_indices + (torch.signbit(x).to(torch.long) << 3)

# Reshape to prepare for packing pairs of values
indices = indices.reshape(-1)
Expand All @@ -155,6 +155,7 @@ def pack_fp4_to_uint8(x: torch.Tensor) -> torch.Tensor:
)

# reference: : https://github.com/vllm-project/vllm/pull/16362
@torch.compile(fullgraph=True)
def unpack_fp4_from_uint8(
a: torch.Tensor, m: int, n: int, dtype: Optional[torch.dtype] = torch.bfloat16
) -> torch.Tensor:
Expand Down