
Commit f7645f3

feat: Update Theoretical TFLOPS (#1236)
Signed-off-by: Youngeun Kwon <[email protected]>
1 parent: cc8a93e

1 file changed

nemo_rl/utils/flops_tracker.py

Lines changed: 20 additions & 1 deletion
```diff
@@ -16,6 +16,7 @@
 from typing import Callable, Optional
 
 import torch
+from packaging.version import Version as PkgVersion
 from transformers import AutoConfig
 from transformers.configuration_utils import PretrainedConfig
 from transformers.models.llama.configuration_llama import LlamaConfig
@@ -80,9 +81,27 @@ def convert_config_to_flops_config(
     raise ValueError(f"Unsupported config type: {type(config)}")
 
 
+def is_using_tf32() -> bool:
+    """Check if the current device is using TF32."""
+    if PkgVersion(torch.__version__) < PkgVersion("2.9.0a0"):
+        return torch.backends.cuda.matmul.allow_tf32
+    else:
+        return torch.backends.cuda.matmul.fp32_precision == "tf32"
+
+
 THEORETICAL_TFLOPS = {
+    ("NVIDIA A100 80GB PCIe", torch.bfloat16): 624 / 2,
+    ("NVIDIA A100 80GB PCIe", torch.float32): 312 / 2 if is_using_tf32() else 19.5,
     ("NVIDIA H100 80GB HBM3", torch.bfloat16): 1979 / 2,
-    ("NVIDIA H100 80GB HBM3", torch.float32): 67.0,
+    ("NVIDIA H100 80GB HBM3", torch.float32): 989 / 2 if is_using_tf32() else 67.0,
+    ("NVIDIA B200", torch.bfloat16): 4500 / 2,
+    ("NVIDIA B200", torch.float32): 2200 / 2 if is_using_tf32() else 80.0,
+    ("NVIDIA B300", torch.bfloat16): 4500 / 2,
+    ("NVIDIA B300", torch.float32): 2200 / 2 if is_using_tf32() else 80.0,
+    ("NVIDIA GB200", torch.bfloat16): 4900 / 2,
+    ("NVIDIA GB200", torch.float32): 2500 / 2 if is_using_tf32() else 80.0,
+    ("NVIDIA GB300", torch.bfloat16): 4900 / 2,
+    ("NVIDIA GB300", torch.float32): 2500 / 2 if is_using_tf32() else 80.0,
 }
```
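Two notes on the change. The datasheet TFLOPS for these GPUs are published as peak throughput with 2:4 structured sparsity, so each entry is halved to get the dense peak (e.g. 1979 / 2 ≈ 989.5 TFLOPS for H100 BF16). And because the float32 entries are ternaries inside the dict literal, `is_using_tf32()` is evaluated once, when the module is imported; the version gate reflects PyTorch 2.9's move from the boolean `allow_tf32` flag to the string-valued `fp32_precision` setting, which this commit reads.

Below is a minimal usage sketch. The `get_peak_tflops` helper, the MFU arithmetic, and the TF32 setter lines are illustrative assumptions, not part of this commit:

```python
import torch
from packaging.version import Version as PkgVersion

# THEORETICAL_TFLOPS consults is_using_tf32() at import time, so configure
# TF32 *before* importing flops_tracker. (Assumes fp32_precision is settable
# on PyTorch >= 2.9, mirroring the getter used in the commit.)
if PkgVersion(torch.__version__) < PkgVersion("2.9.0a0"):
    torch.backends.cuda.matmul.allow_tf32 = True        # pre-2.9 flag
else:
    torch.backends.cuda.matmul.fp32_precision = "tf32"  # 2.9+ API

from nemo_rl.utils.flops_tracker import THEORETICAL_TFLOPS


def get_peak_tflops(dtype: torch.dtype, device: int = 0) -> float:
    """Hypothetical helper: dense peak TFLOPS for the current GPU and dtype."""
    name = torch.cuda.get_device_name(device)  # e.g. "NVIDIA H100 80GB HBM3"
    key = (name, dtype)
    if key not in THEORETICAL_TFLOPS:
        raise ValueError(f"No theoretical TFLOPS entry for {key}")
    return THEORETICAL_TFLOPS[key]


# Model FLOPs utilization (MFU) = achieved TFLOPS / theoretical peak.
achieved_tflops = 400.0  # illustrative measured throughput, not real data
print(f"MFU: {achieved_tflops / get_peak_tflops(torch.bfloat16):.1%}")
```

One consequence of the import-time evaluation is that flipping the TF32 setting after `flops_tracker` has been imported will not update the float32 entries; a process that changes precision mid-run would see stale peak numbers.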