1 change: 1 addition & 0 deletions .github/scripts/utils_build.bash
@@ -370,6 +370,7 @@ install_build_tools () {
patchelf \
rhash \
scikit-build \
tbb-devel \
tbb \
wheel \
xz \
12 changes: 12 additions & 0 deletions cmake/modules/CppLibrary.cmake
@@ -168,6 +168,18 @@ function(cpp_library)
target_link_libraries(${lib_name} PUBLIC OpenMP::OpenMP_CXX)
endif()

if(NOT TARGET TBB::tbb)
find_package(TBB QUIET)
endif()
if(TBB_FOUND)
target_link_libraries(${lib_name} PRIVATE TBB::tbb)
else()
find_library(TBB_LIB NAMES tbb tbb12 HINTS $ENV{CONDA_PREFIX}/lib /usr/lib/x86_64-linux-gnu /usr/local/lib /lib/x86_64-linux-gnu)
if(TBB_LIB)
target_link_libraries(${lib_name} PRIVATE ${TBB_LIB})
endif()
endif()

# Add sanitizer options if needed
if(args_SANITIZER_OPTIONS)
target_link_options(${lib_name} PUBLIC
12 changes: 12 additions & 0 deletions cmake/modules/GpuCppLibrary.cmake
@@ -302,6 +302,18 @@ function(gpu_cpp_library)
list(APPEND library_dependencies ${NVML_LIB_PATH})
endif()

if(NOT TARGET TBB::tbb)
find_package(TBB QUIET)
endif()
if(TBB_FOUND)
list(APPEND library_dependencies TBB::tbb)
else()
find_library(TBB_LIB NAMES tbb tbb12 HINTS $ENV{CONDA_PREFIX}/lib /usr/lib/x86_64-linux-gnu /usr/local/lib /lib/x86_64-linux-gnu)
if(TBB_LIB)
list(APPEND library_dependencies ${TBB_LIB})
endif()
endif()

# Link against the external libraries as needed
target_link_libraries(${lib_name} PRIVATE ${library_dependencies})

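Note: both CMake changes follow the same pattern — prefer the imported TBB::tbb target from find_package(TBB), and fall back to a raw find_library probe of common conda and system paths. As a sketch of why the link dependency matters (hypothetical file, not part of this PR), any translation unit that drives the TBB scheduler needs libtbb on the link line:

// tbb_smoke.cpp — illustrative only; assumes oneTBB headers are installed.
#include <tbb/blocked_range.h>
#include <tbb/parallel_for.h>
#include <cstdio>
#include <vector>

int main() {
  std::vector<float> data(1 << 20, 1.0f);
  // parallel_for is a template at the call site, but it calls into the libtbb
  // scheduler, so the link step needs TBB::tbb (or -ltbb / -ltbb12) to resolve symbols.
  tbb::parallel_for(
      tbb::blocked_range<size_t>(0, data.size()),
      [&](const tbb::blocked_range<size_t>& r) {
        for (size_t i = r.begin(); i != r.end(); ++i) {
          data[i] *= 2.0f;
        }
      });
  std::printf("%f\n", data[0]);
  return 0;
}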
2 changes: 0 additions & 2 deletions fbgemm_gpu/cmake/tbe_sources.py
@@ -176,7 +176,6 @@
"_nobag" if nobag else "",
)
for nobag in [
True,
False,
]
for weighted in (
@@ -495,7 +494,6 @@
"_nobag" if nobag else "",
)
for nobag in [
True,
False,
]
for weighted in (
10 changes: 7 additions & 3 deletions fbgemm_gpu/codegen/genscript/generate_backward_split.py
@@ -52,7 +52,11 @@ def render_backward_templates(
return

weighted_options = [True, False]
nobag_options = [True, False] if (not is_gwd) else [False]
nobag_options = (
[True, False]
if (not (is_gwd or kwargs.get("is_hip_optimized_backward")))
else [False]
)
vbe_options = [True, False] if (kwargs.get("has_vbe_support")) else [False]
ssd_options = [True, False] if kwargs.get("has_ssd_support") else [False]
template = CodeTemplate.load(template_filepath)
@@ -327,8 +331,7 @@ def generate_backward_indices() -> None:

@staticmethod
def generate_rocm_backward_split(**kwargs: Any) -> None:
# Generate backward device kernels based on weighted (True/False), VBE
# (True/False), no bag (True/False)
# Generate backward device kernels based on weighted (True/False)
template_filepath = (
"training/backward/rocm/embedding_backward_split_device_kernel_template.hip"
)
@@ -343,6 +346,7 @@ def generate_rocm_backward_split(**kwargs: Any) -> None:
"has_ssd_support": False,
"dense": False,
"gen_once": False,
"is_hip_optimized_backward": True,
},
)

@@ -172,7 +172,7 @@ Tensor split_embedding_codegen_lookup_dense_function(
c10::SymInt /* max_B = -1 */,
c10::SymInt /* max_B_feature_rank = -1 */,
c10::SymInt /* vbe_output_size = -1 */,
bool /* mixed_D = true */) {
bool /* mixed_D = false */) {
return SplitLookupFunction_Dense_Op::apply(
host_weights,
weights_offsets,
@@ -960,7 +960,7 @@ class {{ autograd_func }} :

#ifdef USE_ROCM
constexpr int32_t BT_block_size = 64;
constexpr int32_t max_segment_length_per_warp = 64;
constexpr int32_t max_segment_length_per_warp = 16384;
#else
constexpr int32_t BT_block_size = 32;
constexpr int32_t max_segment_length_per_warp = 32;
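Note: raising max_segment_length_per_warp to 16384 on ROCm means that, in practice, almost every sorted-indices segment stays on the warp-per-row backward kernel instead of spilling to the cta-per-row path. A rough sketch of the rule this constant is assumed to feed in the generated host code (simplified; not the actual dispatch source):

// Hypothetical helper illustrating the assumed role of the constant.
inline bool use_warp_per_row_kernel(
    long long segment_length, int max_segment_length_per_warp) {
  // Short segments: a single warp accumulates the whole segment.
  // Longer segments: fall back to the cta-per-row kernel, which splits the
  // segment across an entire thread block.
  return segment_length <= max_segment_length_per_warp;
}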
@@ -1116,7 +1116,7 @@ Tensor {{ bwd_mdesc }}_embedding_codegen_lookup_{{ optimizer }}_function(
{%- else %}
const c10::SymInt vbe_output_size = -1,
{%- endif %}
const bool mixed_D = true
const bool mixed_D = false
) {
// TODO: refactor into macro
{%- if has_gpu_support %}
134 changes: 133 additions & 1 deletion fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu
100644 → 100755
@@ -23,6 +23,10 @@
#include "fbgemm_gpu/utils/assert_macros.h"
#include "fbgemm_gpu/utils/kernel_launcher.cuh"

{%- if is_rocm %}
#include "fbgemm_gpu/rocm/cdna_guard.h"
{%- endif %}

using Tensor = at::Tensor;
using namespace fbgemm_gpu;

@@ -209,8 +213,127 @@ __global__ __launch_bounds__(kForwardMaxThreads) void
2, offset_idx + D_emb <= weights_numel, offset_idx
)
{%- endif %}
int32_t j = 0;
{%- if is_rocm and not ssd and not dense and not use_vec_blocking and not vbe %}
// Currently for split_embedding_codegen_grad_indice_weights_kernel only
if (placement != PlacementType::MANAGED_CACHING) {
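// Unrolled path: handle four indices per iteration. Each lane's offset_idx is
// broadcast to the whole warp with shfl_sync so all lanes cooperate on the same rows.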
for (; j < kWarpSize && l_start + j + 3 < L; j += 4) {
const auto offset_idx_j0 = shfl_sync(offset_idx, j);
const auto offset_idx_j1 = shfl_sync(offset_idx, j+1);
const auto offset_idx_j2 = shfl_sync(offset_idx, j+2);
const auto offset_idx_j3 = shfl_sync(offset_idx, j+3);

at::acc_type<cache_t, true> grad_indice_weight0 = 0.0;
at::acc_type<cache_t, true> grad_indice_weight1 = 0.0;
at::acc_type<cache_t, true> grad_indice_weight2 = 0.0;
at::acc_type<cache_t, true> grad_indice_weight3 = 0.0;

const auto weight_row0 = WeightRowAccessor<emb_t, at::acc_type<cache_t, true>>(&weights[offset_idx_j0], D);
const auto weight_row1 = WeightRowAccessor<emb_t, at::acc_type<cache_t, true>>(&weights[offset_idx_j1], D);
const auto weight_row2 = WeightRowAccessor<emb_t, at::acc_type<cache_t, true>>(&weights[offset_idx_j2], D);
const auto weight_row3 = WeightRowAccessor<emb_t, at::acc_type<cache_t, true>>(&weights[offset_idx_j3], D);

#pragma unroll kFixedMaxVecsPerThread
for (int32_t vec = 0; vec < kFixedMaxVecsPerThread && (kWarpSize * vec + threadIdx.x) * kVecWidth < D; ++vec) {
const int32_t d = (kWarpSize * vec + threadIdx.x) * kVecWidth;

Vec4T<at::acc_type<cache_t, true>> weight0, weight1, weight2, weight3;
weight0 = weight_row0.load(d);
weight1 = weight_row1.load(d);
weight2 = weight_row2.load(d);
weight3 = weight_row3.load(d);

grad_indice_weight0 += weight0.acc.x * grad_out[vec].acc.x + weight0.acc.y * grad_out[vec].acc.y +
weight0.acc.z * grad_out[vec].acc.z + weight0.acc.w * grad_out[vec].acc.w;
grad_indice_weight1 += weight1.acc.x * grad_out[vec].acc.x + weight1.acc.y * grad_out[vec].acc.y +
weight1.acc.z * grad_out[vec].acc.z + weight1.acc.w * grad_out[vec].acc.w;
grad_indice_weight2 += weight2.acc.x * grad_out[vec].acc.x + weight2.acc.y * grad_out[vec].acc.y +
weight2.acc.z * grad_out[vec].acc.z + weight2.acc.w * grad_out[vec].acc.w;
grad_indice_weight3 += weight3.acc.x * grad_out[vec].acc.x + weight3.acc.y * grad_out[vec].acc.y +
weight3.acc.z * grad_out[vec].acc.z + weight3.acc.w * grad_out[vec].acc.w;
}

grad_indice_weight0 = warpReduceAllSum<at::acc_type<cache_t, true>>(grad_indice_weight0);
grad_indice_weight1 = warpReduceAllSum<at::acc_type<cache_t, true>>(grad_indice_weight1);
grad_indice_weight2 = warpReduceAllSum<at::acc_type<cache_t, true>>(grad_indice_weight2);
grad_indice_weight3 = warpReduceAllSum<at::acc_type<cache_t, true>>(grad_indice_weight3);

if (threadIdx.x == 0) {
grad_indice_weights[indices_start + l_start + j] = grad_indice_weight0;
grad_indice_weights[indices_start + l_start + j+1] = grad_indice_weight1;
grad_indice_weights[indices_start + l_start + j+2] = grad_indice_weight2;
grad_indice_weights[indices_start + l_start + j+3] = grad_indice_weight3;
}
}
} else {
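// MANAGED_CACHING path: a row may be resident in the LXU cache, so read the
// cached copy when its cache location is valid and fall back to the embedding
// table row otherwise.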
for (; j < kWarpSize && l_start + j + 3 < L; j += 4) {
const auto offset_idx_j0 = shfl_sync(offset_idx, j);
const auto offset_idx_j1 = shfl_sync(offset_idx, j+1);
const auto offset_idx_j2 = shfl_sync(offset_idx, j+2);
const auto offset_idx_j3 = shfl_sync(offset_idx, j+3);

const auto cache_idx_j0 = shfl_sync(cache_idx, j);
const auto cache_idx_j1 = shfl_sync(cache_idx, j+1);
const auto cache_idx_j2 = shfl_sync(cache_idx, j+2);
const auto cache_idx_j3 = shfl_sync(cache_idx, j+3);

at::acc_type<cache_t, true> grad_indice_weight0 = 0.0;
at::acc_type<cache_t, true> grad_indice_weight1 = 0.0;
at::acc_type<cache_t, true> grad_indice_weight2 = 0.0;
at::acc_type<cache_t, true> grad_indice_weight3 = 0.0;

const auto weight_row0 = WeightRowAccessor<emb_t, at::acc_type<cache_t, true>>(&weights[offset_idx_j0], D);
const auto weight_row1 = WeightRowAccessor<emb_t, at::acc_type<cache_t, true>>(&weights[offset_idx_j1], D);
const auto weight_row2 = WeightRowAccessor<emb_t, at::acc_type<cache_t, true>>(&weights[offset_idx_j2], D);
const auto weight_row3 = WeightRowAccessor<emb_t, at::acc_type<cache_t, true>>(&weights[offset_idx_j3], D);

#pragma unroll kFixedMaxVecsPerThread
for (int32_t vec = 0; vec < kFixedMaxVecsPerThread && (kWarpSize * vec + threadIdx.x) * kVecWidth < D; ++vec) {
const int32_t d = (kWarpSize * vec + threadIdx.x) * kVecWidth;

Vec4T<at::acc_type<cache_t, true>> weight0, weight1, weight2, weight3;
weight0 = (cache_idx_j0 != kCacheLocationMissing) ?
Vec4T<at::acc_type<cache_t, true>>(&lxu_cache_weights[cache_idx_j0][d]) :
weight_row0.load(d);

weight1 = (cache_idx_j1 != kCacheLocationMissing) ?
Vec4T<at::acc_type<cache_t, true>>(&lxu_cache_weights[cache_idx_j1][d]) :
weight_row1.load(d);

weight2 = (cache_idx_j2 != kCacheLocationMissing) ?
Vec4T<at::acc_type<cache_t, true>>(&lxu_cache_weights[cache_idx_j2][d]) :
weight_row2.load(d);

weight3 = (cache_idx_j3 != kCacheLocationMissing) ?
Vec4T<at::acc_type<cache_t, true>>(&lxu_cache_weights[cache_idx_j3][d]) :
weight_row3.load(d);


grad_indice_weight0 += weight0.acc.x * grad_out[vec].acc.x + weight0.acc.y * grad_out[vec].acc.y +
weight0.acc.z * grad_out[vec].acc.z + weight0.acc.w * grad_out[vec].acc.w;
grad_indice_weight1 += weight1.acc.x * grad_out[vec].acc.x + weight1.acc.y * grad_out[vec].acc.y +
weight1.acc.z * grad_out[vec].acc.z + weight1.acc.w * grad_out[vec].acc.w;
grad_indice_weight2 += weight2.acc.x * grad_out[vec].acc.x + weight2.acc.y * grad_out[vec].acc.y +
weight2.acc.z * grad_out[vec].acc.z + weight2.acc.w * grad_out[vec].acc.w;
grad_indice_weight3 += weight3.acc.x * grad_out[vec].acc.x + weight3.acc.y * grad_out[vec].acc.y +
weight3.acc.z * grad_out[vec].acc.z + weight3.acc.w * grad_out[vec].acc.w;
}

grad_indice_weight0 = warpReduceAllSum<at::acc_type<cache_t, true>>(grad_indice_weight0);
grad_indice_weight1 = warpReduceAllSum<at::acc_type<cache_t, true>>(grad_indice_weight1);
grad_indice_weight2 = warpReduceAllSum<at::acc_type<cache_t, true>>(grad_indice_weight2);
grad_indice_weight3 = warpReduceAllSum<at::acc_type<cache_t, true>>(grad_indice_weight3);

if (threadIdx.x == 0) {
grad_indice_weights[indices_start + l_start + j] = grad_indice_weight0;
grad_indice_weights[indices_start + l_start + j+1] = grad_indice_weight1;
grad_indice_weights[indices_start + l_start + j+2] = grad_indice_weight2;
grad_indice_weights[indices_start + l_start + j+3] = grad_indice_weight3;
}
}
}
{%- endif %}{#-/* if is_rocm and not ssd and not dense and not use_vec_blocking and not vbe */#}
for (; j < kWarpSize && l_start + j < L; ++j) {
const auto offset_idx_j = shfl_sync(offset_idx, j);
{%- if not dense %}
const auto {{ locs_or_addrs_idx }}_j = shfl_sync({{ locs_or_addrs_idx }}, j);
@@ -359,6 +482,15 @@ Tensor {{ mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda(
auto aligned_grad_output = aligned_grad_output_tensor_for_cuda_backwards(grad_output);

CUDA_DEVICE_GUARD(dev_weights);
#ifdef USE_ROCM
if (!rocm::is_supported_cdna()) {
TORCH_WARN_ONCE("Running on non-CDNA architecture. Performance may be suboptimal.");
}
else {
// Ensure we're running on a supported CDNA architecture (including MI350)
TORCH_WARN_ONCE("Running on CDNA architecture");
}
#endif

const auto T = D_offsets.size(0) - 1;
TORCH_CHECK_GT(T, 0);
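Note: the new ROCm-only fast path above follows a standard warp-cooperative pattern: broadcast each lane's row offset across the warp with shfl_sync, let every lane accumulate a strided partial dot product between the row and grad_out, reduce across the warp, and have lane 0 store the per-index gradient; the PR additionally unrolls this loop four rows at a time. A stripped-down CUDA sketch of the single-row version (illustrative names only, not FBGEMM internals; assumes 32-lane warps):

constexpr unsigned kFullMask = 0xffffffffu;

__device__ float warp_reduce_sum(float v) {
  for (int offset = 16; offset > 0; offset >>= 1) {
    v += __shfl_down_sync(kFullMask, v, offset);
  }
  return __shfl_sync(kFullMask, v, 0);  // broadcast the total back to all lanes
}

__device__ void grad_indice_weights_for_segment(
    const float* weights,            // flat embedding table
    const float* grad_out,           // D-length gradient for this sample slot
    float* grad_indice_weights,      // one scalar per looked-up index
    long long my_row_offset,         // lane j owns the offset of index l_start + j
    int L, int D, int l_start) {
  const int lane = threadIdx.x & 31;
  for (int j = 0; j < 32 && l_start + j < L; ++j) {
    // Broadcast lane j's row offset so the whole warp works on the same row.
    const long long row = __shfl_sync(kFullMask, my_row_offset, j);
    float acc = 0.f;
    for (int d = lane; d < D; d += 32) {
      acc += weights[row + d] * grad_out[d];  // strided partial dot product
    }
    acc = warp_reduce_sum(acc);
    if (lane == 0) {
      grad_indice_weights[l_start + j] = acc;  // lane 0 writes the result
    }
  }
}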
@@ -32,6 +32,14 @@

{%- set locs_or_addrs_tensor = "ssd_row_addrs" if ssd else "lxu_cache_locations" %}
{%- set locs_or_addrs_type = "int64_t" if ssd else "int32_t" %}
{%- set is_optimized_hip_kernel_supported_mode = is_rocm and
optimizer == "rowwise_adagrad" and
not dense and
not nobag and
not is_index_select and
not is_gwd_kernel and
not vbe and
not ssd %}

#include "fbgemm_gpu/embedding_backward_template_helpers.cuh"
#include "fbgemm_gpu/utils/tensor_accessor_builder.h"
@@ -538,7 +546,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row

{%- endif %}

{%- if is_rocm and not is_index_select and optimizer == "rowwise_adagrad" and not dense and not is_gwd_kernel and not vbe and not ssd %}
{%- if is_optimized_hip_kernel_supported_mode %}
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include "fbgemm_gpu/rocm/split_embeddings_common.h"
@@ -612,12 +620,7 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd
{{ args.split_kernel_args | replace_pta_namespace() | join(",\n ") }}
{%- endif %}
) {
{%- if not nobag %}
int32_t T = D_offsets.size(0) - 1;
{%- else %}
int32_t T = weights_offsets.size(0);
{%- endif %}

auto p_output_grad = grad_output.data();
auto p_emb_table = dev_weights.data();
auto p_hash_size_cumsum = hash_size_cumsum.data();
@@ -632,8 +635,6 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd
constexpr int32_t segment_prefetch = 2;
constexpr int32_t segment_unroll = 8;
constexpr int32_t segment_split = 0;
auto batch = grad_output.size(0);
auto num_rows = dev_weights.size(0) / T / max_D;
{%- if weighted %}
constexpr bool is_weighted = true;
{%- else %}
@@ -646,30 +647,15 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd
// weight_decay(_mode) is supplied as args.split_function_args_no_defaults
opt_karg.weight_decay_mode = weight_decay_mode_v;
opt_karg.weight_decay = weight_decay;
auto batch_mdiv = [](uint32_t d) -> rocm::magic_div_u32_t {
assert(d >= 1 && d <= INT32_MAX);
uint8_t shift;
for(shift = 0; shift < 32; shift++)
if((1U << shift) >= d)
break;

uint64_t one = 1;
uint64_t magic = ((one << 32) * ((one << shift) - d)) / d + 1;
assert(magic <= 0xffffffffUL);

rocm::magic_div_u32_t result;
result.magic = magic;
result.shift = shift;
return result;
}(batch);

rocm::split_tbe_backward_hip_kernel_{{kdesc}}<
rocm::{{optimizer}}_optimizer_t<cache_t, emb_t, embedding_dim, weight_decay_mode_v>,
rocm::{{optimizer}}_optimizer_t<cache_t, emb_t, index_t, embedding_dim, weight_decay_mode_v>,
rocm::{{optimizer}}_kernel_arg_t,
emb_t,
cache_t,
grad_t,
index_t,
BLOCK_SIZE,
BLOCK_SIZE_ROCM,
embedding_dim,
segment_prefetch,
segment_unroll,
@@ -680,16 +666,11 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd
p_sorted_linear_indices_run,
p_sorted_linear_indices_cumulative_run_lengths,
p_sorted_linear_indices_num_runs,
{%- if not nobag %}
info_B_num_bits,
info_B_mask,
{%- endif %}
p_sorted_infos,
batch_mdiv,
max_segment_length_per_warp,
emb_dim,
batch,
num_rows,
T,
opt_karg
{%- if weighted %}
@@ -784,7 +765,7 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd
{%- for emb_type in (['float', 'at::Half'] + (['at::Float8_e4m3fnuz'] if is_rocm else ['at::Float8_e4m3fn'])) %}
{%- for cache_type in ['float', 'at::Half'] %}
{%- for index_type in ['int32_t', 'int64_t'] %}
{%- for kEmbeddingDim in [64, 128, 160, 192, 256] %}
{%- for kEmbeddingDim in [64, 128, 160, 192, 256, 320] %}
{%- for kWeighDecayMode in [0, 1, 2] %}
{{ hip_template_instantiation(
emb_type,