|
@@ -16,6 +16,7 @@
 from typing import Callable, Optional

 import torch
+from packaging.version import Version as PkgVersion
 from transformers import AutoConfig
 from transformers.configuration_utils import PretrainedConfig
 from transformers.models.llama.configuration_llama import LlamaConfig
@@ -80,9 +81,27 @@ def convert_config_to_flops_config( |
     raise ValueError(f"Unsupported config type: {type(config)}")


+def is_using_tf32() -> bool:
+    """Check whether CUDA matmuls are configured to use TF32."""
+    if PkgVersion(torch.__version__) < PkgVersion("2.9.0a0"):
+        return torch.backends.cuda.matmul.allow_tf32
+    else:
+        return torch.backends.cuda.matmul.fp32_precision == "tf32"
+
+
 THEORETICAL_TFLOPS = {
+    ("NVIDIA A100 80GB PCIe", torch.bfloat16): 624 / 2,
+    ("NVIDIA A100 80GB PCIe", torch.float32): 312 / 2 if is_using_tf32() else 19.5,
     ("NVIDIA H100 80GB HBM3", torch.bfloat16): 1979 / 2,
-    ("NVIDIA H100 80GB HBM3", torch.float32): 67.0,
+    ("NVIDIA H100 80GB HBM3", torch.float32): 989 / 2 if is_using_tf32() else 67.0,
+    ("NVIDIA B200", torch.bfloat16): 4500 / 2,
+    ("NVIDIA B200", torch.float32): 2200 / 2 if is_using_tf32() else 80.0,
+    ("NVIDIA B300", torch.bfloat16): 4500 / 2,
+    ("NVIDIA B300", torch.float32): 2200 / 2 if is_using_tf32() else 80.0,
+    ("NVIDIA GB200", torch.bfloat16): 4900 / 2,
+    ("NVIDIA GB200", torch.float32): 2500 / 2 if is_using_tf32() else 80.0,
+    ("NVIDIA GB300", torch.bfloat16): 4900 / 2,
+    ("NVIDIA GB300", torch.float32): 2500 / 2 if is_using_tf32() else 80.0,
 }

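For context on the version gate in `is_using_tf32`: PyTorch 2.9 deprecates the boolean `torch.backends.cuda.matmul.allow_tf32` toggle in favor of the string-valued `fp32_precision` setting, so the helper reads back whichever knob the installed version exposes. A minimal sketch of the writer side, assuming the same two APIs:

```python
import torch
from packaging.version import Version as PkgVersion

# Enable TF32 for CUDA matmuls using whichever knob this PyTorch exposes;
# is_using_tf32() above reads the same setting back.
if PkgVersion(torch.__version__) < PkgVersion("2.9.0a0"):
    torch.backends.cuda.matmul.allow_tf32 = True
else:
    torch.backends.cuda.matmul.fp32_precision = "tf32"
```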
|
|
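The table values appear to halve NVIDIA's with-sparsity datasheet figures to get dense peak TFLOPS (e.g. `1979 / 2` for H100 BF16), with each fp32 entry picking the TF32 or IEEE FP32 rate via `is_using_tf32()`. Note the conditionals run once at import time, so the fp32 entries reflect the TF32 setting in effect when the module is first imported. A hypothetical lookup helper (not part of this diff), assuming the keys match the marketing name returned by `torch.cuda.get_device_name()`, e.g. "NVIDIA H100 80GB HBM3":

```python
from typing import Optional

import torch


def get_theoretical_tflops(dtype: torch.dtype) -> Optional[float]:
    # Hypothetical helper: returns the peak TFLOPS entry for the current
    # GPU and dtype, or None for devices not in the table.
    return THEORETICAL_TFLOPS.get((torch.cuda.get_device_name(), dtype))


# e.g. on an H100 with TF32 enabled:
# get_theoretical_tflops(torch.float32) -> 494.5
```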