fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py (35 additions, 3 deletions)
@@ -486,11 +486,19 @@ def print_kernels(kernels: Optional[List[str]]) -> List[QuantizeOpBase]:
     default=None,
     help="If set with grouped mode, repeat input shapes this many times. Comma separated list of groups to benchmark",
 )
+@click.option(
+    "--total-K",
+    default=None,
+    help="If set, adjusts the K values to sum to this number. "
+    "This can help simulate real grouped workloads in backward wgrad. "
+    "Comma separated list of total-K values to benchmark.",
+)
 @click.option(
     "--total-M",
     default=None,
-    help="If set, Adjusts the M values to sum to this number. "
-    "This can help simulate real grouped workloads.",
+    help="If set, adjusts the M values to sum to this number. "
+    "This can help simulate real grouped workloads. "
+    "Comma separated list of total-M values to benchmark.",
 )
 @click.option(
     "--no-cuda-graph",
@@ -542,6 +550,7 @@ def invoke_main(
     pair_nk: bool,
     grouped: bool,
     groups: Optional[str],
+    total_k: Optional[str],
     total_m: Optional[str],
     no_cuda_graph: bool,
     use_rotating_buffer_bench: bool,
@@ -553,6 +562,14 @@
 ):
     if enable_amd_env_vars:
         set_amd_env_vars()
+
+    # Validate that total_m and total_k are mutually exclusive
+    if total_m is not None and total_k is not None:
+        raise ValueError(
+            "total_m and total_k cannot be specified at the same time. "
+            "Please provide only one of them."
+        )
+
     # If kernel filter is provided, parse it. Else, benchmark all kernels.
     all_kernels = kernels.strip().split(",") if kernels else None
     quantize_ops = collect_kernels_to_profile(all_kernels)
@@ -619,16 +636,31 @@ def invoke_main(
     if groups:
         groups_list = [int(g) for g in groups.strip().split(",")]
         if total_m:
+            total_m_list = [int(tm) for tm in total_m.strip().split(",")]
             MNK = [
                 [
                     [b] * g,
-                    generate_group_tensor(g, int(total_m)),
+                    generate_group_tensor(g, tm),
                     [n] * g,
                     [k] * g,
                 ]
                 for g in groups_list
+                for tm in total_m_list
                 for b, _, n, k in MNK
             ]
+        elif total_k:
+            total_k_list = [int(tk) for tk in total_k.strip().split(",")]
+            MNK = [
+                [
+                    [b] * g,
+                    [m] * g,
+                    [n] * g,
+                    generate_group_tensor(g, tk),
+                ]
+                for g in groups_list
+                for tk in total_k_list
+                for b, m, n, _ in MNK
+            ]
         else:
             MNK = [
                 [[b] * g, [m] * g, [n] * g, [k] * g]
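
generate_group_tensor is defined elsewhere in fbgemm_gpu and is not shown in this diff. A minimal sketch consistent with how it is called here, assuming it returns g group sizes that sum to total_size:

    # Sketch under assumptions; the real generate_group_tensor lives outside this diff.
    import torch

    def generate_group_tensor_sketch(g, total_size):
        # Randomly partition total_size across g groups with a multinomial draw,
        # giving uneven per-group sizes like real grouped workloads.
        samples = torch.multinomial(torch.ones(g), total_size, replacement=True)
        return samples.bincount(minlength=g).tolist()

Note that the comprehensions above take a cross product: every value in groups_list pairs with every total-K (or total-M) value and every base (b, m, n, k) shape.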
fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py (47 additions, 1 deletion)
@@ -2084,7 +2084,7 @@ def cuda(self) -> bool:
 @register_quantize_op
 class BF16GroupedGrad(QuantizeOpBase):
     """
-    BF16 grouped matmul with grad inputs backed by cutlass
+    BF16 grouped matmul with dgrad inputs in pretraining backed by cutlass
     """

     def preprocess(self, x, w):
@@ -2126,6 +2126,52 @@ def cuda(self) -> bool:
         return True


+@register_quantize_op
+class BF16GroupedWGrad(QuantizeOpBase):
+    """
+    BF16 grouped matmul with wgrad inputs in pretraining backed by cutlass
+    """
+
+    def preprocess(self, x, w):
+        # Get K values for each group
+        k_values = [xi.shape[1] for xi in x]  # K dimension for each group
+
+        # Convert k_values into a sizes tensor
+        k_sizes = torch.tensor(k_values).to(dtype=torch.int64, device=x[0].device)
+
+        x = torch.concat(x, dim=1).contiguous()  # shape: (M, G*K)
+        w = torch.concat(w, dim=1).contiguous()  # shape: (N, G*K)
+
+        # Transpose both tensors to simulate wgrad input shapes
+        x = x.t().contiguous()  # shape: (G*K, M)
+        w = w.t().contiguous()  # shape: (G*K, N)
+
+        # Return processed tensors
+        return x, w, k_sizes
+
+    def quantize(self, x, w, k_sizes):
+        return x, w, k_sizes
+
+    def compute(self, x, w, k_sizes):
+        return torch.ops.fbgemm.bf16bf16bf16_grouped_wgrad(x, w, k_sizes)
+
+    def quantize_and_compute(self, x, w, k_sizes):
+        x, w, k_sizes = self.quantize(x, w, k_sizes)
+        return self.compute(x, w, k_sizes)
+
+    @property
+    def name(self) -> str:
+        return "bf16_grouped_wgrad"
+
+    @property
+    def hip(self) -> bool:
+        return False
+
+    @property
+    def cuda(self) -> bool:
+        return True
+
+
 @register_quantize_op
 class BF16GroupedStacked(QuantizeOpBase):
     """