+/*************************************************************************
+ * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+ *
+ * License for AMD contributions = MIT. See LICENSE for more information
+ ************************************************************************/
+
+#include <vector>
+
+#include <transformer_engine/comm_gemm_overlap.h>
+#include <transformer_engine/gemm.h>
+#include <transformer_engine/transformer_engine.h>
+
+#include "common/common.h"
+#include "common/util/cuda_driver.h"
+#include "common/util/cuda_runtime.h"
+#include "common/util/logging.h"
+#include "common/util/system.h"
+#include "userbuffers/userbuffers.h"
+
+namespace transformer_engine {
+
+void CommOverlapP2PBase::rocm_split_overlap_ag_rd(const TensorWrapper &A, bool transa,
+                                                  const TensorWrapper &B, bool transb,
+                                                  TensorWrapper &D, TensorWrapper &bias,
+                                                  TensorWrapper &pre_gelu_out,
+                                                  TensorWrapper &workspace, bool grad,
+                                                  bool accumulate, bool use_split_accumulator,
+                                                  TensorWrapper &B_copy, cudaStream_t stream_main) {
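+  // Overlapped all-gather + GEMM using a recursive-doubling exchange over the
+  // registered userbuffers: each rank starts with only its own chunk of B in the
+  // communication buffer and, at every step, forwards all chunks it holds to the
+  // rank `offset` ahead while receiving the matching chunks from the rank
+  // `offset` behind, doubling `offset` each step. The GEMM for a chunk is
+  // launched as soon as that chunk arrives so communication overlaps compute.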
+  int ori_sms = _ub_comm->sms;
+  _ub_comm->use_ce = _use_ce;
+  _ub_comm->sms = _num_comm_sm;
+  _ub_comm->cga_size = _cga_size;
+  // Get GEMM dimensions between TN and NN input layouts
+  const size_t m = (transa) ? A.size(0) : A.size(1);
+  const size_t k = (transa) ? A.size(1) : A.size(0);
+  const size_t n_chunk = _ubufs[0].size(0);
+  const int comm_bytes = _ubufs[0].bytes();
+  const bool do_gelu = pre_gelu_out.numel() > 0;
+  const size_t workspace_size_chunk = workspace.numel() / _stream_compute.size();
+
+  // Check B copy sizing
+  if (B_copy.numel() > 0) {
+    NVTE_CHECK(B_copy.numel() == _ubuf.numel(), "Expected all-gathered B copy buffer with ",
+               _ubuf.numel(), " elements but got ", B_copy.numel());
+    NVTE_CHECK(B_copy.element_size() == _ubuf.element_size(),
+               "Expected all-gathered B copy buffer with ", _ubuf.element_size() * 8,
+               "-bit data type but got ", B_copy.element_size() * 8, "-bit");
+  }
+
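+  // Fence against prior work: the send, recv, and compute streams all wait for
+  // everything already queued on the main stream before the overlap starts.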
+  NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main));
+  for (size_t i = 0; i < _stream_send.size(); i++) {
+    NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[i], _start_compute, 0));
+  }
+  NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0));
+  for (size_t i = 0; i < _stream_compute.size(); i++) {
+    NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[i], _start_compute, 0));
+  }
+
+  // Number of recursive-doubling steps; assumes _tp_size is a power of two.
+  int steps = 31 - __builtin_clz(_tp_size);
+
+  // Chunk dims
+  std::vector<size_t> input_b_chunk_shape =
+      (transb ? std::vector<size_t>{k, n_chunk} : std::vector<size_t>{n_chunk, k});
+  std::vector<size_t> output_chunk_shape = {n_chunk, m};
+  size_t input_b_chunk_size = n_chunk * k;
+  size_t output_chunk_size = n_chunk * m;
+
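+  // Kick off the GEMM for the chunk this rank already holds (index _tp_id) so
+  // useful compute is in flight before any chunk has been exchanged.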
+  // GEMM
+  auto input_b_chunk =
+      get_buffer_chunk_like(B, input_b_chunk_size * _tp_id, input_b_chunk_shape);
+  auto output_chunk =
+      get_tensor_chunk(D, output_chunk_size * _tp_id, output_chunk_shape);
+  auto aux_chunk =
+      (do_gelu)
+          ? get_tensor_chunk(pre_gelu_out, output_chunk_size * _tp_id, {n_chunk, m})
+          : TensorWrapper(nullptr, std::vector<size_t>{0}, pre_gelu_out.dtype());
+  auto workspace_chunk = get_tensor_chunk(
+      workspace, (_tp_id % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk});
+
+  nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
+                   aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
+                   use_split_accumulator, _math_sms,
+                   _stream_compute[_tp_id % _stream_compute.size()]);
+
+  // Chunks currently held locally; each rank starts with only its own chunk (_tp_id).
+  std::vector<size_t> owned_chunks;
+  owned_chunks.reserve(_tp_size);
+  owned_chunks.push_back(_tp_id);
+  size_t offset = 1;
+
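+  // Recursive-doubling all-gather: after step s every rank holds 2^(s+1)
+  // consecutive chunks (indices _tp_id, _tp_id-1, ..., modulo _tp_size), so the
+  // number of chunks exchanged doubles each step and the gather finishes in
+  // log2(_tp_size) steps.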
+  for (int step = 0; step < steps; step++) {
+    int send_rank = (_tp_id + offset) % _tp_size;
+    int recv_rank = (_tp_id - offset + _tp_size) % _tp_size;
+
+    // Forward every chunk we currently hold to the rank `offset` ahead of us.
+    for (size_t i = 0; i < owned_chunks.size(); i++) {
+      size_t send_offset = owned_chunks[i] * comm_bytes;
+      userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes, _ub_comm,
+                       send_rank, _stream_send[i % _stream_send.size()]);
+    }
+
+    // By symmetry, the rank `offset` behind us holds chunks (recv_rank - i) % _tp_size
+    // for i = 0..owned_chunks.size()-1, so those are exactly the chunks received in
+    // this step. Launch the GEMM for each chunk as soon as it arrives.
+    std::vector<size_t> new_chunks;
+    new_chunks.reserve(owned_chunks.size());
+    for (size_t i = 0; i < owned_chunks.size(); i++) {
+      size_t new_chunk_id = (recv_rank + _tp_size - i) % _tp_size;
+      size_t recv_offset = new_chunk_id * comm_bytes;
+      size_t stream_id = new_chunks.size() % _stream_compute.size();
+
+      userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm,
+                       recv_rank, _stream_recv);
+
+      NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv));
+      NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[stream_id], _stop_recv, 0));
+
+      input_b_chunk = get_buffer_chunk_like(B, input_b_chunk_size * new_chunk_id, input_b_chunk_shape);
+      output_chunk = get_tensor_chunk(D, output_chunk_size * new_chunk_id, output_chunk_shape);
+      aux_chunk = (do_gelu)
+                      ? get_tensor_chunk(pre_gelu_out, output_chunk_size * new_chunk_id, {n_chunk, m})
+                      : TensorWrapper(nullptr, std::vector<size_t>{0}, pre_gelu_out.dtype());
+      workspace_chunk = get_tensor_chunk(workspace, stream_id * workspace_size_chunk, {workspace_size_chunk});
+
+      nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
+                       aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
+                       use_split_accumulator, _math_sms, _stream_compute[stream_id]);
+
+      new_chunks.push_back(new_chunk_id);
+    }
+    owned_chunks.insert(owned_chunks.end(), new_chunks.begin(), new_chunks.end());
+    offset <<= 1;
+  }
+
+  // Hand the fully gathered B back to the caller if a copy buffer was provided,
+  // after all received chunks have landed in the communication buffer.
+  if (B_copy.numel() > 0) {
+    NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv));
+    NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _stop_recv, 0));
+    NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubuf.dptr(), _ubuf.bytes(),
+                                    cudaMemcpyDeviceToDevice, _stream_send[0]));
+  }
+
+  // Restore the SM budget and make the main stream wait for all compute and
+  // communication streams before returning.
+  _ub_comm->sms = ori_sms;
+  for (size_t i = 0; i < _stream_compute.size(); i++) {
+    NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, _stream_compute[i]));
+    NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_compute, 0));
+  }
+  for (size_t i = 0; i < _stream_send.size(); i++) {
+    NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, _stream_send[i]));
+    NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0));
+  }
+  NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv));
+  NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_recv, 0));
+}  // rocm_split_overlap_ag_rd
+
+}  // namespace transformer_engine