-
Notifications
You must be signed in to change notification settings - Fork 349
[mxfp8 moe training] add triton kernel for mxfp8 dequantization #3195
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
danielvegamyhre
wants to merge
1
commit into
main
Choose a base branch
from
danielvegamyhre/stack/78
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
170 changes: 170 additions & 0 deletions
170
benchmarks/prototype/moe_training/mxfp8/bench_dequantize.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,170 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD 3-Clause license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
# this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py | ||
|
||
from dataclasses import dataclass | ||
from typing import List | ||
|
||
import torch | ||
from tabulate import tabulate | ||
from tqdm import tqdm | ||
|
||
from benchmarks.utils import benchmark_cuda_function_in_microseconds | ||
from torchao.prototype.mx_formats.kernels import triton_mxfp8_dequant_dim0 | ||
from torchao.prototype.mx_formats.mx_tensor import to_dtype, to_mx | ||
|
||
device = torch.device("cuda") | ||
|
||
# Needed since changing args to function causes recompiles | ||
torch._dynamo.config.cache_size_limit = 1000 | ||
|
||
|
||
@dataclass(frozen=True) | ||
class ExperimentConfig: | ||
input_shape: tuple[int] | ||
|
||
|
||
@dataclass(frozen=True) | ||
class ExperimentResult: | ||
# time | ||
torch_us: float | ||
triton_us: float | ||
torch_gbps: float | ||
triton_gbps: float | ||
|
||
|
||
@dataclass(frozen=True) | ||
class Experiment: | ||
config: ExperimentConfig | ||
result: ExperimentResult | ||
|
||
|
||
def get_configs() -> List[ExperimentConfig]: | ||
input_shapes = [ | ||
# (local_batch_size, seq_len, dim) | ||
(1, 8192, 7168), | ||
(2, 8192, 7168), | ||
(4, 8192, 7168), | ||
(8, 8192, 7168), | ||
] | ||
configs = [] | ||
for shape in input_shapes: | ||
configs.append( | ||
ExperimentConfig( | ||
input_shape=shape, | ||
) | ||
) | ||
return configs | ||
|
||
|
||
def run_experiment(config: ExperimentConfig) -> ExperimentResult: | ||
block_size = 32 | ||
input_shape = config.input_shape | ||
input_tensor = torch.randn( | ||
*input_shape, | ||
dtype=torch.bfloat16, | ||
device=device, | ||
) | ||
|
||
e8m0_scales, e4m3_data = to_mx(input_tensor, torch.float8_e4m3fn, block_size) | ||
|
||
# Bench torch dequant | ||
to_dtype_c = torch.compile(to_dtype) | ||
elem_dtype, target_dtype = torch.float8_e4m3fn, torch.bfloat16 | ||
torch_output = to_dtype_c( | ||
e4m3_data, | ||
e8m0_scales, | ||
elem_dtype, | ||
block_size, | ||
target_dtype, | ||
) | ||
torch_us = benchmark_cuda_function_in_microseconds( | ||
to_dtype_c, | ||
e4m3_data, | ||
e8m0_scales, | ||
elem_dtype, | ||
block_size, | ||
target_dtype, | ||
) | ||
|
||
# Bench triton kernel | ||
_ = triton_mxfp8_dequant_dim0( | ||
e4m3_data, | ||
e8m0_scales, | ||
target_dtype, | ||
block_size, | ||
) | ||
triton_us = benchmark_cuda_function_in_microseconds( | ||
triton_mxfp8_dequant_dim0, | ||
e4m3_data, | ||
e8m0_scales, | ||
target_dtype, | ||
block_size, | ||
) | ||
|
||
# mem bw calculations | ||
bytes_per_input_el = torch.finfo(elem_dtype).bits / 8 | ||
bytes_per_output_el = torch.finfo(target_dtype).bits / 8 | ||
bytes_per_scale_el = torch.finfo(torch.float8_e8m0fnu).bits / 8 | ||
|
||
read_bytes = ( | ||
e4m3_data.numel() * bytes_per_input_el | ||
+ e8m0_scales.numel() * bytes_per_scale_el | ||
) | ||
write_bytes = torch_output.numel() * bytes_per_output_el | ||
|
||
torch_gbps = ((read_bytes + write_bytes) / 1e9) / (torch_us / 1e6) | ||
triton_gbps = ((read_bytes + write_bytes) / 1e9) / (triton_us / 1e6) | ||
|
||
return ExperimentResult( | ||
torch_us=torch_us, | ||
triton_us=triton_us, | ||
triton_gbps=triton_gbps, | ||
torch_gbps=torch_gbps, | ||
) | ||
|
||
|
||
def print_results(experiments: List[Experiment]): | ||
headers = [ | ||
"input_shape", | ||
"torch_us", | ||
"triton_us", | ||
"torch_gbps", | ||
"triton_gbps", | ||
"triton_speedup", | ||
] | ||
rows = [] | ||
for experiment in experiments: | ||
triton_speedup = round( | ||
experiment.result.torch_us / experiment.result.triton_us, 3 | ||
) | ||
rows.append( | ||
[ | ||
str(experiment.config.input_shape), | ||
experiment.result.torch_us, | ||
experiment.result.triton_us, | ||
round(experiment.result.torch_gbps, 3), | ||
round(experiment.result.triton_gbps, 3), | ||
f"{triton_speedup}x", | ||
] | ||
) | ||
print(tabulate(rows, headers=headers)) | ||
|
||
|
||
def main(): | ||
torch.random.manual_seed(123) | ||
configs = get_configs() | ||
results = [] | ||
for config in tqdm(configs): | ||
result = run_experiment(config) | ||
results.append(Experiment(config=config, result=result)) | ||
|
||
# Use Tabulate to print results | ||
print_results(results) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm, didn't look at the rest too closely