Changes from all commits (28 commits)
fe4d8f0
Add gfx950 build support + fp16 fix + index type fix
avbokovoy Jul 29, 2025
2006f08
Change int64_t to index_t as template parameters in load_raw_per_warp
avbokovoy Jul 29, 2025
757a2f4
Implement llvm fp16 buffer load for gfx950
avbokovoy Jul 29, 2025
f875d54
Fix c-style half to float cast
avbokovoy Aug 11, 2025
ea9d8f8
Patch 256 half stores
avbokovoy Aug 11, 2025
e63ead2
cta_per_row workgroup optim
shbiswas834 Aug 8, 2025
b9eebb4
Added mi350 guards
shbiswas834 Aug 11, 2025
69ae10e
Fix index overflow in row load
shbiswas834 Aug 12, 2025
8f692dc
cta_per_row workgroup reduce by 4 optim
shbiswas834 Aug 12, 2025
768dc01
Fix mixed_D frontend to backend connection
avbokovoy Aug 13, 2025
151d2dd
changed max_segment_length_per_cta to 4096
kudomcho Aug 15, 2025
54e0e24
added rocm guards and removed comment
shbiswas834 Aug 18, 2025
7b2684c
clean debug statements in Hip.cmake
liligwu Aug 20, 2025
9b22e17
Merge pull request #121
shbiswas834 Aug 28, 2025
9adc6bc
Guard f16 llvm intrinsics with ROCm >=7.0
avbokovoy Sep 2, 2025
dc16185
fix the bug in dimention 160 in ROCm optimization
liligwu Sep 18, 2025
00c1914
Cleanup optimized warp_per_raw kernel
avbokovoy Aug 19, 2025
f2c662a
Add 320 embedding dim support for optimized warp_per_row kernel
avbokovoy Aug 20, 2025
f8fe9d7
changed the max length per warp and cta per row WG size
Sep 8, 2025
279aeac
added DPP and changed max length per warp to 16k
kudomcho Sep 9, 2025
d59e3d6
guard max segment warp based on emb dim
kudomcho Sep 10, 2025
faf378e
added guarding opt of max segment for the case batch size list=1
kudomcho Sep 10, 2025
d9239e9
opt for grad_indice_weights kernel
Sep 18, 2025
145a673
added store row per warp on emb 192 and added accuracy test functiona…
kudomcho Sep 23, 2025
986cceb
workgroup tuning and loop unrolled
shbiswas834 Sep 22, 2025
2bf70c6
specialize
Hardcode84 Sep 19, 2025
4d2bfdd
explicitly link to tbb
liligwu Sep 24, 2025
f10335a
added warpReduceAllSum with rocm guards
shbiswas834 Sep 25, 2025
12 changes: 12 additions & 0 deletions cmake/modules/CppLibrary.cmake
@@ -168,6 +168,18 @@ function(cpp_library)
target_link_libraries(${lib_name} PUBLIC OpenMP::OpenMP_CXX)
endif()

if(NOT TARGET TBB::tbb)
find_package(TBB QUIET)
endif()
if(TBB_FOUND)
target_link_libraries(${lib_name} PUBLIC TBB::tbb)
else()
find_library(TBB_LIB NAMES tbb tbb12 HINTS $ENV{CONDA_PREFIX}/lib /usr/lib/x86_64-linux-gnu /usr/local/lib /lib/x86_64-linux-gnu)
if(TBB_LIB)
target_link_libraries(${lib_name} PUBLIC ${TBB_LIB})
endif()
endif()

# Add sanitizer options if needed
if(args_SANITIZER_OPTIONS)
target_link_options(${lib_name} PUBLIC
12 changes: 12 additions & 0 deletions cmake/modules/GpuCppLibrary.cmake
@@ -302,6 +302,18 @@ function(gpu_cpp_library)
list(APPEND library_dependencies ${NVML_LIB_PATH})
endif()

if(NOT TARGET TBB::tbb)
find_package(TBB QUIET)
endif()
if(TBB_FOUND)
list(APPEND library_dependencies TBB::tbb)
else()
find_library(TBB_LIB NAMES tbb tbb12 HINTS $ENV{CONDA_PREFIX}/lib /usr/lib/x86_64-linux-gnu /usr/local/lib /lib/x86_64-linux-gnu)
if(TBB_LIB)
list(APPEND library_dependencies ${TBB_LIB})
endif()
endif()

# Link against the external libraries as needed
target_link_libraries(${lib_name} PRIVATE ${library_dependencies})
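Note: both CppLibrary.cmake and GpuCppLibrary.cmake apply the same fallback: prefer the imported TBB::tbb target when find_package(TBB) succeeds, otherwise look the library up directly with find_library across the conda prefix and common system library directories; if neither is found, TBB is simply not linked.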

223 changes: 158 additions & 65 deletions fbgemm_gpu/bench/tbe/split_table_batched_embeddings_benchmark.py
@@ -7,7 +7,8 @@

# pyre-strict


import gzip
import yaml
import logging
import os
import tempfile
@@ -1011,7 +1012,15 @@
@TbeBenchClickInterface.common_options
@TbeBenchClickInterface.device_options
@TbeBenchClickInterface.vbe_options
@click.option("--save", type=str, default=None)
@click.option("--load", type=str, default=None)
@click.option("--random-weights", is_flag=True, default=False)
@click.option("--compressed", is_flag=True, default=False)
@click.option("--slice-min", type=int, default=None)
@click.option("--slice-max", type=int, default=None)
@click.pass_context
def device_with_spec( # noqa C901
ctx,
alpha: float,
bag_size_list: str,
bag_size_sigma_list: str,
@@ -1031,7 +1040,39 @@
bounds_check_mode: int,
flush_gpu_cache_size_mb: int,
output_dtype: SparseType,
save: str,
load: str,
random_weights: bool,
compressed: bool,
slice_min: int,
slice_max: int,
) -> None:
if load:
with open(f"{load}/params.yaml", "r") as f:
ctx.params = yaml.load(f, Loader=yaml.UnsafeLoader)
alpha = ctx.params["alpha"]
bag_size_list = ctx.params["bag_size_list"]
bag_size_sigma_list = ctx.params["bag_size_sigma_list"]
batch_size = ctx.params["batch_size"]
embedding_dim_list = ctx.params["embedding_dim_list"]
weights_precision = ctx.params["weights_precision"]
cache_precision = ctx.params["cache_precision"]
stoc = ctx.params["stoc"]
iters = ctx.params["iters"]
warmup_runs = ctx.params["warmup_runs"]
managed = ctx.params["managed"]
num_embeddings_list = ctx.params["num_embeddings_list"]
reuse = ctx.params["reuse"]
row_wise = ctx.params["row_wise"]
weighted = ctx.params["weighted"]
pooling = ctx.params["pooling"]
bounds_check_mode = ctx.params["bounds_check_mode"]
flush_gpu_cache_size_mb = ctx.params["flush_gpu_cache_size_mb"]
output_dtype = ctx.params["output_dtype"]
random_weights = ctx.params["random_weights"]
compressed = ctx.params["compressed"]
slice_min = ctx.params["slice_min"]
slice_max = ctx.params["slice_max"]
np.random.seed(42)
torch.manual_seed(42)
B = batch_size
@@ -1040,6 +1081,11 @@
T = len(Ds)

use_variable_bag_sizes = bag_size_sigma_list != "None"
params = ctx.params
if save:
os.makedirs(f"{save}", exist_ok=True)
with open(f"{save}/params.yaml", "w") as f:
yaml.dump(params, f, sort_keys=False)

if use_variable_bag_sizes:
Ls = [int(mu) for mu in bag_size_list.split(",")]
@@ -1118,6 +1164,22 @@

if weights_precision == SparseType.INT8:
emb.init_embedding_weights_uniform(-0.0003, 0.0003)
elif random_weights:
emb.init_embedding_weights_uniform(-1.0, 1.0)

if save:
if compressed:
with gzip.open(f"{save}/model_state.pth.gz", "wb") as f:
torch.save(emb.state_dict(), f)
else:
torch.save(emb.state_dict(), f"{save}/model_state.pth")

if load:
if compressed:
with gzip.open(f"{load}/model_state.pth.gz", "rb") as f:
emb.load_state_dict(torch.load(f))
else:
emb.load_state_dict(torch.load(f"{load}/model_state.pth"))

nparams = sum(w.numel() for w in emb.split_embedding_weights())
param_size_multiplier = weights_precision.bit_rate() / 8.0
@@ -1130,53 +1192,68 @@
"weights": [[] for _ in range(iters)],
}
# row = iter, column = tensor
for t, e in enumerate(Es):
# (indices, offsets, weights)
requests = generate_requests(
iters,
B,
1,
Ls[t],
e,
reuse=reuse,
alpha=alpha,
weighted=weighted,
# pyre-fixme[61]: `sigma_Ls` is undefined, or not always defined.
sigma_L=sigma_Ls[t] if use_variable_bag_sizes else None,
zipf_oversample_ratio=3 if Ls[t] > 5 else 5,
use_cpu=get_available_compute_device() == ComputeDevice.CPU,
index_dtype=torch.long,
offset_dtype=torch.long,
)
for i, req in enumerate(requests):
indices, offsets, weights = req.unpack_3()
all_requests["indices"][i].append(indices)
if t > 0:
offsets = offsets[1:] # remove the first element
offsets += all_requests["offsets"][i][t - 1][-1]
all_requests["offsets"][i].append(offsets)
all_requests["weights"][i].append(weights)

prev_indices_len = -1
requests = []
for i in range(iters):
indices = torch.concat(all_requests["indices"][i])
if prev_indices_len == -1:
prev_indices_len = indices.numel()
assert (
prev_indices_len == indices.numel()
), "Number of indices for every iteration must be the same"
offsets = torch.concat(all_requests["offsets"][i])
if weighted:
weights = torch.concat(all_requests["weights"][i])
else:
weights = None
requests.append(TBERequest(indices, offsets, weights))

del all_requests


if load:
requests = []
for i in range(iters):
indices = torch.load(f"{load}/{i}_indices.pt")
offsets = torch.load(f"{load}/{i}_offsets.pt")
per_sample_weights = torch.load(f"{load}/{i}_per_sample_weights.pt")
Bs_per_feature_per_rank = torch.load(f"{load}/{i}_Bs_per_feature_per_rank.pt")
requests.append(TBERequest(indices, offsets, per_sample_weights, Bs_per_feature_per_rank))
else:
for t, e in enumerate(Es):
# (indices, offsets, weights)
requests = generate_requests(
iters,
B,
1,
Ls[t],
e,
reuse=reuse,
alpha=alpha,
weighted=weighted,
# pyre-fixme[61]: `sigma_Ls` is undefined, or not always defined.
sigma_L=sigma_Ls[t] if use_variable_bag_sizes else None,
zipf_oversample_ratio=3 if Ls[t] > 5 else 5,
use_cpu=get_available_compute_device() == ComputeDevice.CPU,
index_dtype=torch.long,
offset_dtype=torch.long,
)
for i, req in enumerate(requests):
indices, offsets, weights = req.unpack_3()
all_requests["indices"][i].append(indices)
if t > 0:
offsets = offsets[1:] # remove the first element
offsets += all_requests["offsets"][i][t - 1][-1]
all_requests["offsets"][i].append(offsets)
all_requests["weights"][i].append(weights)

prev_indices_len = -1
requests = []
for i in range(iters):
indices = torch.concat(all_requests["indices"][i])
if prev_indices_len == -1:
prev_indices_len = indices.numel()
assert (
prev_indices_len == indices.numel()
), "Number of indices for every iteration must be the same"
offsets = torch.concat(all_requests["offsets"][i])
if weighted:
weights = torch.concat(all_requests["weights"][i])
else:
weights = None
requests.append(TBERequest(indices, offsets, weights))
del all_requests
assert len(requests) == iters

if save:
for i in range(iters):
req = requests[i]
torch.save(req.indices, f"{save}/{i}_indices.pt")
torch.save(req.offsets, f"{save}/{i}_offsets.pt")
torch.save(req.per_sample_weights, f"{save}/{i}_per_sample_weights.pt")
torch.save(req.Bs_per_feature_per_rank, f"{save}/{i}_Bs_per_feature_per_rank.pt")

sum_DLs = sum([d * l for d, l in zip(Ds, Ls)])
if do_pooling:
read_write_bytes = (
@@ -1203,34 +1280,44 @@

# forward
time_per_iter = benchmark_requests(
requests,
lambda indices, offsets, per_sample_weights: emb.forward(
indices,
offsets,
per_sample_weights,
feature_requires_grad=feature_requires_grad,
),
flush_gpu_cache_size_mb=flush_gpu_cache_size_mb,
num_warmups=warmup_runs,
)
requests,

lambda indices, offsets, per_sample_weights: emb.forward(
indices,
offsets,
per_sample_weights,
feature_requires_grad=feature_requires_grad,
),
flush_gpu_cache_size_mb=flush_gpu_cache_size_mb,
num_warmups=warmup_runs,
)

logging.info(
f"Forward, B: {B}, "
f"Es: {Es}, T: {T}, Ds: {Ds}, Ls: {Ls_str}, W: {weighted}, "
f"BW: {read_write_bytes / time_per_iter / 1.0e9: .2f} GB/s, " # noqa: B950
f"T: {time_per_iter * 1.0e6:.0f}us"
)
f"Forward, B: {B}, "

f"Es: {Es}, T: {T}, Ds: {Ds}, Ls: {Ls_str}, W: {weighted}, "
f"BW: {read_write_bytes / time_per_iter / 1.0e9: .2f} GB/s, " # noqa: B950
f"T: {time_per_iter * 1.0e6:.0f}us"
)



if output_dtype == SparseType.INT8:

# backward bench not representative
return

if do_pooling:
grad_output = torch.randn(B, sum(Ds)).to(get_device())
if load:
grad_output = torch.load(f"{load}/grad_output.pt")
else:
# Obtain B * L from indices len
# pyre-ignore[19]
# pyre-fixme[61]: `D` is undefined, or not always defined.
grad_output = torch.randn(requests[0].indices.numel(), D).to(get_device())
if do_pooling:
grad_output = torch.randn(B, sum(Ds)).to(get_device())
else:
# Obtain B * L from indices len
# pyre-ignore[19]
# pyre-fixme[61]: `D` is undefined, or not always defined.
grad_output = torch.randn(requests[0].indices.numel(), D).to(get_device())

if save:
torch.save(grad_output, f"{save}/grad_output.pt")
# backward
time_per_iter = benchmark_requests(
requests,
@@ -1244,6 +1331,12 @@
bwd_only=True,
grad=grad_output,
num_warmups=warmup_runs,
emb=emb,
save=save,
load=load,
compressed=compressed,
slice_min=slice_min,
slice_max=slice_max,
)
logging.info(
f"Backward, B: {B}, Es: {Es}, T: {T}, Ds: {Ds}, Ls: {Ls_str}, "
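For reference, the new --save/--load flow above boils down to the following round trip (a minimal sketch, not the benchmark itself: save_dir, params, emb, and requests stand in for the objects the command builds, and only the indices/offsets request tensors are shown):

import gzip
import os

import torch
import yaml


def save_run(save_dir, params, emb, requests, compressed=False):
    # Persist CLI parameters, embedding weights, and per-iteration request tensors.
    os.makedirs(save_dir, exist_ok=True)
    with open(f"{save_dir}/params.yaml", "w") as f:
        yaml.dump(params, f, sort_keys=False)
    if compressed:
        with gzip.open(f"{save_dir}/model_state.pth.gz", "wb") as f:
            torch.save(emb.state_dict(), f)
    else:
        torch.save(emb.state_dict(), f"{save_dir}/model_state.pth")
    for i, req in enumerate(requests):
        torch.save(req.indices, f"{save_dir}/{i}_indices.pt")
        torch.save(req.offsets, f"{save_dir}/{i}_offsets.pt")


def load_run(load_dir, emb, iters, compressed=False):
    # Restore exactly what save_run wrote, so a later run replays identical inputs.
    with open(f"{load_dir}/params.yaml", "r") as f:
        params = yaml.load(f, Loader=yaml.UnsafeLoader)
    if compressed:
        with gzip.open(f"{load_dir}/model_state.pth.gz", "rb") as f:
            emb.load_state_dict(torch.load(f))
    else:
        emb.load_state_dict(torch.load(f"{load_dir}/model_state.pth"))
    requests = [
        (torch.load(f"{load_dir}/{i}_indices.pt"), torch.load(f"{load_dir}/{i}_offsets.pt"))
        for i in range(iters)
    ]
    return params, requests

This mirrors the file layout the benchmark writes (params.yaml, model_state.pth or model_state.pth.gz, and per-iteration {i}_indices.pt / {i}_offsets.pt tensors), which is what makes a saved run replayable bit-for-bit with --load.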
2 changes: 0 additions & 2 deletions fbgemm_gpu/cmake/tbe_sources.py
@@ -176,7 +176,6 @@
"_nobag" if nobag else "",
)
for nobag in [
True,
False,
]
for weighted in (
@@ -495,7 +494,6 @@
"_nobag" if nobag else "",
)
for nobag in [
True,
False,
]
for weighted in (
10 changes: 7 additions & 3 deletions fbgemm_gpu/codegen/genscript/generate_backward_split.py
@@ -52,7 +52,11 @@ def render_backward_templates(
return

weighted_options = [True, False]
nobag_options = [True, False] if (not is_gwd) else [False]
nobag_options = (
[True, False]
if (not (is_gwd or kwargs.get("is_hip_optimized_backward")))
else [False]
)
vbe_options = [True, False] if (kwargs.get("has_vbe_support")) else [False]
ssd_options = [True, False] if kwargs.get("has_ssd_support") else [False]
template = CodeTemplate.load(template_filepath)
@@ -327,8 +331,7 @@ def generate_backward_indices() -> None:

@staticmethod
def generate_rocm_backward_split(**kwargs: Any) -> None:
# Generate backward device kernels based on weighted (True/False), VBE
# (True/False), no bag (True/False)
# Generate backward device kernels based on weighted (True/False)
template_filepath = (
"training/backward/rocm/embedding_backward_split_device_kernel_template.hip"
)
@@ -343,6 +346,7 @@ def generate_rocm_backward_split(**kwargs: Any) -> None:
"has_ssd_support": False,
"dense": False,
"gen_once": False,
"is_hip_optimized_backward": True,
},
)
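Note: together with the tbe_sources.py change above (dropping True from the nobag lists), passing is_hip_optimized_backward=True means the ROCm-optimized backward templates are now rendered only for the pooled (bag) case; nobag variants are no longer generated from this path.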

Expand Down
@@ -949,7 +949,7 @@ class {{ autograd_func }} :

#ifdef USE_ROCM
constexpr int32_t BT_block_size = 64;
constexpr int32_t max_segment_length_per_warp = 64;
constexpr int32_t max_segment_length_per_warp = 16384;
#else
constexpr int32_t BT_block_size = 32;
constexpr int32_t max_segment_length_per_warp = 32;
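Note: under USE_ROCM this raises max_segment_length_per_warp from 64 to 16384 (the CUDA constants are unchanged), in line with the "added DPP and changed max length per warp to 16k" commit above; presumably so that much longer segments stay on the warp-per-row backward kernel before falling back to the cta_per_row path.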