diff --git a/.github/scripts/utils_build.bash b/.github/scripts/utils_build.bash index 82fa3e26a2..709e7b62f4 100644 --- a/.github/scripts/utils_build.bash +++ b/.github/scripts/utils_build.bash @@ -370,6 +370,7 @@ install_build_tools () { patchelf \ rhash \ scikit-build \ + tbb-devel \ tbb \ wheel \ xz \ diff --git a/cmake/modules/CppLibrary.cmake b/cmake/modules/CppLibrary.cmake index 92a93a60b6..f9c2ac4109 100644 --- a/cmake/modules/CppLibrary.cmake +++ b/cmake/modules/CppLibrary.cmake @@ -168,6 +168,18 @@ function(cpp_library) target_link_libraries(${lib_name} PUBLIC OpenMP::OpenMP_CXX) endif() + if(NOT TARGET TBB::tbb) + find_package(TBB QUIET) + endif() + if(TBB_FOUND) + target_link_libraries(${lib_name} PRIVATE TBB::tbb) + else() + find_library(TBB_LIB NAMES tbb tbb12 HINTS $ENV{CONDA_PREFIX}/lib /usr/lib/x86_64-linux-gnu /usr/local/lib /lib/x86_64-linux-gnu) + if(TBB_LIB) + target_link_libraries(${lib_name} PRIVATE ${TBB_LIB}) + endif() + endif() + # Add sanitizer options if needed if(args_SANITIZER_OPTIONS) target_link_options(${lib_name} PUBLIC diff --git a/cmake/modules/GpuCppLibrary.cmake b/cmake/modules/GpuCppLibrary.cmake index 51c30df750..e662848348 100644 --- a/cmake/modules/GpuCppLibrary.cmake +++ b/cmake/modules/GpuCppLibrary.cmake @@ -302,6 +302,18 @@ function(gpu_cpp_library) list(APPEND library_dependencies ${NVML_LIB_PATH}) endif() + if(NOT TARGET TBB::tbb) + find_package(TBB QUIET) + endif() + if(TBB_FOUND) + list(APPEND library_dependencies TBB::tbb) + else() + find_library(TBB_LIB NAMES tbb tbb12 HINTS $ENV{CONDA_PREFIX}/lib /usr/lib/x86_64-linux-gnu /usr/local/lib /lib/x86_64-linux-gnu) + if(TBB_LIB) + list(APPEND library_dependencies ${TBB_LIB}) + endif() + endif() + # Link against the external libraries as needed target_link_libraries(${lib_name} PRIVATE ${library_dependencies}) diff --git a/fbgemm_gpu/cmake/tbe_sources.py b/fbgemm_gpu/cmake/tbe_sources.py index 82092cc173..b38f862564 100644 --- a/fbgemm_gpu/cmake/tbe_sources.py +++ b/fbgemm_gpu/cmake/tbe_sources.py @@ -176,7 +176,6 @@ "_nobag" if nobag else "", ) for nobag in [ - True, False, ] for weighted in ( @@ -495,7 +494,6 @@ "_nobag" if nobag else "", ) for nobag in [ - True, False, ] for weighted in ( diff --git a/fbgemm_gpu/codegen/genscript/generate_backward_split.py b/fbgemm_gpu/codegen/genscript/generate_backward_split.py index a5277a906a..50506decb1 100644 --- a/fbgemm_gpu/codegen/genscript/generate_backward_split.py +++ b/fbgemm_gpu/codegen/genscript/generate_backward_split.py @@ -52,7 +52,11 @@ def render_backward_templates( return weighted_options = [True, False] - nobag_options = [True, False] if (not is_gwd) else [False] + nobag_options = ( + [True, False] + if (not (is_gwd or kwargs.get("is_hip_optimized_backward"))) + else [False] + ) vbe_options = [True, False] if (kwargs.get("has_vbe_support")) else [False] ssd_options = [True, False] if kwargs.get("has_ssd_support") else [False] template = CodeTemplate.load(template_filepath) @@ -327,8 +331,7 @@ def generate_backward_indices() -> None: @staticmethod def generate_rocm_backward_split(**kwargs: Any) -> None: - # Generate backward device kernels based on weighted (True/False), VBE - # (True/False), no bag (True/False) + # Generate backward device kernels based on weighted (True/False) template_filepath = ( "training/backward/rocm/embedding_backward_split_device_kernel_template.hip" ) @@ -343,6 +346,7 @@ def generate_rocm_backward_split(**kwargs: Any) -> None: "has_ssd_support": False, "dense": False, "gen_once": False, + "is_hip_optimized_backward": True, }, ) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_dense_host_cpu.cpp b/fbgemm_gpu/codegen/training/backward/embedding_backward_dense_host_cpu.cpp index 626838e930..0bc3c5f254 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_dense_host_cpu.cpp +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_dense_host_cpu.cpp @@ -172,7 +172,7 @@ Tensor split_embedding_codegen_lookup_dense_function( c10::SymInt /* max_B = -1 */, c10::SymInt /* max_B_feature_rank = -1 */, c10::SymInt /* vbe_output_size = -1 */, - bool /* mixed_D = true */) { + bool /* mixed_D = false */) { return SplitLookupFunction_Dense_Op::apply( host_weights, weights_offsets, diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp index 134a03b983..3fe516891f 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp @@ -960,7 +960,7 @@ class {{ autograd_func }} : #ifdef USE_ROCM constexpr int32_t BT_block_size = 64; - constexpr int32_t max_segment_length_per_warp = 64; + constexpr int32_t max_segment_length_per_warp = 16384; #else constexpr int32_t BT_block_size = 32; constexpr int32_t max_segment_length_per_warp = 32; @@ -1116,7 +1116,7 @@ Tensor {{ bwd_mdesc }}_embedding_codegen_lookup_{{ optimizer }}_function( {%- else %} const c10::SymInt vbe_output_size = -1, {%- endif %} - const bool mixed_D = true + const bool mixed_D = false ) { // TODO: refactor into macro {%- if has_gpu_support %} diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu old mode 100644 new mode 100755 index 6d38d1d99a..9ffaea3a67 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu @@ -23,6 +23,10 @@ #include "fbgemm_gpu/utils/assert_macros.h" #include "fbgemm_gpu/utils/kernel_launcher.cuh" +{%- if is_rocm %} +#include "fbgemm_gpu/rocm/cdna_guard.h" +{%- endif %} + using Tensor = at::Tensor; using namespace fbgemm_gpu; @@ -209,8 +213,127 @@ __global__ __launch_bounds__(kForwardMaxThreads) void 2, offset_idx + D_emb <= weights_numel, offset_idx ) {%- endif %} + int32_t j = 0; + {%- if is_rocm and not ssd and not dense and not use_vec_blocking and not vbe %} + // Currently for split_embedding_codegen_grad_indice_weights_kernel only + if (placement != PlacementType::MANAGED_CACHING) { + for (; j < kWarpSize && l_start + j + 3 < L; j += 4) { + const auto offset_idx_j0 = shfl_sync(offset_idx, j); + const auto offset_idx_j1 = shfl_sync(offset_idx, j+1); + const auto offset_idx_j2 = shfl_sync(offset_idx, j+2); + const auto offset_idx_j3 = shfl_sync(offset_idx, j+3); + + at::acc_type grad_indice_weight0 = 0.0; + at::acc_type grad_indice_weight1 = 0.0; + at::acc_type grad_indice_weight2 = 0.0; + at::acc_type grad_indice_weight3 = 0.0; + + const auto weight_row0 = WeightRowAccessor>(&weights[offset_idx_j0], D); + const auto weight_row1 = WeightRowAccessor>(&weights[offset_idx_j1], D); + const auto weight_row2 = WeightRowAccessor>(&weights[offset_idx_j2], D); + const auto weight_row3 = WeightRowAccessor>(&weights[offset_idx_j3], D); + + #pragma unroll kFixedMaxVecsPerThread + for (int32_t vec = 0; vec < kFixedMaxVecsPerThread && (kWarpSize * vec + threadIdx.x) * kVecWidth < D; ++vec) { + const int32_t d = (kWarpSize * vec + threadIdx.x) * kVecWidth; + + Vec4T> weight0, weight1, weight2, weight3; + weight0 = weight_row0.load(d); + weight1 = weight_row1.load(d); + weight2 = weight_row2.load(d); + weight3 = weight_row3.load(d); + + grad_indice_weight0 += weight0.acc.x * grad_out[vec].acc.x + weight0.acc.y * grad_out[vec].acc.y + + weight0.acc.z * grad_out[vec].acc.z + weight0.acc.w * grad_out[vec].acc.w; + grad_indice_weight1 += weight1.acc.x * grad_out[vec].acc.x + weight1.acc.y * grad_out[vec].acc.y + + weight1.acc.z * grad_out[vec].acc.z + weight1.acc.w * grad_out[vec].acc.w; + grad_indice_weight2 += weight2.acc.x * grad_out[vec].acc.x + weight2.acc.y * grad_out[vec].acc.y + + weight2.acc.z * grad_out[vec].acc.z + weight2.acc.w * grad_out[vec].acc.w; + grad_indice_weight3 += weight3.acc.x * grad_out[vec].acc.x + weight3.acc.y * grad_out[vec].acc.y + + weight3.acc.z * grad_out[vec].acc.z + weight3.acc.w * grad_out[vec].acc.w; + } + + grad_indice_weight0 = warpReduceAllSum>(grad_indice_weight0); + grad_indice_weight1 = warpReduceAllSum>(grad_indice_weight1); + grad_indice_weight2 = warpReduceAllSum>(grad_indice_weight2); + grad_indice_weight3 = warpReduceAllSum>(grad_indice_weight3); + + if (threadIdx.x == 0) { + grad_indice_weights[indices_start + l_start + j] = grad_indice_weight0; + grad_indice_weights[indices_start + l_start + j+1] = grad_indice_weight1; + grad_indice_weights[indices_start + l_start + j+2] = grad_indice_weight2; + grad_indice_weights[indices_start + l_start + j+3] = grad_indice_weight3; + } + } + } else { + for (; j < kWarpSize && l_start + j + 3 < L; j += 4) { + const auto offset_idx_j0 = shfl_sync(offset_idx, j); + const auto offset_idx_j1 = shfl_sync(offset_idx, j+1); + const auto offset_idx_j2 = shfl_sync(offset_idx, j+2); + const auto offset_idx_j3 = shfl_sync(offset_idx, j+3); + + const auto cache_idx_j0 = shfl_sync(cache_idx, j); + const auto cache_idx_j1 = shfl_sync(cache_idx, j+1); + const auto cache_idx_j2 = shfl_sync(cache_idx, j+2); + const auto cache_idx_j3 = shfl_sync(cache_idx, j+3); + + at::acc_type grad_indice_weight0 = 0.0; + at::acc_type grad_indice_weight1 = 0.0; + at::acc_type grad_indice_weight2 = 0.0; + at::acc_type grad_indice_weight3 = 0.0; + + const auto weight_row0 = WeightRowAccessor>(&weights[offset_idx_j0], D); + const auto weight_row1 = WeightRowAccessor>(&weights[offset_idx_j1], D); + const auto weight_row2 = WeightRowAccessor>(&weights[offset_idx_j2], D); + const auto weight_row3 = WeightRowAccessor>(&weights[offset_idx_j3], D); + + #pragma unroll kFixedMaxVecsPerThread + for (int32_t vec = 0; vec < kFixedMaxVecsPerThread && (kWarpSize * vec + threadIdx.x) * kVecWidth < D; ++vec) { + const int32_t d = (kWarpSize * vec + threadIdx.x) * kVecWidth; + + Vec4T> weight0, weight1, weight2, weight3; + weight0 = (cache_idx_j0 != kCacheLocationMissing) ? + Vec4T>(&lxu_cache_weights[cache_idx_j0][d]) : + weight_row0.load(d); + + weight1 = (cache_idx_j1 != kCacheLocationMissing) ? + Vec4T>(&lxu_cache_weights[cache_idx_j1][d]) : + weight_row1.load(d); + + weight2 = (cache_idx_j2 != kCacheLocationMissing) ? + Vec4T>(&lxu_cache_weights[cache_idx_j2][d]) : + weight_row2.load(d); + + weight3 = (cache_idx_j3 != kCacheLocationMissing) ? + Vec4T>(&lxu_cache_weights[cache_idx_j3][d]) : + weight_row3.load(d); + + + grad_indice_weight0 += weight0.acc.x * grad_out[vec].acc.x + weight0.acc.y * grad_out[vec].acc.y + + weight0.acc.z * grad_out[vec].acc.z + weight0.acc.w * grad_out[vec].acc.w; + grad_indice_weight1 += weight1.acc.x * grad_out[vec].acc.x + weight1.acc.y * grad_out[vec].acc.y + + weight1.acc.z * grad_out[vec].acc.z + weight1.acc.w * grad_out[vec].acc.w; + grad_indice_weight2 += weight2.acc.x * grad_out[vec].acc.x + weight2.acc.y * grad_out[vec].acc.y + + weight2.acc.z * grad_out[vec].acc.z + weight2.acc.w * grad_out[vec].acc.w; + grad_indice_weight3 += weight3.acc.x * grad_out[vec].acc.x + weight3.acc.y * grad_out[vec].acc.y + + weight3.acc.z * grad_out[vec].acc.z + weight3.acc.w * grad_out[vec].acc.w; + } + + grad_indice_weight0 = warpReduceAllSum>(grad_indice_weight0); + grad_indice_weight1 = warpReduceAllSum>(grad_indice_weight1); + grad_indice_weight2 = warpReduceAllSum>(grad_indice_weight2); + grad_indice_weight3 = warpReduceAllSum>(grad_indice_weight3); - for (auto j = 0; j < kWarpSize && l_start + j < L; ++j) { + if (threadIdx.x == 0) { + grad_indice_weights[indices_start + l_start + j] = grad_indice_weight0; + grad_indice_weights[indices_start + l_start + j+1] = grad_indice_weight1; + grad_indice_weights[indices_start + l_start + j+2] = grad_indice_weight2; + grad_indice_weights[indices_start + l_start + j+3] = grad_indice_weight3; + } + } + } + {%- endif %}{#-/* if is_rocm and not ssd and not dense and not use_vec_blocking and not vbe */#} + for (; j < kWarpSize && l_start + j < L; ++j) { const auto offset_idx_j = shfl_sync(offset_idx, j); {%- if not dense %} const auto {{ locs_or_addrs_idx }}_j = shfl_sync({{ locs_or_addrs_idx }}, j); @@ -359,6 +482,15 @@ Tensor {{ mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda( auto aligned_grad_output = aligned_grad_output_tensor_for_cuda_backwards(grad_output); CUDA_DEVICE_GUARD(dev_weights); + #ifdef USE_ROCM + if (!rocm::is_supported_cdna()) { + TORCH_WARN_ONCE("Running on non-CDNA architecture. Performance may be suboptimal."); + } + else { + // Ensure we're running on a supported CDNA architecture (including MI350) + TORCH_WARN_ONCE("Running on CDNA architecture"); + } + #endif const auto T = D_offsets.size(0) - 1; TORCH_CHECK_GT(T, 0); diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu index 5137b5766c..deffa8bfab 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu @@ -32,6 +32,14 @@ {%- set locs_or_addrs_tensor = "ssd_row_addrs" if ssd else "lxu_cache_locations" %} {%- set locs_or_addrs_type = "int64_t" if ssd else "int32_t" %} +{%- set is_optimized_hip_kernel_supported_mode = is_rocm and + optimizer == "rowwise_adagrad" and + not dense and + not nobag and + not is_index_select and + not is_gwd_kernel and + not vbe and + not ssd %} #include "fbgemm_gpu/embedding_backward_template_helpers.cuh" #include "fbgemm_gpu/utils/tensor_accessor_builder.h" @@ -538,7 +546,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row {%- endif %} -{%- if is_rocm and not is_index_select and optimizer == "rowwise_adagrad" and not dense and not is_gwd_kernel and not vbe and not ssd %} +{%- if is_optimized_hip_kernel_supported_mode %} #include #include #include "fbgemm_gpu/rocm/split_embeddings_common.h" @@ -612,12 +620,7 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd {{ args.split_kernel_args | replace_pta_namespace() | join(",\n ") }} {%- endif %} ) { - {%- if not nobag %} int32_t T = D_offsets.size(0) - 1; - {%- else %} - int32_t T = weights_offsets.size(0); - {%- endif %} - auto p_output_grad = grad_output.data(); auto p_emb_table = dev_weights.data(); auto p_hash_size_cumsum = hash_size_cumsum.data(); @@ -632,8 +635,6 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd constexpr int32_t segment_prefetch = 2; constexpr int32_t segment_unroll = 8; constexpr int32_t segment_split = 0; - auto batch = grad_output.size(0); - auto num_rows = dev_weights.size(0) / T / max_D; {%- if weighted %} constexpr bool is_weighted = true; {%- else %} @@ -646,30 +647,15 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd // weight_decay(_mode) is supplied as args.split_function_args_no_defaults opt_karg.weight_decay_mode = weight_decay_mode_v; opt_karg.weight_decay = weight_decay; - auto batch_mdiv = [](uint32_t d) -> rocm::magic_div_u32_t { - assert(d >= 1 && d <= INT32_MAX); - uint8_t shift; - for(shift = 0; shift < 32; shift++) - if((1U << shift) >= d) - break; - - uint64_t one = 1; - uint64_t magic = ((one << 32) * ((one << shift) - d)) / d + 1; - assert(magic <= 0xffffffffUL); - - rocm::magic_div_u32_t result; - result.magic = magic; - result.shift = shift; - return result; - }(batch); + rocm::split_tbe_backward_hip_kernel_{{kdesc}}< - rocm::{{optimizer}}_optimizer_t, + rocm::{{optimizer}}_optimizer_t, rocm::{{optimizer}}_kernel_arg_t, emb_t, cache_t, grad_t, index_t, - BLOCK_SIZE, + BLOCK_SIZE_ROCM, embedding_dim, segment_prefetch, segment_unroll, @@ -680,16 +666,11 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd p_sorted_linear_indices_run, p_sorted_linear_indices_cumulative_run_lengths, p_sorted_linear_indices_num_runs, - {%- if not nobag %} info_B_num_bits, info_B_mask, - {%- endif %} p_sorted_infos, - batch_mdiv, max_segment_length_per_warp, emb_dim, - batch, - num_rows, T, opt_karg {%- if weighted %} @@ -784,7 +765,7 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd {%- for emb_type in (['float', 'at::Half'] + (['at::Float8_e4m3fnuz'] if is_rocm else ['at::Float8_e4m3fn'])) %} {%- for cache_type in ['float', 'at::Half'] %} {%- for index_type in ['int32_t', 'int64_t'] %} - {%- for kEmbeddingDim in [64, 128, 160, 192, 256] %} + {%- for kEmbeddingDim in [64, 128, 160, 192, 256, 320] %} {%- for kWeighDecayMode in [0, 1, 2] %} {{ hip_template_instantiation( emb_type, diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu old mode 100644 new mode 100755 index 76eba64c99..f07ef5830e --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -48,6 +48,15 @@ using namespace fbgemm_gpu; has_global_weight_decay_support, ssd) %} {%- set desc_suffix = get_desc_suffix(is_gwd_kernel) %} +{%- set is_optimized_hip_kernel_supported_mode = is_rocm and + optimizer == "rowwise_adagrad" and + not dense and + not nobag and + not is_index_select and + not is_gwd_kernel and + not vbe and + not ssd %} + template < typename emb_t, typename grad_t, @@ -227,8 +236,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row( {%- endif %} ); -{%- if is_rocm and optimizer == "rowwise_adagrad" and not dense and not is_index_select - and not is_gwd_kernel and not vbe and not ssd %} +{%- if is_optimized_hip_kernel_supported_mode %} #include "fbgemm_gpu/rocm/split_embeddings_common.h" template < typename emb_t, @@ -652,6 +660,16 @@ Tensor {{ embedding_cuda_op }}( CUDA_DEVICE_GUARD(dev_weights); + #ifdef USE_ROCM + if (!rocm::is_supported_cdna()) { + TORCH_WARN_ONCE("Running on non-CDNA architecture. Performance may be suboptimal."); + } + else { + // Ensure we're running on a supported CDNA architecture (including MI350) + TORCH_WARN_ONCE("Running on CDNA architecture"); + } + #endif + {%- if nobag and not is_index_select %} auto max_D = D; {%- endif %} @@ -852,8 +870,7 @@ Tensor {{ embedding_cuda_op }}( } {%- endif %} - {%- if is_rocm and optimizer == "rowwise_adagrad" and not dense and not is_index_select - and not is_gwd_kernel and not vbe and not ssd %} + {%- if is_optimized_hip_kernel_supported_mode %} {%- set hip_kernel = "hip_split_embedding{}_backward_codegen_{}_{}{}_kernel_warp_per_row_1".format( ndesc, optimizer, @@ -970,8 +987,11 @@ Tensor {{ embedding_cuda_op }}( auto num_long_run_ids = at::zeros({1}, indices.options().dtype(at::kInt)); const bool use_deterministic_algorithms = at::globalContext().deterministicAlgorithms(); - const int max_segment_length_per_cta = use_deterministic_algorithms ? INT_MAX : 1024; - + {% if is_rocm %} + const int max_segment_length_per_cta = use_deterministic_algorithms ? INT_MAX : 4096; + {% else %} + const int max_segment_length_per_cta = use_deterministic_algorithms ? INT_MAX : 1024; + {%- endif %} Tensor long_run_id_to_really_long_run_ids; if (use_deterministic_algorithms) { long_run_id_to_really_long_run_ids = @@ -1042,7 +1062,22 @@ Tensor {{ embedding_cuda_op }}( // Compute shared memory size for cta_per_row constexpr auto kCacheAccBytes = sizeof(at::acc_type); + {% if is_rocm %} + int32_t total_L = indices.numel(); + int32_t num_cta_per_row_groups; + int32_t work_group_size; + if (total_L/total_B > 1) { + num_cta_per_row_groups = (kMaxThreads/4) / kWarpSize; + work_group_size = (kMaxThreads/4); + } + else { + num_cta_per_row_groups = kMaxThreads / kWarpSize; + work_group_size = kMaxThreads; + } + {%- else %} int32_t num_cta_per_row_groups = kMaxThreads / kWarpSize; + const int32_t work_group_size = kMaxThreads; + {%- endif %} const size_t cta_per_row_smem_bytes = compute_num_groups_and_dynamic_smem_bytes( &num_cta_per_row_groups, [&] (int num_groups) { @@ -1053,7 +1088,7 @@ Tensor {{ embedding_cuda_op }}( ); const int32_t cta_per_row_grid_size = std::min( - div_round_up(total_unique_indices, kMaxThreads), + div_round_up(total_unique_indices, work_group_size), get_max_thread_blocks_()); FBGEMM_LAUNCH_KERNEL( @@ -1162,7 +1197,17 @@ Tensor {{ embedding_cuda_op }}( kUseVecBlocking>; // Compute shared memory size for warp_per_row - int32_t num_warp_per_row_groups = kBackwardMaxThreads / kThreadGroupSize; + {%- if is_rocm %} + int32_t num_warp_per_row_groups; + if (total_L/total_B > 1){ + num_warp_per_row_groups = (kBackwardMaxThreads/2) / kThreadGroupSize; + } + else{ + num_warp_per_row_groups = kBackwardMaxThreads / kThreadGroupSize; + } + {%- else %} + int32_t num_warp_per_row_groups = kBackwardMaxThreads / kThreadGroupSize; + {%- endif %} int32_t warp_per_row_smem_bytes = 0; if constexpr (kUseVecBlocking) { @@ -1185,18 +1230,17 @@ Tensor {{ embedding_cuda_op }}( get_max_thread_blocks_()); #ifdef USE_ROCM - {%- if is_rocm and not is_index_select and optimizer == "rowwise_adagrad" and - not dense and not is_gwd_kernel and not vbe and not ssd and not nobag %} + {%- if is_optimized_hip_kernel_supported_mode %} const static auto use_hip_kernel = fbgemm_gpu::config::is_feature_enabled(fbgemm_gpu::config::FeatureGateName::TBE_ROCM_HIP_BACKWARD_KERNEL); - const auto supported_weights_type = dev_weights.scalar_type() == at::ScalarType::Half - || dev_weights.scalar_type() == at::ScalarType::Float; + constexpr bool supported_weights_type = std::is_same_v || std::is_same_v; + constexpr bool supported_grad_type = std::is_same_v || std::is_same_v; - if (use_hip_kernel && supported_weights_type && !mixed_D && rocm::is_supported_cdna()) + if (use_hip_kernel && !mixed_D && supported_weights_type && supported_grad_type && rocm::is_supported_cdna()) { constexpr int segments_per_workgroup = 4; - {%- for kDimSize in [64, 128, 160, 192, 256] %} + {%- for kDimSize in [64, 128, 160, 192, 256, 320] %} {%- for kWeightDecayMode in [0, 1, 2] %} if (max_D == {{ kDimSize }} && weight_decay_mode == {{ kWeightDecayMode }}) { diff --git a/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip index 2fcbba395e..cd3d645775 100644 --- a/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip +++ b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip @@ -27,7 +27,7 @@ #include "fbgemm_gpu/rocm/split_embeddings_common.h" namespace fbgemm_gpu::rocm { -template +template struct rowwise_adagrad_optimizer_t { __device__ rowwise_adagrad_optimizer_t(const rowwise_adagrad_kernel_arg_t& karg_) @@ -36,7 +36,7 @@ struct rowwise_adagrad_optimizer_t } template - __device__ void update(cache_t* acc, emb_t* weight, uint32_t row_index) + __device__ void update(cache_t* acc, emb_t* weight, index_t row_index) { if constexpr(segment_split == 0) { @@ -122,20 +122,11 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( const index_t* p_sorted_linear_indices_run, const int32_t* p_sorted_linear_indices_cumulative_run_lengths, const int32_t* p_sorted_linear_indices_num_runs, - {%- if not nobag %} const int32_t info_B_num_bits, const uint32_t info_B_mask, - {%- endif %} - {%- if not nobag %} const int32_t* p_sorted_infos, - {%- else %} - const int64_t* p_sorted_infos, - {%- endif %} - magic_div_u32_t batch_mdiv, uint32_t max_segment_length_per_warp, uint32_t emb_dim, - uint32_t batch, - uint32_t num_rows, uint32_t num_tables, optimizer_karg_t opt_karg, const float * p_sorted_indice_weights = nullptr) @@ -157,13 +148,9 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( const int32_t segment_start = p_sorted_linear_indices_cumulative_run_lengths[run_id]; const int32_t segment_end = p_sorted_linear_indices_cumulative_run_lengths[run_id + 1]; - {%- if nobag %} - const auto info_0 = p_sorted_infos[segment_start]; - int32_t t_0 = info_0 % num_tables; - {%- else %} const auto info_0 = reinterpret_cast(&p_sorted_infos[0])[segment_start]; const auto t_0 = info_0 >> info_B_num_bits; - {%- endif %} + int64_t hash_size = p_hash_size_cumsum[t_0]; const int64_t emb_idx = linear_index - hash_size; @@ -179,7 +166,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( const int32_t segment_length_mod = segment_length & length_mask; cache_t grad_acc[dword_per_row]; - int32_t infos[segment_unroll]; + uint32_t infos[segment_unroll]; grad_t grad_data[dword_per_row * segment_prefetch]; emb_t emb_data[dword_per_row]; float indice_weights[segment_unroll]; @@ -221,22 +208,16 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( // LOOP for(; itr < segment_length_mod; itr += segment_unroll) { - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[0], batch, table_index, bag_index); - {%- else %} table_index = infos[0] >> info_B_num_bits; bag_index = infos[0] & info_B_mask; - {%- endif %} - load_row_per_warp::run( + + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[1], batch, table_index, bag_index); - {%- else %} table_index = infos[1] >> info_B_num_bits; bag_index = infos[1] & info_B_mask; - {%- endif %} - load_row_per_warp::run( + + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); if constexpr (!weighted){ #pragma unroll @@ -244,24 +225,20 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( { accumulate_row_per_warp::run( &grad_acc[0], &grad_data[0], lane_id); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[j], batch, table_index, bag_index); - {%- else %} + table_index = infos[j] >> info_B_num_bits; bag_index = infos[j] & info_B_mask; - {%- endif %} - load_row_per_warp::run( + + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( &grad_acc[0], &grad_data[dword_per_row], lane_id); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[j + 1], batch, table_index, bag_index); - {%- else %} + table_index = infos[j + 1] >> info_B_num_bits; bag_index = infos[j + 1] & info_B_mask; - {%- endif %} - load_row_per_warp::run( + + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); } @@ -284,24 +261,20 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( { accumulate_row_per_warp::run( &grad_acc[0], &grad_data[0], lane_id, indice_weights[j-2]); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[j], batch, table_index, bag_index); - {%- else %} + table_index = infos[j] >> info_B_num_bits; bag_index = infos[j] & info_B_mask; - {%- endif %} - load_row_per_warp::run( + + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( &grad_acc[0], &grad_data[dword_per_row], lane_id, indice_weights[j-1]); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[j + 1], batch, table_index, bag_index); - {%- else %} + table_index = infos[j + 1] >> info_B_num_bits; bag_index = infos[j + 1] & info_B_mask; - {%- endif %} - load_row_per_warp::run( + + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); } @@ -322,22 +295,16 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( } // LAST - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[0], batch, table_index, bag_index); - {%- else %} table_index = infos[0] >> info_B_num_bits; bag_index = infos[0] & info_B_mask; - {%- endif %} - load_row_per_warp::run( + + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[1], batch, table_index, bag_index); - {%- else %} table_index = infos[1] >> info_B_num_bits; bag_index = infos[1] & info_B_mask; - {%- endif %} - load_row_per_warp::run( + + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); if constexpr (!weighted) { @@ -346,24 +313,20 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( { accumulate_row_per_warp::run( &grad_acc[0], &grad_data[0], lane_id); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[j], batch, table_index, bag_index); - {%- else %} + table_index = infos[j] >> info_B_num_bits; bag_index = infos[j] & info_B_mask; - {%- endif %} - load_row_per_warp::run( + + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( &grad_acc[0], &grad_data[dword_per_row], lane_id); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[j + 1], batch, table_index, bag_index); - {%- else %} + table_index = infos[j + 1] >> info_B_num_bits; bag_index = infos[j + 1] & info_B_mask; - {%- endif %} - load_row_per_warp::run( + + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); } @@ -377,24 +340,20 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( { accumulate_row_per_warp::run( &grad_acc[0], &grad_data[0], lane_id, indice_weights[j-2]); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[j], batch, table_index, bag_index); - {%- else %} + table_index = infos[j] >> info_B_num_bits; bag_index = infos[j] & info_B_mask; - {%- endif %} - load_row_per_warp::run( + + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( &grad_acc[0], &grad_data[dword_per_row], lane_id, indice_weights[j-1]); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[j + 1], batch, table_index, bag_index); - {%- else %} + table_index = infos[j + 1] >> info_B_num_bits; bag_index = infos[j + 1] & info_B_mask; - {%- endif %} - load_row_per_warp::run( + + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); } @@ -414,13 +373,10 @@ L_tail_grad_acc: infos[0] = p_sorted_infos[segment_start]; p_sorted_infos++; - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[0], batch, table_index, bag_index); - {%- else %} table_index = infos[0] >> info_B_num_bits; bag_index = infos[0] & info_B_mask; - {%- endif %} - load_row_per_warp::run( + + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( &grad_acc[0], &grad_data[0], lane_id); @@ -435,13 +391,10 @@ L_tail_grad_acc: p_sorted_infos++; p_sorted_indice_weights++; - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[0], batch, table_index, bag_index); - {%- else %} table_index = infos[0] >> info_B_num_bits; bag_index = infos[0] & info_B_mask; - {%- endif %} - load_row_per_warp::run( + + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( &grad_acc[0], &grad_data[0], lane_id, indice_weights[0]); @@ -452,11 +405,11 @@ L_tail_grad_acc: } // load the old emb weight data - load_row_per_warp::run( + load_row_per_warp::run( &emb_data[0], emb_idx, p_emb_table, lane_id); optimizer_t optimizer(opt_karg); optimizer.template update(grad_acc, emb_data, emb_idx); - store_row_per_warp::run(&emb_data[0], p_emb_table + emb_idx * embedding_dim, lane_id); + store_row_per_warp::run(&emb_data[0], p_emb_table + emb_idx * embedding_dim, lane_id); } } // namespace fbgemm_gpu::rocm diff --git a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu old mode 100644 new mode 100755 index aada1cdad5..526c146ad3 --- a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu +++ b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu @@ -141,7 +141,7 @@ using namespace fbgemm_gpu; {%- endmacro %} {#-/* - Splitted version of load_and_accumulate macro. This code chunk describes + Split version of load_and_accumulate macro. This code chunk describes the weights load in forward kernel. Set up the WeightRow and load quantization parameters. Shortcut store for nobag mode. @@ -221,7 +221,7 @@ using namespace fbgemm_gpu; {%- endmacro %} {#-/* - Splitted version of load_and_accumulate macro. This code chunk + Split version of load_and_accumulate macro. This code chunk describes the weights accumulate step in the forward kernel. Accumulate the slices of values from the row. Does nothing for nobag mode assuming all the work is done in load() macro. diff --git a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu old mode 100644 new mode 100755 index 37e774bb49..a3edb6b965 --- a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu +++ b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu @@ -31,6 +31,10 @@ #include "fbgemm_gpu/utils/dispatch_macros.h" {%- endif %} +{%- if is_rocm %} +#include "fbgemm_gpu/rocm/cdna_guard.h" +{%- endif %} + {%- if not is_index_select %} //////////////////////////////////////////////////////////////////////////////// // Required for op registrations @@ -459,6 +463,16 @@ batch_index_select_dim0_codegen_forward_cuda( CUDA_DEVICE_GUARD(dev_weights); + {% if is_rocm %} + if (!rocm::is_supported_cdna()) { + TORCH_WARN_ONCE("Running on non-CDNA architecture. Performance may be suboptimal."); + } + else { + // Ensure we're running on a supported CDNA architecture (including MI350) + TORCH_WARN_ONCE("Running on CDNA architecture"); + } + {%- endif %} + {%- if not nobag %} int32_t T = D_offsets.numel() - 1; {%- else %} diff --git a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp index ae501509c7..4660e6ad0f 100644 --- a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp +++ b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp @@ -1059,7 +1059,25 @@ static torch::autograd::variable_list backward( #ifdef USE_ROCM constexpr int32_t BT_block_size = 64; - constexpr int32_t max_segment_length_per_warp = 64; + int32_t max_segment_length_per_warp = 64; + {%- if (not nobag) and + (optimizer == "rowwise_adagrad") and + (not vbe) and + (not is_gwd) and + (not ssd) and + (not is_index_select) and + (not dense) %} + int32_t total_L = indices.numel(); + const auto T = weights_offsets.sym_numel(); + auto total_B = (offsets.size(0) - 1); + const auto B = total_B / T; + {%- for kDimSize in [64, 128, 160, 192, 256, 320] %} + if(!mixed_D && total_L / total_B > 1 && (max_D == {{ kDimSize }})) + { + max_segment_length_per_warp = 16384; + } + {%- endfor %} + {%- endif %} #else constexpr int32_t BT_block_size = 32; constexpr int32_t max_segment_length_per_warp = 32; diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py index 597446a36b..5220e48d61 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py @@ -821,7 +821,7 @@ def __init__( # noqa C901 assert ( self.pooling_mode != PoolingMode.NONE ), "Mixed dimension tables only supported for pooling tables." - + self.mixed_D: bool = mixed_D assert all( cd == compute_devices[0] for cd in compute_devices ), "Heterogenous compute_devices are NOT supported!" @@ -2519,6 +2519,7 @@ def forward( # noqa: C901 row_counter, iter_int, self.max_counter.item(), + mixed_D=self.mixed_D, ), ) elif self._used_rowwise_adagrad_with_global_weight_decay: @@ -2537,6 +2538,7 @@ def forward( # noqa: C901 # `Optional[Tensor]` but got `Union[Module, Tensor]`. prev_iter_dev=self.prev_iter_dev, gwd_lower_bound=self.gwd_lower_bound, + mixed_D=self.mixed_D, ), ) else: @@ -2546,6 +2548,7 @@ def forward( # noqa: C901 common_args, self.optimizer_args, momentum1, + mixed_D=self.mixed_D, ), ) diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/cdna_guard.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/cdna_guard.h index b55fd72fce..447613c5fc 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/cdna_guard.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/cdna_guard.h @@ -38,7 +38,7 @@ namespace fbgemm_gpu::rocm { [[nodiscard]] inline bool is_supported_cdna() { - const std::set supported_archs{"gfx942", "gfx90a"}; + const std::set supported_archs{"gfx942", "gfx90a", "gfx950"}; int device_id = 0; HIP_CHECK(hipGetDevice(&device_id)); hipDeviceProp_t dev_props; diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h index b3a56c4b52..59f96a19b7 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h @@ -21,9 +21,14 @@ * ******************************************************************************/ #pragma once + +#include +#include #include + #include #include +#include /******************************************************************************/ typedef int32_t int32x4_t __attribute__((ext_vector_type(4))); @@ -31,7 +36,7 @@ typedef float floatx2_t __attribute__((ext_vector_type(2))); #define AMDGCN_BUFFER_RES_3 0x00027000 #define AMDGCN_WAVE_SIZE 64 #define THREADS_PER_ROW 64 -#define BLOCK_SIZE 256 +#define BLOCK_SIZE_ROCM 256 namespace fbgemm_gpu::rocm { template @@ -46,10 +51,12 @@ union amdgcn_buffer_resource { }; template -__device__ int32x4_t amdgcn_make_buffer_resource(const T* addr) { +__device__ int32x4_t amdgcn_make_buffer_resource( + const T* addr, + const int32_t size_in_bytes = 0xFFFFFFFF) { amdgcn_buffer_resource buffer_resource; buffer_resource.address = const_cast(addr); - buffer_resource.range = 0xffffffff; + buffer_resource.range = size_in_bytes; buffer_resource.config = AMDGCN_BUFFER_RES_3; // for gfx9 return buffer_resource.content; @@ -59,34 +66,68 @@ __device__ int32x4_t amdgcn_make_buffer_resource(const T* addr) { __device__ half llvm_amdgcn_raw_buffer_load_fp16( int32x4_t srsrc, int32_t voffset, - int32_t soffset, - int32_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f16"); + int32_t soffset = 0, + int32_t glc_slc = 0) +#if ROCM_VERSION_MAJOR >= 7 + __asm("llvm.amdgcn.raw.buffer.load.i16"); +#else + __asm("llvm.amdgcn.raw.buffer.load.f16"); +#endif __device__ float llvm_amdgcn_raw_buffer_load_fp32( int32x4_t srsrc, int32_t voffset, - int32_t soffset, - int32_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f32"); + int32_t soffset = 0, + int32_t glc_slc = 0) __asm("llvm.amdgcn.raw.buffer.load.f32"); __device__ half2 llvm_amdgcn_raw_buffer_load_fp16x2( int32x4_t srsrc, int32_t voffset, - int32_t soffset, - int32_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f16"); + int32_t soffset = 0, + int32_t glc_slc = 0) +#if ROCM_VERSION_MAJOR >= 7 + __asm("llvm.amdgcn.raw.buffer.load.i32"); +#else + __asm("llvm.amdgcn.raw.buffer.load.v2f16"); +#endif + +__device__ void llvm_amdgcn_raw_buffer_store_fp16( + const half vdata, + int32x4_t rsrc, + int32_t voffset, + int32_t soffset = 0, + int32_t glc_slc = 0) +#if ROCM_VERSION_MAJOR >= 7 + __asm("llvm.amdgcn.raw.buffer.store.i16"); +#else + __asm("llvm.amdgcn.raw.buffer.store.f16"); +#endif + +__device__ void llvm_amdgcn_raw_buffer_store_fp16x2( + const half2 vdata, + int32x4_t rsrc, + int32_t voffset, + int32_t soffset = 0, + int32_t glc_slc = 0) +#if ROCM_VERSION_MAJOR >= 7 + __asm("llvm.amdgcn.raw.buffer.store.i32"); +#else + __asm("llvm.amdgcn.raw.buffer.store.v2f16"); +#endif __device__ void llvm_amdgcn_raw_buffer_store_fp32( float vdata, int32x4_t rsrc, int32_t voffset, - int32_t soffset, - int32_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f32"); + int32_t soffset = 0, + int32_t glc_slc = 0) __asm("llvm.amdgcn.raw.buffer.store.f32"); __device__ void llvm_amdgcn_raw_buffer_store_fp32x2( floatx2_t vdata, int32x4_t rsrc, int32_t voffset, - int32_t soffset, - int32_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f32"); + int32_t soffset = 0, + int32_t glc_slc = 0) __asm("llvm.amdgcn.raw.buffer.store.v2f32"); /******************************************************************************/ @@ -96,33 +137,17 @@ struct load_row_per_warp { emb_t* emb_data, index_t row_index, const emb_t* p_emb_table, - int lane_id) {} -}; - -template -struct load_row_per_warp { - static constexpr int dword_per_row = - (embedding_dim + THREADS_PER_ROW - 1) / THREADS_PER_ROW; - static __device__ void run( - float* emb_data, - index_t row_index, - const float* p_emb_table, int lane_id) { - int32x4_t emb_res = - amdgcn_make_buffer_resource(p_emb_table + row_index * embedding_dim); -#pragma unroll - for (int i = 0; i < dword_per_row; i++) { - if constexpr (embedding_dim == 160) { - if ((lane_id + i * THREADS_PER_ROW) % 192 < 160) { - emb_data[i] = llvm_amdgcn_raw_buffer_load_fp32( - emb_res, (lane_id + i * THREADS_PER_ROW) * sizeof(float), 0, 0); - } else { - emb_data[i] = 0.f; - } - } else { - emb_data[i] = llvm_amdgcn_raw_buffer_load_fp32( - emb_res, (lane_id + i * THREADS_PER_ROW) * sizeof(float), 0, 0); - } + // Types are not supported, but we need an instance of run method to avoid + // run-time .so symbol failure. Currently, the kernel dispatch for + // unsupported type is guarded on host side + if constexpr ( + std::is_same_v || + std::is_same_v) { + __builtin_trap(); + } else { + static_assert( + false, "HIP: Optimized load operation is not supported yet"); } } }; @@ -134,7 +159,7 @@ struct load_row_per_warp { int32x4_t emb_res = amdgcn_make_buffer_resource(p_emb_table + row_index * 64); emb_data[0] = - llvm_amdgcn_raw_buffer_load_fp16(emb_res, lane_id * sizeof(half), 0, 0); + llvm_amdgcn_raw_buffer_load_fp16(emb_res, lane_id * sizeof(half)); } }; @@ -144,8 +169,8 @@ struct load_row_per_warp { run(half* emb_data, index_t row_index, const half* p_emb_table, int lane_id) { int32x4_t emb_res = amdgcn_make_buffer_resource(p_emb_table + row_index * 128); - *reinterpret_cast(emb_data) = llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, lane_id * sizeof(half2), 0, 0); + *reinterpret_cast(emb_data) = + llvm_amdgcn_raw_buffer_load_fp16x2(emb_res, lane_id * sizeof(half2)); } }; @@ -153,16 +178,12 @@ template struct load_row_per_warp { static __device__ void run(half* emb_data, index_t row_index, const half* p_emb_table, int lane_id) { - int32x4_t emb_res = - amdgcn_make_buffer_resource(p_emb_table + row_index * 192); - *reinterpret_cast(emb_data) = llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, lane_id * sizeof(half2), 0, 0); - if ((lane_id + 128) % 192 < 160) { - emb_data[2] = llvm_amdgcn_raw_buffer_load_fp16( - emb_res, (lane_id + 128) * sizeof(half), 0, 0); - } else { - emb_data[2] = __float2half(0.0); - } + int32x4_t emb_res = amdgcn_make_buffer_resource( + p_emb_table + row_index * 160, sizeof(half) * 160); + *reinterpret_cast(emb_data) = + llvm_amdgcn_raw_buffer_load_fp16x2(emb_res, lane_id * sizeof(half2)); + emb_data[2] = llvm_amdgcn_raw_buffer_load_fp16( + emb_res, (lane_id + 128) * sizeof(half)); } }; @@ -172,10 +193,10 @@ struct load_row_per_warp { run(half* emb_data, index_t row_index, const half* p_emb_table, int lane_id) { int32x4_t emb_res = amdgcn_make_buffer_resource(p_emb_table + row_index * 192); - *reinterpret_cast(emb_data) = llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, lane_id * sizeof(half2), 0, 0); + *reinterpret_cast(emb_data) = + llvm_amdgcn_raw_buffer_load_fp16x2(emb_res, lane_id * sizeof(half2)); emb_data[2] = llvm_amdgcn_raw_buffer_load_fp16( - emb_res, (lane_id + 128) * sizeof(half), 0, 0); + emb_res, (lane_id + 128) * sizeof(half)); } }; @@ -186,32 +207,149 @@ struct load_row_per_warp { int32x4_t emb_res = amdgcn_make_buffer_resource(p_emb_table + row_index * 256); *reinterpret_cast(&emb_data[0]) = - llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, lane_id * sizeof(half2), 0, 0); + llvm_amdgcn_raw_buffer_load_fp16x2(emb_res, lane_id * sizeof(half2)); *reinterpret_cast(&emb_data[2]) = llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, (lane_id + 64) * sizeof(half2), 0, 0); + emb_res, (lane_id + 64) * sizeof(half2)); } }; template -struct load_row_per_warp { +struct load_row_per_warp { static __device__ void run(half* emb_data, index_t row_index, const half* p_emb_table, int lane_id) { - int32x4_t emb_res = - amdgcn_make_buffer_resource(p_emb_table + row_index * 512); + int32x4_t emb_res = amdgcn_make_buffer_resource( + p_emb_table + row_index * 320, sizeof(half) * 320); *reinterpret_cast(&emb_data[0]) = - llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, lane_id * sizeof(half2), 0, 0); + llvm_amdgcn_raw_buffer_load_fp16x2(emb_res, lane_id * sizeof(half2)); *reinterpret_cast(&emb_data[2]) = llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, (lane_id + 64) * sizeof(half2), 0, 0); - *reinterpret_cast(&emb_data[4]) = - llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, (lane_id + 64 * 2) * sizeof(half2), 0, 0); - *reinterpret_cast(&emb_data[6]) = - llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, (lane_id + 64 * 3) * sizeof(half2), 0, 0); + emb_res, (lane_id + 64) * sizeof(half2)); + emb_data[4] = llvm_amdgcn_raw_buffer_load_fp16( + emb_res, (lane_id + 128) * sizeof(half)); + } +}; + +template +struct load_row_per_warp { + static __device__ void run( + c10::Half* emb_data, + index_t row_index, + const c10::Half* p_emb_table, + int lane_id) { + load_row_per_warp::run( + reinterpret_cast(emb_data), + row_index, + reinterpret_cast(p_emb_table), + lane_id); + } +}; + +template +struct load_row_per_warp { + static __device__ void run( + float* emb_data, + index_t row_index, + const float* p_emb_table, + int lane_id) { + int32x4_t emb_res = + amdgcn_make_buffer_resource(p_emb_table + row_index * 64); + emb_data[0] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, lane_id * sizeof(float)); + } +}; + +template +struct load_row_per_warp { + static __device__ void run( + float* emb_data, + index_t row_index, + const float* p_emb_table, + int lane_id) { + int32x4_t emb_res = + amdgcn_make_buffer_resource(p_emb_table + row_index * 128); + emb_data[0] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, lane_id * sizeof(float)); + emb_data[1] = llvm_amdgcn_raw_buffer_load_fp32( + emb_res, (lane_id + 64) * sizeof(float)); + } +}; + +template +struct load_row_per_warp { + static __device__ void run( + float* emb_data, + index_t row_index, + const float* p_emb_table, + int lane_id) { + int32x4_t emb_res = amdgcn_make_buffer_resource( + p_emb_table + row_index * 160, sizeof(float) * 160); + emb_data[0] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, lane_id * sizeof(float)); + emb_data[1] = llvm_amdgcn_raw_buffer_load_fp32( + emb_res, (lane_id + 64) * sizeof(float)); + emb_data[2] = llvm_amdgcn_raw_buffer_load_fp32( + emb_res, (lane_id + 128) * sizeof(float)); + } +}; + +template +struct load_row_per_warp { + static __device__ void run( + float* emb_data, + index_t row_index, + const float* p_emb_table, + int lane_id) { + int32x4_t emb_res = + amdgcn_make_buffer_resource(p_emb_table + row_index * 192); + emb_data[0] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, lane_id * sizeof(float)); + emb_data[1] = llvm_amdgcn_raw_buffer_load_fp32( + emb_res, (lane_id + 64) * sizeof(float)); + emb_data[2] = llvm_amdgcn_raw_buffer_load_fp32( + emb_res, (lane_id + 128) * sizeof(float)); + } +}; + +template +struct load_row_per_warp { + static __device__ void run( + float* emb_data, + index_t row_index, + const float* p_emb_table, + int lane_id) { + int32x4_t emb_res = + amdgcn_make_buffer_resource(p_emb_table + row_index * 256); + emb_data[0] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, lane_id * sizeof(float)); + emb_data[1] = llvm_amdgcn_raw_buffer_load_fp32( + emb_res, (lane_id + 64) * sizeof(float)); + emb_data[2] = llvm_amdgcn_raw_buffer_load_fp32( + emb_res, (lane_id + 128) * sizeof(float)); + emb_data[3] = llvm_amdgcn_raw_buffer_load_fp32( + emb_res, (lane_id + 192) * sizeof(float)); + } +}; + +template +struct load_row_per_warp { + static __device__ void run( + float* emb_data, + index_t row_index, + const float* p_emb_table, + int lane_id) { + int32x4_t emb_res = amdgcn_make_buffer_resource( + p_emb_table + row_index * 320, sizeof(float) * 320); + emb_data[0] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, lane_id * sizeof(float)); + emb_data[1] = llvm_amdgcn_raw_buffer_load_fp32( + emb_res, (lane_id + 64) * sizeof(float)); + emb_data[2] = llvm_amdgcn_raw_buffer_load_fp32( + emb_res, (lane_id + 128) * sizeof(float)); + emb_data[3] = llvm_amdgcn_raw_buffer_load_fp32( + emb_res, (lane_id + 192) * sizeof(float)); + emb_data[4] = llvm_amdgcn_raw_buffer_load_fp32( + emb_res, (lane_id + 256) * sizeof(float)); } }; @@ -233,93 +371,186 @@ struct accumulate_row_per_warp { } else { #pragma unroll for (int i = 0; i < dword_per_row; i++) { - acc[i] += static_cast((float)emb_data[i] * row_weight); + if constexpr (std::is_same_v) { + acc[i] += + static_cast(__half2float(emb_data[i]) * row_weight); + } else { + acc[i] += static_cast( + static_cast(emb_data[i]) * row_weight); + } } } } }; -template +template struct store_row_per_warp { - static constexpr int dword_per_row = - (embedding_dim + THREADS_PER_ROW - 1) / THREADS_PER_ROW; - static __device__ void run(output_t* acc, output_t* p_output, int lane_id) { - if constexpr (embedding_dim == 160) { - for (int i = 0; i < dword_per_row; i++) { - if ((lane_id + i * THREADS_PER_ROW) % 192 < 160) { - p_output[lane_id + i * THREADS_PER_ROW] = acc[i]; - } - } + static __device__ void run(const emb_t* acc, emb_t* p_output, int lane_id) { + // Types are not supported, but we need an instance of run method to avoid + // run-time .so symbol failure. Currently, the kernel dispatch for + // unsupported type is guarded on host function + if constexpr ( + std::is_same_v || + std::is_same_v) { + __builtin_trap(); } else { -#pragma unroll - for (int i = 0; i < dword_per_row; i++) { - p_output[lane_id + i * THREADS_PER_ROW] = acc[i]; - } + static_assert( + false, "HIP: Optimized load operation is not supported yet"); } } }; template <> -struct store_row_per_warp { - static __device__ void run(float* acc, float* p_output, int lane_id) { +struct store_row_per_warp { + static __device__ void run(const half* acc, half* p_output, int lane_id) { int32x4_t out_res = amdgcn_make_buffer_resource(p_output); - llvm_amdgcn_raw_buffer_store_fp32x2( - *reinterpret_cast(acc), - out_res, - lane_id * sizeof(floatx2_t), - 0, - 0); + llvm_amdgcn_raw_buffer_store_fp16(acc[0], out_res, lane_id * sizeof(half)); } }; template <> -struct store_row_per_warp { - static __device__ void run(float* acc, float* p_output, int lane_id) { +struct store_row_per_warp { + static __device__ void run(const half* acc, half* p_output, int lane_id) { int32x4_t out_res = amdgcn_make_buffer_resource(p_output); - llvm_amdgcn_raw_buffer_store_fp32x2( - *reinterpret_cast(acc), - out_res, - lane_id * sizeof(floatx2_t), - 0, - 0); - if ((lane_id + 128) % 192 < 160) { - llvm_amdgcn_raw_buffer_store_fp32( - acc[2], out_res, (lane_id + 128) * sizeof(float), 0, 0); - } + llvm_amdgcn_raw_buffer_store_fp16x2( + *reinterpret_cast(acc), out_res, lane_id * sizeof(half2)); + } +}; + +template <> +struct store_row_per_warp { + static __device__ void run(const half* acc, half* p_output, int lane_id) { + int32x4_t out_res = + amdgcn_make_buffer_resource(p_output, 160 * sizeof(half)); + llvm_amdgcn_raw_buffer_store_fp16x2( + *reinterpret_cast(acc), out_res, lane_id * sizeof(half2)); + llvm_amdgcn_raw_buffer_store_fp16( + acc[2], out_res, (lane_id + 128) * sizeof(half)); } }; template <> -struct store_row_per_warp { - static __device__ void run(float* acc, float* p_output, int lane_id) { +struct store_row_per_warp { + static __device__ void run(const half* acc, half* p_output, int lane_id) { int32x4_t out_res = amdgcn_make_buffer_resource(p_output); - llvm_amdgcn_raw_buffer_store_fp32x2( - *reinterpret_cast(acc), - out_res, - lane_id * sizeof(floatx2_t), - 0, - 0); - llvm_amdgcn_raw_buffer_store_fp32( - acc[2], out_res, (lane_id + 128) * sizeof(float), 0, 0); + llvm_amdgcn_raw_buffer_store_fp16x2( + *reinterpret_cast(acc), out_res, lane_id * sizeof(half2)); + llvm_amdgcn_raw_buffer_store_fp16( + acc[2], out_res, (lane_id + 128) * sizeof(half)); } }; template <> -struct store_row_per_warp { - static __device__ void run(float* acc, float* p_output, int lane_id) { +struct store_row_per_warp { + static __device__ void run(const half* acc, half* p_output, int lane_id) { int32x4_t out_res = amdgcn_make_buffer_resource(p_output); - llvm_amdgcn_raw_buffer_store_fp32x2( - *reinterpret_cast(acc), + llvm_amdgcn_raw_buffer_store_fp16x2( + *reinterpret_cast(acc), out_res, lane_id * sizeof(half2)); + llvm_amdgcn_raw_buffer_store_fp16x2( + *reinterpret_cast(acc + 2), out_res, - lane_id * sizeof(floatx2_t), - 0, - 0); - llvm_amdgcn_raw_buffer_store_fp32x2( - *reinterpret_cast(&acc[2]), + (lane_id + 64) * sizeof(half2)); + } +}; + +template <> +struct store_row_per_warp { + static __device__ void run(const half* acc, half* p_output, int lane_id) { + int32x4_t out_res = + amdgcn_make_buffer_resource(p_output, 320 * sizeof(half)); + llvm_amdgcn_raw_buffer_store_fp16x2( + *reinterpret_cast(acc), out_res, lane_id * sizeof(half2)); + llvm_amdgcn_raw_buffer_store_fp16x2( + *reinterpret_cast(acc + 2), out_res, - (lane_id + 64) * sizeof(floatx2_t), - 0, - 0); + (lane_id + 64) * sizeof(half2)); + llvm_amdgcn_raw_buffer_store_fp16( + acc[4], out_res, (lane_id + 256) * sizeof(half)); + } +}; + +template +struct store_row_per_warp { + static __device__ void + run(const c10::Half* emb_data, c10::Half* p_emb_table, int lane_id) { + store_row_per_warp::run( + reinterpret_cast(emb_data), + reinterpret_cast(p_emb_table), + lane_id); + } +}; + +template <> +struct store_row_per_warp { + static __device__ void run(const float* acc, float* p_output, int lane_id) { + int32x4_t out_res = amdgcn_make_buffer_resource(p_output); + llvm_amdgcn_raw_buffer_store_fp32(acc[0], out_res, lane_id * sizeof(float)); + } +}; + +template <> +struct store_row_per_warp { + static __device__ void run(const float* acc, float* p_output, int lane_id) { + int32x4_t out_res = amdgcn_make_buffer_resource(p_output); + llvm_amdgcn_raw_buffer_store_fp32(acc[0], out_res, lane_id * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32( + acc[1], out_res, (lane_id + 64) * sizeof(float)); + } +}; + +template <> +struct store_row_per_warp { + static __device__ void run(const float* acc, float* p_output, int lane_id) { + int32x4_t out_res = + amdgcn_make_buffer_resource(p_output, sizeof(float) * 160); + llvm_amdgcn_raw_buffer_store_fp32(acc[0], out_res, lane_id * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32( + acc[1], out_res, (lane_id + 64) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32( + acc[2], out_res, (lane_id + 128) * sizeof(float)); + } +}; + +template <> +struct store_row_per_warp { + static __device__ void run(const float* acc, float* p_output, int lane_id) { + int32x4_t out_res = amdgcn_make_buffer_resource(p_output); + llvm_amdgcn_raw_buffer_store_fp32(acc[0], out_res, lane_id * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32( + acc[1], out_res, (lane_id + 64) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32( + acc[2], out_res, (lane_id + 128) * sizeof(float)); + } +}; + +template <> +struct store_row_per_warp { + static __device__ void run(const float* acc, float* p_output, int lane_id) { + int32x4_t out_res = amdgcn_make_buffer_resource(p_output); + llvm_amdgcn_raw_buffer_store_fp32(acc[0], out_res, lane_id * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32( + acc[1], out_res, (lane_id + 64) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32( + acc[2], out_res, (lane_id + 128) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32( + acc[3], out_res, (lane_id + 192) * sizeof(float)); + } +}; + +template <> +struct store_row_per_warp { + static __device__ void run(const float* acc, float* p_output, int lane_id) { + int32x4_t out_res = + amdgcn_make_buffer_resource(p_output, sizeof(float) * 320); + llvm_amdgcn_raw_buffer_store_fp32(acc[0], out_res, lane_id * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32( + acc[1], out_res, (lane_id + 64) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32( + acc[2], out_res, (lane_id + 128) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32( + acc[4], out_res, (lane_id + 192) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32( + acc[5], out_res, (lane_id + 256) * sizeof(float)); } }; @@ -471,7 +702,7 @@ __device__ __forceinline__ void generic_dpp_reduction(data_t& result) { // of trivial operation with an option to use custom operation template __device__ __forceinline__ void dpp_reduction(data_t& result) { -#if defined(__gfx942__) || defined(__gfx90a__) +#if defined(__gfx942__) || defined(__gfx90a__) || defined(__gfx950__) if constexpr (std::is_same_v) { DPP_REDUCE_F16_F32(add); return; diff --git a/fbgemm_gpu/include/fbgemm_gpu/utils/cuda_prelude.cuh b/fbgemm_gpu/include/fbgemm_gpu/utils/cuda_prelude.cuh old mode 100644 new mode 100755 index 0d65c4798a..d5ec2648a8 --- a/fbgemm_gpu/include/fbgemm_gpu/utils/cuda_prelude.cuh +++ b/fbgemm_gpu/include/fbgemm_gpu/utils/cuda_prelude.cuh @@ -21,7 +21,9 @@ #include #endif #include - +#ifdef USE_ROCM +#include "fbgemm_gpu/rocm/split_embeddings_common.h" +#endif namespace { inline int get_device_sm_cnt_() { @@ -138,11 +140,19 @@ template DEVICE_INLINE T warpReduceAllSum( T val, unsigned shfl_sync_mask = static_cast(kFullWarpMask)) { +#ifdef USE_ROCM + return rocm::wave_reduce< + rocm::reduce_op::sum, // Sum reduction + T, // Data type + ReduceWidth // Wave/Warp size + >(val); +#else #pragma unroll for (int mask = ReduceWidth / 2; mask > 0; mask >>= 1) { val += shfl_xor(val, mask, ReduceWidth, shfl_sync_mask); } return val; +#endif } DEVICE_INLINE void syncwarp() { diff --git a/fbgemm_gpu/src/tbe/eeg/indices_generator.cpp b/fbgemm_gpu/src/tbe/eeg/indices_generator.cpp old mode 100644 new mode 100755