From ac337b864e4f3d0c0aa64a210a9c37ff9133888f Mon Sep 17 00:00:00 2001 From: alextmagro Date: Thu, 30 Oct 2025 15:57:29 -0500 Subject: [PATCH 1/3] Initial UB code dump --- .../te_layer_with_overlap_profile.py | 499 ++++++++++++++++++ .../pytorch/comm_gemm_overlap/ub_config.json | 6 + hipify_custom_map.json | 4 +- .../distributed/run_layer_with_overlap.py | 5 + transformer_engine/common/CMakeLists.txt | 26 +- .../comm_gemm_overlap/comm_gemm_overlap.cpp | 250 ++++++++- .../userbuffers/userbuffers-host.cpp | 4 + .../userbuffers/userbuffers.cu | 370 +++++++++++++ .../transformer_engine/comm_gemm_overlap.h | 34 +- .../common/util/cuda_runtime.cpp | 7 +- transformer_engine/common/util/cuda_runtime.h | 2 +- .../common/util/pybind_helper.h | 17 +- transformer_engine/pytorch/csrc/common.h | 2 - transformer_engine/pytorch/csrc/extensions.h | 2 - .../csrc/extensions/comm_gemm_overlap.cpp | 6 +- .../pytorch/csrc/extensions/pybind.cpp | 12 +- transformer_engine/pytorch/module/base.py | 4 +- 17 files changed, 1199 insertions(+), 51 deletions(-) create mode 100644 examples/pytorch/comm_gemm_overlap/te_layer_with_overlap_profile.py create mode 100644 examples/pytorch/comm_gemm_overlap/ub_config.json diff --git a/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap_profile.py b/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap_profile.py new file mode 100644 index 000000000..a4e5fd15d --- /dev/null +++ b/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap_profile.py @@ -0,0 +1,499 @@ +#!/usr/bin/python3 + +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import os +import sys +import socket +import fcntl +import struct +import argparse +import warnings + +import torch +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel + +import torch.profiler + +import transformer_engine.pytorch as te +import transformer_engine.pytorch.cpp_extensions as tex +from transformer_engine.common.recipe import Format, DelayedScaling + +warnings.filterwarnings("ignore", category=DeprecationWarning) +warnings.filterwarnings("ignore", category=FutureWarning) +warnings.filterwarnings("ignore", category=UserWarning) + +os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" +if not tex.device_supports_multicast(): + os.environ["UB_SKIPMC"] = "1" + + +def _te_layer_argtype(name): + te_layers = [ + te.Linear, + te.LayerNormLinear, + te.LayerNormMLP, + te.MultiheadAttention, + te.TransformerLayer, + ] + layer_map = dict(zip([layer.__name__.lower() for layer in te_layers], te_layers)) + if name.lower() not in layer_map.keys(): + raise argparse.ArgumentTypeError( + f"Invalid TE layer name! Please choose from: {layer_map.keys()}" + ) + return layer_map[name.lower()] + + +def _parse_args(argv=None, namespace=None): + parser = argparse.ArgumentParser( + description="Train a Transformer Engine module with GEMM+comm overlap via Userbuffers." + ) + parser.add_argument( + "-i", "--num-iters", type=int, default=10, help="Number of dummy 'training' iterations." + ) + parser.add_argument("-b", "--batch-size", type=int, default=2, help="Input batch size.") + parser.add_argument("-s", "--seq-length", type=int, default=2048, help="Input sequence length.") + parser.add_argument( + "-n", "--num-heads", type=int, default=64, help="Number of attention heads." + ) + parser.add_argument( + "-d", "--head-dim", type=int, default=128, help="Dimension of each attention head." 
+ ) + parser.add_argument( + "--layer-type", + type=_te_layer_argtype, + default=te.TransformerLayer, + help="Transformer Engine layer to train with comm+GEMM overlap.", + ) + parser.add_argument("--seed", type=int, default=1234, help="RNG seed.") + parser.add_argument( + "--fp8", action="store_true", default=False, help="Enables the te.fp8_autocast() context." + ) + parser.add_argument( + "--no-comm-overlap", + action="store_true", + default=False, + help="Disable the comm+GEMM overlap.", + ) + parser.add_argument( + "--num-replicas", type=int, default=1, help="Number of data-parallel model replicas." + ) + parser.add_argument( + "--tcp-init", + action="store_true", + default=False, + help="Initialize torch.distributed with TcpStore.", + ) + parser.add_argument( + "--bind-to-device", + action="store_true", + default=False, + help="Initialize torch.distributed with `device_id` to bind each rank to a single device.", + ) + parser.add_argument( + "--bootstrap-backend", + type=str.lower, + default="nccl", + choices=["gloo", "mpi", "nccl"], + help="Communications backend for host tensor collectives during Userbuffers bootstrapping.", + ) + parser.add_argument( + "-v", + "--verbose", + action="store_true", + default=False, + help="Print out from every rank instead of just the root rank of relevant process groups.", + ) + parser.add_argument( + "--debug", + action="store_true", + default=False, + help="Print out additional debug information.", + ) + parser.add_argument( + "--profile", + action="store_true", + default=False, + help="Enable PyTorch profiler.", + ) + parser.add_argument( + "--profile-dir", + type=str, + default="./logs/profiler_traces", + help="Directory to save PyTorch profiler traces.", + ) + parser.add_argument( + "--ub_config", + type=str, + default="./ub_config.json", + help="Userbuffer configuration file.", + ) + + args = parser.parse_args(argv, namespace) + if args.bootstrap_backend == "nccl": + args.bind_to_device = True + return args + + +def _get_layer_args(config, tp_group, tp_size, reference=False): + hidden_size = config.num_heads * config.head_dim + input_shape = [config.seq_length, config.batch_size, hidden_size] + args = [hidden_size] + kwargs = { + "params_dtype": torch.float32, + "device": "cuda", + "tp_group": tp_group, + "tp_size": tp_size, + "sequence_parallel": True, + } + kwargs["ub_overlap_ag"] = not config.no_comm_overlap + + if config.layer_type is te.Linear: + input_shape[2] = hidden_size // tp_size + args.append(hidden_size) + kwargs["parallel_mode"] = "row" + kwargs["ub_overlap_rs"] = not config.no_comm_overlap + kwargs["ub_name"] = "proj" + else: + input_shape[0] = config.seq_length // tp_size + kwargs["ub_bulk_wgrad"] = not config.no_comm_overlap + kwargs["ub_bulk_dgrad"] = not config.no_comm_overlap + if config.layer_type is te.LayerNormLinear: + args.append(3 * hidden_size) + kwargs["parallel_mode"] = "column" + kwargs["ub_name"] = "qkv" + else: + kwargs["set_parallel_mode"] = True + kwargs["ub_overlap_rs"] = not config.no_comm_overlap + if config.layer_type in [te.LayerNormMLP, te.TransformerLayer]: + # args.append(4 * hidden_size) + args.append(int(3.5 * hidden_size)) + + kwargs["seq_length"] = config.seq_length + if config.layer_type in [te.MultiheadAttention, te.TransformerLayer]: + args.append(config.num_heads) + kwargs["attention_dropout"] = 0.0 + kwargs["fuse_qkv_params"] = True + if config.layer_type is te.MultiheadAttention: + kwargs["input_layernorm"] = True + else: + kwargs["ub_tp_comm_overlap"] = not config.no_comm_overlap + 
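+            # TransformerLayer drives all of its internal userbuffer overlaps from this single flag.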
kwargs["hidden_dropout"] = 0.0 + + return args, kwargs, input_shape + +def create_ub_cfgs(config_file:str, tp_size: int = 8): + import json + with open(config_file, 'r') as f: + data = json.load(f) + cfgs = {} + _MIN_STREAM_PRIORITY, _MAX_STREAM_PRIORITY = None, None + layers_reduce_scatter_overlap = ["proj_fprop", "fc2_fprop", "qkv_wgrad", "fc1_wgrad"] + + for name, method in data.items(): + is_reduce_scatter = name in layers_reduce_scatter_overlap + + layers_all_gather_overlap = [ + "qkv_fprop", + "qkv_dgrad", + "proj_dgrad", + "fc1_fprop", + "fc1_dgrad", + "fc2_dgrad", + ] + if _MIN_STREAM_PRIORITY is None or _MAX_STREAM_PRIORITY is None: + _MIN_STREAM_PRIORITY, _MAX_STREAM_PRIORITY = tex.get_stream_priority_range() + + + cfg = { + "method": method, + "is_reduce_scatter": is_reduce_scatter, + "num_sm": 1 if method == "ring_exchange" else 16, + "cga_size": 1 if method == "ring_exchange" else 2, + "set_sm_margin": False, + "num_splits": 4 if method == "pipeline" else tp_size, + "aggregate": False, + "atomic_gemm": False, + "use_ce": True, + "fp8_buf": name in layers_all_gather_overlap, + "comm_priority": _MAX_STREAM_PRIORITY, + "gemm_priority": _MIN_STREAM_PRIORITY, + } + + cfgs[name] = cfg + + return cfgs + +def _train(opts): + if "OMPI_COMM_WORLD_SIZE" in os.environ: + # Execution with `mpirun -np N` + WORLD_RANK = int(os.getenv("OMPI_COMM_WORLD_RANK", "0")) + WORLD_SIZE = int(os.getenv("OMPI_COMM_WORLD_SIZE", "1")) + LOCAL_RANK = int(os.getenv("OMPI_COMM_WORLD_LOCAL_RANK", "0")) + LOCAL_SIZE = int(os.getenv("OMPI_COMM_WORLD_LOCAL_SIZE", "1")) + opts.tcp_init = True + opts.bind_to_device = True + opts.bootstrap_backend = "mpi" + elif "TORCHELASTIC_RUN_ID" in os.environ: + WORLD_RANK = int(os.getenv("RANK", "0")) + WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1")) + LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0")) + LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1")) + else: + raise RuntimeError(f"{__file__} must be launched with either `mpirun` or `torchrun`!") + NUM_NODES = WORLD_SIZE // LOCAL_SIZE + + # Initialize torch.distributed global process group and get DP/TP groups + torch.cuda.set_device(LOCAL_RANK) + dist_init_kwargs = { + "backend": "nccl", + "rank": WORLD_RANK, + "world_size": WORLD_SIZE, + } + if opts.tcp_init or NUM_NODES > 1: + if NUM_NODES > 1: + assert ( + "MASTER_ADDR" in os.environ + ), "Multi-node run requires MASTER_ADDR to be set in the environment." 
+ MASTER_ADDR = os.getenv("MASTER_ADDR", socket.gethostbyname(socket.gethostname())) + MASTER_PORT = os.getenv("MASTER_PORT", "1234") + dist_init_kwargs["init_method"] = f"tcp://{MASTER_ADDR}:{MASTER_PORT}" + if opts.bind_to_device or opts.bootstrap_backend == "nccl": + dist_init_kwargs["device_id"] = torch.device(f"cuda:{LOCAL_RANK}") + assert dist.is_nccl_available() + dist.init_process_group(**dist_init_kwargs) + nccl_world = dist.new_group(backend="nccl") + + def dist_print(msg, end="\n", group=nccl_world, src=0, debug=False, error=False): + if debug and not opts.debug: + return + group_rank = dist.get_rank(group) + stream = sys.stderr if error else sys.stdout + if group_rank == src: + stream.write(f"[rank{WORLD_RANK}] {msg}{end}") + dist.barrier(group) + + dist_print(f"Initialized default NCCL process group with {WORLD_SIZE} GPUs") + + # Figure out process groups for tensor- and data-parallelism (if any) + if NUM_NODES > 1: + # Create a list of world ranks on this node + hostname = socket.gethostname() + ifname = os.getenv( + "NVTE_UB_SOCKET_IFNAME", + os.getenv("NCCL_SOCKET_IFNAME", os.getenv("GLOO_SOCKET_IFNAME")), + ) + + if ifname is not None: + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + try: + hostname = socket.inet_ntoa( + fcntl.ioctl( + s.fileno(), 0x8915, struct.pack("256s", ifname[:15].encode("UTF-8")) + )[20:24] + ) + except OSError as err: + raise OSError(f"Invalid network interface: {ifname}") from err + + hostnames = [None for _ in range(WORLD_SIZE)] + dist.all_gather_object(hostnames, hostname) + unique_hosts = [] + for host in hostnames: + if host not in unique_hosts: + unique_hosts.append(host) + assert len(unique_hosts) == NUM_NODES + + ranks_per_node_list = [[] for _ in range(NUM_NODES)] + self_node_idx = -1 + for i, host in enumerate(hostnames): + node_idx = unique_hosts.index(host) + ranks_per_node_list[node_idx].append(i) + if host == hostname: + self_node_idx = node_idx + assert self_node_idx >= 0 + self_node_ranks = ranks_per_node_list[self_node_idx] + + if opts.num_replicas > 1: + # Split node ranks into multiple replicas + assert len(self_node_ranks) % opts.num_replicas == 0 + tp_size = len(self_node_ranks) // opts.num_replicas + ranks_per_replica_list = [] + for node_ranks in ranks_per_node_list: + for i in range(opts.num_replicas): + start = i * tp_size + end = start + tp_size + ranks_per_replica_list.append(node_ranks[start:end]) + + self_replica_idx = -1 + for i, replica_ranks in enumerate(ranks_per_replica_list): + if WORLD_RANK in replica_ranks: + self_replica_idx = i + break + assert self_replica_idx >= 0 + + else: + # The entire node is the tensor-parallel group + ranks_per_replica_list = ranks_per_node_list + self_replica_idx = self_node_idx + + tp_group, _ = dist.new_subgroups_by_enumeration(ranks_per_replica_list, backend="nccl") + ranks_per_replica_tensor = torch.tensor(ranks_per_replica_list, dtype=torch.int32) + dp_group, _ = dist.new_subgroups_by_enumeration( + ranks_per_replica_tensor.transpose(0, 1).tolist(), backend="nccl" + ) + + else: + if opts.num_replicas > 1: + # Mixed data- and tensor-parallelism on a single node + # NOTE: Avoid dist.init_device_mesh() to support older PyTorch versions + all_ranks = torch.tensor(list(range(LOCAL_SIZE)), dtype=torch.uint8, device="cpu") + ranks_per_replica_tensor = all_ranks.reshape( + (opts.num_replicas, LOCAL_SIZE // opts.num_replicas) + ) + tp_group, _ = dist.new_subgroups_by_enumeration( + ranks_per_replica_tensor.tolist(), backend="nccl" + ) + dp_group, _ = 
dist.new_subgroups_by_enumeration( + ranks_per_replica_tensor.transpose(0, 1).tolist(), backend="nccl" + ) + else: + dp_group = None + tp_group = nccl_world + + tp_rank = dist.get_rank(tp_group) + tp_size = dist.get_world_size(tp_group) + dist_print( + f"Created tensor-parallel group: {dist.get_process_group_ranks(tp_group)}", + group=tp_group, + ) + if dp_group is not None: + dp_rank = dist.get_rank(dp_group) + dist_print( + f"Created data-parallel group: {dist.get_process_group_ranks(dp_group)}", + group=dp_group, + ) + else: + dp_rank = 0 + + # Intialize userbuffers + hidden_size = opts.num_heads * opts.head_dim + batched_size = opts.seq_length * opts.batch_size + if not opts.no_comm_overlap: + te.module.base.initialize_ub( + [batched_size, hidden_size], + tp_size, + use_fp8=opts.fp8, + dtype=torch.bfloat16, + bootstrap_backend=opts.bootstrap_backend, + ub_cfgs=create_ub_cfgs(opts.ub_config, tp_size) + ) + # Initialize the fused LayerNorm + Multi-layer Perceptron module + torch.manual_seed(opts.seed + dp_rank) + torch.cuda.manual_seed(opts.seed + tp_rank) + layer_args, layer_kwargs, input_size = _get_layer_args(opts, tp_group, tp_size) + model = opts.layer_type(*layer_args, **layer_kwargs) + if dp_group is not None: + model = DistributedDataParallel(model, dim=1, process_group=dp_group) + + # Initialize optimizer with model parameters + optim = torch.optim.Adam(model.parameters(), lr=0.0001) + + # Fp8 recipe setup + fp8_format = Format.HYBRID + fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max") + + if opts.profile: + log_dir = os.path.join(opts.profile_dir, f"rank_{WORLD_RANK}") + os.makedirs(log_dir, exist_ok=True) + dist_print(f"Profiler traces will be saved to: {log_dir}", group=nccl_world) + + schedule = torch.profiler.schedule(wait=1, warmup=2, active=5, repeat=1) + + on_trace_ready = torch.profiler.tensorboard_trace_handler( + log_dir, worker_name=f"rank_{WORLD_RANK}" + ) + + profiler_activities = [ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ] + + with torch.profiler.profile( + schedule=schedule, + # record_shapes=True, + # with_stack=True, + # with_flops=True, + # with_modules=True, + on_trace_ready=on_trace_ready, + profile_memory=True, + activities=profiler_activities, + ) as prof: + dist_print("Starting training iterations...") + for i in range(opts.num_iters): + dist_print(f" Iter {i+1}", group=tp_group, debug=True) + + dist_print(" |-- Generate random input batch", group=tp_group, debug=True) + x = torch.randn(input_size, dtype=torch.float32, device="cuda", requires_grad=True) + + dist_print(" |-- Forward pass", group=tp_group, debug=True) + with torch.amp.autocast("cuda", dtype=torch.bfloat16): + with te.fp8_autocast(enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=nccl_world): + y = model(x) + if isinstance(y, tuple): + out, *_ = y + else: + out = y + dist_print(" |-- Compute loss", group=tp_group, debug=True) + loss = out.sum() + + dist_print(" |-- Backward pass", group=tp_group, debug=True) + loss.backward() + + dist_print(" |-- Optimizer step", group=tp_group, debug=True) + optim.step() + + prof.step() + + else: + dist_print("Starting training iterations...") + for i in range(opts.num_iters): + dist_print(f" Iter {i+1}", group=tp_group, debug=True) + + dist_print(" |-- Generate random input batch", group=tp_group, debug=True) + x = torch.randn(input_size, dtype=torch.float32, device="cuda", requires_grad=True) + + dist_print(" |-- Forward pass", group=tp_group, debug=True) + with 
torch.amp.autocast("cuda", dtype=torch.bfloat16): + with te.fp8_autocast(enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=nccl_world): + y = model(x) + if isinstance(y, tuple): + out, *_ = y + else: + out = y + dist_print(" |-- Compute loss", group=tp_group, debug=True) + loss = out.sum() + + dist_print(" |-- Backward pass", group=tp_group, debug=True) + loss.backward() + + dist_print(" |-- Optimizer step", group=tp_group, debug=True) + optim.step() + + + torch.cuda.synchronize() + dist_print("Finished training!") + te.module.base.destroy_ub() + + dist_print("Destroying all process groups...", debug=True) + dist.destroy_process_group() + if opts.debug and WORLD_RANK == 0: + print("Exiting...\n", end="", flush=True) + + return 0 + + +if __name__ == "__main__": + sys.exit(_train(_parse_args())) \ No newline at end of file diff --git a/examples/pytorch/comm_gemm_overlap/ub_config.json b/examples/pytorch/comm_gemm_overlap/ub_config.json new file mode 100644 index 000000000..3dec195d7 --- /dev/null +++ b/examples/pytorch/comm_gemm_overlap/ub_config.json @@ -0,0 +1,6 @@ +{ + "proj_fprop" : "pipeline", + "fc2_fprop" : "ring_exchange", + "qkv_fprop" : "ring_exchange", + "fc1_fprop" : "recursive_doubling" +} diff --git a/hipify_custom_map.json b/hipify_custom_map.json index 8773c233e..7c8bac22b 100644 --- a/hipify_custom_map.json +++ b/hipify_custom_map.json @@ -5,7 +5,9 @@ "util/cuda_runtime.h" : "util/hip_runtime.h", "ATen/cudnn/Handle.h" : "ATen/miopen/Handle.h", "CUfunc_cache" : "hipFuncCache_t", - "" : "" + "" : "", + "cudaLaunchKernel": "hipLaunchKernel", + "CUmemGenericAllocationHandle": "hipMemGenericAllocationHandle_t" } } diff --git a/tests/pytorch/distributed/run_layer_with_overlap.py b/tests/pytorch/distributed/run_layer_with_overlap.py index 48ace31c3..4cf0d18f6 100644 --- a/tests/pytorch/distributed/run_layer_with_overlap.py +++ b/tests/pytorch/distributed/run_layer_with_overlap.py @@ -28,6 +28,11 @@ warnings.filterwarnings("ignore", category=FutureWarning) warnings.filterwarnings("ignore", category=UserWarning) +import transformer_engine.pytorch.cpp_extensions as tex +os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" +if not tex.device_supports_multicast(): + os.environ["UB_SKIPMC"] = "1" + class multi_module_model(torch.nn.Module): def __init__(self, module, num_layers, *args, **kwargs): diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index f70c9f8bb..7d978d796 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -217,7 +217,11 @@ else() fused_rope/fused_rope.cu recipe/current_scaling.cu recipe/delayed_scaling.cu - recipe/fp8_block_scaling.cu) + recipe/fp8_block_scaling.cu + comm_gemm_overlap/userbuffers/ipcsocket.cc + comm_gemm_overlap/userbuffers/userbuffers-host.cpp + comm_gemm_overlap/userbuffers/userbuffers.cu + comm_gemm_overlap/comm_gemm_overlap.cpp) # process source code files set(TE ${CMAKE_CURRENT_SOURCE_DIR}/../..) 
@@ -261,17 +265,19 @@ if (USE_CUDA) target_include_directories(transformer_engine PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) target_include_directories(transformer_engine PRIVATE "${CUDNN_FRONTEND_INCLUDE_DIR}") +endif() - # Compiling Userbuffers with native MPI bootstrapping requires linking against MPI - # Changed - option(NVTE_UB_WITH_MPI "Bootstrap Userbuffers with MPI" OFF) - if (NVTE_UB_WITH_MPI) - find_package(MPI REQUIRED) - target_link_libraries(transformer_engine PUBLIC MPI::MPI_CXX) - target_include_directories(transformer_engine PRIVATE ${MPI_CXX_INCLUDES}) - target_compile_definitions(transformer_engine PUBLIC NVTE_UB_WITH_MPI) - endif() +# Compiling Userbuffers with native MPI bootstrapping requires linking against MPI +# Changed +option(NVTE_UB_WITH_MPI "Bootstrap Userbuffers with MPI" OFF) +if (NVTE_UB_WITH_MPI) + find_package(MPI REQUIRED) + target_link_libraries(transformer_engine PUBLIC MPI::MPI_CXX) + target_include_directories(transformer_engine PRIVATE ${MPI_CXX_INCLUDES}) + target_compile_definitions(transformer_engine PUBLIC NVTE_UB_WITH_MPI) +endif() +if (USE_CUDA) # Hack to enable dynamic loading in cuDNN frontend target_compile_definitions(transformer_engine PUBLIC NV_CUDNN_FRONTEND_USE_DYNAMIC_LOADING) else() diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index 98a970c98..819e115f7 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -448,6 +448,7 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, TensorWrapper &rs_output, cudaStream_t stream_main) { + printf("split_overlap_rs_pipeline"); // Get GEMM dimensions int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; @@ -595,7 +596,7 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector &buffer_shape, CommOverlapType comm_type, int num_max_streams, int comm_cga_size, int gemm_priority, int comm_priority, int num_comm_sm, bool set_sm_margin, bool use_ce, - bool atomic_gemm, bool aggregate) + bool atomic_gemm, bool aggregate, bool use_rd = false) : CommOverlapCore(myrank, numranks, mylocal, numlocal, mynode, numnodes, tp_size, allgather_handle, barrier_handle, tp_size, num_max_streams, comm_cga_size, gemm_priority, comm_priority, num_comm_sm, set_sm_margin, use_ce, @@ -603,6 +604,7 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector &buffer_shape, _is_p2p = true; _is_reduce_scatter = comm_type == CommOverlapType::RS; _aggregate = aggregate; + _use_rd = use_rd; // Create workspace tensor with userbuffer NVTE_CHECK(buffer_shape.size() == 2, "Userbuffer shape must be 2-dimensional!"); @@ -796,6 +798,7 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa, TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, TensorWrapper &B_copy, cudaStream_t stream_main) { + printf("split_overlap_ag"); int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; @@ -952,6 +955,250 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa, NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_recv, 0)); } // CommOverlapP2PBase::split_overlap_ag +/* +** Split AllGather + GEMM using P2P communication using recursive doubling +** This function assumes the input_b is 
pre-copied to _ubufs[rank_id]. This is needed to have AG +** outputs in each rank to be in the contiguous memory space after all ring exchange phases. +*/ +void CommOverlapP2PBase::split_overlap_ag_rd(TensorWrapper &A, bool transa, TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, + bool grad, bool accumulate, bool use_split_accumulator, + TensorWrapper &B_copy, cudaStream_t stream_main) { + printf("split_overlap_ag_rd"); + int ori_sms = _ub_comm->sms; + _ub_comm->use_ce = _use_ce; + _ub_comm->sms = _num_comm_sm; + _ub_comm->cga_size = _cga_size; + // Get GEMM dimensions between TN and NN input layouts + const size_t m = (transa) ? A.size(0) : A.size(1); + const size_t k = (transa) ? A.size(1) : A.size(0); + const size_t n_chunk = _ubufs[0].size(0); + + // Get communication and GEMM output chunk sizes + const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size(); + const bool do_gelu = pre_gelu_out.numel() > 0; + const int output_chunk_bytes = (n_chunk * m) * D.element_size(); + const int aux_chunk_bytes = do_gelu ? (n_chunk * m) * pre_gelu_out.element_size() : 0; + + // Get output and workspace data pointers + char *output_ptr = reinterpret_cast(D.dptr()); + char *pre_gelu_out_ptr = reinterpret_cast(pre_gelu_out.dptr()); + char *workspace_ptr = reinterpret_cast(workspace.dptr()); + size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); + + NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _start_compute, 0)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0)); + for (size_t i = 0; i < _stream_compute.size(); i++) { + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[i], _start_compute, 0)); + } + if (_aggregate) { + const int num_steps = _tp_size / 2; + char *input_b_ptr = reinterpret_cast(_ubuf.dptr()); + + // Initial 1X input chunk exchange between neighboring peers + int send_chunk_id = _tp_id; + int recv_chunk_id = (_tp_id % 2 == 0) ? _tp_id + 1 : _tp_id - 1; + int send_offset = comm_bytes * send_chunk_id; + int recv_offset = comm_bytes * recv_chunk_id; + int peer_rank = (_tp_id % 2 == 0) ? _next_rank : _prev_rank; + userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes, _ub_comm, peer_rank, + _stream_send[0]); + userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, peer_rank, + _stream_recv); + NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _stop_recv, 0)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[0], _stop_recv, 0)); + + int local_rank_round2 = (_tp_id % 2 == 0) ? 
_tp_id : _tp_id - 1; + const int next_rank = (_tp_size + _tp_id + 2) % _tp_size + _rank_round_tp; + const int prev_rank = (_tp_size + _tp_id - 2) % _tp_size + _rank_round_tp; + + // Ring exchange of 2X inputs chunks + for (int i = 0; i < num_steps; i++) { + send_chunk_id = (_tp_size + local_rank_round2 - i * 2) % _tp_size; + recv_chunk_id = (_tp_size + local_rank_round2 - i * 2 - 2) % _tp_size; + send_offset = comm_bytes * send_chunk_id; + recv_offset = comm_bytes * recv_chunk_id; + + // GEMM + char *input_b_chunk_ptr = input_b_ptr + send_offset; + auto input_b_chunk = + TensorWrapper(reinterpret_cast(input_b_chunk_ptr), {n_chunk * 2, k}, B.dtype(), + nullptr, nullptr, B.scale_inv()); + + char *output_chunk_ptr = output_ptr + (send_chunk_id * output_chunk_bytes); + auto output_chunk = TensorWrapper(reinterpret_cast(output_chunk_ptr), + {n_chunk * 2, m}, D.dtype(), D.amax(), D.scale(), nullptr); + + char *aux_chunk_ptr = + (do_gelu) ? pre_gelu_out_ptr + (send_chunk_id * aux_chunk_bytes) : nullptr; + auto aux_chunk_shape = + (do_gelu) ? std::vector{n_chunk * 2, m} : std::vector{0}; + auto aux_chunk = TensorWrapper(reinterpret_cast(aux_chunk_ptr), aux_chunk_shape, + pre_gelu_out.dtype()); + + char *workspace_chunk_ptr = + workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk; + auto workspace_chunk = + TensorWrapper(reinterpret_cast(workspace_chunk_ptr), + std::vector{workspace_size_chunk}, workspace.dtype()); + + nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(), + aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate, + use_split_accumulator, _math_sms, + _stream_compute[i % _stream_compute.size()]); + + if (i < num_steps - 1) { + // P2P communication + userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes * 2, _ub_comm, + next_rank, _stream_send[0]); + userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes * 2, _ub_comm, + prev_rank, _stream_recv); + NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _stop_recv, 0)); + NVTE_CHECK_CUDA( + cudaStreamWaitEvent(_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0)); + } else if (B_copy.numel() > 0) { + assert(B_copy.numel() == _ubufs[_tp_id].numel()); + assert(B_copy.element_size() == _ubufs[_tp_id].element_size()); + NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubufs[_tp_id].dptr(), + _ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(), + cudaMemcpyDeviceToDevice, _stream_send[0])); + } + } + } else { + //recursive doubling ag + int steps = 0; + int tmp_size = _tp_size; + while (tmp_size > 1) { + steps++; + tmp_size >>= 1; + } + + + //compute the first gemm using own data + { + int chunk_id = _tp_id; + cudaStream_t compute_stream = _stream_compute[chunk_id % _stream_compute.size()]; + + auto input_b_chunk = TensorWrapper(_ubufs[chunk_id].dptr(), + {n_chunk, k}, B.dtype(), + nullptr, nullptr, B.scale_inv()); + + char* output_chunk_ptr = output_ptr + (chunk_id * output_chunk_bytes); + auto output_chunk = TensorWrapper(reinterpret_cast(output_chunk_ptr), + {n_chunk, m}, + D.dtype(), D.amax(), D.scale(), nullptr); + + char *aux_chunk_ptr = + (do_gelu) ? pre_gelu_out_ptr + (chunk_id * aux_chunk_bytes) : nullptr; + auto aux_chunk_shape = (do_gelu) ? 
std::vector{n_chunk, m} : std::vector{0}; + auto aux_chunk = TensorWrapper(reinterpret_cast(aux_chunk_ptr), + aux_chunk_shape, pre_gelu_out.dtype()); + + char *workspace_chunk_ptr = + workspace_ptr + (chunk_id % _stream_compute.size()) * workspace_size_chunk; + auto workspace_chunk = TensorWrapper(reinterpret_cast(workspace_chunk_ptr), + std::vector{workspace_size_chunk}, + workspace.dtype()); + + nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), + bias.data(), aux_chunk.data(), + transa, transb, grad, workspace_chunk.data(), + accumulate, use_split_accumulator, + _math_sms, compute_stream); + } + + std::vector owned_chunks = {_tp_id}; + int offset = 1; + + for (int step = 0; step < steps; step++) { + int send_rank = (_tp_id + offset) % _tp_size; + int recv_rank = (_tp_id - offset + _tp_size) % _tp_size; + + // send and recv + for (auto chunk_id : owned_chunks) { + size_t send_offset = chunk_id * comm_bytes; + userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, + comm_bytes, _ub_comm, send_rank, _stream_send[0]); + } + for (int j = 0; j < offset; j++) { + int recv_chunk_id = (recv_rank + j) % _tp_size; + size_t recv_offset = recv_chunk_id * comm_bytes; + userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, + comm_bytes, _ub_comm, recv_rank, _stream_recv); + } + NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv)); + + // when previous recv finishes, proceed the GEMM + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[0], _stop_recv, 0)); + + for (int j = 0; j < offset; j++) { + int new_chunk_id = (recv_rank + j) % _tp_size; + cudaStream_t compute_stream = _stream_compute[new_chunk_id % _stream_compute.size()]; + + auto input_b_chunk = TensorWrapper(_ubufs[new_chunk_id].dptr(), + {n_chunk, k}, B.dtype(), + nullptr, nullptr, B.scale_inv()); + + char* output_chunk_ptr = output_ptr + (new_chunk_id * output_chunk_bytes); + auto output_chunk = TensorWrapper(reinterpret_cast(output_chunk_ptr), + {n_chunk, m}, + D.dtype(), D.amax(), D.scale(), nullptr); + + char *aux_chunk_ptr = + (do_gelu) ? pre_gelu_out_ptr + (new_chunk_id * aux_chunk_bytes) : nullptr; + auto aux_chunk_shape = (do_gelu) ? 
std::vector{n_chunk, m} : std::vector{0}; + auto aux_chunk = TensorWrapper(reinterpret_cast(aux_chunk_ptr), + aux_chunk_shape, pre_gelu_out.dtype()); + + char *workspace_chunk_ptr = + workspace_ptr + (new_chunk_id % _stream_compute.size()) * workspace_size_chunk; + auto workspace_chunk = TensorWrapper(reinterpret_cast(workspace_chunk_ptr), + std::vector{workspace_size_chunk}, + workspace.dtype()); + + nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), + bias.data(), aux_chunk.data(), + transa, transb, grad, workspace_chunk.data(), + accumulate, use_split_accumulator, + _math_sms, compute_stream); + } + + for (int j = 0; j < offset; j++) { + owned_chunks.push_back((recv_rank + j) % _tp_size); + } + offset *= 2; + } + + // synchronize compute streams + for (auto& s : _stream_compute) { + NVTE_CHECK_CUDA(cudaStreamSynchronize(s)); + } + + if (B_copy.numel() > 0) { + assert(B_copy.numel() == _ubufs[_tp_id].numel()); + assert(B_copy.element_size() == _ubufs[_tp_id].element_size()); + NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubufs[_tp_id].dptr(), + _ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(), + cudaMemcpyDeviceToDevice, _stream_send[0])); + } + + } + + _ub_comm->sms = ori_sms; + for (size_t i = 0; i < _stream_compute.size(); i++) { + NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, _stream_compute[i])); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_compute, 0)); + } + NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, _stream_send[0])); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0)); + NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_recv, 0)); +} // CommOverlapP2PBase::split_overlap_ag_rd + /* ** Split ReduceScatter + GEMM using P2P communication */ @@ -1024,6 +1271,7 @@ void CommOverlapP2PBase::split_overlap_rs(const TensorWrapper &A, bool transa, TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, TensorWrapper &rs_output, cudaStream_t stream_main) { + printf("split_overlap_rs_p2p"); int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp index e52cdd8a1..fa10e9329 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp @@ -360,8 +360,12 @@ int create_communicator_grouped2(communicator **comm, int myrank, int numranks, NVTE_CHECK_CUDA(cudaMalloc(&(*comm)->flags, 2 * GPU_PAGE_SIZE)); NVTE_CHECK_CUDA(cudaMemset((*comm)->flags, 0, 2 * GPU_PAGE_SIZE)); +#ifdef __HIP_PLATFORM_AMD__ + reinterpret_cast((reinterpret_cast((*comm)->flags) + GPU_PAGE_SIZE - 1) & GPU_PAGE_MASK); +#else (*comm)->flags = reinterpret_cast(((CUdeviceptr)(*comm)->flags + GPU_PAGE_SIZE - 1) & GPU_PAGE_MASK); +#endif using namespace std; diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu index 1211392e4..e17cfc4f2 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu @@ -14,6 +14,12 @@ #define half_dtype half #endif +#ifdef __HIP_PLATFORM_AMD__ +#define half_dtype hip_bfloat16 +#define __nv_fp8_e5m2 te_hip_fp8_e5m2 +#define __nv_fp8_e4m3 te_hip_fp8_e4m3 
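+// ROCm build: remap the type names used by the kernels below for HIP compilation.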
+#endif + #include #include #include @@ -24,6 +30,7 @@ #define MAX_THREADS 1024 +#ifndef __HIP_PLATFORM_AMD__ #define ATOMIC_CONSUMER(chunk) \ if (counters) { \ if (threadIdx.x == 0 && blockIdx.x == 0) { \ @@ -34,6 +41,18 @@ } \ if (blockIdx.x == 0) __syncthreads(); \ } +#else +#define ATOMIC_CONSUMER(chunk) \ + if (counters) { \ + if (threadIdx.x == 0 && blockIdx.x == 0) { \ + while (0 != (atomicCAS(((unsigned int *)counters) + chunk, 0, 0))) { \ + } \ + ((unsigned int *)counters)[chunk] = 1; \ + __threadfence(); \ + } \ + if (blockIdx.x == 0) __syncthreads(); \ + } +#endif #define ATOMIC_PRODUCER(chunk) \ if (counters) { \ @@ -1025,7 +1044,11 @@ __global__ void __launch_bounds__(MAX_THREADS) // reset counter for next producer. ((unsigned int *)counters)[0] = 1; +#ifndef __HIP_PLATFORM_AMD__ asm volatile("fence.sc.gpu;\n"); +#else + __threadfence(); +#endif } } __syncthreads(); @@ -1116,7 +1139,11 @@ __global__ void __launch_bounds__(MAX_THREADS) // reset counter for next producer. ((unsigned int *)counters)[chunk_i] = 1; +#ifndef __HIP_PLATFORM_AMD__ asm volatile("fence.sc.gpu;\n"); +#else + __threadfence(); +#endif } } __syncthreads(); @@ -1357,6 +1384,7 @@ __global__ void __launch_bounds__(MAX_THREADS) } } // fp16 inplace allgather kernel (Volta,Hopper) +#ifndef __HIP_PLATFORM_AMD__ #define SETUP_LAUNCH_CONFIG(sms, threads, stream) \ cudaLaunchConfig_t cfg = {sms, threads, 0, stream, NULL, 0}; \ cudaLaunchAttribute attribute_ub[2]; \ @@ -1662,6 +1690,244 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_stride_multiatomic), \ kernelArgs)); \ } +#else +#define callranks_ag(x) \ + if (ar_nvsize == x) { \ + int arg1 = op - NVTE_MAX_OPS, \ + arg2 = NVTE_REG0_OFFSET(comm) - \ + (op == userbuffers_allreduceop_nonsharp ? 2 : 1) * NVTE_REG0_SINGLENODE + \ + NVTE_MAX_OPS, \ + arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step, arg7 = elements / 8 / x, \ + arg6 = offset / 8 + (comm->use_rr_kernel ? 0 : arg4 * arg7); \ + void **arg8 = reinterpret_cast(comm->gpu_ptrs); \ + int arg9 = handler * comm->nvsize; \ + uint64_t arg10 = comm->ub_timeout; \ + void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), \ + reinterpret_cast(&arg3), reinterpret_cast(&arg4), \ + reinterpret_cast(&arg5), reinterpret_cast(&arg6), \ + reinterpret_cast(&arg7), reinterpret_cast(&arg8), \ + reinterpret_cast(&arg9), reinterpret_cast(&arg10)}; \ + NVTE_CHECK_CUDA(cudaLaunchKernel( \ + reinterpret_cast(comm->use_rr_kernel ? userbuffers_fp16_sum_inplace_gpu_rr_ag \ + : userbuffers_fp16_sum_inplace_gpu_rw_ag), \ + sms, threads, kernelArgs, 0, stream)); \ + } + +#define callranks_agMC(x) \ + if (ar_nvsize == x) { \ + int arg1 = op - NVTE_MAX_OPS; \ + int arg2 = NVTE_REG0_OFFSET(comm) - \ + (op == userbuffers_allreduceop_nonsharp ? 2 : 1) * NVTE_REG0_SINGLENODE + \ + NVTE_MAX_OPS; \ + int arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step, arg7 = elements / 8 / x; \ + int arg6 = offset / 8 + arg4 * arg7; \ + void **arg8 = reinterpret_cast(comm->gpu_ptrs); \ + int arg9 = handler * comm->nvsize; \ + uint4 *arg10 = reinterpret_cast(comm->mc_ptr[handler]); \ + uint64_t arg11 = comm->ub_timeout; \ + \ + hipLaunchKernelGGL( \ + (userbuffers_fp16_sum_inplace_gpu_mc_ag), \ + sms, threads, 0, stream, \ + arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10, arg11); \ + } + +#define callranks_rs(x) \ + if (ar_nvsize == x) { \ + int arg1 = op - NVTE_MAX_OPS, \ + arg2 = NVTE_REG0_OFFSET(comm) - \ + (op == userbuffers_allreduceop_nonsharp ? 
2 : 1) * NVTE_REG0_SINGLENODE + \ + NVTE_MAX_OPS, \ + arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step, arg7 = elements / 8 / x, \ + arg6 = offset / 8 + arg4 * arg7; \ + void **arg8 = reinterpret_cast(comm->gpu_ptrs); \ + int arg9 = handler * comm->nvsize; \ + uint64_t arg10 = comm->ub_timeout; \ + void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), \ + reinterpret_cast(&arg3), reinterpret_cast(&arg4), \ + reinterpret_cast(&arg5), reinterpret_cast(&arg6), \ + reinterpret_cast(&arg7), reinterpret_cast(&arg8), \ + reinterpret_cast(&arg9), reinterpret_cast(&arg10)}; \ + NVTE_CHECK_CUDA(cudaLaunchKernel( \ + reinterpret_cast(userbuffers_fp16_sum_inplace_gpu_rr_rs), \ + sms, threads, kernelArgs, 0, stream)); \ + } + +#define callranks_rsMC(x) \ + if (ar_nvsize == x) { \ + int arg1 = op - NVTE_MAX_OPS; \ + int arg2 = NVTE_REG0_OFFSET(comm) - \ + (op == userbuffers_allreduceop_nonsharp ? 2 : 1) * NVTE_REG0_SINGLENODE + \ + NVTE_MAX_OPS; \ + int arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step, arg7 = elements / 8 / x; \ + int arg6 = offset / 8 + arg4 * arg7; \ + void **arg8 = reinterpret_cast(comm->gpu_ptrs); \ + int arg9 = handler * comm->nvsize; \ + float4 *arg10 = reinterpret_cast(comm->mc_ptr[handler]); \ + uint64_t arg11 = comm->ub_timeout; \ + \ + hipLaunchKernelGGL( \ + (userbuffers_fp16_sum_inplace_gpu_mc_rs), \ + sms, threads, 0, stream, \ + arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10, arg11); \ + } + +#define callranks_rs_oop(x) \ + if (ar_nvsize == x) { \ + int arg1 = op - NVTE_MAX_OPS, \ + arg2 = NVTE_REG0_OFFSET(comm) - \ + (op == userbuffers_allreduceop_nonsharp ? 2 : 1) * NVTE_REG0_SINGLENODE + \ + NVTE_MAX_OPS, \ + arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step, arg7 = elements / 8 / x, \ + arg6 = offset / 8 + arg4 * arg7, arg8 = rowelements / 8, arg9 = strideelements / 8; \ + void **arg10 = reinterpret_cast(comm->gpu_ptrs); \ + int arg11 = handler * comm->nvsize; \ + void *arg12 = output; \ + uint64_t arg13 = comm->ub_timeout; \ + void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), \ + reinterpret_cast(&arg3), reinterpret_cast(&arg4), \ + reinterpret_cast(&arg5), reinterpret_cast(&arg6), \ + reinterpret_cast(&arg7), reinterpret_cast(&arg8), \ + reinterpret_cast(&arg9), reinterpret_cast(&arg10), \ + reinterpret_cast(&arg11), reinterpret_cast(&arg12), \ + reinterpret_cast(&arg13)}; \ + NVTE_CHECK_CUDA(cudaLaunchKernel( \ + reinterpret_cast(userbuffers_fp16_sum_inplace_gpu_rr_rs_oop), \ + sms, threads, kernelArgs, 0, stream)); \ + } + +#define callranks_rs_oop_fp8(x) \ + if (ar_nvsize == x) { \ + int arg1 = op - NVTE_MAX_OPS, \ + arg2 = NVTE_REG0_OFFSET(comm) - \ + (op == userbuffers_allreduceop_nonsharp ? 
2 : 1) * NVTE_REG0_SINGLENODE + \ + NVTE_MAX_OPS, \ + arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step, arg7 = elements / 16 / x, \ + arg6 = offset / 16 + arg4 * arg7, arg8 = rowelements / 8, arg9 = strideelements / 8; \ + void **arg10 = reinterpret_cast(comm->gpu_ptrs); \ + int arg11 = handler * comm->nvsize; \ + void *arg12 = output; \ + float *arg13 = scale; \ + uint64_t arg14 = comm->ub_timeout; \ + void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), \ + reinterpret_cast(&arg3), reinterpret_cast(&arg4), \ + reinterpret_cast(&arg5), reinterpret_cast(&arg6), \ + reinterpret_cast(&arg7), reinterpret_cast(&arg8), \ + reinterpret_cast(&arg9), reinterpret_cast(&arg10), \ + reinterpret_cast(&arg11), reinterpret_cast(&arg12), \ + reinterpret_cast(&arg13), reinterpret_cast(&arg14)}; \ + NVTE_CHECK_CUDA(cudaLaunchKernel( \ + reinterpret_cast(userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_fp8), \ + sms, threads, kernelArgs, 0, stream)); \ + } + +#define callranks_rs_oop_atomic_fp8(x) \ + if (ar_nvsize == x) { \ + int arg1 = op - NVTE_MAX_OPS, \ + arg2 = NVTE_REG0_OFFSET(comm) - \ + (op == userbuffers_allreduceop_nonsharp ? 2 : 1) * NVTE_REG0_SINGLENODE + \ + NVTE_MAX_OPS, \ + arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step, arg7 = elements / 16 / x, \ + arg6 = offset / 16, arg8 = rowelements / 8, arg9 = strideelements_out / 8, \ + arg10 = strideelements_in / 16; \ + void **arg11 = reinterpret_cast(comm->gpu_ptrs); \ + int arg12 = handler * comm->nvsize; \ + void *arg13 = output; \ + float *arg14 = scale; \ + void *arg15 = counters; \ + int arg16 = numchunks, arg17 = atomicindex; \ + uint64_t arg18 = comm->ub_timeout; \ + void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), \ + reinterpret_cast(&arg3), reinterpret_cast(&arg4), \ + reinterpret_cast(&arg5), reinterpret_cast(&arg6), \ + reinterpret_cast(&arg7), reinterpret_cast(&arg8), \ + reinterpret_cast(&arg9), reinterpret_cast(&arg10), \ + reinterpret_cast(&arg11), reinterpret_cast(&arg12), \ + reinterpret_cast(&arg13), reinterpret_cast(&arg14), \ + reinterpret_cast(&arg15), reinterpret_cast(&arg16), \ + reinterpret_cast(&arg17), reinterpret_cast(&arg18)}; \ + NVTE_CHECK_CUDA(cudaLaunchKernel( \ + reinterpret_cast( \ + userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_atomic_fp8), \ + sms, threads, kernelArgs, 0, stream)); \ + } + +#define callranks_rs_oop_stride(x) \ + if (ar_nvsize == x) { \ + int arg1 = op - NVTE_MAX_OPS, \ + arg2 = NVTE_REG0_OFFSET(comm) - \ + (op == userbuffers_allreduceop_nonsharp ? 
2 : 1) * NVTE_REG0_SINGLENODE + \ + NVTE_MAX_OPS, \ + arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step, arg7 = elements / 8 / x, \ + arg6 = offset / 8, arg8 = rowelements / 8, arg9 = strideelements / 8; \ + void **arg10 = reinterpret_cast(comm->gpu_ptrs); \ + int arg11 = handler * comm->nvsize; \ + void *arg12 = output; \ + uint64_t arg13 = comm->ub_timeout; \ + void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), \ + reinterpret_cast(&arg3), reinterpret_cast(&arg4), \ + reinterpret_cast(&arg5), reinterpret_cast(&arg6), \ + reinterpret_cast(&arg7), reinterpret_cast(&arg8), \ + reinterpret_cast(&arg9), reinterpret_cast(&arg10), \ + reinterpret_cast(&arg11), reinterpret_cast(&arg12), \ + reinterpret_cast(&arg13)}; \ + NVTE_CHECK_CUDA(cudaLaunchKernel( \ + reinterpret_cast(userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_stride), \ + sms, threads, kernelArgs, 0, stream)); \ + } + +#define callranks_rs_oop_stride_atomic(x) \ + if (ar_nvsize == x) { \ + int arg1 = op - NVTE_MAX_OPS, \ + arg2 = NVTE_REG0_OFFSET(comm) - \ + (op == userbuffers_allreduceop_nonsharp ? 2 : 1) * NVTE_REG0_SINGLENODE + \ + NVTE_MAX_OPS, \ + arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step, arg7 = elements / 8 / x, \ + arg6 = offset / 8, arg8 = rowelements / 8, arg9 = strideelements / 8, arg10 = numchunks; \ + void **arg11 = reinterpret_cast(comm->gpu_ptrs); \ + int arg12 = handler * comm->nvsize; \ + void *arg13 = output; \ + void *arg14 = counters; \ + uint64_t arg15 = comm->ub_timeout; \ + void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), \ + reinterpret_cast(&arg3), reinterpret_cast(&arg4), \ + reinterpret_cast(&arg5), reinterpret_cast(&arg6), \ + reinterpret_cast(&arg7), reinterpret_cast(&arg8), \ + reinterpret_cast(&arg9), reinterpret_cast(&arg10), \ + reinterpret_cast(&arg11), reinterpret_cast(&arg12), \ + reinterpret_cast(&arg13), reinterpret_cast(&arg14), \ + reinterpret_cast(&arg15)}; \ + NVTE_CHECK_CUDA(cudaLaunchKernel( \ + reinterpret_cast(userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_stride_atomic), \ + sms, threads, kernelArgs, 0, stream)); \ + } + +#define callranks_rs_oop_stride_multiatomic(x) \ + if (ar_nvsize == x) { \ + int arg1 = op - NVTE_MAX_OPS, \ + arg2 = NVTE_REG0_OFFSET(comm) - \ + (op == userbuffers_allreduceop_nonsharp ? 
2 : 1) * NVTE_REG0_SINGLENODE + \ + NVTE_MAX_OPS, \ + arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step, arg7 = elements / 8 / x, \ + arg6 = offset / 8, arg8 = rowelements / 8, arg9 = strideelements / 8, arg10 = numchunks; \ + void **arg11 = reinterpret_cast(comm->gpu_ptrs); \ + int arg12 = handler * comm->nvsize; \ + void *arg13 = output; \ + void *arg14 = counters; \ + uint64_t arg15 = comm->ub_timeout; \ + void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), \ + reinterpret_cast(&arg3), reinterpret_cast(&arg4), \ + reinterpret_cast(&arg5), reinterpret_cast(&arg6), \ + reinterpret_cast(&arg7), reinterpret_cast(&arg8), \ + reinterpret_cast(&arg9), reinterpret_cast(&arg10), \ + reinterpret_cast(&arg11), reinterpret_cast(&arg12), \ + reinterpret_cast(&arg13), reinterpret_cast(&arg14), \ + reinterpret_cast(&arg15)}; \ + NVTE_CHECK_CUDA(cudaLaunchKernel( \ + reinterpret_cast(userbuffers_fp16_sum_inplace_gpu_rr_rs_oop_stride_multiatomic), \ + sms, threads, kernelArgs, 0, stream)); \ + } +#endif void reducescatter2_userbuff_strided(void *output, const int handler, const int offset, const int rowelements, const int colelements, @@ -1680,7 +1946,11 @@ void reducescatter2_userbuff_strided(void *output, const int handler, const int int warps = comm->threads / 32; if (warps < ar_nvsize) warps = ar_nvsize; +#ifndef __HIP_PLATFORM_AMD__ SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); +#else + int threads = comm->threads; +#endif callranks_rs_oop_stride(2) callranks_rs_oop_stride(4) callranks_rs_oop_stride(8) callranks_rs_oop_stride(16) callranks_rs_oop_stride(32) } @@ -1702,7 +1972,11 @@ void reducescatter2_userbuff_strided_atomic(void *output, const int handler, con int warps = comm->threads / 32; if (warps < ar_nvsize) warps = ar_nvsize; +#ifndef __HIP_PLATFORM_AMD__ SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); +#else + int threads = comm->threads; +#endif callranks_rs_oop_stride_atomic(2) callranks_rs_oop_stride_atomic(4) callranks_rs_oop_stride_atomic(8) callranks_rs_oop_stride_atomic(16) callranks_rs_oop_stride_atomic(32) @@ -1729,7 +2003,11 @@ void reducescatter2_userbuff_strided_universal_fp8(void *output, float *scale, c int warps = comm->threads / 32; if (warps < ar_nvsize) warps = ar_nvsize; +#ifndef __HIP_PLATFORM_AMD__ SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); +#else + int threads = comm->threads; +#endif callranks_rs_oop_atomic_fp8(2) callranks_rs_oop_atomic_fp8(4) callranks_rs_oop_atomic_fp8(8) callranks_rs_oop_atomic_fp8(16) callranks_rs_oop_atomic_fp8(32) } @@ -1774,7 +2052,11 @@ void reducescatter2_userbuff_strided_multiatomic(void *output, const int handler int warps = comm->threads / 32; if (warps < ar_nvsize) warps = ar_nvsize; +#ifndef __HIP_PLATFORM_AMD__ SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); +#else + int threads = comm->threads; +#endif callranks_rs_oop_stride_multiatomic(2) callranks_rs_oop_stride_multiatomic(4) callranks_rs_oop_stride_multiatomic(8) callranks_rs_oop_stride_multiatomic(16) callranks_rs_oop_stride_multiatomic(32) @@ -1795,6 +2077,7 @@ void allgather2_userbuff_inplace(const int handler, const int offset, const int int warps = comm->threads / 32; if (warps < ar_nvsize) warps = ar_nvsize; +#ifndef __HIP_PLATFORM_AMD__ if (comm_launch_event) { SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event); if (comm->use_mc && (comm->memflags[handler] & NVTE_UB_MEM_MC_CREATED)) { @@ -1810,6 +2093,13 @@ void allgather2_userbuff_inplace(const int handler, const int offset, const int callranks_ag(2) callranks_ag(4) 
callranks_ag(8) callranks_ag(16) callranks_ag(32) } } +#else + if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) { + callranks_agMC(2) callranks_agMC(4) callranks_agMC(8) callranks_agMC(16) callranks_agMC(32) + } else { + callranks_ag(2) callranks_ag(4) callranks_ag(8) callranks_ag(16) callranks_ag(32) + } +#endif } void allgather2_userbuff_inplace_sliced(const int handler, const int offset, const int elements, @@ -1842,6 +2132,7 @@ void reducescatter2_userbuff_inplace(const int handler, const int offset, const int warps = comm->threads / 32; if (warps < ar_nvsize) warps = ar_nvsize; +#ifndef __HIP_PLATFORM_AMD__ if (comm_launch_event) { SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event); if (comm->use_mc && (comm->memflags[handler] & NVTE_UB_MEM_MC_CREATED)) { @@ -1857,7 +2148,16 @@ void reducescatter2_userbuff_inplace(const int handler, const int offset, const callranks_rs(2) callranks_rs(4) callranks_rs(8) callranks_rs(16) callranks_rs(32) } } +#else + int threads = comm->threads; + if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) { + callranks_rsMC(2) callranks_rsMC(4) callranks_rsMC(8) callranks_rsMC(16) callranks_rsMC(32) + } else { + callranks_rs(2) callranks_rs(4) callranks_rs(8) callranks_rs(16) callranks_rs(32) + } +#endif } + void reducescatter2_userbuff_stridedoutput(void *output, const int handler, const int offset, const int rowelements, const int colelements, const int strideelements, communicator *comm, @@ -1875,6 +2175,7 @@ void reducescatter2_userbuff_stridedoutput(void *output, const int handler, cons int warps = comm->threads / 32; if (warps < ar_nvsize) warps = ar_nvsize; +#ifndef __HIP_PLATFORM_AMD__ if (comm_launch_event) { SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event); if (comm->use_mc && (comm->memflags[handler] & NVTE_UB_MEM_MC_CREATED)) { @@ -1894,7 +2195,13 @@ void reducescatter2_userbuff_stridedoutput(void *output, const int handler, cons callranks_rs_oop(32) } } +#else + int threads = comm->threads; + callranks_rs_oop(2) callranks_rs_oop(4) callranks_rs_oop(8) callranks_rs_oop(16) + callranks_rs_oop(32) +#endif } + void reducescatter2_userbuff(void *output, const int handler, const int offset, const int elements, communicator *comm, cudaStream_t stream, cudaEvent_t comm_launch_event) { @@ -1921,6 +2228,7 @@ void reducescatter2_userbuff_stridedoutput_fp8(void *output, float *scale, const int warps = comm->threads / 32; if (warps < ar_nvsize) warps = ar_nvsize; +#ifndef __HIP_PLATFORM_AMD__ if (comm_launch_event) { SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event); callranks_rs_oop_fp8(2) callranks_rs_oop_fp8(4) callranks_rs_oop_fp8(8) callranks_rs_oop_fp8(16) @@ -1930,6 +2238,11 @@ void reducescatter2_userbuff_stridedoutput_fp8(void *output, float *scale, const callranks_rs_oop_fp8(2) callranks_rs_oop_fp8(4) callranks_rs_oop_fp8(8) callranks_rs_oop_fp8(16) callranks_rs_oop_fp8(32) } +#else + int threads = comm->threads; + callranks_rs_oop_fp8(2) callranks_rs_oop_fp8(4) callranks_rs_oop_fp8(8) callranks_rs_oop_fp8(16) + callranks_rs_oop_fp8(32) +#endif } template void reducescatter2_userbuff_stridedoutput_fp8<__nv_fp8_e5m2>( @@ -2196,7 +2509,11 @@ __global__ void __launch_bounds__(MAX_THREADS) // Decrement atomic val to signal current output tile finish if (counters) { ((unsigned int *)counters)[0] = 0; +#ifndef __HIP_PLATFORM_AMD__ asm volatile("fence.sc.gpu;\n"); +#else + __threadfence(); +#endif } } } @@ -2267,7 
+2584,11 @@ __global__ void __launch_bounds__(MAX_THREADS) kuserbuffers_pushsendrecv_multiat // Decrement atomic val to signal current output tile finish if (counters) { ((unsigned int *)counters)[recv_chunk_id /*chunk_i+1*/] = 0; +#ifndef __HIP_PLATFORM_AMD__ asm volatile("fence.sc.gpu;\n"); +#else + __threadfence(); +#endif } } @@ -2328,7 +2649,12 @@ void userbuffers_send(const int srchandler, const size_t srcoffset, const int ds NVTE_CHECK_CUDA(cudaMemcpyAsync(dstptr, srcptr, bytes, cudaMemcpyDeviceToDevice, stream)); // kuserbuffers_inc<<<1, 1, 0, stream>>>(reinterpret_cast(ce_send_end_ptr)); } +#ifndef __HIP_PLATFORM_AMD__ SETUP_LAUNCH_CONFIG(signalonly ? 1 : comm->sms, signalonly ? 1 : 1024, stream); +#else + int sms = signalonly ? 1 : comm->sms; + int threads = signalonly ? 1 : 1024; +#endif int *arg1 = &comm->send_id[peer], *arg2 = reinterpret_cast(flagptr); int4 *arg3 = reinterpret_cast(srcptr), *arg4 = reinterpret_cast(dstptr); int arg5 = signalonly ? 0 : bytes / 16; @@ -2336,7 +2662,10 @@ void userbuffers_send(const int srchandler, const size_t srcoffset, const int ds reinterpret_cast(&arg3), reinterpret_cast(&arg4), reinterpret_cast(&arg5)}; NVTE_CHECK_CUDA( +#ifndef __HIP_PLATFORM_AMD__ cudaLaunchKernelExC(&cfg, reinterpret_cast(kuserbuffers_pushsend), kernelArgs)); +#else + cudaLaunchKernel(reinterpret_cast(kuserbuffers_pushsend), sms, threads, kernelArgs, 0, stream)); } } @@ -2361,7 +2690,12 @@ void userbuffers_sendrecv(const int srchandler, const int dsthandler, const size cudaMemcpyAsync(send_dstptr, send_srcptr, bytes, cudaMemcpyDeviceToDevice, stream)); // kuserbuffers_inc<<<1, 1, 0, stream>>>(reinterpret_cast(ce_send_end_ptr)); } +#ifndef __HIP_PLATFORM_AMD__ SETUP_LAUNCH_CONFIG(signalonly ? 1 : comm->sms, signalonly ? 1 : 1024, stream); +#else + int sms = signalonly ? 1 : comm->sms; + int threads = signalonly ? 1 : 1024; +#endif int *arg1 = &comm->send_id[send_peer]; int *arg2 = reinterpret_cast(flagptr_send); @@ -2391,7 +2725,11 @@ void userbuffers_sendrecv(const int srchandler, const int dsthandler, const size reinterpret_cast(&arg13), reinterpret_cast(&arg14), reinterpret_cast(&arg15)}; NVTE_CHECK_CUDA( +#ifndef __HIP_PLATFORM_AMD__ cudaLaunchKernelExC(&cfg, reinterpret_cast(kuserbuffers_pushsendrecv), kernelArgs)); +#else + cudaLaunchKernel(reinterpret_cast(kuserbuffers_pushsendrecv), sms, threads, kernelArgs, 0, stream)); +#endif } void userbuffers_sendrecv_atomic(const int srchandler, const int dsthandler, @@ -2417,7 +2755,12 @@ void userbuffers_sendrecv_atomic(const int srchandler, const int dsthandler, cudaMemcpyAsync(send_dstptr, send_srcptr, bytes, cudaMemcpyDeviceToDevice, stream)); // kuserbuffers_inc<<<1, 1, 0, stream>>>(reinterpret_cast(ce_send_end_ptr)); } +#ifndef __HIP_PLATFORM_AMD__ SETUP_LAUNCH_CONFIG(signalonly ? 1 : comm->sms, signalonly ? 1 : 1024, stream); +#else + int sms = signalonly ? 1 : comm->sms; + int threads = signalonly ? 
1 : 1024; +#endif int *arg1 = &comm->send_id[send_peer]; int *arg2 = reinterpret_cast(flagptr_send); @@ -2447,8 +2790,13 @@ void userbuffers_sendrecv_atomic(const int srchandler, const int dsthandler, reinterpret_cast(&arg11), reinterpret_cast(&arg12), reinterpret_cast(&arg13), reinterpret_cast(&arg14), reinterpret_cast(&arg15), reinterpret_cast(&arg16)}; +#ifndef __HIP_PLATFORM_AMD__ NVTE_CHECK_CUDA(cudaLaunchKernelExC( &cfg, reinterpret_cast(kuserbuffers_pushsendrecv_atomic), kernelArgs)); +#else + NVTE_CHECK_CUDA( + cudaLaunchKernel(reinterpret_cast(kuserbuffers_pushsendrecv_atomic), sms, threads, kernelArgs, 0, stream)); +#endif } void userbuffers_sendrecv_multiatomic(const int srchandler, const int dsthandler, @@ -2464,7 +2812,12 @@ void userbuffers_sendrecv_multiatomic(const int srchandler, const int dsthandler void *flagptr_send = GET_SEND_PTR_BY_INDEX(send_peerlocal, comm, dsthandler, 0); void *flagptr_recv = GET_RECV_PTR_BY_INDEX(recv_peer, comm, dsthandler, 0); +#ifndef SETUP_LAUNCH_CONFIG(comm->sms, 1024, stream); +#else + int sms = comm->sms; + int threads = 1024; +#endif int *arg1 = &comm->send_id[send_peer]; int *arg2 = reinterpret_cast(flagptr_send); @@ -2493,8 +2846,13 @@ void userbuffers_sendrecv_multiatomic(const int srchandler, const int dsthandler reinterpret_cast(&arg13), reinterpret_cast(&arg14), reinterpret_cast(&arg15), reinterpret_cast(&arg16), reinterpret_cast(&arg17), reinterpret_cast(&arg18)}; +#ifndef __HIP_PLATFORM_AMD__ NVTE_CHECK_CUDA(cudaLaunchKernelExC( &cfg, reinterpret_cast(kuserbuffers_pushsendrecv_multiatomic), kernelArgs)); +#else + NVTE_CHECK_CUDA( + cudaLaunchKernel(reinterpret_cast(kuserbuffers_pushsendrecv_multiatomic), sms, threads, kernelArgs, 0, stream)); +#endif } void userbuffers_recv(const int srchandler, const size_t srcoffset, const int dsthandler, @@ -2545,7 +2903,11 @@ static __global__ void producer_kernel(void *atomic_ptr, int chunk_i) { // COMM kernel need to explicitely flash gmem. 
// GEMM kernel already executed, and can not see gmem // change without COMM kernel explicitely make change +#ifndef __HIP_PLATFORM_AMD__ asm volatile("fence.sc.gpu;\n"); +#else + __threadfence(); +#endif } // consumer @@ -2555,7 +2917,11 @@ static __global__ void consumer_kernel(void *atomic_ptr, int chunk_i) { while (0 != (atomicCAS((unsigned int *)atomic_ptr + chunk_i, 0, 0))) { } ((unsigned int *)atomic_ptr)[chunk_i] = 1; +#ifndef __HIP_PLATFORM_AMD__ asm volatile("fence.sc.gpu;\n"); +#else + __threadfence(); +#endif } } @@ -2567,7 +2933,11 @@ static __global__ void consumer_batch_kernel(void *atomic_ptr, int first_chunk_i while (0 != (atomicCAS((unsigned int *)atomic_ptr + i, 0, 0))) { } ((unsigned int *)atomic_ptr)[i] = 1; +#ifndef __HIP_PLATFORM_AMD__ asm volatile("fence.sc.gpu;\n"); +#else + __threadfence(); +#endif } } } diff --git a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h index 293c57526..5f810e4d8 100644 --- a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h +++ b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h @@ -36,7 +36,8 @@ enum class CommOverlapAlgo { SPLIT_PIPELINED_RS_P2P = 4, ATOMIC_GEMM_RS = 5, ATOMIC_GEMM_AG_P2P = 6, - ATOMIC_GEMM_RS_P2P = 7 + ATOMIC_GEMM_RS_P2P = 7, + SPLIT_PIPELINED_AG_RD_P2P = 8 }; class CommOverlapCore { @@ -133,6 +134,14 @@ class CommOverlapCore { cudaStream_t stream_main) { NVTE_ERROR("Operation is not implemented."); } + + virtual void split_overlap_ag_rd(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, TensorWrapper &B_copy, + cudaStream_t stream_main) { + NVTE_ERROR("Operation is not implemented."); + } }; // CommOverlapCore class CommOverlapBase : public CommOverlapCore { @@ -181,6 +190,17 @@ class CommOverlapBase : public CommOverlapCore { NVTE_ERROR("Operation not supported."); } + /* + ** Split AllGather + GEMM using P2P communication using recursive doubling + ** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is needed to have AG + ** outputs in each rank to be in the contiguous memory space after all ring exchange phases. 
+ */ + void split_overlap_ag_rd(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, TensorWrapper &B_copy, + cudaStream_t stream_main) override; + /* ** Split FPROP GEMM + ReduceScatter */ @@ -205,6 +225,7 @@ class CommOverlapP2PBase : public CommOverlapCore { bool _is_reduce_scatter{false}; bool _use_multiatomic_ag{false}; bool _aggregate; + bool use_rd; int _next_rank; int _prev_rank; int _rank_round_tp; @@ -224,7 +245,7 @@ class CommOverlapP2PBase : public CommOverlapCore { CommOverlapType comm_type, int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 1, int gemm_priority = 0, int comm_priority = 0, int num_comm_sm = 1, bool set_sm_margin = false, bool use_ce = true, - bool atomic_gemm = false, bool aggregate = false); + bool atomic_gemm = false, bool aggregate = false, bool use_rd = false); virtual ~CommOverlapP2PBase(); @@ -260,6 +281,15 @@ class CommOverlapP2PBase : public CommOverlapCore { bool use_split_accumulator, TensorWrapper &B_copy, cudaStream_t stream_main) override; + /* + ** Split AllGather + GEMM using P2P communication using recursive doubling + */ + void split_overlap_ag_rd(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, TensorWrapper &B_copy, + cudaStream_t stream_main) override; + /* ** Split ReduceScatter + GEMM using P2P communication */ diff --git a/transformer_engine/common/util/cuda_runtime.cpp b/transformer_engine/common/util/cuda_runtime.cpp index 9453c2f86..4f47c3e00 100644 --- a/transformer_engine/common/util/cuda_runtime.cpp +++ b/transformer_engine/common/util/cuda_runtime.cpp @@ -27,7 +27,7 @@ namespace { #include "string_path_cuda_include.h" } // namespace -#endif // __HIP_PLATFORM_AMD__ +#endif // #ifndef __HIP_PLATFORM_AMD__ int num_devices() { auto query_num_devices = []() -> int { @@ -103,7 +103,6 @@ int sm_count(int device_id) { return cache[device_id]; } -#ifndef __HIP_PLATFORM_AMD__ void stream_priority_range(int *low_priority, int *high_priority, int device_id) { static std::vector> cache(num_devices()); static std::vector flags(num_devices()); @@ -125,7 +124,7 @@ void stream_priority_range(int *low_priority, int *high_priority, int device_id) } bool supports_multicast(int device_id) { -#if CUDART_VERSION >= 12010 +#if !defined(__HIP_PLATFORM_AMD__) && CUDART_VERSION >= 12010 // NOTE: This needs to be guarded at compile time because the // CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED enum is not defined in earlier CUDA versions. 
static std::vector cache(num_devices(), false); @@ -155,7 +154,7 @@ bool supports_multicast(int device_id) { #endif } - +#ifndef __HIP_PLATFORM_AMD__ const std::string &include_directory(bool required) { static std::string path; diff --git a/transformer_engine/common/util/cuda_runtime.h b/transformer_engine/common/util/cuda_runtime.h index 0f0f730be..137be5498 100644 --- a/transformer_engine/common/util/cuda_runtime.h +++ b/transformer_engine/common/util/cuda_runtime.h @@ -50,7 +50,6 @@ const std::string &sm_arch_name(int device_id = -1); */ int sm_count(int device_id = -1); -#ifndef __HIP_PLATFORM_AMD__ /* \brief Minimum and maximum stream priorities supported on device * * \param[in] device_id CUDA device (default is current device) @@ -69,6 +68,7 @@ void stream_priority_range(int *low_priority, int *high_priority, int device_id */ bool supports_multicast(int device_id = -1); +#ifndef __HIP_PLATFORM_AMD__ /* \brief Path to CUDA Toolkit headers * * The path can be configured by setting NVTE_CUDA_INCLUDE_DIR in the diff --git a/transformer_engine/common/util/pybind_helper.h b/transformer_engine/common/util/pybind_helper.h index fbc6dd1e1..d074f560b 100644 --- a/transformer_engine/common/util/pybind_helper.h +++ b/transformer_engine/common/util/pybind_helper.h @@ -10,10 +10,7 @@ #define TRANSFORMER_ENGINE_COMMON_UTIL_PYBIND_HELPER_H_ #include -//TODO: rocm does not support comm gemm overlap yet -#ifndef USE_ROCM #include -#endif #include #include @@ -35,8 +32,6 @@ .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend); #endif -// Define comm overlap handles if not using ROCm -#ifndef USE_ROCM #define NVTE_DECLARE_COMM_OVERLAP_HANDLES(m) \ pybind11::enum_(m, "CommOverlapType", \ pybind11::module_local()) \ @@ -53,7 +48,9 @@ transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_RS_P2P) \ .value("ATOMIC_GEMM_RS", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS) \ .value("ATOMIC_GEMM_AG_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_AG_P2P) \ - .value("ATOMIC_GEMM_RS_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS_P2P); \ + .value("ATOMIC_GEMM_RS_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS_P2P) \ + .value("SPLIT_PIPELINED_AG_RD_P2P", \ + transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_AG_RD_P2P); \ py::class_>(m, "CommOverlapCore", \ pybind11::module_local()) \ @@ -88,14 +85,6 @@ py::call_guard(), py::arg("device_id") = -1); \ m.def("ubuf_built_with_mpi", &transformer_engine::ubuf_built_with_mpi, \ py::call_guard()); -#else -#define NVTE_DECLARE_COMM_OVERLAP_HANDLES(m) \ - pybind11::class_(m, "CommOverlapType", \ - pybind11::module_local()); \ - py::class_>(m, "CommOverlapCore", \ - pybind11::module_local()); -#endif #define NVTE_DECLARE_COMMON_PYBIND11_HANDLES(m) \ pybind11::enum_(m, "DType", pybind11::module_local()) \ diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index c15a1ae3c..edc18fa0e 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -29,9 +29,7 @@ #include #include #include -#ifndef USE_ROCM #include -#endif #include #include #include diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 2ff64ae90..0a2ca3aa2 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -399,7 +399,6 @@ void nvshmem_finalize(); } // namespace transformer_engine::pytorch -#ifndef USE_ROCM 
/*************************************************************************************************** * Comm+GEMM Overlap Wrappers **************************************************************************************************/ @@ -467,6 +466,5 @@ class CommOverlapP2P : torch::CustomClassHolder, public transformer_engine::Comm std::optional> shape = std::nullopt); }; // CommOverlapP2P -#endif // !USE_ROCM #endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_H_ diff --git a/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp b/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp index af59d544e..aed677feb 100644 --- a/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp +++ b/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp @@ -5,7 +5,6 @@ * * See LICENSE for license information. ************************************************************************/ -#ifndef USE_ROCM #include "../extensions.h" #include "transformer_engine/transformer_engine.h" @@ -227,14 +226,14 @@ CommOverlapP2P::CommOverlapP2P(const std::vector &buffer_shape, at::Scal te::CommOverlapType comm_type, int num_max_streams, int comm_cga_size, int gemm_priority, int comm_priority, int num_comm_sm, bool set_sm_margin, bool atomic_gemm, bool use_ce, - bool aggregate) + bool aggregate, bool use_rd) : te::CommOverlapP2PBase( buffer_shape, te::pytorch::GetTransformerEngineDType(buffer_dtype), helper->myrank, helper->numranks, helper->mylocal, helper->numlocal, helper->mynode, helper->numnodes, tp_size, std::bind(&CommOverlapHelper::ub_allgather, helper, _1, _2, _3, _4, _5), std::bind(&CommOverlapHelper::ub_barrier, helper, _1), comm_type, num_max_streams, comm_cga_size, gemm_priority, comm_priority, num_comm_sm, set_sm_margin, use_ce, - atomic_gemm, aggregate) {} + atomic_gemm, aggregate, use_rd) {} /* ** Copy input to _ubufs[0] @@ -302,4 +301,3 @@ at::Tensor CommOverlapP2P::get_buffer(bool local_chunk, std::optional(m, "CommOverlapHelper") .def(py::init<>(), py::call_guard()) .def(py::init>(), @@ -398,20 +397,15 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m, "CommOverlapP2P") .def(py::init &, at::ScalarType, CommOverlapHelper *, int, transformer_engine::CommOverlapType, int, int, int, int, int, bool, bool, bool, - bool>(), + bool, bool>(), py::call_guard(), py::arg("buffer_shape"), py::arg("buffer_dtype"), py::arg("helper"), py::arg("tp_size"), py::arg("comm_type"), py::arg("num_max_streams") = NVTE_COMM_OVERLAP_MAX_STREAMS, py::arg("comm_cga_size") = 1, py::arg("gemm_priority") = 0, py::arg("comm_priority") = 0, py::arg("num_comm_sm") = 1, py::arg("set_sm_margin") = false, py::arg("atomic_gemm") = false, - py::arg("use_ce") = true, py::arg("aggregate") = false) + py::arg("use_ce") = true, py::arg("aggregate") = false, py::arg("use_rd" = false)) .def("copy_into_buffer", &CommOverlapP2P::copy_into_buffer, py::arg("input"), py::arg("local_chunk") = false) .def("get_buffer", &CommOverlapP2P::get_buffer, py::arg("local_chunk") = false, - py::arg("shape") = std::nullopt); -#else - m.def("CommOverlapHelper", &transformer_engine::pytorch::placeholder, "Dummy function for python side annotations"); - m.def("CommOverlap", &transformer_engine::pytorch::placeholder, "Dummy function for python side annotations"); - m.def("CommOverlapP2P", &transformer_engine::pytorch::placeholder, "Dummy function for python side annotations"); -#endif //USE_ROCM + py::arg("shape") = std::nullopt,); } diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py 
index 815c808b2..d7a4257f9 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -362,7 +362,8 @@ def add_ub( assert rs_ag_pairs[name] in layers_atomic_ring_exchange, assert_message buffer_dtype = torch.uint8 if (use_fp8 and fp8_buf) else dtype - if method == "ring_exchange": + use_rd = method == "recursive_doubling" + if method == "ring_exchange" or use_rd: ub_obj = tex.CommOverlapP2P( shape, # Communication buffer shape buffer_dtype, # Communication buffer data type @@ -378,6 +379,7 @@ def add_ub( aggregate=aggregate, gemm_priority=gemm_priority, comm_priority=comm_priority, + use_rd=use_rd, ) else: ub_obj = tex.CommOverlap( From c793b5e6c5a119e44a4104f3d010213ffb2411cd Mon Sep 17 00:00:00 2001 From: alextmagro Date: Tue, 11 Nov 2025 15:19:46 -0600 Subject: [PATCH 2/3] Fixes for rocm 7.0 and dev --- .../te_layer_with_overlap_profile.py | 13 +++---- .../comm_gemm_overlap/comm_gemm_overlap.cpp | 38 +++++++++++-------- .../userbuffers/userbuffers.cu | 19 ++++++---- .../transformer_engine/comm_gemm_overlap.h | 7 +++- .../common/util/pybind_helper.h | 4 ++ transformer_engine/pytorch/csrc/extensions.h | 10 +---- .../pytorch/csrc/extensions/gemm.cpp | 11 ++++-- .../pytorch/csrc/extensions/pybind.cpp | 4 +- transformer_engine/pytorch/module/base.py | 2 + 9 files changed, 60 insertions(+), 48 deletions(-) diff --git a/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap_profile.py b/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap_profile.py index a4e5fd15d..71c2aa6c4 100644 --- a/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap_profile.py +++ b/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap_profile.py @@ -185,18 +185,14 @@ def _get_layer_args(config, tp_group, tp_size, reference=False): return args, kwargs, input_shape -def create_ub_cfgs(config_file:str, tp_size: int = 8): +def create_ub_cfgs(config_file: str, tp_size: int = 8): import json with open(config_file, 'r') as f: data = json.load(f) cfgs = {} _MIN_STREAM_PRIORITY, _MAX_STREAM_PRIORITY = None, None layers_reduce_scatter_overlap = ["proj_fprop", "fc2_fprop", "qkv_wgrad", "fc1_wgrad"] - - for name, method in data.items(): - is_reduce_scatter = name in layers_reduce_scatter_overlap - - layers_all_gather_overlap = [ + layers_all_gather_overlap = [ "qkv_fprop", "qkv_dgrad", "proj_dgrad", @@ -204,13 +200,14 @@ def create_ub_cfgs(config_file:str, tp_size: int = 8): "fc1_dgrad", "fc2_dgrad", ] + + for name, method in data.items(): if _MIN_STREAM_PRIORITY is None or _MAX_STREAM_PRIORITY is None: _MIN_STREAM_PRIORITY, _MAX_STREAM_PRIORITY = tex.get_stream_priority_range() - cfg = { "method": method, - "is_reduce_scatter": is_reduce_scatter, + "is_reduce_scatter": name in layers_reduce_scatter_overlap, "num_sm": 1 if method == "ring_exchange" else 16, "cga_size": 1 if method == "ring_exchange" else 2, "set_sm_margin": False, diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index 819e115f7..b0ae0d421 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -21,6 +21,12 @@ #define HALF_BYTES 2 #define UB_MAX_SM 32 +#ifdef __HIP_PLATFORM_AMD__ +#define half_dtype hip_bfloat16 +#define __nv_fp8_e5m2 te_hip_fp8_e5m2 +#define __nv_fp8_e4m3 te_hip_fp8_e4m3 +#endif + using namespace std::placeholders; namespace transformer_engine { @@ -448,7 +454,7 @@ void 
CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, TensorWrapper &rs_output, cudaStream_t stream_main) { - printf("split_overlap_rs_pipeline"); + printf("split_overlap_rs_pipeline\n"); // Get GEMM dimensions int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; @@ -596,7 +602,7 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector &buffer_shape, CommOverlapType comm_type, int num_max_streams, int comm_cga_size, int gemm_priority, int comm_priority, int num_comm_sm, bool set_sm_margin, bool use_ce, - bool atomic_gemm, bool aggregate, bool use_rd = false) + bool atomic_gemm, bool aggregate, bool use_rd) : CommOverlapCore(myrank, numranks, mylocal, numlocal, mynode, numnodes, tp_size, allgather_handle, barrier_handle, tp_size, num_max_streams, comm_cga_size, gemm_priority, comm_priority, num_comm_sm, set_sm_margin, use_ce, @@ -798,7 +804,7 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa, TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, TensorWrapper &B_copy, cudaStream_t stream_main) { - printf("split_overlap_ag"); + printf("split_overlap_ag\n"); int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; @@ -960,12 +966,12 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa, ** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is needed to have AG ** outputs in each rank to be in the contiguous memory space after all ring exchange phases. */ -void CommOverlapP2PBase::split_overlap_ag_rd(TensorWrapper &A, bool transa, TensorWrapper &B, - bool transb, TensorWrapper &D, TensorWrapper &bias, - TensorWrapper &pre_gelu_out, TensorWrapper &workspace, - bool grad, bool accumulate, bool use_split_accumulator, - TensorWrapper &B_copy, cudaStream_t stream_main) { - printf("split_overlap_ag_rd"); +void CommOverlapP2PBase::split_overlap_ag_rd(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, TensorWrapper &B_copy, + cudaStream_t stream_main) { + printf("split_overlap_ag_rd\n"); int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; @@ -1025,12 +1031,12 @@ void CommOverlapP2PBase::split_overlap_ag_rd(TensorWrapper &A, bool transa, Tens // GEMM char *input_b_chunk_ptr = input_b_ptr + send_offset; auto input_b_chunk = - TensorWrapper(reinterpret_cast(input_b_chunk_ptr), {n_chunk * 2, k}, B.dtype(), + TensorWrapper(reinterpret_cast(input_b_chunk_ptr), std::vector{n_chunk * 2, k}, B.dtype(), nullptr, nullptr, B.scale_inv()); char *output_chunk_ptr = output_ptr + (send_chunk_id * output_chunk_bytes); auto output_chunk = TensorWrapper(reinterpret_cast(output_chunk_ptr), - {n_chunk * 2, m}, D.dtype(), D.amax(), D.scale(), nullptr); + std::vector{n_chunk * 2, m}, D.dtype(), D.amax(), D.scale(), nullptr); char *aux_chunk_ptr = (do_gelu) ? 
pre_gelu_out_ptr + (send_chunk_id * aux_chunk_bytes) : nullptr; @@ -1084,12 +1090,12 @@ void CommOverlapP2PBase::split_overlap_ag_rd(TensorWrapper &A, bool transa, Tens cudaStream_t compute_stream = _stream_compute[chunk_id % _stream_compute.size()]; auto input_b_chunk = TensorWrapper(_ubufs[chunk_id].dptr(), - {n_chunk, k}, B.dtype(), + std::vector{n_chunk, k}, B.dtype(), nullptr, nullptr, B.scale_inv()); char* output_chunk_ptr = output_ptr + (chunk_id * output_chunk_bytes); auto output_chunk = TensorWrapper(reinterpret_cast(output_chunk_ptr), - {n_chunk, m}, + std::vector{n_chunk, m}, D.dtype(), D.amax(), D.scale(), nullptr); char *aux_chunk_ptr = @@ -1140,12 +1146,12 @@ void CommOverlapP2PBase::split_overlap_ag_rd(TensorWrapper &A, bool transa, Tens cudaStream_t compute_stream = _stream_compute[new_chunk_id % _stream_compute.size()]; auto input_b_chunk = TensorWrapper(_ubufs[new_chunk_id].dptr(), - {n_chunk, k}, B.dtype(), + std::vector{n_chunk, k}, B.dtype(), nullptr, nullptr, B.scale_inv()); char* output_chunk_ptr = output_ptr + (new_chunk_id * output_chunk_bytes); auto output_chunk = TensorWrapper(reinterpret_cast(output_chunk_ptr), - {n_chunk, m}, + std::vector{n_chunk, m}, D.dtype(), D.amax(), D.scale(), nullptr); char *aux_chunk_ptr = @@ -1271,7 +1277,7 @@ void CommOverlapP2PBase::split_overlap_rs(const TensorWrapper &A, bool transa, TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, TensorWrapper &rs_output, cudaStream_t stream_main) { - printf("split_overlap_rs_p2p"); + printf("split_overlap_rs_p2p\n"); int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu index e17cfc4f2..c4ec0a9d8 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu @@ -8,16 +8,17 @@ #include #include -#if __CUDA_ARCH__ >= 800 -#define half_dtype nv_bfloat16 -#else -#define half_dtype half -#endif #ifdef __HIP_PLATFORM_AMD__ #define half_dtype hip_bfloat16 #define __nv_fp8_e5m2 te_hip_fp8_e5m2 #define __nv_fp8_e4m3 te_hip_fp8_e4m3 +#else +#if __CUDA_ARCH__ >= 800 +#define half_dtype nv_bfloat16 +#else +#define half_dtype half +#endif #endif #include @@ -2094,7 +2095,8 @@ void allgather2_userbuff_inplace(const int handler, const int offset, const int } } #else - if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) { + int threads = comm->threads; + if (comm->use_mc && (comm->memflags[handler] & NVTE_UB_MEM_MC_CREATED)) { callranks_agMC(2) callranks_agMC(4) callranks_agMC(8) callranks_agMC(16) callranks_agMC(32) } else { callranks_ag(2) callranks_ag(4) callranks_ag(8) callranks_ag(16) callranks_ag(32) @@ -2150,7 +2152,7 @@ void reducescatter2_userbuff_inplace(const int handler, const int offset, const } #else int threads = comm->threads; - if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) { + if (comm->use_mc && (comm->memflags[handler] & NVTE_UB_MEM_MC_CREATED)) { callranks_rsMC(2) callranks_rsMC(4) callranks_rsMC(8) callranks_rsMC(16) callranks_rsMC(32) } else { callranks_rs(2) callranks_rs(4) callranks_rs(8) callranks_rs(16) callranks_rs(32) @@ -2666,6 +2668,7 @@ void userbuffers_send(const int srchandler, const size_t srcoffset, const int ds cudaLaunchKernelExC(&cfg, reinterpret_cast(kuserbuffers_pushsend), kernelArgs)); #else 
cudaLaunchKernel(reinterpret_cast(kuserbuffers_pushsend), sms, threads, kernelArgs, 0, stream)); +#endif } } @@ -2812,7 +2815,7 @@ void userbuffers_sendrecv_multiatomic(const int srchandler, const int dsthandler void *flagptr_send = GET_SEND_PTR_BY_INDEX(send_peerlocal, comm, dsthandler, 0); void *flagptr_recv = GET_RECV_PTR_BY_INDEX(recv_peer, comm, dsthandler, 0); -#ifndef +#ifndef __HIP_PLATFORM_AMD__ SETUP_LAUNCH_CONFIG(comm->sms, 1024, stream); #else int sms = comm->sms; diff --git a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h index 5f810e4d8..2ea3ed6cb 100644 --- a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h +++ b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h @@ -58,6 +58,7 @@ class CommOverlapCore { int _comm_priority; bool _atomic_gemm{false}; bool _is_p2p{false}; + bool _use_rd{false}; TensorWrapper _ubuf; TensorWrapper _counter; @@ -93,6 +94,8 @@ class CommOverlapCore { bool is_p2p_overlap() { return _is_p2p; } + bool is_use_rd() { return _use_rd; } + bool is_fp8_ubuf() { return _ubuf.element_size() == 1; } virtual void bulk_overlap(const TensorWrapper &A, bool transa, const TensorWrapper &B, @@ -199,7 +202,9 @@ class CommOverlapBase : public CommOverlapCore { TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, TensorWrapper &B_copy, - cudaStream_t stream_main) override; + cudaStream_t stream_main) override { + NVTE_ERROR("Operation not supported."); + }; /* ** Split FPROP GEMM + ReduceScatter diff --git a/transformer_engine/common/util/pybind_helper.h b/transformer_engine/common/util/pybind_helper.h index d074f560b..c893c6916 100644 --- a/transformer_engine/common/util/pybind_helper.h +++ b/transformer_engine/common/util/pybind_helper.h @@ -14,7 +14,11 @@ #include #include +#ifdef __HIP_PLATFORM_AMD__ +#include "hip_runtime.h" +#else #include "cuda_runtime.h" +#endif // Define fused-attention handles separately for USE_ROCM #ifndef USE_ROCM diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 0a2ca3aa2..22c2691ab 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -13,14 +13,6 @@ #include "common.h" -#ifdef USE_ROCM -namespace transformer_engine { -//dummy CommOverlapCore, CommOverlapType in rocm -class CommOverlapCore{}; -class CommOverlapType{}; -} -#endif - namespace transformer_engine::pytorch { /*************************************************************************************************** @@ -456,7 +448,7 @@ class CommOverlapP2P : torch::CustomClassHolder, public transformer_engine::Comm int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2, int gemm_priority = 0, int comm_priority = 0, int num_comm_sm = 3, bool set_sm_margin = true, bool atomic_gemm = false, bool use_ce = true, - bool aggregate = false); + bool aggregate = false, bool use_rd = false); ~CommOverlapP2P() {} diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index 9b93d62b3..a5d8bc636 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -187,7 +187,6 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans std::move(swizzle_scaling_factors(B_tensor, !transb))); 
if (comm_overlap) { -#ifndef USE_ROCM // Prepare extra output tensor TensorWrapper extra_output_tensor; if (extra_output.has_value()) { @@ -213,6 +212,13 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans accumulate, use_split_accumulator, extra_output_tensor, main_stream); }); + } else if (comm_overlap->is_use_rd()) { + NVTE_SCOPED_GIL_RELEASE({ + comm_overlap->split_overlap_ag_rd(A_tensor, transa, B_tensor, transb, D_tensor, + bias_tensor, te_pre_gelu_out, te_workspace, grad, + accumulate, use_split_accumulator, + extra_output_tensor, main_stream); + }); } else { NVTE_SCOPED_GIL_RELEASE({ comm_overlap->split_overlap_ag(A_tensor, transa, B_tensor, transb, D_tensor, @@ -238,9 +244,6 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans }); } } -#else - NVTE_ERROR("ROCm TE does not support comm_overlap\n"); -#endif //!USE_ROCM } else { // Launch GEMM NVTE_SCOPED_GIL_RELEASE({ diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 1d14b05df..bcce6d777 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -403,9 +403,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("num_max_streams") = NVTE_COMM_OVERLAP_MAX_STREAMS, py::arg("comm_cga_size") = 1, py::arg("gemm_priority") = 0, py::arg("comm_priority") = 0, py::arg("num_comm_sm") = 1, py::arg("set_sm_margin") = false, py::arg("atomic_gemm") = false, - py::arg("use_ce") = true, py::arg("aggregate") = false, py::arg("use_rd" = false)) + py::arg("use_ce") = true, py::arg("aggregate") = false, py::arg("use_rd") = false) .def("copy_into_buffer", &CommOverlapP2P::copy_into_buffer, py::arg("input"), py::arg("local_chunk") = false) .def("get_buffer", &CommOverlapP2P::get_buffer, py::arg("local_chunk") = false, - py::arg("shape") = std::nullopt,); + py::arg("shape") = std::nullopt); } diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index d7a4257f9..61997f350 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -308,6 +308,7 @@ def get_default_config(name): "comm_priority": _MAX_STREAM_PRIORITY, "gemm_priority": _MIN_STREAM_PRIORITY, "pipeline_rs_overlap_first_gemm": False, + "use_rd": False, } return default_cfg @@ -326,6 +327,7 @@ def add_ub( comm_priority: int = 0, gemm_priority: int = 0, pipeline_rs_overlap_first_gemm: bool = False, + use_rd: bool = False, ) -> None: if atomic_gemm: warnings.warn( From 896c1915b52002da370f2d4a58f703e04d7bfd65 Mon Sep 17 00:00:00 2001 From: alextmagro Date: Tue, 11 Nov 2025 15:22:52 -0600 Subject: [PATCH 3/3] Copyrights --- .../pytorch/comm_gemm_overlap/te_layer_with_overlap_profile.py | 2 ++ tests/pytorch/distributed/run_layer_with_overlap.py | 2 ++ .../common/comm_gemm_overlap/comm_gemm_overlap.cpp | 2 ++ .../common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp | 2 ++ .../common/comm_gemm_overlap/userbuffers/userbuffers.cu | 2 ++ .../common/include/transformer_engine/comm_gemm_overlap.h | 2 ++ 6 files changed, 12 insertions(+) diff --git a/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap_profile.py b/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap_profile.py index 71c2aa6c4..cf0fd360b 100644 --- a/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap_profile.py +++ b/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap_profile.py @@ -1,5 +1,7 @@ #!/usr/bin/python3 +# This file was modified for 
portability to AMDGPU +# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/tests/pytorch/distributed/run_layer_with_overlap.py b/tests/pytorch/distributed/run_layer_with_overlap.py index 4cf0d18f6..2df24ddf2 100644 --- a/tests/pytorch/distributed/run_layer_with_overlap.py +++ b/tests/pytorch/distributed/run_layer_with_overlap.py @@ -1,5 +1,7 @@ #!/usr/bin/python3 +# This file was modified for portability to AMDGPU +# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index b0ae0d421..2af8b494b 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -1,4 +1,6 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp index fa10e9329..57df32ad1 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp @@ -1,4 +1,6 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu index c4ec0a9d8..7e0c483c1 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu @@ -1,4 +1,6 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. diff --git a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h index 2ea3ed6cb..2deb00276 100644 --- a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h +++ b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h @@ -1,4 +1,6 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information.
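
Usage note (editor's addition, not part of the patches): the series wires the new "recursive_doubling" method through the same plumbing as "ring_exchange" -- add_ub() sets use_rd = (method == "recursive_doubling"), forwards it to tex.CommOverlapP2P, and the GEMM wrapper dispatches to split_overlap_ag_rd() when is_use_rd() is true. The sketch below shows how a user might select the method from Python. It is hedged: the initialize_ub() call signature, the buffer shape, and the layer names are assumptions based on the existing Userbuffers configuration path, not something these patches define.

    # Sketch only: enable the recursive-doubling AG+GEMM overlap for a few layers.
    # Assumes te.initialize_ub(shape, tp_size, use_fp8=..., dtype=..., ub_cfgs=...)
    # exists as in the current ring_exchange flow; layer names are illustrative.
    import torch
    import transformer_engine.pytorch as te

    tp_size = 8
    seq_length, batch_size, hidden_size = 2048, 2, 8192

    # Per-layer overrides; anything not listed falls back to get_default_config(),
    # which this series extends with "use_rd": False.
    ub_cfgs = {
        name: {"method": "recursive_doubling"}
        for name in ("qkv_fprop", "proj_dgrad", "fc1_fprop", "fc2_dgrad")
    }

    te.initialize_ub(
        [seq_length * batch_size, hidden_size],  # communication buffer shape
        tp_size,
        use_fp8=False,
        dtype=torch.bfloat16,
        ub_cfgs=ub_cfgs,  # add_ub() maps "recursive_doubling" -> use_rd=True
    )

The same selection can also come from the example script's --ub_config JSON file, which maps layer names to method strings and is expanded by create_ub_cfgs() before being handed to initialize_ub().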