diff --git a/fbgemm_gpu/bench/merge_embeddings_benchmark.py b/fbgemm_gpu/bench/merge_embeddings_benchmark.py index 6ed8ebb264..4ad4c36d64 100644 --- a/fbgemm_gpu/bench/merge_embeddings_benchmark.py +++ b/fbgemm_gpu/bench/merge_embeddings_benchmark.py @@ -17,7 +17,6 @@ import numpy as np import tabulate import torch - from fbgemm_gpu.split_embedding_configs import SparseType from fbgemm_gpu.split_table_batched_embeddings_ops_common import ( BoundsCheckMode, @@ -101,6 +100,8 @@ def generate_requests( # pyre-fixme[3]: Return type must be annotated. + + def _get_random_tensor( num_ads: int, embedding_dimension: int, @@ -108,38 +109,65 @@ def _get_random_tensor( data_type: str, gpu_idx: int, include_quantization: bool, + use_pitched: bool, + alignment: int = 256, # alignment in bytes ): + device = torch.device(f"cuda:{gpu_idx}") + if data_type == "FP16" or include_quantization: - result_tensor = torch.randn( - num_ads, - embedding_dimension * ads_tables, - dtype=torch.float16, - device=torch.device(f"cuda:{gpu_idx}"), - ) + dtype = torch.float16 + width_elems = embedding_dimension * ads_tables + elem_size = torch.finfo(dtype).bits // 8 + + if use_pitched: + width_bytes = width_elems * elem_size + pitch_bytes = int(np.ceil(width_bytes / alignment) * alignment) + pitch_elems = pitch_bytes // elem_size + storage = torch.empty((num_ads, pitch_elems), dtype=dtype, device=device) + result_tensor = storage[:, :width_elems] # logical view + else: + result_tensor = torch.randn( + num_ads, width_elems, dtype=dtype, device=device + ) + elif data_type == "INT8": - assert ( - embedding_dimension % 2 - ) == 0, "needs to align to 2 bytes (half type size) for INT8" - result_tensor = torch.randint( - 0, - 255, - # 2 FP16 numbers for scale and bias, total of 4 bytes overhead - size=(num_ads, (embedding_dimension + 4) * ads_tables), - dtype=torch.uint8, - device=torch.device(f"cuda:{gpu_idx}"), - ) + assert embedding_dimension % 2 == 0, "needs to align to 2 bytes for INT8" + dtype = 
torch.uint8 + width_elems = (embedding_dimension + 4) * ads_tables + elem_size = 1 + + if use_pitched: + width_bytes = width_elems * elem_size + pitch_bytes = int(np.ceil(width_bytes / alignment) * alignment) + pitch_elems = pitch_bytes // elem_size + storage = torch.randint( + 0, 255, (num_ads, pitch_elems), dtype=dtype, device=device + ) + result_tensor = storage[:, :width_elems] + else: + result_tensor = torch.randint( + 0, 255, (num_ads, width_elems), dtype=dtype, device=device + ) + elif data_type == "INT4": - assert ( - embedding_dimension % 4 - ) == 0, "needs to align to 2 bytes (half type size) for INT4" - result_tensor = torch.randint( - 0, - 255, - # Using torch.uint8 for int4 storage - size=(num_ads, (embedding_dimension // 2 + 4) * ads_tables), - dtype=torch.uint8, - device=torch.device(f"cuda:{gpu_idx}"), - ) + assert embedding_dimension % 4 == 0, "needs to align to 2 bytes for INT4" + dtype = torch.uint8 + width_elems = (embedding_dimension // 2 + 4) * ads_tables + elem_size = 1 + + if use_pitched: + width_bytes = width_elems * elem_size + pitch_bytes = int(np.ceil(width_bytes / alignment) * alignment) + pitch_elems = pitch_bytes // elem_size + storage = torch.randint( + 0, 255, (num_ads, pitch_elems), dtype=dtype, device=device + ) + result_tensor = storage[:, :width_elems] + else: + result_tensor = torch.randint( + 0, 255, (num_ads, width_elems), dtype=dtype, device=device + ) + else: raise ValueError @@ -254,6 +282,7 @@ def benchmark( # noqa C901 num_ads: int, embedding_dimension: int, ads_tables: int, + use_pitched: bool, iters: int = 10, p2p_bw: bool = False, dst_device: int = 0, @@ -299,6 +328,7 @@ def benchmark( # noqa C901 data_type, gpu_idx, include_quantization, + use_pitched, ) for gpu_idx in range(num_gpus) ] @@ -486,6 +516,7 @@ def pool_func_with_quantization( @click.option("--num_of_embeddings", default=100000, type=int) @click.option("--pooling_factor", default=25, type=int) @click.option("--sweep", is_flag=True, default=False) 
def make_pitched_tensor(
    height: int,
    width: int,
    dtype: torch.dtype,
    device,
    alignment: int = 256,
) -> torch.Tensor:
    """Create a random ``(height, width)`` tensor whose rows are padded to a pitch.

    The backing storage has shape ``(height, pitch_elems)``, where the row
    size in bytes is ``width * element_size`` rounded up to the next multiple
    of ``alignment``. The returned tensor is the ``[:, :width]`` slice of that
    storage, so its row stride is ``pitch_elems`` rather than ``width``.

    Args:
        height: Number of rows.
        width: Number of logical elements per row.
        dtype: Element type. Floating-point dtypes are filled with
            ``torch.randn``; other dtypes are filled with ``torch.randint``.
        device: Device passed through to the torch factory function.
        alignment: Row pitch alignment in bytes. ``0`` disables padding and
            yields a contiguous tensor.

    Returns:
        A ``(height, width)`` view of the padded storage. It is contiguous
        only when no padding was needed.

    Raises:
        ValueError: If ``alignment`` is negative.
    """
    if alignment < 0:
        raise ValueError(f"alignment must be non-negative, got {alignment}")

    # element_size() works for every dtype, unlike finfo/iinfo which each
    # cover only one family.
    elem_size = torch.empty((), dtype=dtype).element_size()

    if alignment == 0:
        # No padding requested: the pitch equals the logical width, so the
        # slice below is already contiguous.
        pitch_elems = width
    else:
        width_bytes = width * elem_size
        # Integer ceil-division keeps the round-up exact (no float rounding).
        pitch_bytes = -(-width_bytes // alignment) * alignment
        pitch_elems = pitch_bytes // elem_size

    shape = (height, pitch_elems)
    if dtype.is_floating_point:
        storage = torch.randn(shape, dtype=dtype, device=device)
    else:
        # randn rejects integer dtypes; draw small non-negative values that
        # fit every integer width (randint's upper bound is exclusive).
        high = min(torch.iinfo(dtype).max, 255) + 1
        storage = torch.randint(0, high, shape, dtype=dtype, device=device)

    # Logical view over the padded storage; the padding columns stay hidden.
    return storage[:, :width]
`hypothesis.strategies.integers($parameter$min_value = 1, $parameter$max_value = @@ -114,6 +129,7 @@ def ref(pooled_ad_embeddings, batch_indices): num_inputs=st.integers(min_value=1, max_value=10), num_gpus=st.integers(min_value=1, max_value=torch.cuda.device_count()), r=st.randoms(use_true_random=False), + use_pitched=st.booleans(), ) # Can instantiate 8 contexts which takes a long time. @settings(verbosity=Verbosity.verbose, max_examples=40, deadline=None) @@ -125,10 +141,18 @@ def test_all_to_one_device( num_gpus, # pyre-fixme[2]: Parameter must be annotated. r, + use_pitched, ) -> None: dst_device = torch.device(f"cuda:{r.randint(0, num_gpus - 1)}") with torch.cuda.device(dst_device): - inputs = [torch.randn(10, 20) for _ in range(num_inputs)] + if use_pitched: + inputs = [ + make_pitched_tensor(10, 20, torch.float32, "cpu", alignment=256) + for _ in range(num_inputs) + ] + else: + inputs = [torch.randn(10, 20) for _ in range(num_inputs)] + cuda_inputs = [ input.to(f"cuda:{i % num_gpus}") for i, input in enumerate(inputs) ]