From fe4d8f0de5f1b8df5a17673f71f0139647c5d055 Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Tue, 29 Jul 2025 11:57:27 +0000 Subject: [PATCH 01/28] Add gfx950 build support + fp16 fix + index type fix --- fbgemm_gpu/cmake/Hip.cmake | 8 ++++++++ .../embedding_backward_split_template.cu | 2 +- ..._backward_split_device_kernel_template.hip | 2 +- .../include/fbgemm_gpu/rocm/cdna_guard.h | 2 +- .../fbgemm_gpu/rocm/split_embeddings_common.h | 20 ++++++++++++++++++- fbgemm_gpu/src/tbe/eeg/indices_generator.cpp | 2 +- 6 files changed, 31 insertions(+), 5 deletions(-) diff --git a/fbgemm_gpu/cmake/Hip.cmake b/fbgemm_gpu/cmake/Hip.cmake index 17640b7254..2011a34c33 100644 --- a/fbgemm_gpu/cmake/Hip.cmake +++ b/fbgemm_gpu/cmake/Hip.cmake @@ -78,6 +78,14 @@ if(HIP_FOUND) list(APPEND HIP_CXX_FLAGS -mf16c) list(APPEND HIP_CXX_FLAGS -mfma) list(APPEND HIP_CXX_FLAGS -std=c++20) + list(APPEND HIP_CXX_FLAGS -g) + list(APPEND HIP_CXX_FLAGS -ggdb) + + # list(APPEND HIP_CXX_FLAGS -Wa,-adhln) + #list(APPEND HIP_CXX_FLAGS -adhln) + list(APPEND HIP_CXX_FLAGS -save-temps) + list(APPEND HIP_CXX_FLAGS -fverbose-asm) + set(HIP_HCC_FLAGS ${HIP_CXX_FLAGS}) # Ask hcc to generate device code during compilation so we can use diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index 76eba64c99..76a2b347d8 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -1193,7 +1193,7 @@ Tensor {{ embedding_cuda_op }}( const auto supported_weights_type = dev_weights.scalar_type() == at::ScalarType::Half || dev_weights.scalar_type() == at::ScalarType::Float; - if (use_hip_kernel && supported_weights_type && !mixed_D && rocm::is_supported_cdna()) + if (use_hip_kernel && supported_weights_type && rocm::is_supported_cdna()) { constexpr int segments_per_workgroup = 4; {%- for kDimSize in [64, 128, 160, 192, 256] %} diff --git a/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip index 2fcbba395e..5acc61382e 100644 --- a/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip +++ b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip @@ -179,7 +179,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( const int32_t segment_length_mod = segment_length & length_mask; cache_t grad_acc[dword_per_row]; - int32_t infos[segment_unroll]; + uint32_t infos[segment_unroll]; grad_t grad_data[dword_per_row * segment_prefetch]; emb_t emb_data[dword_per_row]; float indice_weights[segment_unroll]; diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/cdna_guard.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/cdna_guard.h index b55fd72fce..447613c5fc 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/cdna_guard.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/cdna_guard.h @@ -38,7 +38,7 @@ namespace fbgemm_gpu::rocm { [[nodiscard]] inline bool is_supported_cdna() { - const std::set supported_archs{"gfx942", "gfx90a"}; + const std::set supported_archs{"gfx942", "gfx90a", "gfx950"}; int device_id = 0; HIP_CHECK(hipGetDevice(&device_id)); hipDeviceProp_t dev_props; diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h index b3a56c4b52..c96da01063 
100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h @@ -215,6 +215,24 @@ struct load_row_per_warp { } }; +template +struct load_row_per_warp { + static __device__ void run( + c10::Half* emb_data, + index_t row_index, + const c10::Half* p_emb_table, + int lane_id) { + load_row_per_warp::run( + reinterpret_cast(emb_data), + row_index, + reinterpret_cast(p_emb_table), + lane_id + ); + } + +}; + + template < typename emb_t, int32_t embedding_dim, @@ -471,7 +489,7 @@ __device__ __forceinline__ void generic_dpp_reduction(data_t& result) { // of trivial operation with an option to use custom operation template __device__ __forceinline__ void dpp_reduction(data_t& result) { -#if defined(__gfx942__) || defined(__gfx90a__) +#if defined(__gfx942__) || defined(__gfx90a__) || defined(__gfx950__) if constexpr (std::is_same_v) { DPP_REDUCE_F16_F32(add); return; diff --git a/fbgemm_gpu/src/tbe/eeg/indices_generator.cpp b/fbgemm_gpu/src/tbe/eeg/indices_generator.cpp index 64de6489be..b4e67570b6 100644 --- a/fbgemm_gpu/src/tbe/eeg/indices_generator.cpp +++ b/fbgemm_gpu/src/tbe/eeg/indices_generator.cpp @@ -131,7 +131,7 @@ torch::Tensor IndicesGenerator::generate() { // Now sort the indices by their tags. Use parallel sort for some extra speed // (vector is very large). std::sort( - std::execution::par, + // std::execution::par, std::begin(indicesWithTags), std::end(indicesWithTags), [](const std::pair& lhs, From 2006f081190e0b9acf7884a6ceca690dcae7dcbf Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Tue, 29 Jul 2025 13:16:41 +0000 Subject: [PATCH 02/28] Change int64_t to index_t as template parameters in load_raw_per_warp --- .../rocm/embedding_backward_split_device_kernel_template.hip | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip index 5acc61382e..d5841d6e00 100644 --- a/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip +++ b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip @@ -452,7 +452,7 @@ L_tail_grad_acc: } // load the old emb weight data - load_row_per_warp::run( + load_row_per_warp::run( &emb_data[0], emb_idx, p_emb_table, lane_id); optimizer_t optimizer(opt_karg); optimizer.template update(grad_acc, emb_data, emb_idx); From 757a2f4704a682134028389605785672e0e232a5 Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Tue, 29 Jul 2025 14:39:22 +0000 Subject: [PATCH 03/28] Implement llvm fp16 buffer load for gfx950 --- .../fbgemm_gpu/rocm/split_embeddings_common.h | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h index c96da01063..4b33fd1422 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h @@ -60,7 +60,12 @@ __device__ half llvm_amdgcn_raw_buffer_load_fp16( int32x4_t srsrc, int32_t voffset, int32_t soffset, - int32_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f16"); + int32_t glc_slc) +#if defined(__gfx950__) + __asm("llvm.amdgcn.raw.buffer.load.i16"); +#else + __asm("llvm.amdgcn.raw.buffer.load.f16"); +#endif __device__ float llvm_amdgcn_raw_buffer_load_fp32( int32x4_t 
srsrc, @@ -72,7 +77,12 @@ __device__ half2 llvm_amdgcn_raw_buffer_load_fp16x2( int32x4_t srsrc, int32_t voffset, int32_t soffset, - int32_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f16"); + int32_t glc_slc) +#if defined(__gfx950__) + __asm("llvm.amdgcn.raw.buffer.load.i32"); +#else + __asm("llvm.amdgcn.raw.buffer.load.v2f16"); +#endif __device__ void llvm_amdgcn_raw_buffer_store_fp32( float vdata, From f875d5485abbbd7ee6485b138f2728d67806b601 Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Mon, 11 Aug 2025 08:23:47 +0000 Subject: [PATCH 04/28] Fix c-style half to float cast --- .../include/fbgemm_gpu/rocm/split_embeddings_common.h | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h index 4b33fd1422..238a83440a 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h @@ -261,7 +261,14 @@ struct accumulate_row_per_warp { } else { #pragma unroll for (int i = 0; i < dword_per_row; i++) { - acc[i] += static_cast((float)emb_data[i] * row_weight); + if constexpr (std::is_same_v) + { + acc[i] += static_cast(__half2float(emb_data[i]) * row_weight); + } + else + { + acc[i] += static_cast(static_cast(emb_data[i]) * row_weight); + } } } } From ea9d8f8d364445eca729c3c7bb07c0673c92dda6 Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Mon, 11 Aug 2025 08:24:29 +0000 Subject: [PATCH 05/28] Patch 256 half stores --- .../include/fbgemm_gpu/rocm/split_embeddings_common.h | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h index 238a83440a..974eae2594 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h @@ -294,6 +294,16 @@ struct store_row_per_warp { } }; +template <> +struct store_row_per_warp { + static __device__ void run(c10::Half* acc, c10::Half* p_output, int lane_id) { + auto out = reinterpret_cast(p_output); + out[lane_id] = *reinterpret_cast(acc); + out[lane_id + 64] = *reinterpret_cast(&acc[2]); + } +}; + + template <> struct store_row_per_warp { static __device__ void run(float* acc, float* p_output, int lane_id) { From e63ead2c7cae4c3afe2ef8936ed1727ad7ed82f1 Mon Sep 17 00:00:00 2001 From: shbiswas834 Date: Fri, 8 Aug 2025 05:02:58 +0000 Subject: [PATCH 06/28] cta_per_row workgroup optim --- .../training/backward/embedding_backward_split_template.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index 76a2b347d8..9412edc1a5 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -1042,7 +1042,7 @@ Tensor {{ embedding_cuda_op }}( // Compute shared memory size for cta_per_row constexpr auto kCacheAccBytes = sizeof(at::acc_type); - int32_t num_cta_per_row_groups = kMaxThreads / kWarpSize; + int32_t num_cta_per_row_groups = (kMaxThreads/2) / kWarpSize; const size_t cta_per_row_smem_bytes = compute_num_groups_and_dynamic_smem_bytes( &num_cta_per_row_groups, [&] (int num_groups) { @@ -1053,7 +1053,7 @@ Tensor {{ embedding_cuda_op }}( ); const int32_t 
cta_per_row_grid_size = std::min( - div_round_up(total_unique_indices, kMaxThreads), + div_round_up(total_unique_indices, (kMaxThreads/2)), get_max_thread_blocks_()); FBGEMM_LAUNCH_KERNEL( From b9eebb40e4d30ceb31ef2ee20ede3d83da7433e6 Mon Sep 17 00:00:00 2001 From: Shreyashri Biswas Date: Mon, 11 Aug 2025 21:06:48 +0000 Subject: [PATCH 07/28] Added mi350 guards --- ...ding_backward_split_indice_weights_template.cu | 15 ++++++++++++++- .../backward/embedding_backward_split_template.cu | 10 ++++++++++ .../forward/embedding_forward_split_template.cu | 14 ++++++++++++++ 3 files changed, 38 insertions(+), 1 deletion(-) mode change 100644 => 100755 fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu mode change 100644 => 100755 fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu mode change 100644 => 100755 fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu old mode 100644 new mode 100755 index 1afb2943bb..8f190d04d2 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu @@ -23,6 +23,10 @@ #include "fbgemm_gpu/utils/assert_macros.h" #include "fbgemm_gpu/utils/kernel_launcher.cuh" +{%- if is_rocm %} +#include "fbgemm_gpu/rocm/cdna_guard.h" +{%- endif %} + using Tensor = at::Tensor; using namespace fbgemm_gpu; @@ -359,7 +363,16 @@ Tensor {{ mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda( auto aligned_grad_output = aligned_grad_output_tensor_for_cuda_backwards(grad_output); CUDA_DEVICE_GUARD(dev_weights); - + #ifdef USE_ROCM + if (!rocm::is_supported_cdna()) { + TORCH_WARN_ONCE("Running on non-CDNA architecture. Performance may be suboptimal."); + } + else { + // Ensure we're running on a supported CDNA architecture (including MI350) + TORCH_WARN_ONCE("Running on CDNA architecture"); + } + #endif + const auto T = D_offsets.size(0) - 1; TORCH_CHECK_GT(T, 0); // offsets = [B x T + 1] diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu old mode 100644 new mode 100755 index 9412edc1a5..9e9e7aac68 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -652,6 +652,16 @@ Tensor {{ embedding_cuda_op }}( CUDA_DEVICE_GUARD(dev_weights); + #ifdef USE_ROCM + if (!rocm::is_supported_cdna()) { + TORCH_WARN_ONCE("Running on non-CDNA architecture. 
Performance may be suboptimal."); + } + else { + // Ensure we're running on a supported CDNA architecture (including MI350) + TORCH_WARN_ONCE("Running on CDNA architecture"); + } + #endif + {%- if nobag and not is_index_select %} auto max_D = D; {%- endif %} diff --git a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu old mode 100644 new mode 100755 index 6574bda45e..bbd62a8bbc --- a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu +++ b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu @@ -31,6 +31,10 @@ #include "fbgemm_gpu/utils/dispatch_macros.h" {%- endif %} +{%- if is_rocm %} +#include "fbgemm_gpu/rocm/cdna_guard.h" +{%- endif %} + {%- if not is_index_select %} //////////////////////////////////////////////////////////////////////////////// // Required for op registrations @@ -454,6 +458,16 @@ batch_index_select_dim0_codegen_forward_cuda( CUDA_DEVICE_GUARD(dev_weights); + #ifdef USE_ROCM + if (!rocm::is_supported_cdna()) { + TORCH_WARN_ONCE("Running on non-CDNA architecture. Performance may be suboptimal."); + } + else { + // Ensure we're running on a supported CDNA architecture (including MI350) + TORCH_WARN_ONCE("Running on CDNA architecture"); + } + #endif + {%- if not nobag %} int32_t T = D_offsets.numel() - 1; {%- else %} From 69ae10ebe577612174c4433b70b58ca8e44476c8 Mon Sep 17 00:00:00 2001 From: shbiswas834 Date: Tue, 12 Aug 2025 15:09:39 +0000 Subject: [PATCH 08/28] Fix index overflow in row load --- ..._backward_split_device_kernel_template.hip | 28 +++++++++---------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip index d5841d6e00..d1a874805a 100644 --- a/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip +++ b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip @@ -227,7 +227,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( table_index = infos[0] >> info_B_num_bits; bag_index = infos[0] & info_B_mask; {%- endif %} - load_row_per_warp::run( + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); {%- if nobag %} @@ -236,7 +236,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( table_index = infos[1] >> info_B_num_bits; bag_index = infos[1] & info_B_mask; {%- endif %} - load_row_per_warp::run( + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); if constexpr (!weighted){ #pragma unroll @@ -250,7 +250,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( table_index = infos[j] >> info_B_num_bits; bag_index = infos[j] & info_B_mask; {%- endif %} - load_row_per_warp::run( + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( @@ -261,7 +261,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( table_index = infos[j + 1] >> info_B_num_bits; bag_index = infos[j + 1] & info_B_mask; {%- endif %} - load_row_per_warp::run( + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); } @@ -290,7 +290,7 @@ 
__device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( table_index = infos[j] >> info_B_num_bits; bag_index = infos[j] & info_B_mask; {%- endif %} - load_row_per_warp::run( + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( @@ -301,7 +301,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( table_index = infos[j + 1] >> info_B_num_bits; bag_index = infos[j + 1] & info_B_mask; {%- endif %} - load_row_per_warp::run( + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); } @@ -328,7 +328,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( table_index = infos[0] >> info_B_num_bits; bag_index = infos[0] & info_B_mask; {%- endif %} - load_row_per_warp::run( + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); {%- if nobag %} @@ -337,7 +337,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( table_index = infos[1] >> info_B_num_bits; bag_index = infos[1] & info_B_mask; {%- endif %} - load_row_per_warp::run( + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); if constexpr (!weighted) { @@ -352,7 +352,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( table_index = infos[j] >> info_B_num_bits; bag_index = infos[j] & info_B_mask; {%- endif %} - load_row_per_warp::run( + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( @@ -363,7 +363,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( table_index = infos[j + 1] >> info_B_num_bits; bag_index = infos[j + 1] & info_B_mask; {%- endif %} - load_row_per_warp::run( + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); } @@ -383,7 +383,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( table_index = infos[j] >> info_B_num_bits; bag_index = infos[j] & info_B_mask; {%- endif %} - load_row_per_warp::run( + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( @@ -394,7 +394,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( table_index = infos[j + 1] >> info_B_num_bits; bag_index = infos[j + 1] & info_B_mask; {%- endif %} - load_row_per_warp::run( + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); } @@ -420,7 +420,7 @@ L_tail_grad_acc: table_index = infos[0] >> info_B_num_bits; bag_index = infos[0] & info_B_mask; {%- endif %} - load_row_per_warp::run( + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( &grad_acc[0], &grad_data[0], lane_id); @@ -441,7 +441,7 @@ L_tail_grad_acc: table_index = infos[0] >> info_B_num_bits; bag_index = infos[0] & info_B_mask; {%- endif %} - load_row_per_warp::run( + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( &grad_acc[0], &grad_data[0], lane_id, indice_weights[0]); From 8f692dce242d5836f1746967487ba5e5945e5561 Mon Sep 17 00:00:00 2001 From: shbiswas834 Date: Tue, 12 Aug 2025 20:13:09 
+0000 Subject: [PATCH 09/28] cta_per_row workgroup reduce by 4 optim --- .../training/backward/embedding_backward_split_template.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index 9e9e7aac68..c59f6fe9aa 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -1052,7 +1052,7 @@ Tensor {{ embedding_cuda_op }}( // Compute shared memory size for cta_per_row constexpr auto kCacheAccBytes = sizeof(at::acc_type); - int32_t num_cta_per_row_groups = (kMaxThreads/2) / kWarpSize; + int32_t num_cta_per_row_groups = (kMaxThreads/4) / kWarpSize; const size_t cta_per_row_smem_bytes = compute_num_groups_and_dynamic_smem_bytes( &num_cta_per_row_groups, [&] (int num_groups) { @@ -1063,7 +1063,7 @@ Tensor {{ embedding_cuda_op }}( ); const int32_t cta_per_row_grid_size = std::min( - div_round_up(total_unique_indices, (kMaxThreads/2)), + div_round_up(total_unique_indices, (kMaxThreads/4)), get_max_thread_blocks_()); FBGEMM_LAUNCH_KERNEL( From 768dc01584473b5d764fa27c35025e883893b9a1 Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Wed, 13 Aug 2025 13:21:38 +0000 Subject: [PATCH 10/28] Fix mixed_D frontend to backend connection --- .../training/backward/embedding_backward_split_template.cu | 2 +- .../pt2/embedding_split_host_pt2_autograd_template.cpp | 1 + .../split_table_batched_embeddings_ops_training.py | 5 ++++- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index c59f6fe9aa..c8a846a552 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -1203,7 +1203,7 @@ Tensor {{ embedding_cuda_op }}( const auto supported_weights_type = dev_weights.scalar_type() == at::ScalarType::Half || dev_weights.scalar_type() == at::ScalarType::Float; - if (use_hip_kernel && supported_weights_type && rocm::is_supported_cdna()) + if (use_hip_kernel && !mixed_D && supported_weights_type && rocm::is_supported_cdna()) { constexpr int segments_per_workgroup = 4; {%- for kDimSize in [64, 128, 160, 192, 256] %} diff --git a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp index 3720f1ea42..20c055e917 100644 --- a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp +++ b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp @@ -698,6 +698,7 @@ class {{ autograd_func }} : TORCH_CHECK(aux_tensor[IDX_LXU_CACHE_LOCATIONS].has_value(), "lxu_cache_locations should have value."); const auto lxu_cache_locations = aux_tensor[IDX_LXU_CACHE_LOCATIONS].value(); const auto is_experimental = aux_bool[IDX_IS_EXPERIMENTAL_TBE]; + const auto mixed_D = aux_bool[IDX_MIXED_D]; {%- endif %} // Default values for Dynamo tracing diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py index 1b45f7f147..fc96386d48 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py +++ 
b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py @@ -813,7 +813,7 @@ def __init__( # noqa C901 assert ( self.pooling_mode != PoolingMode.NONE ), "Mixed dimension tables only supported for pooling tables." - + self.mixed_D = mixed_D assert all( cd == compute_devices[0] for cd in compute_devices ), "Heterogenous compute_devices are NOT supported!" @@ -2314,6 +2314,7 @@ def forward( # noqa: C901 row_counter, iter_int, self.max_counter.item(), + mixed_D=self.mixed_D, ), ) elif self._used_rowwise_adagrad_with_global_weight_decay: @@ -2332,6 +2333,7 @@ def forward( # noqa: C901 # `Optional[Tensor]` but got `Union[Module, Tensor]`. prev_iter_dev=self.prev_iter_dev, gwd_lower_bound=self.gwd_lower_bound, + mixed_D=self.mixed_D, ), ) else: @@ -2341,6 +2343,7 @@ def forward( # noqa: C901 common_args, self.optimizer_args, momentum1, + mixed_D=self.mixed_D, ), ) From 151d2dd3fbbd8afcbc09d95a780410e3543ff391 Mon Sep 17 00:00:00 2001 From: kudomcho Date: Fri, 15 Aug 2025 15:32:19 +0000 Subject: [PATCH 11/28] changed max_segment_length_per_cta to 4096 --- .../training/backward/embedding_backward_split_template.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index c8a846a552..1ddcea55b2 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -980,7 +980,7 @@ Tensor {{ embedding_cuda_op }}( auto num_long_run_ids = at::zeros({1}, indices.options().dtype(at::kInt)); const bool use_deterministic_algorithms = at::globalContext().deterministicAlgorithms(); - const int max_segment_length_per_cta = use_deterministic_algorithms ? INT_MAX : 1024; + const int max_segment_length_per_cta = use_deterministic_algorithms ? INT_MAX : 4096; Tensor long_run_id_to_really_long_run_ids; if (use_deterministic_algorithms) { From 54e0e24b280b8c16f5ffdbd774051981cf1822fc Mon Sep 17 00:00:00 2001 From: shbiswas834 Date: Mon, 18 Aug 2025 22:32:58 +0000 Subject: [PATCH 12/28] added rocm guards and removed comment --- .../embedding_backward_split_template.cu | 19 ++++++++++++++++--- fbgemm_gpu/src/tbe/eeg/indices_generator.cpp | 1 - 2 files changed, 16 insertions(+), 4 deletions(-) mode change 100644 => 100755 fbgemm_gpu/src/tbe/eeg/indices_generator.cpp diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index 1ddcea55b2..099c7e5685 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -980,7 +980,11 @@ Tensor {{ embedding_cuda_op }}( auto num_long_run_ids = at::zeros({1}, indices.options().dtype(at::kInt)); const bool use_deterministic_algorithms = at::globalContext().deterministicAlgorithms(); - const int max_segment_length_per_cta = use_deterministic_algorithms ? INT_MAX : 4096; + #ifdef USE_ROCM + const int max_segment_length_per_cta = use_deterministic_algorithms ? INT_MAX : 4096; + #else + const int max_segment_length_per_cta = use_deterministic_algorithms ? 
INT_MAX : 1024; + #endif Tensor long_run_id_to_really_long_run_ids; if (use_deterministic_algorithms) { @@ -1052,7 +1056,11 @@ Tensor {{ embedding_cuda_op }}( // Compute shared memory size for cta_per_row constexpr auto kCacheAccBytes = sizeof(at::acc_type); - int32_t num_cta_per_row_groups = (kMaxThreads/4) / kWarpSize; + #ifdef USE_ROCM + int32_t num_cta_per_row_groups = (kMaxThreads/4) / kWarpSize; + #else + int32_t num_cta_per_row_groups = kMaxThreads / kWarpSize; + #endif const size_t cta_per_row_smem_bytes = compute_num_groups_and_dynamic_smem_bytes( &num_cta_per_row_groups, [&] (int num_groups) { @@ -1063,7 +1071,12 @@ Tensor {{ embedding_cuda_op }}( ); const int32_t cta_per_row_grid_size = std::min( - div_round_up(total_unique_indices, (kMaxThreads/4)), + #ifdef USE_ROCM + div_round_up(total_unique_indices, (kMaxThreads/4)), + #else + div_round_up(total_unique_indices, kMaxThreads), + #endif + get_max_thread_blocks_()); FBGEMM_LAUNCH_KERNEL( diff --git a/fbgemm_gpu/src/tbe/eeg/indices_generator.cpp b/fbgemm_gpu/src/tbe/eeg/indices_generator.cpp old mode 100644 new mode 100755 index b4e67570b6..7c22337bc3 --- a/fbgemm_gpu/src/tbe/eeg/indices_generator.cpp +++ b/fbgemm_gpu/src/tbe/eeg/indices_generator.cpp @@ -131,7 +131,6 @@ torch::Tensor IndicesGenerator::generate() { // Now sort the indices by their tags. Use parallel sort for some extra speed // (vector is very large). std::sort( - // std::execution::par, std::begin(indicesWithTags), std::end(indicesWithTags), [](const std::pair& lhs, From 7b2684c02c6ff0aa157ffa122835e69711c11e30 Mon Sep 17 00:00:00 2001 From: Li Li Date: Wed, 20 Aug 2025 03:00:56 +0000 Subject: [PATCH 13/28] clean debug statements in Hip.cmake --- fbgemm_gpu/cmake/Hip.cmake | 8 -------- 1 file changed, 8 deletions(-) diff --git a/fbgemm_gpu/cmake/Hip.cmake b/fbgemm_gpu/cmake/Hip.cmake index 2011a34c33..17640b7254 100644 --- a/fbgemm_gpu/cmake/Hip.cmake +++ b/fbgemm_gpu/cmake/Hip.cmake @@ -78,14 +78,6 @@ if(HIP_FOUND) list(APPEND HIP_CXX_FLAGS -mf16c) list(APPEND HIP_CXX_FLAGS -mfma) list(APPEND HIP_CXX_FLAGS -std=c++20) - list(APPEND HIP_CXX_FLAGS -g) - list(APPEND HIP_CXX_FLAGS -ggdb) - - # list(APPEND HIP_CXX_FLAGS -Wa,-adhln) - #list(APPEND HIP_CXX_FLAGS -adhln) - list(APPEND HIP_CXX_FLAGS -save-temps) - list(APPEND HIP_CXX_FLAGS -fverbose-asm) - set(HIP_HCC_FLAGS ${HIP_CXX_FLAGS}) # Ask hcc to generate device code during compilation so we can use From 9b22e17e72be0fd6a7711d00a91505384ba08ac7 Mon Sep 17 00:00:00 2001 From: Shreya Date: Thu, 28 Aug 2025 11:43:32 -0500 Subject: [PATCH 14/28] Merge pull request #121 warp per row wg change --- .../embedding_backward_split_template.cu | 34 ++++++++++++++----- 1 file changed, 26 insertions(+), 8 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index 099c7e5685..2425322948 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -1056,10 +1056,21 @@ Tensor {{ embedding_cuda_op }}( // Compute shared memory size for cta_per_row constexpr auto kCacheAccBytes = sizeof(at::acc_type); + int32_t total_L = indices.numel(); #ifdef USE_ROCM - int32_t num_cta_per_row_groups = (kMaxThreads/4) / kWarpSize; + int32_t num_cta_per_row_groups; + int32_t work_group_size; + if (total_L/total_B > 1){ + num_cta_per_row_groups = (kMaxThreads/4) / kWarpSize; + work_group_size = (kMaxThreads/4); + } + 
else{ + num_cta_per_row_groups = kMaxThreads / kWarpSize; + work_group_size = kMaxThreads; + } #else int32_t num_cta_per_row_groups = kMaxThreads / kWarpSize; + int32_t work_group_size = kMaxThreads; #endif const size_t cta_per_row_smem_bytes = compute_num_groups_and_dynamic_smem_bytes( &num_cta_per_row_groups, @@ -1071,17 +1082,13 @@ Tensor {{ embedding_cuda_op }}( ); const int32_t cta_per_row_grid_size = std::min( - #ifdef USE_ROCM - div_round_up(total_unique_indices, (kMaxThreads/4)), - #else - div_round_up(total_unique_indices, kMaxThreads), - #endif - + div_round_up(total_unique_indices, work_group_size), get_max_thread_blocks_()); FBGEMM_LAUNCH_KERNEL( backward_cta_per_row_kernel, cta_per_row_grid_size, + // (64, 2) dim3(kThreadGroupSize, num_cta_per_row_groups), cta_per_row_smem_bytes, at::cuda::getCurrentCUDAStream(), @@ -1185,7 +1192,18 @@ Tensor {{ embedding_cuda_op }}( kUseVecBlocking>; // Compute shared memory size for warp_per_row - int32_t num_warp_per_row_groups = kBackwardMaxThreads / kThreadGroupSize; + #ifdef USE_ROCM + int32_t num_warp_per_row_groups; + + if (total_L/total_B > 1){ + num_warp_per_row_groups = (kBackwardMaxThreads/2) / kThreadGroupSize; + } + else{ + num_warp_per_row_groups = kBackwardMaxThreads / kThreadGroupSize; + } + #else + int32_t num_warp_per_row_groups = kBackwardMaxThreads / kThreadGroupSize; + #endif int32_t warp_per_row_smem_bytes = 0; if constexpr (kUseVecBlocking) { From 9adc6bcbb06c793c26a82fea6a7753718ac03247 Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Tue, 2 Sep 2025 09:25:03 +0000 Subject: [PATCH 15/28] Guard f16 llvm intrinsics with ROCm >=7.0 --- fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h index 974eae2594..46c4603381 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h @@ -24,6 +24,7 @@ #include #include #include +#include /******************************************************************************/ typedef int32_t int32x4_t __attribute__((ext_vector_type(4))); @@ -61,7 +62,7 @@ __device__ half llvm_amdgcn_raw_buffer_load_fp16( int32_t voffset, int32_t soffset, int32_t glc_slc) -#if defined(__gfx950__) +#if ROCM_VERSION_MAJOR >= 7 __asm("llvm.amdgcn.raw.buffer.load.i16"); #else __asm("llvm.amdgcn.raw.buffer.load.f16"); @@ -78,7 +79,7 @@ __device__ half2 llvm_amdgcn_raw_buffer_load_fp16x2( int32_t voffset, int32_t soffset, int32_t glc_slc) -#if defined(__gfx950__) +#if ROCM_VERSION_MAJOR >= 7 __asm("llvm.amdgcn.raw.buffer.load.i32"); #else __asm("llvm.amdgcn.raw.buffer.load.v2f16"); From dc16185436cec4b835c4f4a9c3127bcd09cf931f Mon Sep 17 00:00:00 2001 From: Li Li Date: Thu, 18 Sep 2025 16:28:31 +0000 Subject: [PATCH 16/28] fix the bug in dimention 160 in ROCm optimization --- fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h index 46c4603381..8a97579d6a 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h @@ -165,7 +165,7 @@ struct load_row_per_warp { static __device__ void run(half* emb_data, index_t row_index, const half* p_emb_table, int 
lane_id) { int32x4_t emb_res = - amdgcn_make_buffer_resource(p_emb_table + row_index * 192); + amdgcn_make_buffer_resource(p_emb_table + row_index * 160); *reinterpret_cast(emb_data) = llvm_amdgcn_raw_buffer_load_fp16x2( emb_res, lane_id * sizeof(half2), 0, 0); if ((lane_id + 128) % 192 < 160) { From 00c19144b19f2ad946a32144a446d5b8d9cccab5 Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Tue, 19 Aug 2025 13:41:17 +0000 Subject: [PATCH 17/28] Cleanup optimized warp_per_raw kernel --- fbgemm_gpu/cmake/tbe_sources.py | 2 - .../genscript/generate_backward_split.py | 10 +- ...ing_backward_split_kernel_warp_template.cu | 40 +++----- .../embedding_backward_split_template.cu | 18 ++-- ..._backward_split_device_kernel_template.hip | 94 +++++-------------- 5 files changed, 54 insertions(+), 110 deletions(-) diff --git a/fbgemm_gpu/cmake/tbe_sources.py b/fbgemm_gpu/cmake/tbe_sources.py index 31200b6190..dc3acace35 100644 --- a/fbgemm_gpu/cmake/tbe_sources.py +++ b/fbgemm_gpu/cmake/tbe_sources.py @@ -176,7 +176,6 @@ "_nobag" if nobag else "", ) for nobag in [ - True, False, ] for weighted in ( @@ -495,7 +494,6 @@ "_nobag" if nobag else "", ) for nobag in [ - True, False, ] for weighted in ( diff --git a/fbgemm_gpu/codegen/genscript/generate_backward_split.py b/fbgemm_gpu/codegen/genscript/generate_backward_split.py index a817232910..5acb6f2e7f 100644 --- a/fbgemm_gpu/codegen/genscript/generate_backward_split.py +++ b/fbgemm_gpu/codegen/genscript/generate_backward_split.py @@ -52,7 +52,11 @@ def render_backward_templates( return weighted_options = [True, False] - nobag_options = [True, False] if (not is_gwd) else [False] + nobag_options = ( + [True, False] + if (not (is_gwd or kwargs.get("is_hip_optimized_backward"))) + else [False] + ) vbe_options = [True, False] if (kwargs.get("has_vbe_support")) else [False] ssd_options = [True, False] if kwargs.get("has_ssd_support") else [False] template = CodeTemplate.load(template_filepath) @@ -327,8 +331,7 @@ def generate_backward_indices() -> None: @staticmethod def generate_rocm_backward_split(**kwargs: Any) -> None: - # Generate backward device kernels based on weighted (True/False), VBE - # (True/False), no bag (True/False) + # Generate backward device kernels based on weighted (True/False) template_filepath = ( "training/backward/rocm/embedding_backward_split_device_kernel_template.hip" ) @@ -343,6 +346,7 @@ def generate_rocm_backward_split(**kwargs: Any) -> None: "has_ssd_support": False, "dense": False, "gen_once": False, + "is_hip_optimized_backward": True, }, ) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu index 5137b5766c..1158721526 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu @@ -32,6 +32,14 @@ {%- set locs_or_addrs_tensor = "ssd_row_addrs" if ssd else "lxu_cache_locations" %} {%- set locs_or_addrs_type = "int64_t" if ssd else "int32_t" %} +{%- set is_optimized_hip_kernel_supported_mode = is_rocm and + optimizer == "rowwise_adagrad" and + not dense and + not nobag and + not is_index_select and + not is_gwd_kernel and + not vbe and + not ssd %} #include "fbgemm_gpu/embedding_backward_template_helpers.cuh" #include "fbgemm_gpu/utils/tensor_accessor_builder.h" @@ -538,7 +546,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row {%- endif %} -{%- if 
is_rocm and not is_index_select and optimizer == "rowwise_adagrad" and not dense and not is_gwd_kernel and not vbe and not ssd %} +{%- if is_optimized_hip_kernel_supported_mode %} #include #include #include "fbgemm_gpu/rocm/split_embeddings_common.h" @@ -612,12 +620,8 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd {{ args.split_kernel_args | replace_pta_namespace() | join(",\n ") }} {%- endif %} ) { - {%- if not nobag %} int32_t T = D_offsets.size(0) - 1; - {%- else %} - int32_t T = weights_offsets.size(0); - {%- endif %} - + auto p_output_grad = grad_output.data(); auto p_emb_table = dev_weights.data(); auto p_hash_size_cumsum = hash_size_cumsum.data(); @@ -632,8 +636,6 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd constexpr int32_t segment_prefetch = 2; constexpr int32_t segment_unroll = 8; constexpr int32_t segment_split = 0; - auto batch = grad_output.size(0); - auto num_rows = dev_weights.size(0) / T / max_D; {%- if weighted %} constexpr bool is_weighted = true; {%- else %} @@ -646,22 +648,7 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd // weight_decay(_mode) is supplied as args.split_function_args_no_defaults opt_karg.weight_decay_mode = weight_decay_mode_v; opt_karg.weight_decay = weight_decay; - auto batch_mdiv = [](uint32_t d) -> rocm::magic_div_u32_t { - assert(d >= 1 && d <= INT32_MAX); - uint8_t shift; - for(shift = 0; shift < 32; shift++) - if((1U << shift) >= d) - break; - - uint64_t one = 1; - uint64_t magic = ((one << 32) * ((one << shift) - d)) / d + 1; - assert(magic <= 0xffffffffUL); - - rocm::magic_div_u32_t result; - result.magic = magic; - result.shift = shift; - return result; - }(batch); + rocm::split_tbe_backward_hip_kernel_{{kdesc}}< rocm::{{optimizer}}_optimizer_t, rocm::{{optimizer}}_kernel_arg_t, @@ -680,16 +667,11 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd p_sorted_linear_indices_run, p_sorted_linear_indices_cumulative_run_lengths, p_sorted_linear_indices_num_runs, - {%- if not nobag %} info_B_num_bits, info_B_mask, - {%- endif %} p_sorted_infos, - batch_mdiv, max_segment_length_per_warp, emb_dim, - batch, - num_rows, T, opt_karg {%- if weighted %} diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index 2425322948..fb125101e7 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -48,6 +48,15 @@ using namespace fbgemm_gpu; has_global_weight_decay_support, ssd) %} {%- set desc_suffix = get_desc_suffix(is_gwd_kernel) %} +{%- set is_optimized_hip_kernel_supported_mode = is_rocm and + optimizer == "rowwise_adagrad" and + not dense and + not nobag and + not is_index_select and + not is_gwd_kernel and + not vbe and + not ssd %} + template < typename emb_t, typename grad_t, @@ -227,8 +236,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row( {%- endif %} ); -{%- if is_rocm and optimizer == "rowwise_adagrad" and not dense and not is_index_select - and not is_gwd_kernel and not vbe and not ssd %} +{%- if is_optimized_hip_kernel_supported_mode %} #include "fbgemm_gpu/rocm/split_embeddings_common.h" template < typename emb_t, @@ -862,8 +870,7 @@ Tensor {{ embedding_cuda_op }}( } {%- endif %} - {%- if is_rocm and optimizer == "rowwise_adagrad" and not dense and not is_index_select - 
and not is_gwd_kernel and not vbe and not ssd %} + {%- if is_optimized_hip_kernel_supported_mode %} {%- set hip_kernel = "hip_split_embedding{}_backward_codegen_{}_{}{}_kernel_warp_per_row_1".format( ndesc, optimizer, @@ -1226,8 +1233,7 @@ Tensor {{ embedding_cuda_op }}( get_max_thread_blocks_()); #ifdef USE_ROCM - {%- if is_rocm and not is_index_select and optimizer == "rowwise_adagrad" and - not dense and not is_gwd_kernel and not vbe and not ssd and not nobag %} + {%- if is_optimized_hip_kernel_supported_mode %} const static auto use_hip_kernel = fbgemm_gpu::config::is_feature_enabled(fbgemm_gpu::config::FeatureGateName::TBE_ROCM_HIP_BACKWARD_KERNEL); diff --git a/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip index d1a874805a..951cff4399 100644 --- a/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip +++ b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip @@ -122,20 +122,11 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( const index_t* p_sorted_linear_indices_run, const int32_t* p_sorted_linear_indices_cumulative_run_lengths, const int32_t* p_sorted_linear_indices_num_runs, - {%- if not nobag %} const int32_t info_B_num_bits, const uint32_t info_B_mask, - {%- endif %} - {%- if not nobag %} const int32_t* p_sorted_infos, - {%- else %} - const int64_t* p_sorted_infos, - {%- endif %} - magic_div_u32_t batch_mdiv, uint32_t max_segment_length_per_warp, uint32_t emb_dim, - uint32_t batch, - uint32_t num_rows, uint32_t num_tables, optimizer_karg_t opt_karg, const float * p_sorted_indice_weights = nullptr) @@ -157,13 +148,9 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( const int32_t segment_start = p_sorted_linear_indices_cumulative_run_lengths[run_id]; const int32_t segment_end = p_sorted_linear_indices_cumulative_run_lengths[run_id + 1]; - {%- if nobag %} - const auto info_0 = p_sorted_infos[segment_start]; - int32_t t_0 = info_0 % num_tables; - {%- else %} const auto info_0 = reinterpret_cast(&p_sorted_infos[0])[segment_start]; const auto t_0 = info_0 >> info_B_num_bits; - {%- endif %} + int64_t hash_size = p_hash_size_cumsum[t_0]; const int64_t emb_idx = linear_index - hash_size; @@ -221,21 +208,15 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( // LOOP for(; itr < segment_length_mod; itr += segment_unroll) { - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[0], batch, table_index, bag_index); - {%- else %} table_index = infos[0] >> info_B_num_bits; bag_index = infos[0] & info_B_mask; - {%- endif %} + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[1], batch, table_index, bag_index); - {%- else %} table_index = infos[1] >> info_B_num_bits; bag_index = infos[1] & info_B_mask; - {%- endif %} + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); if constexpr (!weighted){ @@ -244,23 +225,19 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( { accumulate_row_per_warp::run( &grad_acc[0], &grad_data[0], lane_id); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[j], batch, table_index, bag_index); - {%- else %} + table_index = infos[j] >> info_B_num_bits; bag_index = infos[j] & 
info_B_mask; - {%- endif %} + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( &grad_acc[0], &grad_data[dword_per_row], lane_id); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[j + 1], batch, table_index, bag_index); - {%- else %} + table_index = infos[j + 1] >> info_B_num_bits; bag_index = infos[j + 1] & info_B_mask; - {%- endif %} + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); } @@ -284,23 +261,19 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( { accumulate_row_per_warp::run( &grad_acc[0], &grad_data[0], lane_id, indice_weights[j-2]); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[j], batch, table_index, bag_index); - {%- else %} + table_index = infos[j] >> info_B_num_bits; bag_index = infos[j] & info_B_mask; - {%- endif %} + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( &grad_acc[0], &grad_data[dword_per_row], lane_id, indice_weights[j-1]); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[j + 1], batch, table_index, bag_index); - {%- else %} + table_index = infos[j + 1] >> info_B_num_bits; bag_index = infos[j + 1] & info_B_mask; - {%- endif %} + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); } @@ -322,21 +295,16 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( } // LAST - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[0], batch, table_index, bag_index); - {%- else %} table_index = infos[0] >> info_B_num_bits; bag_index = infos[0] & info_B_mask; - {%- endif %} + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[1], batch, table_index, bag_index); - {%- else %} + table_index = infos[1] >> info_B_num_bits; bag_index = infos[1] & info_B_mask; - {%- endif %} + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); @@ -346,23 +314,19 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( { accumulate_row_per_warp::run( &grad_acc[0], &grad_data[0], lane_id); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[j], batch, table_index, bag_index); - {%- else %} + table_index = infos[j] >> info_B_num_bits; bag_index = infos[j] & info_B_mask; - {%- endif %} + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( &grad_acc[0], &grad_data[dword_per_row], lane_id); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[j + 1], batch, table_index, bag_index); - {%- else %} + table_index = infos[j + 1] >> info_B_num_bits; bag_index = infos[j + 1] & info_B_mask; - {%- endif %} + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); } @@ -377,23 +341,19 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( { accumulate_row_per_warp::run( &grad_acc[0], &grad_data[0], lane_id, indice_weights[j-2]); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[j], batch, table_index, bag_index); - {%- else %} + table_index = infos[j] >> 
info_B_num_bits; bag_index = infos[j] & info_B_mask; - {%- endif %} + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( &grad_acc[0], &grad_data[dword_per_row], lane_id, indice_weights[j-1]); - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[j + 1], batch, table_index, bag_index); - {%- else %} + table_index = infos[j + 1] >> info_B_num_bits; bag_index = infos[j + 1] & info_B_mask; - {%- endif %} + load_row_per_warp::run( &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); } @@ -414,12 +374,9 @@ L_tail_grad_acc: infos[0] = p_sorted_infos[segment_start]; p_sorted_infos++; - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[0], batch, table_index, bag_index); - {%- else %} table_index = infos[0] >> info_B_num_bits; bag_index = infos[0] & info_B_mask; - {%- endif %} + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( @@ -435,12 +392,9 @@ L_tail_grad_acc: p_sorted_infos++; p_sorted_indice_weights++; - {%- if nobag %} - magic_div_u32_run_with_mod(batch_mdiv, infos[0], batch, table_index, bag_index); - {%- else %} table_index = infos[0] >> info_B_num_bits; bag_index = infos[0] & info_B_mask; - {%- endif %} + load_row_per_warp::run( &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); accumulate_row_per_warp::run( From f2c662ace41550b3934257ef3bc6805e26618e50 Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Wed, 20 Aug 2025 12:15:37 +0000 Subject: [PATCH 18/28] Add 320 embedding dim support for optimized warp_per_row kernel --- ...ing_backward_split_kernel_warp_template.cu | 2 +- .../embedding_backward_split_template.cu | 2 +- .../fbgemm_gpu/rocm/split_embeddings_common.h | 26 +++++++++++++++++++ 3 files changed, 28 insertions(+), 2 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu index 1158721526..e61b3fc0aa 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu @@ -766,7 +766,7 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd {%- for emb_type in (['float', 'at::Half'] + (['at::Float8_e4m3fnuz'] if is_rocm else ['at::Float8_e4m3fn'])) %} {%- for cache_type in ['float', 'at::Half'] %} {%- for index_type in ['int32_t', 'int64_t'] %} - {%- for kEmbeddingDim in [64, 128, 160, 192, 256] %} + {%- for kEmbeddingDim in [64, 128, 160, 192, 256, 320] %} {%- for kWeighDecayMode in [0, 1, 2] %} {{ hip_template_instantiation( emb_type, diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index fb125101e7..7eb2b6880f 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -1243,7 +1243,7 @@ Tensor {{ embedding_cuda_op }}( if (use_hip_kernel && !mixed_D && supported_weights_type && rocm::is_supported_cdna()) { constexpr int segments_per_workgroup = 4; - {%- for kDimSize in [64, 128, 160, 192, 256] %} + {%- for kDimSize in [64, 128, 160, 192, 256, 320] %} {%- 
for kWeightDecayMode in [0, 1, 2] %} if (max_D == {{ kDimSize }} && weight_decay_mode == {{ kWeightDecayMode }}) { diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h index 8a97579d6a..5b9d69d910 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h @@ -205,6 +205,22 @@ struct load_row_per_warp { } }; +template +struct load_row_per_warp { + static __device__ void + run(half* emb_data, index_t row_index, const half* p_emb_table, int lane_id) { + int32x4_t emb_res = + amdgcn_make_buffer_resource(p_emb_table + row_index * 320); + *reinterpret_cast(&emb_data[0]) = + llvm_amdgcn_raw_buffer_load_fp16x2( + emb_res, lane_id * sizeof(half2), 0, 0); + *reinterpret_cast(&emb_data[2]) = + llvm_amdgcn_raw_buffer_load_fp16x2( + emb_res, (lane_id + 64) * sizeof(half2), 0, 0); + emb_data[4] = p_emb_table[row_index * 320 + 256 + lane_id]; + } +}; + template struct load_row_per_warp { static __device__ void @@ -304,6 +320,16 @@ struct store_row_per_warp { } }; +template <> +struct store_row_per_warp { + static __device__ void run(c10::Half* acc, c10::Half* p_output, int lane_id) { + auto out = reinterpret_cast(p_output); + out[lane_id] = *reinterpret_cast(acc); + out[lane_id + 64] = *reinterpret_cast(&acc[2]); + p_output[lane_id + 256] = acc[4]; + } +}; + template <> struct store_row_per_warp { From f8fe9d787f2650285429908324deae58e78c65b0 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 8 Sep 2025 19:34:16 +0000 Subject: [PATCH 19/28] changed the max length per warp and cta per row WG size --- .../backward/embedding_backward_split_host_template.cpp | 2 +- .../training/backward/embedding_backward_split_template.cu | 6 +----- .../training/index_select/batch_index_select_dim0_host.cpp | 2 +- .../pt2/embedding_split_host_pt2_autograd_template.cpp | 2 +- 4 files changed, 4 insertions(+), 8 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp index e071d88768..6d3769534e 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp @@ -949,7 +949,7 @@ class {{ autograd_func }} : #ifdef USE_ROCM constexpr int32_t BT_block_size = 64; - constexpr int32_t max_segment_length_per_warp = 64; + constexpr int32_t max_segment_length_per_warp = 4096; #else constexpr int32_t BT_block_size = 32; constexpr int32_t max_segment_length_per_warp = 32; diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index 7eb2b6880f..86d4ce8b8b 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -987,11 +987,7 @@ Tensor {{ embedding_cuda_op }}( auto num_long_run_ids = at::zeros({1}, indices.options().dtype(at::kInt)); const bool use_deterministic_algorithms = at::globalContext().deterministicAlgorithms(); - #ifdef USE_ROCM - const int max_segment_length_per_cta = use_deterministic_algorithms ? INT_MAX : 4096; - #else - const int max_segment_length_per_cta = use_deterministic_algorithms ? INT_MAX : 1024; - #endif + const int max_segment_length_per_cta = use_deterministic_algorithms ? 
INT_MAX : 4096; Tensor long_run_id_to_really_long_run_ids; if (use_deterministic_algorithms) { diff --git a/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_host.cpp b/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_host.cpp index 06cd53b16b..fd43cc18a6 100644 --- a/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_host.cpp +++ b/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_host.cpp @@ -656,7 +656,7 @@ class BatchIndexSelectDim0TensorGPUOp const auto permute_output_dim_0_1 = ctx->saved_data["permute_output_dim_0_1"].toBool(); - constexpr int32_t max_segment_length_per_warp = 32; + constexpr int32_t max_segment_length_per_warp = 4096; auto grad_output = grad_outputs[0]; diff --git a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp index 20c055e917..46384be1bb 100644 --- a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp +++ b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp @@ -1006,7 +1006,7 @@ static torch::autograd::variable_list backward( #ifdef USE_ROCM constexpr int32_t BT_block_size = 64; - constexpr int32_t max_segment_length_per_warp = 64; + constexpr int32_t max_segment_length_per_warp = 4096; #else constexpr int32_t BT_block_size = 32; constexpr int32_t max_segment_length_per_warp = 32; From 279aeac088b9cb029ec1e090b388cae051fa490c Mon Sep 17 00:00:00 2001 From: kudomcho Date: Tue, 9 Sep 2025 20:25:30 +0000 Subject: [PATCH 20/28] added DPP and changed max length per warp to 16k --- .../embedding_backward_split_host_template.cpp | 2 +- .../index_select/batch_index_select_dim0_host.cpp | 4 ++-- .../embedding_split_host_pt2_autograd_template.cpp | 2 +- .../include/fbgemm_gpu/utils/cuda_prelude.cuh | 14 ++++++++------ 4 files changed, 12 insertions(+), 10 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp index 6d3769534e..05b93d9d7e 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp @@ -949,7 +949,7 @@ class {{ autograd_func }} : #ifdef USE_ROCM constexpr int32_t BT_block_size = 64; - constexpr int32_t max_segment_length_per_warp = 4096; + constexpr int32_t max_segment_length_per_warp = 16384; #else constexpr int32_t BT_block_size = 32; constexpr int32_t max_segment_length_per_warp = 32; diff --git a/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_host.cpp b/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_host.cpp index fd43cc18a6..37c5ce7cc0 100644 --- a/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_host.cpp +++ b/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_host.cpp @@ -341,7 +341,7 @@ class BatchIndexSelectDim0GPUOp Tensor grad_dev_weights; TORCH_CHECK_EQ(grad_outputs.size(), 1); - constexpr int32_t max_segment_length_per_warp = 32; + constexpr int32_t max_segment_length_per_warp = 16384; auto grad_output = grad_outputs[0]; @@ -656,7 +656,7 @@ class BatchIndexSelectDim0TensorGPUOp const auto permute_output_dim_0_1 = ctx->saved_data["permute_output_dim_0_1"].toBool(); - constexpr int32_t max_segment_length_per_warp = 4096; + constexpr int32_t max_segment_length_per_warp = 16384; auto grad_output = grad_outputs[0]; 
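The patches in this stretch keep raising max_segment_length_per_warp (originally 32/64, then 4096, now 16384). That constant is the cutoff that decides which sorted-index segments the backward pass handles in the lighter warp_per_row kernel instead of the heavier cta_per_row kernel, so raising it pushes nearly all segments onto the warp_per_row path, matching the "Redistribute all cta_per_row work to warp_per_row" comment added in patch 21 below. The sketch that follows is only a simplified, self-contained illustration of that routing decision, not FBGEMM code; count_kernel_paths and the toy segment lengths are assumptions made for this example.

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // Hypothetical helper for this sketch only: count how many duplicate-index
    // segments would be routed to each backward kernel for a given threshold.
    static void count_kernel_paths(const std::vector<int32_t>& segment_lengths,
                                   int32_t max_segment_length_per_warp,
                                   int64_t* warp_per_row,
                                   int64_t* cta_per_row) {
      *warp_per_row = 0;
      *cta_per_row = 0;
      for (const int32_t len : segment_lengths) {
        if (len <= max_segment_length_per_warp) {
          ++(*warp_per_row);  // short segment: warp_per_row kernel
        } else {
          ++(*cta_per_row);   // long segment: cta_per_row kernel
        }
      }
    }

    int main() {
      // Toy distribution of segment (duplicate-index run) lengths.
      const std::vector<int32_t> segment_lengths = {1, 3, 40, 70, 500, 3000, 20000};
      for (const int32_t threshold : {32, 4096, 16384}) {
        int64_t warp_rows = 0;
        int64_t cta_rows = 0;
        count_kernel_paths(segment_lengths, threshold, &warp_rows, &cta_rows);
        std::printf("threshold=%d -> warp_per_row=%lld cta_per_row=%lld\n",
                    threshold, static_cast<long long>(warp_rows),
                    static_cast<long long>(cta_rows));
      }
      return 0;
    }

With a 32-element cutoff most of the longer runs fall back to cta_per_row; at 16384 effectively everything in this toy distribution stays on the warp_per_row path, which is the redistribution these patches are after.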
diff --git a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp index 46384be1bb..8fb2cdf2ed 100644 --- a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp +++ b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp @@ -1006,7 +1006,7 @@ static torch::autograd::variable_list backward( #ifdef USE_ROCM constexpr int32_t BT_block_size = 64; - constexpr int32_t max_segment_length_per_warp = 4096; + constexpr int32_t max_segment_length_per_warp = 16384; #else constexpr int32_t BT_block_size = 32; constexpr int32_t max_segment_length_per_warp = 32; diff --git a/fbgemm_gpu/include/fbgemm_gpu/utils/cuda_prelude.cuh b/fbgemm_gpu/include/fbgemm_gpu/utils/cuda_prelude.cuh index 0d65c4798a..a1d9819017 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/utils/cuda_prelude.cuh +++ b/fbgemm_gpu/include/fbgemm_gpu/utils/cuda_prelude.cuh @@ -21,7 +21,9 @@ #include #endif #include - +#ifdef USE_ROCM +#include "fbgemm_gpu/rocm/split_embeddings_common.h" +#endif namespace { inline int get_device_sm_cnt_() { @@ -138,11 +140,11 @@ template DEVICE_INLINE T warpReduceAllSum( T val, unsigned shfl_sync_mask = static_cast(kFullWarpMask)) { -#pragma unroll - for (int mask = ReduceWidth / 2; mask > 0; mask >>= 1) { - val += shfl_xor(val, mask, ReduceWidth, shfl_sync_mask); - } - return val; + return rocm::wave_reduce< + rocm::reduce_op::sum, // Sum reduction + T, // Data type + ReduceWidth // Wave/Warp size + >(val); } DEVICE_INLINE void syncwarp() { From d59e3d6e3162c57e47cb6ef4f8faa269e7a22c39 Mon Sep 17 00:00:00 2001 From: kudomcho Date: Wed, 10 Sep 2025 19:33:44 +0000 Subject: [PATCH 21/28] guard max segment warp based on emb dim --- ...dding_split_host_pt2_autograd_template.cpp | 20 ++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp index 8fb2cdf2ed..c587ccb83a 100644 --- a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp +++ b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp @@ -1006,7 +1006,25 @@ static torch::autograd::variable_list backward( #ifdef USE_ROCM constexpr int32_t BT_block_size = 64; - constexpr int32_t max_segment_length_per_warp = 16384; + int32_t max_segment_length_per_warp = 64; + // Workaround. Should not be upstreamed in any way. + // Redistribute all cta_per_row work to warp_per_row. 
+ {%- if (not nobag) and + (optimizer == "rowwise_adagrad") and + (not vbe) and + (not is_gwd) and + (not ssd) and + (not is_index_select) and + (not dense) %} + const auto T = weights_offsets.sym_numel(); + const auto B = (offsets.size(0) - 1) / T; + {%- for kDimSize in [64, 128, 160, 192, 256, 320] %} + if(!mixed_D && (max_D == {{ kDimSize }})) + { + max_segment_length_per_warp = 16384; + } + {%- endfor %} + {%- endif %} #else constexpr int32_t BT_block_size = 32; constexpr int32_t max_segment_length_per_warp = 32; From faf378e1559bbc3fecf928c1ea1bade51f172a41 Mon Sep 17 00:00:00 2001 From: kudomcho Date: Wed, 10 Sep 2025 22:00:20 +0000 Subject: [PATCH 22/28] added guarding opt of max segment for the case batch size list=1 --- .../pt2/embedding_split_host_pt2_autograd_template.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp index c587ccb83a..fa6a27ab55 100644 --- a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp +++ b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp @@ -1009,6 +1009,7 @@ static torch::autograd::variable_list backward( int32_t max_segment_length_per_warp = 64; // Workaround. Should not be upstreamed in any way. // Redistribute all cta_per_row work to warp_per_row. + int32_t total_L = indices.numel(); {%- if (not nobag) and (optimizer == "rowwise_adagrad") and (not vbe) and @@ -1017,9 +1018,10 @@ static torch::autograd::variable_list backward( (not is_index_select) and (not dense) %} const auto T = weights_offsets.sym_numel(); - const auto B = (offsets.size(0) - 1) / T; + auto total_B = (offsets.size(0) - 1); + const auto B = total_B / T; {%- for kDimSize in [64, 128, 160, 192, 256, 320] %} - if(!mixed_D && (max_D == {{ kDimSize }})) + if(!mixed_D && total_L / total_B > 1 && (max_D == {{ kDimSize }})) { max_segment_length_per_warp = 16384; } From d9239e9da002d45662efc80f492fc82efc6788f2 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 18 Sep 2025 09:26:57 +0000 Subject: [PATCH 23/28] opt for grad_indice_weights kernel --- ..._backward_split_indice_weights_template.cu | 77 ++++++++++++++++++- 1 file changed, 76 insertions(+), 1 deletion(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu index 8f190d04d2..e24a812e8b 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu @@ -214,7 +214,82 @@ __global__ __launch_bounds__(kForwardMaxThreads) void ) {%- endif %} - for (auto j = 0; j < kWarpSize && l_start + j < L; ++j) { + int32_t j = 0; + {%- if not ssd and not dense and not use_vec_blocking and not vbe %} + // Currently for split_embedding_codegen_grad_indice_weights_kernel only + for (; j < kWarpSize && l_start + j + 3 < L; j += 4) { + const auto offset_idx_j0 = shfl_sync(offset_idx, j); + const auto offset_idx_j1 = shfl_sync(offset_idx, j+1); + const auto offset_idx_j2 = shfl_sync(offset_idx, j+2); + const auto offset_idx_j3 = shfl_sync(offset_idx, j+3); + + const auto cache_idx_j0 = shfl_sync(cache_idx, j); + const auto cache_idx_j1 = shfl_sync(cache_idx, j+1); + const auto cache_idx_j2 = shfl_sync(cache_idx, j+2); + const auto cache_idx_j3 = 
shfl_sync(cache_idx, j+3); + + at::acc_type grad_indice_weight0 = 0.0; + at::acc_type grad_indice_weight1 = 0.0; + at::acc_type grad_indice_weight2 = 0.0; + at::acc_type grad_indice_weight3 = 0.0; + + [[maybe_unused]] const auto weight_row0 = WeightRowAccessor>(&weights[offset_idx_j0], D); + [[maybe_unused]] const auto weight_row1 = WeightRowAccessor>(&weights[offset_idx_j1], D); + [[maybe_unused]] const auto weight_row2 = WeightRowAccessor>(&weights[offset_idx_j2], D); + [[maybe_unused]] const auto weight_row3 = WeightRowAccessor>(&weights[offset_idx_j3], D); + + #pragma unroll kFixedMaxVecsPerThread + for (int32_t vec = 0; vec < kFixedMaxVecsPerThread && (kWarpSize * vec + threadIdx.x) * kVecWidth < D; ++vec) { + const int32_t d = (kWarpSize * vec + threadIdx.x) * kVecWidth; + + Vec4T> weight0, weight1, weight2, weight3; + if (placement == PlacementType::MANAGED_CACHING) { + weight0 = (cache_idx_j0 != kCacheLocationMissing) ? + Vec4T>(&lxu_cache_weights[cache_idx_j0][d]) : + weight_row0.load(d); + + weight1 = (cache_idx_j1 != kCacheLocationMissing) ? + Vec4T>(&lxu_cache_weights[cache_idx_j1][d]) : + weight_row1.load(d); + + weight2 = (cache_idx_j2 != kCacheLocationMissing) ? + Vec4T>(&lxu_cache_weights[cache_idx_j2][d]) : + weight_row2.load(d); + + weight3 = (cache_idx_j3 != kCacheLocationMissing) ? + Vec4T>(&lxu_cache_weights[cache_idx_j3][d]) : + weight_row3.load(d); + } else { + weight0 = weight_row0.load(d); + weight1 = weight_row1.load(d); + weight2 = weight_row2.load(d); + weight3 = weight_row3.load(d); + } + + grad_indice_weight0 += weight0.acc.x * grad_out[vec].acc.x + weight0.acc.y * grad_out[vec].acc.y + + weight0.acc.z * grad_out[vec].acc.z + weight0.acc.w * grad_out[vec].acc.w; + grad_indice_weight1 += weight1.acc.x * grad_out[vec].acc.x + weight1.acc.y * grad_out[vec].acc.y + + weight1.acc.z * grad_out[vec].acc.z + weight1.acc.w * grad_out[vec].acc.w; + grad_indice_weight2 += weight2.acc.x * grad_out[vec].acc.x + weight2.acc.y * grad_out[vec].acc.y + + weight2.acc.z * grad_out[vec].acc.z + weight2.acc.w * grad_out[vec].acc.w; + grad_indice_weight3 += weight3.acc.x * grad_out[vec].acc.x + weight3.acc.y * grad_out[vec].acc.y + + weight3.acc.z * grad_out[vec].acc.z + weight3.acc.w * grad_out[vec].acc.w; + } + + grad_indice_weight0 = warpReduceAllSum>(grad_indice_weight0); + grad_indice_weight1 = warpReduceAllSum>(grad_indice_weight1); + grad_indice_weight2 = warpReduceAllSum>(grad_indice_weight2); + grad_indice_weight3 = warpReduceAllSum>(grad_indice_weight3); + + if (threadIdx.x == 0) { + grad_indice_weights[indices_start + l_start + j] = grad_indice_weight0; + grad_indice_weights[indices_start + l_start + j+1] = grad_indice_weight1; + grad_indice_weights[indices_start + l_start + j+2] = grad_indice_weight2; + grad_indice_weights[indices_start + l_start + j+3] = grad_indice_weight3; + } + } + {%- endif %} + for (; j < kWarpSize && l_start + j < L; ++j) { const auto offset_idx_j = shfl_sync(offset_idx, j); {%- if not dense %} const auto {{ locs_or_addrs_idx }}_j = shfl_sync({{ locs_or_addrs_idx }}, j); From 145a6737b3651447b30516e1f095bd54d9bca9f6 Mon Sep 17 00:00:00 2001 From: kudomcho Date: Tue, 23 Sep 2025 02:09:26 +0000 Subject: [PATCH 24/28] added store row per warp on emb 192 and added accuracy test functionality --- ...plit_table_batched_embeddings_benchmark.py | 223 +++++++++++++----- fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py | 125 ++++++++-- .../fbgemm_gpu/rocm/split_embeddings_common.h | 18 +- 3 files changed, 277 insertions(+), 89 deletions(-) diff --git 
a/fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py b/fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py index 4ffb7341a5..3fad8f53fe 100644 --- a/fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py +++ b/fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py @@ -7,7 +7,8 @@ # pyre-strict - +import gzip +import yaml import logging import os import tempfile @@ -1011,7 +1012,15 @@ def context_factory(on_trace_ready: Callable[[profile], None]): @TbeBenchClickInterface.common_options @TbeBenchClickInterface.device_options @TbeBenchClickInterface.vbe_options +@click.option("--save", type=str, default=None) +@click.option("--load", type=str, default=None) +@click.option("--random-weights", is_flag=True, default=False) +@click.option("--compressed", is_flag=True, default=False) +@click.option("--slice-min", type=int, default=None) +@click.option("--slice-max", type=int, default=None) +@click.pass_context def device_with_spec( # noqa C901 + ctx, alpha: float, bag_size_list: str, bag_size_sigma_list: str, @@ -1031,7 +1040,39 @@ def device_with_spec( # noqa C901 bounds_check_mode: int, flush_gpu_cache_size_mb: int, output_dtype: SparseType, + save: str, + load: str, + random_weights: bool, + compressed: bool, + slice_min: int, + slice_max: int, ) -> None: + if load: + with open(f"{load}/params.yaml", "r") as f: + ctx.params = yaml.load(f, Loader=yaml.UnsafeLoader) + alpha = ctx.params["alpha"] + bag_size_list = ctx.params["bag_size_list"] + bag_size_sigma_list = ctx.params["bag_size_sigma_list"] + batch_size = ctx.params["batch_size"] + embedding_dim_list = ctx.params["embedding_dim_list"] + weights_precision = ctx.params["weights_precision"] + cache_precision = ctx.params["cache_precision"] + stoc = ctx.params["stoc"] + iters = ctx.params["iters"] + warmup_runs = ctx.params["warmup_runs"] + managed = ctx.params["managed"] + num_embeddings_list = ctx.params["num_embeddings_list"] + reuse = ctx.params["reuse"] + row_wise = ctx.params["row_wise"] + weighted = ctx.params["weighted"] + pooling = ctx.params["pooling"] + bounds_check_mode = ctx.params["bounds_check_mode"] + flush_gpu_cache_size_mb = ctx.params["flush_gpu_cache_size_mb"] + output_dtype = ctx.params["output_dtype"] + random_weights = ctx.params["random_weights"] + compressed = ctx.params["compressed"] + slice_min = ctx.params["slice_min"] + slice_max = ctx.params["slice_max"] np.random.seed(42) torch.manual_seed(42) B = batch_size @@ -1040,6 +1081,11 @@ def device_with_spec( # noqa C901 T = len(Ds) use_variable_bag_sizes = bag_size_sigma_list != "None" + params = ctx.params + if save: + os.makedirs(f"{save}", exist_ok=True) + with open(f"{save}/params.yaml", "w") as f: + yaml.dump(params, f, sort_keys=False) if use_variable_bag_sizes: Ls = [int(mu) for mu in bag_size_list.split(",")] @@ -1118,6 +1164,22 @@ def device_with_spec( # noqa C901 if weights_precision == SparseType.INT8: emb.init_embedding_weights_uniform(-0.0003, 0.0003) + elif random_weights: + emb.init_embedding_weights_uniform(-1.0, 1.0) + + if save: + if compressed: + with gzip.open(f"{save}/model_state.pth.gz", "wb") as f: + torch.save(emb.state_dict(), f) + else: + torch.save(emb.state_dict(), f"{save}/model_state.pth") + + if load: + if compressed: + with gzip.open(f"{load}/model_state.pth.gz", "rb") as f: + emb.load_state_dict(torch.load(f)) + else: + emb.load_state_dict(torch.load(f"{load}/model_state.pth")) nparams = sum(w.numel() for w in emb.split_embedding_weights()) param_size_multiplier = 
weights_precision.bit_rate() / 8.0 @@ -1130,53 +1192,68 @@ def device_with_spec( # noqa C901 "weights": [[] for _ in range(iters)], } # row = iter, column = tensor - for t, e in enumerate(Es): - # (indices, offsets, weights) - requests = generate_requests( - iters, - B, - 1, - Ls[t], - e, - reuse=reuse, - alpha=alpha, - weighted=weighted, - # pyre-fixme[61]: `sigma_Ls` is undefined, or not always defined. - sigma_L=sigma_Ls[t] if use_variable_bag_sizes else None, - zipf_oversample_ratio=3 if Ls[t] > 5 else 5, - use_cpu=get_available_compute_device() == ComputeDevice.CPU, - index_dtype=torch.long, - offset_dtype=torch.long, - ) - for i, req in enumerate(requests): - indices, offsets, weights = req.unpack_3() - all_requests["indices"][i].append(indices) - if t > 0: - offsets = offsets[1:] # remove the first element - offsets += all_requests["offsets"][i][t - 1][-1] - all_requests["offsets"][i].append(offsets) - all_requests["weights"][i].append(weights) - - prev_indices_len = -1 - requests = [] - for i in range(iters): - indices = torch.concat(all_requests["indices"][i]) - if prev_indices_len == -1: - prev_indices_len = indices.numel() - assert ( - prev_indices_len == indices.numel() - ), "Number of indices for every iteration must be the same" - offsets = torch.concat(all_requests["offsets"][i]) - if weighted: - weights = torch.concat(all_requests["weights"][i]) - else: - weights = None - requests.append(TBERequest(indices, offsets, weights)) - - del all_requests - + + if load: + requests = [] + for i in range(iters): + indices = torch.load(f"{load}/{i}_indices.pt") + offsets = torch.load(f"{load}/{i}_offsets.pt") + per_sample_weights = torch.load(f"{load}/{i}_per_sample_weights.pt") + Bs_per_feature_per_rank = torch.load(f"{load}/{i}_Bs_per_feature_per_rank.pt") + requests.append(TBERequest(indices, offsets, per_sample_weights, Bs_per_feature_per_rank)) + else: + for t, e in enumerate(Es): + # (indices, offsets, weights) + requests = generate_requests( + iters, + B, + 1, + Ls[t], + e, + reuse=reuse, + alpha=alpha, + weighted=weighted, + # pyre-fixme[61]: `sigma_Ls` is undefined, or not always defined. 
+ sigma_L=sigma_Ls[t] if use_variable_bag_sizes else None, + zipf_oversample_ratio=3 if Ls[t] > 5 else 5, + use_cpu=get_available_compute_device() == ComputeDevice.CPU, + index_dtype=torch.long, + offset_dtype=torch.long, + ) + for i, req in enumerate(requests): + indices, offsets, weights = req.unpack_3() + all_requests["indices"][i].append(indices) + if t > 0: + offsets = offsets[1:] # remove the first element + offsets += all_requests["offsets"][i][t - 1][-1] + all_requests["offsets"][i].append(offsets) + all_requests["weights"][i].append(weights) + + prev_indices_len = -1 + requests = [] + for i in range(iters): + indices = torch.concat(all_requests["indices"][i]) + if prev_indices_len == -1: + prev_indices_len = indices.numel() + assert ( + prev_indices_len == indices.numel() + ), "Number of indices for every iteration must be the same" + offsets = torch.concat(all_requests["offsets"][i]) + if weighted: + weights = torch.concat(all_requests["weights"][i]) + else: + weights = None + requests.append(TBERequest(indices, offsets, weights)) + del all_requests assert len(requests) == iters - + if save: + for i in range(iters): + req = requests[i] + torch.save(req.indices, f"{save}/{i}_indices.pt") + torch.save(req.offsets, f"{save}/{i}_offsets.pt") + torch.save(req.per_sample_weights, f"{save}/{i}_per_sample_weights.pt") + torch.save(req.Bs_per_feature_per_rank, f"{save}/{i}_Bs_per_feature_per_rank.pt") + sum_DLs = sum([d * l for d, l in zip(Ds, Ls)]) if do_pooling: read_write_bytes = ( @@ -1203,34 +1280,44 @@ def device_with_spec( # noqa C901 # forward time_per_iter = benchmark_requests( - requests, - lambda indices, offsets, per_sample_weights: emb.forward( - indices, - offsets, - per_sample_weights, - feature_requires_grad=feature_requires_grad, - ), - flush_gpu_cache_size_mb=flush_gpu_cache_size_mb, - num_warmups=warmup_runs, - ) + requests, + lambda indices, offsets, per_sample_weights: emb.forward( + indices, + offsets, + per_sample_weights, + feature_requires_grad=feature_requires_grad, + ), + flush_gpu_cache_size_mb=flush_gpu_cache_size_mb, + num_warmups=warmup_runs, + ) logging.info( - f"Forward, B: {B}, " - f"Es: {Es}, T: {T}, Ds: {Ds}, Ls: {Ls_str}, W: {weighted}, " - f"BW: {read_write_bytes / time_per_iter / 1.0e9: .2f} GB/s, " # noqa: B950 - f"T: {time_per_iter * 1.0e6:.0f}us" - ) + f"Forward, B: {B}, " + f"Es: {Es}, T: {T}, Ds: {Ds}, Ls: {Ls_str}, W: {weighted}, " + f"BW: {read_write_bytes / time_per_iter / 1.0e9: .2f} GB/s, " # noqa: B950 + f"T: {time_per_iter * 1.0e6:.0f}us" + ) + if output_dtype == SparseType.INT8: # backward bench not representative return - if do_pooling: - grad_output = torch.randn(B, sum(Ds)).to(get_device()) + if load: + grad_output = torch.load(f"{load}/grad_output.pt") else: # Obtain B * L from indices len # pyre-ignore[19] # pyre-fixme[61]: `D` is undefined, or not always defined. - grad_output = torch.randn(requests[0].indices.numel(), D).to(get_device()) + if do_pooling: + grad_output = torch.randn(B, sum(Ds)).to(get_device()) + else: + # Obtain B * L from indices len + # pyre-ignore[19] + # pyre-fixme[61]: `D` is undefined, or not always defined. 
+ grad_output = torch.randn(requests[0].indices.numel(), D).to(get_device()) + + if save: + torch.save(grad_output, f"{save}/grad_output.pt") # backward time_per_iter = benchmark_requests( requests, @@ -1244,6 +1331,12 @@ def device_with_spec( # noqa C901 bwd_only=True, grad=grad_output, num_warmups=warmup_runs, + emb=emb, + save=save, + load=load, + compressed=compressed, + slice_min=slice_min, + slice_max=slice_max, ) logging.info( f"Backward, B: {B}, Es: {Es}, T: {T}, Ds: {Ds}, Ls: {Ls_str}, " diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py b/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py index 1243f14db4..1bda3188e5 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py @@ -11,6 +11,7 @@ import statistics import threading import time +import gzip from subprocess import Popen from typing import Callable, Optional @@ -18,7 +19,7 @@ from fbgemm_gpu.tbe.utils import b_indices, TBERequest from fbgemm_gpu.tbe.utils.common import get_device - +from fbgemm_gpu.split_table_batched_embeddings_ops_training import SplitTableBatchedEmbeddingBagsCodegen logging.basicConfig(level=logging.DEBUG) @@ -241,36 +242,43 @@ def benchmark_requests( # noqa: C901 periodic_logs: bool = False, warmup_ms: Optional[int] = None, iters: int = -1, + emb: Optional[SplitTableBatchedEmbeddingBagsCodegen] = None, + save: Optional[str] = None, + load: Optional[str] = None, + compressed: bool = False, + slice_min: Optional[int] = None, + slice_max: Optional[int] = None, ) -> float: times = [] # Run at least one warmup iteration to avoid the long cudaLaunchKernel time # for the first kernel if warmup_ms > 0 # warmup_ms is prioritized over num_warmups - + import copy if warmup_ms is None: num_warmups = num_warmups + 1 if num_warmups >= 0 else 1 - # warm-up the GPU before profiling - bench_warmup( - requests[0], - # pyre-ignore[6] - warmup_ms, - num_warmups, - lambda indices, offsets, per_sample_weights: func( - indices, - offsets, - per_sample_weights, - ), - bwd_only=bwd_only, - grad=grad, - ) + if not (load or save): + # warm-up the GPU before profiling + bench_warmup( + requests[0], + # pyre-ignore[6] + warmup_ms, + num_warmups, + lambda indices, offsets, per_sample_weights: func( + indices, + offsets, + per_sample_weights, + ), + bwd_only=bwd_only, + grad=grad, + ) - if callback_after_warmup is not None: - callback_after_warmup() + if callback_after_warmup is not None: + callback_after_warmup() num_reqs = len(requests) iters = num_reqs if iters == -1 else iters - + sliced = slice_min is not None and slice_max is not None if torch.cuda.is_available(): torch.cuda.synchronize() start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] @@ -278,7 +286,86 @@ def benchmark_requests( # noqa: C901 else: start_events = [] end_events = [] + if save and emb: + for it in range(iters): + req = requests[it % num_reqs] + indices, offsets, weights = req.unpack_3() + out = emb(indices, offsets, weights) + torch.cuda.synchronize() + if compressed: + with gzip.open(f"{save}/{it}_fwd_grad_out.pt.gz", "wb") as f: + torch.save(out, f) + else: + torch.save(out, f"{save}/{it}_fwd_grad_out.pt") + + out.backward(grad) + torch.cuda.synchronize() + torch.save(out, f"{save}/{it}_bwd_grad_out.pt") + + if sliced: + for id, t in enumerate(emb.split_embedding_weights()): + if compressed: + with gzip.open(f"{save}/{it}_{id}_bwd_weights_out.pt.gz", "wb") as f: + torch.save(t[slice_min:slice_max,:].clone(), f) + else: + torch.save(t[slice_min:slice_max,:].clone(), 
f"{save}/{it}_{id}_bwd_weights_out.pt") + else: + torch.save(t[slice_min:slice_max,:].clone(), f"{save}/{it}_{id}_bwd_weights_out.pt") + torch.save(emb.momentum1_dev, f"{save}/{it}_bwd_momentum1_dev_out.pt") + torch.save(emb.momentum1_uvm, f"{save}/{it}_bwd_momentum1_uvm_out.pt") + + else: + if compressed: + with gzip.open(f"{save}/{it}_bwd_state_out.pth.gz", "wb") as f: + torch.save(emb.state_dict(), f) + else: + torch.save(emb.state_dict(), f"{save}/{it}_bwd_state_out.pth") + + if load and emb: + for it in range(iters): + req = requests[it % num_reqs] + + indices, offsets, weights = req.unpack_3() + out = emb(indices, offsets, weights) + torch.cuda.synchronize() + + out.backward(grad) + torch.cuda.synchronize() + emb_ref = copy.deepcopy(emb) + if not sliced: + if compressed: + with gzip.open(f"{load}/{it}_bwd_state_out.pth.gz", "rb") as f: + emb_ref.load_state_dict(torch.load(f)) + else: + emb_ref.load_state_dict(torch.load(f"{load}/{it}_bwd_state_out.pth")) + + print(f"[{it + 1}/{iters}] Backward weights check... ", end="", flush=True) + if sliced: + for id, t in enumerate(emb.split_embedding_weights()): + if compressed: + with gzip.open(f"{it}_{id}_bwd_weights_out.pt.gz", "rb") as f: + w_ref = torch.load(f) + else: + w_ref = torch.load(f"{load}/{it}_{id}_bwd_weights_out.pt") + torch.testing.assert_close(t[slice_min:slice_max,:], w_ref, + msg=f"FAILED table = {id}", atol=1.0e-3, rtol=10e-3) + else: + for id, t in enumerate(emb.split_embedding_weights()): + torch.testing.assert_close(t, emb_ref.split_embedding_weights()[id], + msg=f"FAILED table = {id}", atol=1.0e-3, rtol=10e-3) + print("PASS") + + print(f"[{it + 1}/{iters}] Backward momentum check... ", end="", flush=True) + if sliced: + m_dev_ref = torch.load(f"{load}/{it}_bwd_momentum1_dev_out.pt") + m_uvm_ref = torch.load(f"{load}/{it}_bwd_momentum1_uvm_out.pt") + else: + m_dev_ref = emb_ref.momentum1_dev + m_uvm_ref = emb_ref.momentum1_uvm + torch.testing.assert_close(emb.momentum1_dev, m_dev_ref, atol=1.0e-4, rtol=1.0e-4) + torch.testing.assert_close(emb.momentum1_uvm, m_uvm_ref, atol=1.0e-4, rtol=1.0e-4) + print("PASS") for it in range(iters): req = requests[it % num_reqs] diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h index 5b9d69d910..745499ac08 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h @@ -24,7 +24,6 @@ #include #include #include -#include /******************************************************************************/ typedef int32_t int32x4_t __attribute__((ext_vector_type(4))); @@ -62,7 +61,7 @@ __device__ half llvm_amdgcn_raw_buffer_load_fp16( int32_t voffset, int32_t soffset, int32_t glc_slc) -#if ROCM_VERSION_MAJOR >= 7 +#if defined(__gfx950__) __asm("llvm.amdgcn.raw.buffer.load.i16"); #else __asm("llvm.amdgcn.raw.buffer.load.f16"); @@ -79,7 +78,7 @@ __device__ half2 llvm_amdgcn_raw_buffer_load_fp16x2( int32_t voffset, int32_t soffset, int32_t glc_slc) -#if ROCM_VERSION_MAJOR >= 7 +#if defined(__gfx950__) __asm("llvm.amdgcn.raw.buffer.load.i32"); #else __asm("llvm.amdgcn.raw.buffer.load.v2f16"); @@ -165,7 +164,7 @@ struct load_row_per_warp { static __device__ void run(half* emb_data, index_t row_index, const half* p_emb_table, int lane_id) { int32x4_t emb_res = - amdgcn_make_buffer_resource(p_emb_table + row_index * 160); + amdgcn_make_buffer_resource(p_emb_table + row_index * 192); *reinterpret_cast(emb_data) = 
llvm_amdgcn_raw_buffer_load_fp16x2( emb_res, lane_id * sizeof(half2), 0, 0); if ((lane_id + 128) % 192 < 160) { @@ -320,6 +319,15 @@ struct store_row_per_warp { } }; +template <> +struct store_row_per_warp { + static __device__ void run(c10::Half* acc, c10::Half* p_output, int lane_id) { + auto out = reinterpret_cast(p_output); + out[lane_id] = *reinterpret_cast(acc); + *(reinterpret_cast(&out[64]) + lane_id) = *reinterpret_cast(acc + 2); + } +}; + template <> struct store_row_per_warp { static __device__ void run(c10::Half* acc, c10::Half* p_output, int lane_id) { @@ -619,4 +627,4 @@ __device__ inline void magic_div_u32_run_with_mod( quo = magic_div_u32_run(mdiv, n); rem = n - quo * d; } -} // namespace fbgemm_gpu::rocm +} // namespace fbgemm_gpu::rocm \ No newline at end of file From 986cceb660f14d5f26bae57b45682c4357052c5e Mon Sep 17 00:00:00 2001 From: Shreyashri Biswas Date: Mon, 22 Sep 2025 16:09:05 +0000 Subject: [PATCH 25/28] workgroup tuning and loop unrolled --- .../forward/embedding_forward_split_kernel_template.cu | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) mode change 100644 => 100755 fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu diff --git a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu old mode 100644 new mode 100755 index 0122cfcee9..5b13aefef8 --- a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu +++ b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu @@ -469,10 +469,10 @@ using namespace fbgemm_gpu; {%- endif %} {%- if is_rocm %} - for(auto j = L % kThreadGroupSize - L % kManualUnrollLength; l_start + kThreadGroupSize > L && l_start + j < L; ++j) { + for(auto j = L % kThreadGroupSize - L % kManualUnrollLength; l_start + (kThreadGroupSize/32) > L && l_start + j < L; ++j) { {%- else %} // Iterate over kThreadGroupSize indices - for (auto j = 0; j < kThreadGroupSize && l_start + j < L; ++j) { + for (auto j = 0; j < (kThreadGroupSize/32) && l_start + j < L; ++j) { {%- endif %} {%- if dense or lxu_miss_rate != "cache_conflict_miss_rate::zero" %} // Load index from thread j in the group @@ -641,7 +641,7 @@ batch_index_select_dim0_codegen_forward_kernel( {%- endif %} {%- if is_rocm %} // Unroll factor for ROCm devices - constexpr int kManualUnrollLength = 4; + constexpr int kManualUnrollLength = 8; {%- endif %} // Determine the linearized warp ID, and exit early if needed From 2bf70c6b137757292f459dce483cb53041110b72 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Fri, 19 Sep 2025 22:38:17 +0200 Subject: [PATCH 26/28] specialize --- ..._backward_split_indice_weights_template.cu | 145 ++++++++++++------ 1 file changed, 95 insertions(+), 50 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu index e24a812e8b..292a08d62a 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu @@ -217,33 +217,82 @@ __global__ __launch_bounds__(kForwardMaxThreads) void int32_t j = 0; {%- if not ssd and not dense and not use_vec_blocking and not vbe %} // Currently for split_embedding_codegen_grad_indice_weights_kernel only - for (; j < kWarpSize && l_start + j + 3 < L; j += 4) { - const auto 
offset_idx_j0 = shfl_sync(offset_idx, j); - const auto offset_idx_j1 = shfl_sync(offset_idx, j+1); - const auto offset_idx_j2 = shfl_sync(offset_idx, j+2); - const auto offset_idx_j3 = shfl_sync(offset_idx, j+3); - - const auto cache_idx_j0 = shfl_sync(cache_idx, j); - const auto cache_idx_j1 = shfl_sync(cache_idx, j+1); - const auto cache_idx_j2 = shfl_sync(cache_idx, j+2); - const auto cache_idx_j3 = shfl_sync(cache_idx, j+3); - - at::acc_type grad_indice_weight0 = 0.0; - at::acc_type grad_indice_weight1 = 0.0; - at::acc_type grad_indice_weight2 = 0.0; - at::acc_type grad_indice_weight3 = 0.0; - - [[maybe_unused]] const auto weight_row0 = WeightRowAccessor>(&weights[offset_idx_j0], D); - [[maybe_unused]] const auto weight_row1 = WeightRowAccessor>(&weights[offset_idx_j1], D); - [[maybe_unused]] const auto weight_row2 = WeightRowAccessor>(&weights[offset_idx_j2], D); - [[maybe_unused]] const auto weight_row3 = WeightRowAccessor>(&weights[offset_idx_j3], D); + if (placement != PlacementType::MANAGED_CACHING) { + for (; j < kWarpSize && l_start + j + 3 < L; j += 4) { + const auto offset_idx_j0 = shfl_sync(offset_idx, j); + const auto offset_idx_j1 = shfl_sync(offset_idx, j+1); + const auto offset_idx_j2 = shfl_sync(offset_idx, j+2); + const auto offset_idx_j3 = shfl_sync(offset_idx, j+3); + + at::acc_type grad_indice_weight0 = 0.0; + at::acc_type grad_indice_weight1 = 0.0; + at::acc_type grad_indice_weight2 = 0.0; + at::acc_type grad_indice_weight3 = 0.0; + + const auto weight_row0 = WeightRowAccessor>(&weights[offset_idx_j0], D); + const auto weight_row1 = WeightRowAccessor>(&weights[offset_idx_j1], D); + const auto weight_row2 = WeightRowAccessor>(&weights[offset_idx_j2], D); + const auto weight_row3 = WeightRowAccessor>(&weights[offset_idx_j3], D); + + #pragma unroll kFixedMaxVecsPerThread + for (int32_t vec = 0; vec < kFixedMaxVecsPerThread && (kWarpSize * vec + threadIdx.x) * kVecWidth < D; ++vec) { + const int32_t d = (kWarpSize * vec + threadIdx.x) * kVecWidth; + + Vec4T> weight0, weight1, weight2, weight3; + weight0 = weight_row0.load(d); + weight1 = weight_row1.load(d); + weight2 = weight_row2.load(d); + weight3 = weight_row3.load(d); - #pragma unroll kFixedMaxVecsPerThread - for (int32_t vec = 0; vec < kFixedMaxVecsPerThread && (kWarpSize * vec + threadIdx.x) * kVecWidth < D; ++vec) { - const int32_t d = (kWarpSize * vec + threadIdx.x) * kVecWidth; + grad_indice_weight0 += weight0.acc.x * grad_out[vec].acc.x + weight0.acc.y * grad_out[vec].acc.y + + weight0.acc.z * grad_out[vec].acc.z + weight0.acc.w * grad_out[vec].acc.w; + grad_indice_weight1 += weight1.acc.x * grad_out[vec].acc.x + weight1.acc.y * grad_out[vec].acc.y + + weight1.acc.z * grad_out[vec].acc.z + weight1.acc.w * grad_out[vec].acc.w; + grad_indice_weight2 += weight2.acc.x * grad_out[vec].acc.x + weight2.acc.y * grad_out[vec].acc.y + + weight2.acc.z * grad_out[vec].acc.z + weight2.acc.w * grad_out[vec].acc.w; + grad_indice_weight3 += weight3.acc.x * grad_out[vec].acc.x + weight3.acc.y * grad_out[vec].acc.y + + weight3.acc.z * grad_out[vec].acc.z + weight3.acc.w * grad_out[vec].acc.w; + } + + grad_indice_weight0 = warpReduceAllSum>(grad_indice_weight0); + grad_indice_weight1 = warpReduceAllSum>(grad_indice_weight1); + grad_indice_weight2 = warpReduceAllSum>(grad_indice_weight2); + grad_indice_weight3 = warpReduceAllSum>(grad_indice_weight3); - Vec4T> weight0, weight1, weight2, weight3; - if (placement == PlacementType::MANAGED_CACHING) { + if (threadIdx.x == 0) { + grad_indice_weights[indices_start + l_start + j] = 
grad_indice_weight0; + grad_indice_weights[indices_start + l_start + j+1] = grad_indice_weight1; + grad_indice_weights[indices_start + l_start + j+2] = grad_indice_weight2; + grad_indice_weights[indices_start + l_start + j+3] = grad_indice_weight3; + } + } + } else { + for (; j < kWarpSize && l_start + j + 3 < L; j += 4) { + const auto offset_idx_j0 = shfl_sync(offset_idx, j); + const auto offset_idx_j1 = shfl_sync(offset_idx, j+1); + const auto offset_idx_j2 = shfl_sync(offset_idx, j+2); + const auto offset_idx_j3 = shfl_sync(offset_idx, j+3); + + const auto cache_idx_j0 = shfl_sync(cache_idx, j); + const auto cache_idx_j1 = shfl_sync(cache_idx, j+1); + const auto cache_idx_j2 = shfl_sync(cache_idx, j+2); + const auto cache_idx_j3 = shfl_sync(cache_idx, j+3); + + at::acc_type grad_indice_weight0 = 0.0; + at::acc_type grad_indice_weight1 = 0.0; + at::acc_type grad_indice_weight2 = 0.0; + at::acc_type grad_indice_weight3 = 0.0; + + const auto weight_row0 = WeightRowAccessor>(&weights[offset_idx_j0], D); + const auto weight_row1 = WeightRowAccessor>(&weights[offset_idx_j1], D); + const auto weight_row2 = WeightRowAccessor>(&weights[offset_idx_j2], D); + const auto weight_row3 = WeightRowAccessor>(&weights[offset_idx_j3], D); + + #pragma unroll kFixedMaxVecsPerThread + for (int32_t vec = 0; vec < kFixedMaxVecsPerThread && (kWarpSize * vec + threadIdx.x) * kVecWidth < D; ++vec) { + const int32_t d = (kWarpSize * vec + threadIdx.x) * kVecWidth; + + Vec4T> weight0, weight1, weight2, weight3; weight0 = (cache_idx_j0 != kCacheLocationMissing) ? Vec4T>(&lxu_cache_weights[cache_idx_j0][d]) : weight_row0.load(d); @@ -259,33 +308,29 @@ __global__ __launch_bounds__(kForwardMaxThreads) void weight3 = (cache_idx_j3 != kCacheLocationMissing) ? Vec4T>(&lxu_cache_weights[cache_idx_j3][d]) : weight_row3.load(d); - } else { - weight0 = weight_row0.load(d); - weight1 = weight_row1.load(d); - weight2 = weight_row2.load(d); - weight3 = weight_row3.load(d); + + + grad_indice_weight0 += weight0.acc.x * grad_out[vec].acc.x + weight0.acc.y * grad_out[vec].acc.y + + weight0.acc.z * grad_out[vec].acc.z + weight0.acc.w * grad_out[vec].acc.w; + grad_indice_weight1 += weight1.acc.x * grad_out[vec].acc.x + weight1.acc.y * grad_out[vec].acc.y + + weight1.acc.z * grad_out[vec].acc.z + weight1.acc.w * grad_out[vec].acc.w; + grad_indice_weight2 += weight2.acc.x * grad_out[vec].acc.x + weight2.acc.y * grad_out[vec].acc.y + + weight2.acc.z * grad_out[vec].acc.z + weight2.acc.w * grad_out[vec].acc.w; + grad_indice_weight3 += weight3.acc.x * grad_out[vec].acc.x + weight3.acc.y * grad_out[vec].acc.y + + weight3.acc.z * grad_out[vec].acc.z + weight3.acc.w * grad_out[vec].acc.w; } - grad_indice_weight0 += weight0.acc.x * grad_out[vec].acc.x + weight0.acc.y * grad_out[vec].acc.y + - weight0.acc.z * grad_out[vec].acc.z + weight0.acc.w * grad_out[vec].acc.w; - grad_indice_weight1 += weight1.acc.x * grad_out[vec].acc.x + weight1.acc.y * grad_out[vec].acc.y + - weight1.acc.z * grad_out[vec].acc.z + weight1.acc.w * grad_out[vec].acc.w; - grad_indice_weight2 += weight2.acc.x * grad_out[vec].acc.x + weight2.acc.y * grad_out[vec].acc.y + - weight2.acc.z * grad_out[vec].acc.z + weight2.acc.w * grad_out[vec].acc.w; - grad_indice_weight3 += weight3.acc.x * grad_out[vec].acc.x + weight3.acc.y * grad_out[vec].acc.y + - weight3.acc.z * grad_out[vec].acc.z + weight3.acc.w * grad_out[vec].acc.w; - } - - grad_indice_weight0 = warpReduceAllSum>(grad_indice_weight0); - grad_indice_weight1 = warpReduceAllSum>(grad_indice_weight1); - grad_indice_weight2 
= warpReduceAllSum>(grad_indice_weight2); - grad_indice_weight3 = warpReduceAllSum>(grad_indice_weight3); + grad_indice_weight0 = warpReduceAllSum>(grad_indice_weight0); + grad_indice_weight1 = warpReduceAllSum>(grad_indice_weight1); + grad_indice_weight2 = warpReduceAllSum>(grad_indice_weight2); + grad_indice_weight3 = warpReduceAllSum>(grad_indice_weight3); - if (threadIdx.x == 0) { - grad_indice_weights[indices_start + l_start + j] = grad_indice_weight0; - grad_indice_weights[indices_start + l_start + j+1] = grad_indice_weight1; - grad_indice_weights[indices_start + l_start + j+2] = grad_indice_weight2; - grad_indice_weights[indices_start + l_start + j+3] = grad_indice_weight3; + if (threadIdx.x == 0) { + grad_indice_weights[indices_start + l_start + j] = grad_indice_weight0; + grad_indice_weights[indices_start + l_start + j+1] = grad_indice_weight1; + grad_indice_weights[indices_start + l_start + j+2] = grad_indice_weight2; + grad_indice_weights[indices_start + l_start + j+3] = grad_indice_weight3; + } } } {%- endif %} @@ -447,7 +492,7 @@ Tensor {{ mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda( TORCH_WARN_ONCE("Running on CDNA architecture"); } #endif - + const auto T = D_offsets.size(0) - 1; TORCH_CHECK_GT(T, 0); // offsets = [B x T + 1] From 4d2bfddf472907e443203acc185939fabfabe9f6 Mon Sep 17 00:00:00 2001 From: Li Li Date: Wed, 24 Sep 2025 00:48:35 +0000 Subject: [PATCH 27/28] explicitly link to tbb --- cmake/modules/CppLibrary.cmake | 12 ++++++++++++ cmake/modules/GpuCppLibrary.cmake | 12 ++++++++++++ 2 files changed, 24 insertions(+) diff --git a/cmake/modules/CppLibrary.cmake b/cmake/modules/CppLibrary.cmake index 92a93a60b6..388d3ac779 100644 --- a/cmake/modules/CppLibrary.cmake +++ b/cmake/modules/CppLibrary.cmake @@ -168,6 +168,18 @@ function(cpp_library) target_link_libraries(${lib_name} PUBLIC OpenMP::OpenMP_CXX) endif() + if(NOT TARGET TBB::tbb) + find_package(TBB QUIET) + endif() + if(TBB_FOUND) + target_link_libraries(${lib_name} PUBLIC TBB::tbb) + else() + find_library(TBB_LIB NAMES tbb tbb12 HINTS $ENV{CONDA_PREFIX}/lib /usr/lib/x86_64-linux-gnu /usr/local/lib /lib/x86_64-linux-gnu) + if(TBB_LIB) + target_link_libraries(${lib_name} PUBLIC ${TBB_LIB}) + endif() + endif() + # Add sanitizer options if needed if(args_SANITIZER_OPTIONS) target_link_options(${lib_name} PUBLIC diff --git a/cmake/modules/GpuCppLibrary.cmake b/cmake/modules/GpuCppLibrary.cmake index 51c30df750..e662848348 100644 --- a/cmake/modules/GpuCppLibrary.cmake +++ b/cmake/modules/GpuCppLibrary.cmake @@ -302,6 +302,18 @@ function(gpu_cpp_library) list(APPEND library_dependencies ${NVML_LIB_PATH}) endif() + if(NOT TARGET TBB::tbb) + find_package(TBB QUIET) + endif() + if(TBB_FOUND) + list(APPEND library_dependencies TBB::tbb) + else() + find_library(TBB_LIB NAMES tbb tbb12 HINTS $ENV{CONDA_PREFIX}/lib /usr/lib/x86_64-linux-gnu /usr/local/lib /lib/x86_64-linux-gnu) + if(TBB_LIB) + list(APPEND library_dependencies ${TBB_LIB}) + endif() + endif() + # Link against the external libraries as needed target_link_libraries(${lib_name} PRIVATE ${library_dependencies}) From f10335ae558b55ac052e50e5a17c0e57b5224cbd Mon Sep 17 00:00:00 2001 From: Shreyashri Biswas Date: Thu, 25 Sep 2025 19:00:23 +0000 Subject: [PATCH 28/28] added warpReduceAllSum with rocm guards --- .../include/fbgemm_gpu/utils/cuda_prelude.cuh | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) mode change 100644 => 100755 fbgemm_gpu/include/fbgemm_gpu/utils/cuda_prelude.cuh diff --git 
a/fbgemm_gpu/include/fbgemm_gpu/utils/cuda_prelude.cuh b/fbgemm_gpu/include/fbgemm_gpu/utils/cuda_prelude.cuh
old mode 100644
new mode 100755
index a1d9819017..d51e3fa475
--- a/fbgemm_gpu/include/fbgemm_gpu/utils/cuda_prelude.cuh
+++ b/fbgemm_gpu/include/fbgemm_gpu/utils/cuda_prelude.cuh
@@ -140,11 +140,19 @@ template <typename T, int ReduceWidth = kWarpSize>
 DEVICE_INLINE T warpReduceAllSum(
     T val,
     unsigned shfl_sync_mask = static_cast<unsigned>(kFullWarpMask)) {
-  return rocm::wave_reduce<
-      rocm::reduce_op::sum, // Sum reduction
-      T, // Data type
-      ReduceWidth // Wave/Warp size
-      >(val);
+  #ifdef USE_ROCM
+  return rocm::wave_reduce<
+      rocm::reduce_op::sum, // Sum reduction
+      T, // Data type
+      ReduceWidth // Wave/Warp size
+      >(val);
+  #else
+  #pragma unroll
+  for (int mask = ReduceWidth / 2; mask > 0; mask >>= 1) {
+    val += shfl_xor(val, mask, ReduceWidth, shfl_sync_mask);
+  }
+  return val;
+  #endif
 }
 
 DEVICE_INLINE void syncwarp() {
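Either side of the new USE_ROCM guard leaves every lane holding the sum of all lanes' inputs: the ROCm path delegates to rocm::wave_reduce (the DPP-backed reduction this series switched to in PATCH 20), while the CUDA path keeps the original shfl_xor butterfly. The host-side model below replays that butterfly for a 64-wide wave so the invariant can be checked without a GPU; the 64-lane width is an assumption of the sketch, and none of this is FBGEMM code.

#include <array>
#include <cstdio>

int main() {
  constexpr int kWaveSize = 64;  // assumed ROCm wavefront width
  std::array<float, kWaveSize> lane_vals{};
  for (int lane = 0; lane < kWaveSize; ++lane) {
    lane_vals[lane] = static_cast<float>(lane);  // each lane contributes its id
  }

  // Butterfly (XOR-shuffle) all-reduce: after log2(kWaveSize) rounds every
  // lane holds the full sum, mirroring the shfl_xor loop in the kernel.
  for (int mask = kWaveSize / 2; mask > 0; mask >>= 1) {
    std::array<float, kWaveSize> next = lane_vals;
    for (int lane = 0; lane < kWaveSize; ++lane) {
      next[lane] = lane_vals[lane] + lane_vals[lane ^ mask];
    }
    lane_vals = next;
  }

  // 0 + 1 + ... + 63 = 2016 on every lane.
  std::printf("lane 0 = %.0f, lane 63 = %.0f\n", lane_vals[0], lane_vals[63]);
  return 0;
}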