Skip to content

Commit 13b9fd6

Browse files
author
Gaurav Shukla
committed
[TBE] Add a test module for table batch embedding
This commit adds a test module specifically for table batch embedding algorithm. This test case is in reference to the FBGEMM table batch embedding: https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py#L270 Signed-Off-by: Gaurav Shukla <[email protected]>
1 parent eb06d21 commit 13b9fd6

File tree

3 files changed

+63
-0
lines changed

3 files changed

+63
-0
lines changed

e2e_testing/torchscript/main.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
from . import constant_alloc
5050
from . import threshold
5151
from . import histogram_binning_calibration
52+
from . import table_batch_embedding
5253

5354
def _get_argparse():
5455
config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'external']
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2+
# See https://llvm.org/LICENSE.txt for license information.
3+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4+
# Also available under a BSD-style license. See LICENSE.
5+
6+
import torch
7+
8+
from torch_mlir_e2e_test.torchscript.framework import TestUtils
9+
from torch_mlir_e2e_test.torchscript.registry import register_test_case
10+
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
11+
12+
# ==============================================================================
13+
14+
# Reference: https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py#L270
15+
16+
# Global parameters.
17+
NUM_TABLES = 2
18+
NUM_EMBEDDINGS = 10
19+
EMBEDDING_DIM = 4
20+
BATCH_SIZE = 4
21+
BAG_SIZE = 3
22+
23+
24+
class TableBatchEmbeddingModule(torch.nn.Module):
25+
def __init__(self):
26+
super(TableBatchEmbeddingModule, self).__init__()
27+
self.num_tables = NUM_TABLES
28+
self.num_embeddings = NUM_EMBEDDINGS
29+
self.embedding_dim = EMBEDDING_DIM
30+
self.batch_size = BATCH_SIZE
31+
self.bag_size = BAG_SIZE
32+
# Currently, pooling_mode is fixed to 'sum'.
33+
self.nn_embedding_list = torch.nn.ModuleList([
34+
torch.nn.EmbeddingBag(
35+
self.num_embeddings, self.embedding_dim, mode="sum", sparse=False)
36+
for i in range(self.num_tables)
37+
])
38+
39+
@export
40+
@annotate_args([
41+
None,
42+
([-1], torch.int64, True),
43+
([-1], torch.int64, True),
44+
])
45+
def forward(self, indices, offsets):
46+
indices_list = indices.view(self.num_tables, self.batch_size, self.bag_size)
47+
final_output = torch.tensor([])
48+
for i, nn_embedding in enumerate(self.nn_embedding_list):
49+
indices = indices_list[i].view(-1)
50+
output = nn_embedding(indices, offsets).view(self.batch_size, -1)
51+
final_output = torch.cat((final_output, output), dim=1)
52+
return final_output
53+
54+
55+
@register_test_case(module_factory=lambda: TableBatchEmbeddingModule())
56+
def TableBatchEmbeddingModule_basic(module, tu: TestUtils):
57+
indices = torch.randint(0, NUM_EMBEDDINGS, (NUM_TABLES * BATCH_SIZE * BAG_SIZE,))
58+
offsets = torch.cumsum(
59+
torch.tensor([0] + [BAG_SIZE for _ in range(BATCH_SIZE - 1)], dtype=torch.int64), 0)
60+
module.forward(indices, offsets)
61+

e2e_testing/torchscript/xfail_sets.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
COMMON_TORCH_MLIR_LOWERING_XFAILS = {
1717
"QuantizedMLP_basic",
1818
"IouOfModule_basic",
19+
"TableBatchEmbeddingModule_basic",
1920
}
2021
REFBACKEND_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS
2122

0 commit comments

Comments
 (0)