Skip to content

Commit 3a36197

Browse files
committed
fixed the correctness of merged pool embedding
1 parent 517b712 commit 3a36197

File tree

2 files changed

+67
-33
lines changed

2 files changed

+67
-33
lines changed

fbgemm_gpu/bench/merge_embeddings_benchmark.py

Lines changed: 47 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import numpy as np
1818
import tabulate
1919
import torch
20-
20+
import math
2121
from fbgemm_gpu.split_embedding_configs import SparseType
2222
from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
2323
BoundsCheckMode,
@@ -101,45 +101,63 @@ def generate_requests(
101101

102102

103103
# pyre-fixme[3]: Return type must be annotated.
104+
104105
def _get_random_tensor(
105106
num_ads: int,
106107
embedding_dimension: int,
107108
ads_tables: int,
108109
data_type: str,
109110
gpu_idx: int,
110111
include_quantization: bool,
112+
use_pitched: bool = True,
113+
alignment: int = 256, # alignment in bytes
111114
):
115+
device = torch.device(f"cuda:{gpu_idx}")
116+
112117
if data_type == "FP16" or include_quantization:
113-
result_tensor = torch.randn(
114-
num_ads,
115-
embedding_dimension * ads_tables,
116-
dtype=torch.float16,
117-
device=torch.device(f"cuda:{gpu_idx}"),
118-
)
118+
dtype = torch.float16
119+
width_elems = embedding_dimension * ads_tables
120+
elem_size = torch.finfo(dtype).bits // 8
121+
122+
if use_pitched:
123+
width_bytes = width_elems * elem_size
124+
pitch_bytes = math.ceil(width_bytes / alignment) * alignment
125+
pitch_elems = pitch_bytes // elem_size
126+
storage = torch.empty((num_ads, pitch_elems), dtype=dtype, device=device)
127+
result_tensor = storage[:, :width_elems] # logical view
128+
else:
129+
result_tensor = torch.randn(num_ads, width_elems, dtype=dtype, device=device)
130+
119131
elif data_type == "INT8":
120-
assert (
121-
embedding_dimension % 2
122-
) == 0, "needs to align to 2 bytes (half type size) for INT8"
123-
result_tensor = torch.randint(
124-
0,
125-
255,
126-
# 2 FP16 numbers for scale and bias, total of 4 bytes overhead
127-
size=(num_ads, (embedding_dimension + 4) * ads_tables),
128-
dtype=torch.uint8,
129-
device=torch.device(f"cuda:{gpu_idx}"),
130-
)
132+
assert embedding_dimension % 2 == 0, "needs to align to 2 bytes for INT8"
133+
dtype = torch.uint8
134+
width_elems = (embedding_dimension + 4) * ads_tables
135+
elem_size = 1
136+
137+
if use_pitched:
138+
width_bytes = width_elems * elem_size
139+
pitch_bytes = math.ceil(width_bytes / alignment) * alignment
140+
pitch_elems = pitch_bytes // elem_size
141+
storage = torch.randint(0, 255, (num_ads, pitch_elems), dtype=dtype, device=device)
142+
result_tensor = storage[:, :width_elems]
143+
else:
144+
result_tensor = torch.randint(0, 255, (num_ads, width_elems), dtype=dtype, device=device)
145+
131146
elif data_type == "INT4":
132-
assert (
133-
embedding_dimension % 4
134-
) == 0, "needs to align to 2 bytes (half type size) for INT4"
135-
result_tensor = torch.randint(
136-
0,
137-
255,
138-
# Using torch.uint8 for int4 storage
139-
size=(num_ads, (embedding_dimension // 2 + 4) * ads_tables),
140-
dtype=torch.uint8,
141-
device=torch.device(f"cuda:{gpu_idx}"),
142-
)
147+
assert embedding_dimension % 4 == 0, "needs to align to 2 bytes for INT4"
148+
dtype = torch.uint8
149+
width_elems = (embedding_dimension // 2 + 4) * ads_tables
150+
elem_size = 1
151+
152+
if use_pitched:
153+
width_bytes = width_elems * elem_size
154+
pitch_bytes = math.ceil(width_bytes / alignment) * alignment
155+
pitch_elems = pitch_bytes // elem_size
156+
storage = torch.randint(0, 255, (num_ads, pitch_elems), dtype=dtype, device=device)
157+
result_tensor = storage[:, :width_elems]
158+
else:
159+
result_tensor = torch.randint(0, 255, (num_ads, width_elems), dtype=dtype, device=device)
160+
143161
else:
144162
raise ValueError
145163

fbgemm_gpu/test/merge_pooled_embeddings_test.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
import unittest
1212
from typing import Tuple
13-
13+
import math
1414
import fbgemm_gpu
1515

1616
import hypothesis.strategies as st
@@ -32,9 +32,18 @@
3232

3333
typed_gpu_unavailable: Tuple[bool, str] = gpu_unavailable
3434

35+
def make_pitched_tensor(height, width, dtype, device, alignment=256):
36+
elem_size = torch.finfo(dtype).bits // 8 if dtype.is_floating_point else torch.iinfo(dtype).bits // 8
37+
width_bytes = width * elem_size
38+
pitch_bytes = math.ceil(width_bytes / alignment) * alignment
39+
pitch_elems = pitch_bytes // elem_size
40+
storage = torch.randn((height, pitch_elems), dtype=dtype, device=device)
41+
view = storage[:, :width] # logical shape
42+
return view.contiguous() if alignment == 0 else view # return pitched view
43+
3544

36-
@unittest.skipIf(*gpu_unavailable)
37-
@unittest.skipIf(open_source, "Not supported in open source yet")
45+
# @unittest.skipIf(*gpu_unavailable)
46+
# @unittest.skipIf(open_source, "Not supported in open source yet")
3847
class MergePooledEmbeddingsTest(unittest.TestCase):
3948
# pyre-fixme[56]: Pyre was not able to infer the type of argument
4049
# `hypothesis.strategies.integers($parameter$min_value = 1, $parameter$max_value =
@@ -128,7 +137,14 @@ def test_all_to_one_device(
128137
) -> None:
129138
dst_device = torch.device(f"cuda:{r.randint(0, num_gpus - 1)}")
130139
with torch.cuda.device(dst_device):
131-
inputs = [torch.randn(10, 20) for _ in range(num_inputs)]
140+
pitch = True
141+
if pitch:
142+
inputs = [
143+
make_pitched_tensor(10, 20, torch.float32, "cpu", alignment=256)
144+
for _ in range(num_inputs)]
145+
else:
146+
inputs = [torch.randn(10, 20) for _ in range(num_inputs)]
147+
132148
cuda_inputs = [
133149
input.to(f"cuda:{i % num_gpus}") for i, input in enumerate(inputs)
134150
]

0 commit comments

Comments
 (0)