23 changes: 13 additions & 10 deletions build_tools/pytorch.py
@@ -98,14 +98,6 @@ def setup_pytorch_extension(
     if version < (12, 0):
         raise RuntimeError("Transformer Engine requires CUDA 12.0 or newer")
 
-    if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))):
-        assert (
-            os.getenv("MPI_HOME") is not None
-        ), "MPI_HOME=/path/to/mpi must be set when compiling with NVTE_UB_WITH_MPI=1!"
-        mpi_path = Path(os.getenv("MPI_HOME"))
-        include_dirs.append(mpi_path / "include")
-        cxx_flags.append("-DNVTE_UB_WITH_MPI")
-
     library_dirs = []
     libraries = []
     if bool(int(os.getenv("NVTE_ENABLE_NVSHMEM", 0))):
@@ -119,12 +111,22 @@ def setup_pytorch_extension(
         cxx_flags.append("-DNVTE_ENABLE_NVSHMEM")
 
     if bool(int(os.getenv("NVTE_ENABLE_ROCSHMEM", 0))):
-        cxx_flags.append("-DNVTE_ENABLE_ROCSHMEM")
         mpi_home = Path(os.getenv("MPI_HOME", "/usr/lib/x86_64-linux-gnu/openmpi"))
         include_dirs.append(mpi_home / "include")
         library_dirs.append(mpi_home / "lib")
-        libraries.append("mpi_cxx")
         libraries.append("mpi")
+        cxx_flags.extend(["-DNVTE_ENABLE_ROCSHMEM", "-DOMPI_SKIP_MPICXX"])
+
+    extra_link_args = []
+    if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))):
+        assert (
+            os.getenv("MPI_HOME") is not None
+        ), "MPI_HOME=/path/to/mpi must be set when compiling with NVTE_UB_WITH_MPI=1!"
+        mpi_path = Path(os.getenv("MPI_HOME", "/usr/lib/x86_64-linux-gnu/openmpi"))
+        include_dirs.append(mpi_path / "include")
+        library_dirs.append(mpi_path / "lib")
+        libraries.append("mpi")
+        cxx_flags.extend(["-DNVTE_UB_WITH_MPI", "-DOMPI_SKIP_MPICXX"])
 
     # Construct PyTorch CUDA extension
     sources = [str(path) for path in sources]
@@ -138,4 +140,5 @@ def setup_pytorch_extension(
         extra_compile_args={"cxx": cxx_flags},
         libraries=[str(lib) for lib in libraries],
        library_dirs=[str(lib_dir) for lib_dir in library_dirs],
+        extra_link_args=[str(arg) for arg in extra_link_args],
     )
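
Note for anyone trying this locally: the ROCSHMEM and UB-with-MPI blocks now resolve MPI the same way. A minimal sketch of that resolution follows; the fallback prefix is the one hard-coded above, while the build invocations in the trailing comment are illustrative only, not part of this PR:

import os
from pathlib import Path

# Mirrors the resolution order introduced above: an explicit MPI_HOME wins,
# otherwise the Debian/Ubuntu OpenMPI prefix is assumed.
mpi_home = Path(os.getenv("MPI_HOME", "/usr/lib/x86_64-linux-gnu/openmpi"))
print("MPI include dir:", mpi_home / "include")
print("MPI library dir:", mpi_home / "lib")
# Only libmpi is linked now; -DOMPI_SKIP_MPICXX disables the deprecated
# C++ bindings, so libmpi_cxx is no longer needed.
print("Link libraries :", ["mpi"])

# Illustrative ways to drive the real build (example paths, not part of this PR):
#   NVTE_UB_WITH_MPI=1 MPI_HOME=/opt/openmpi pip install -e .
#   NVTE_ENABLE_ROCSHMEM=1 pip install -e .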
12 changes: 9 additions & 3 deletions examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py
@@ -68,7 +68,7 @@ def _parse_args(argv=None, namespace=None):
     )
     parser.add_argument("--seed", type=int, default=1234, help="RNG seed.")
     parser.add_argument(
-        "--fp8", action="store_true", default=False, help="Enables the te.fp8_autocast() context."
+        "--fp8", action="store_true", default=False, help="Enables the te.autocast() context."
     )
     parser.add_argument(
         "--no-comm-overlap",
@@ -263,7 +263,13 @@ def dist_print(msg, end="\n", group=nccl_world, src=0, debug=False, error=False)
     te.module.base.initialize_ub(
         [batched_size, hidden_size],
         tp_size,
-        use_fp8=opts.fp8,
+        quantization_modes=[
+            (
+                te.module.base.UserBufferQuantizationMode.FP8
+                if opts.fp8
+                else te.module.base.UserBufferQuantizationMode.NONE
+            )
+        ],
         dtype=torch.bfloat16,
         bootstrap_backend=opts.bootstrap_backend,
     )
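
initialize_ub() now takes a list of per-buffer quantization modes instead of a single use_fp8 boolean. A self-contained sketch of the old-to-new mapping; the enum below is a local stand-in mirroring only the two UserBufferQuantizationMode values this example uses, not the real class from transformer_engine:

from enum import Enum

class UserBufferQuantizationMode(Enum):
    # Stand-in for te.module.base.UserBufferQuantizationMode (introduced by this PR).
    NONE = "none"
    FP8 = "fp8"

def select_ub_modes(use_fp8: bool) -> list:
    """Map the old boolean switch onto the new per-buffer mode list."""
    return [UserBufferQuantizationMode.FP8 if use_fp8 else UserBufferQuantizationMode.NONE]

print(select_ub_modes(True))   # [<UserBufferQuantizationMode.FP8: 'fp8'>]
print(select_ub_modes(False))  # [<UserBufferQuantizationMode.NONE: 'none'>]

A list (rather than a bare mode) leaves room for configuring several userbuffer regions with different quantization settings in one call.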
@@ -293,7 +299,7 @@ def dist_print(msg, end="\n", group=nccl_world, src=0, debug=False, error=False)
 
         dist_print(" |-- Forward pass", group=tp_group, debug=True)
         with torch.amp.autocast("cuda", dtype=torch.bfloat16):
-            with te.fp8_autocast(enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=nccl_world):
+            with te.autocast(enabled=opts.fp8, recipe=fp8_recipe, amax_reduction_group=nccl_world):
                 y = model(x)
                 if isinstance(y, tuple):
                     out, *_ = y
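
For callers migrating from the old context manager, the rename maps arguments one-to-one: fp8_recipe becomes recipe and fp8_group becomes amax_reduction_group. A minimal sketch, assuming the DelayedScaling recipe from transformer_engine.common.recipe, an already-initialized process group, and placeholder model / x for the example's layer and input:

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

# Assumes torch.distributed.init_process_group() has run and model / x exist.
recipe = DelayedScaling(margin=0, amax_history_len=16)

# Pre-PR spelling:
#   with te.fp8_autocast(enabled=True, fp8_recipe=recipe, fp8_group=torch.distributed.group.WORLD):
with te.autocast(enabled=True, recipe=recipe, amax_reduction_group=torch.distributed.group.WORLD):
    y = model(x)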