76 changes: 47 additions & 29 deletions fbgemm_gpu/bench/merge_embeddings_benchmark.py
@@ -16,7 +16,7 @@
 import numpy as np
 import tabulate
 import torch
-
+import math
 from fbgemm_gpu.split_embedding_configs import SparseType
 from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
     BoundsCheckMode,
@@ -100,45 +100,63 @@ def generate_requests(


 # pyre-fixme[3]: Return type must be annotated.
 def _get_random_tensor(
     num_ads: int,
     embedding_dimension: int,
     ads_tables: int,
     data_type: str,
     gpu_idx: int,
     include_quantization: bool,
+    use_pitched: bool = True,
+    alignment: int = 256,  # alignment in bytes
 ):
+    device = torch.device(f"cuda:{gpu_idx}")
+
     if data_type == "FP16" or include_quantization:
-        result_tensor = torch.randn(
-            num_ads,
-            embedding_dimension * ads_tables,
-            dtype=torch.float16,
-            device=torch.device(f"cuda:{gpu_idx}"),
-        )
+        dtype = torch.float16
+        width_elems = embedding_dimension * ads_tables
+        elem_size = torch.finfo(dtype).bits // 8
+
+        if use_pitched:
+            width_bytes = width_elems * elem_size
+            pitch_bytes = math.ceil(width_bytes / alignment) * alignment
+            pitch_elems = pitch_bytes // elem_size
+            storage = torch.empty((num_ads, pitch_elems), dtype=dtype, device=device)
+            result_tensor = storage[:, :width_elems]  # logical view
+        else:
+            result_tensor = torch.randn(num_ads, width_elems, dtype=dtype, device=device)
+
     elif data_type == "INT8":
-        assert (
-            embedding_dimension % 2
-        ) == 0, "needs to align to 2 bytes (half type size) for INT8"
-        result_tensor = torch.randint(
-            0,
-            255,
-            # 2 FP16 numbers for scale and bias, total of 4 bytes overhead
-            size=(num_ads, (embedding_dimension + 4) * ads_tables),
-            dtype=torch.uint8,
-            device=torch.device(f"cuda:{gpu_idx}"),
-        )
+        assert embedding_dimension % 2 == 0, "needs to align to 2 bytes for INT8"
+        dtype = torch.uint8
+        # 2 FP16 numbers for scale and bias, total of 4 bytes overhead
+        width_elems = (embedding_dimension + 4) * ads_tables
+        elem_size = 1
+
+        if use_pitched:
+            width_bytes = width_elems * elem_size
+            pitch_bytes = math.ceil(width_bytes / alignment) * alignment
+            pitch_elems = pitch_bytes // elem_size
+            storage = torch.randint(0, 255, (num_ads, pitch_elems), dtype=dtype, device=device)
+            result_tensor = storage[:, :width_elems]
+        else:
+            result_tensor = torch.randint(0, 255, (num_ads, width_elems), dtype=dtype, device=device)
+
     elif data_type == "INT4":
-        assert (
-            embedding_dimension % 4
-        ) == 0, "needs to align to 2 bytes (half type size) for INT4"
-        result_tensor = torch.randint(
-            0,
-            255,
-            # Using torch.uint8 for int4 storage
-            size=(num_ads, (embedding_dimension // 2 + 4) * ads_tables),
-            dtype=torch.uint8,
-            device=torch.device(f"cuda:{gpu_idx}"),
-        )
+        assert embedding_dimension % 4 == 0, "needs to align to 2 bytes for INT4"
+        dtype = torch.uint8
+        # Using torch.uint8 for int4 storage
+        width_elems = (embedding_dimension // 2 + 4) * ads_tables
+        elem_size = 1
+
+        if use_pitched:
+            width_bytes = width_elems * elem_size
+            pitch_bytes = math.ceil(width_bytes / alignment) * alignment
+            pitch_elems = pitch_bytes // elem_size
+            storage = torch.randint(0, 255, (num_ads, pitch_elems), dtype=dtype, device=device)
+            result_tensor = storage[:, :width_elems]
+        else:
+            result_tensor = torch.randint(0, 255, (num_ads, width_elems), dtype=dtype, device=device)
+
     else:
         raise ValueError
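The pattern added above is a pitched (row-padded) allocation: each row's byte width is rounded up to the next multiple of alignment, the padded buffer is allocated, and a narrowed view with the logical width is returned, so every row starts at an aligned offset. A minimal self-contained sketch of the same idea (illustrative only, not PR code; pitched_rows is a hypothetical name, and it runs on CPU so no GPU is needed):

import math
import torch

def pitched_rows(height: int, width_bytes: int, alignment: int = 256) -> torch.Tensor:
    # Round each row's width up to the next multiple of `alignment`.
    pitch_bytes = math.ceil(width_bytes / alignment) * alignment
    storage = torch.empty((height, pitch_bytes), dtype=torch.uint8)
    # Narrow to the logical width; rows stay `pitch_bytes` apart in memory.
    return storage[:, :width_bytes]

view = pitched_rows(4, 100, alignment=256)
assert view.shape == (4, 100)
assert view.stride(0) == 256        # row-to-row distance is the pitch
assert not view.is_contiguous()     # padding makes the view non-contiguous
base = view.data_ptr()
assert all((view[i].data_ptr() - base) % 256 == 0 for i in range(4))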
25 changes: 21 additions & 4 deletions fbgemm_gpu/test/merge_pooled_embeddings_test.py
@@ -9,7 +9,8 @@


 import unittest
-
+from typing import Tuple
+import math
 import fbgemm_gpu

 import hypothesis.strategies as st

Check failure (GitHub Actions / run-lint (3.13)) on line 12: F401 'typing.Tuple' imported but unused
@@ -31,9 +32,18 @@

 typed_gpu_unavailable: tuple[bool, str] = gpu_unavailable

+def make_pitched_tensor(height, width, dtype, device, alignment=256):
+    elem_size = torch.finfo(dtype).bits // 8 if dtype.is_floating_point else torch.iinfo(dtype).bits // 8
+    width_bytes = width * elem_size
+    pitch_bytes = math.ceil(width_bytes / alignment) * alignment
+    pitch_elems = pitch_bytes // elem_size
+    storage = torch.randn((height, pitch_elems), dtype=dtype, device=device)
+    view = storage[:, :width]  # logical shape
+    return view.contiguous() if alignment == 0 else view  # return pitched view
+
-@unittest.skipIf(*gpu_unavailable)
-@unittest.skipIf(open_source, "Not supported in open source yet")
+# @unittest.skipIf(*gpu_unavailable)
+# @unittest.skipIf(open_source, "Not supported in open source yet")
 class MergePooledEmbeddingsTest(unittest.TestCase):
     # pyre-fixme[56]: Pyre was not able to infer the type of argument
     # `hypothesis.strategies.integers($parameter$min_value = 1, $parameter$max_value =

Check failure (GitHub Actions / run-lint (3.13)) on line 35: E302 expected 2 blank lines, found 1
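For context, an illustrative use of the helper above (not part of the PR): the returned view keeps the logical shape but its row stride is the pitch, so it is non-contiguous whenever the pitch exceeds the row width:

t = make_pitched_tensor(10, 20, torch.float32, "cpu", alignment=256)
print(t.shape)            # torch.Size([10, 20])
print(t.stride())         # (64, 1): 256-byte pitch / 4 bytes per float32 element
print(t.is_contiguous())  # False, since 64 > 20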
@@ -127,7 +137,14 @@
     ) -> None:
         dst_device = torch.device(f"cuda:{r.randint(0, num_gpus - 1)}")
         with torch.cuda.device(dst_device):
-            inputs = [torch.randn(10, 20) for _ in range(num_inputs)]
+            pitch = True
+            if pitch:
+                inputs = [
+                    make_pitched_tensor(10, 20, torch.float32, "cpu", alignment=256)
+                for _ in range(num_inputs)]
+            else:
+                inputs = [torch.randn(10, 20) for _ in range(num_inputs)]
+
             cuda_inputs = [
                 input.to(f"cuda:{i % num_gpus}") for i, input in enumerate(inputs)
             ]

Check failure (GitHub Actions / run-lint (3.13)) on lines 143-144: E122 continuation line missing indentation or outdented
Check failure (GitHub Actions / run-lint (3.13)) on line 147: W293 blank line contains whitespace

Review thread on the `pitch = True` line:

q10 (Contributor): @kudomcho This logic is strange, we set pitch and immediately check its value.

kudomcho (Author, Oct 6, 2025): @q10 This is to check the allclose assertion in case pitching is enabled. Currently it forces pitch to True to test the all-to-one path under the pitched condition. Any preference on passing the pitch condition as a test argument?

q10 (Contributor, Oct 7, 2025): Yes, could you make this an argument to the test method, and use hypothesis @given(...) to pass the selection in?
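A minimal sketch of the change q10 requests (illustrative only, not PR code; it assumes the file's existing hypothesis imports of given, settings, and st, and the test name and elided arguments stand in for the real ones): draw the pitch choice from a strategy instead of hard-coding `pitch = True`:

# Sketch: `pitched` replaces the hard-coded `pitch = True`; other arguments elided.
@given(
    num_inputs=st.integers(min_value=1, max_value=10),
    pitched=st.booleans(),  # hypothesis exercises both layouts
)
@settings(max_examples=20, deadline=None)
def test_merge_pooled_embeddings(self, num_inputs: int, pitched: bool) -> None:
    if pitched:
        inputs = [
            make_pitched_tensor(10, 20, torch.float32, "cpu", alignment=256)
            for _ in range(num_inputs)
        ]
    else:
        inputs = [torch.randn(10, 20) for _ in range(num_inputs)]
    # ... remainder of the test body unchanged ...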