Skip to content

Commit 60f7f87

Browse files
chouxifacebook-github-bot
authored and committed
Add res parameters to SplitTableBatchedEmbeddingBagsCodegen (#3330)
Summary: Pull Request resolved: #3330 Adds RES (raw embedding streaming) parameters to SplitTableBatchedEmbeddingBagsCodegen. The parameter-population logic is similar to the existing logic in _populate_ssd_tbe_params, so it is extracted into a shared function. Everything is now pulled from the fused params so that it can be reused by both the UVM_CACHE and SSD modes. The one piece of new logic is res_enabled_tables: RES is enabled for a given TBE only when that TBE contains at least one of the enabled tables. - This is needed because we currently want to enable it for only one table and do not want to add overhead to the other TBEs. Reviewed By: yixin94, yingufan, xinyuanzzz Differential Revision: D80503597 fbshipit-source-id: 9d854844a80e72df740a3a0ca8b62d85e3ebf491
1 parent a07bc63 commit 60f7f87

File tree

1 file changed

+63
-15
lines changed

1 file changed

+63
-15
lines changed

torchrec/distributed/batched_embedding_kernel.py

Lines changed: 63 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,58 @@
9898

9999
logger: logging.Logger = logging.getLogger(__name__)
100100

101+
# Keys that _populate_res_params looks up in a table group's fused_params to
# configure raw embedding streaming (RES).
# Value is a comma-separated string of table names; when it names any tables,
# RES is only enabled for a group that contains at least one of them.
RES_ENABLED_TABLES_STR = "res_enabled_tables"
102+
# Value is copied verbatim into RESParams.res_store_shards.
RES_STORE_SHARDS_STR = "res_store_shards"
103+
# Flag that turns RES on; a missing, None or False value disables it.
ENABLE_RAW_EMBEDDING_STREAMING_STR = "enable_raw_embedding_streaming"
104+
105+
106+
def _populate_res_params(config: GroupedEmbeddingConfig) -> Tuple[bool, RESParams]:
    """
    Build the raw embedding streaming (RES) parameters for one table group.

    Only the values available from ``config.fused_params`` and from the tables
    in ``config.embedding_tables`` are populated here.

    Side effect: the ``res_store_shards`` and ``res_enabled_tables`` entries are
    removed from ``config.fused_params`` in place (when that dict is non-empty),
    because they are not constructor arguments of the embedding module the
    remaining fused params are forwarded to. The
    ``enable_raw_embedding_streaming`` entry is read but left in the dict.

    Args:
        config: grouped embedding config whose ``fused_params`` and
            ``embedding_tables`` are inspected.

    Returns:
        A ``(enabled, res_params)`` tuple.

        ``enabled`` is ``False`` when the enable flag is missing, ``None`` or
        ``False``, or when ``res_enabled_tables`` names at least one table but
        none of the named tables belongs to this group. Otherwise it is the
        flag value found in the fused params.

        ``res_params.res_store_shards`` is set whenever the key was present.
        ``res_params.table_names`` is set once the enable flag check passes.
        ``res_params.table_offsets`` is set only on the enabled path and holds
        the first shard offset of every table that has local metadata with a
        non-empty ``shard_offsets``; tables without one are skipped.
    """
    res_params: RESParams = RESParams()
    fused_params = config.fused_params or {}

    # Read and remove the fused params that are not constructor arguments.
    if RES_STORE_SHARDS_STR in fused_params:
        res_params.res_store_shards = fused_params.pop(RES_STORE_SHARDS_STR)

    # pop(..., None) covers both "key absent" and "key present but None".
    enabled_tables_csv: Optional[str] = fused_params.pop(RES_ENABLED_TABLES_STR, None)
    res_enabled_tables: Optional[List[str]] = None
    if enabled_tables_csv is not None:
        # Strip each name so "t1, t2" matches tables t1 and t2; without this a
        # space after the comma silently disables streaming for the group.
        res_enabled_tables = [name.strip() for name in enabled_tables_csv.split(",")]

    enable_raw_embedding_streaming: Optional[bool] = fused_params.get(
        ENABLE_RAW_EMBEDDING_STREAMING_STR
    )
    if (
        enable_raw_embedding_streaming is None
        or enable_raw_embedding_streaming is False
    ):
        return (False, res_params)

    res_params.table_names = [table.name for table in config.embedding_tables]

    # A non-empty allow-list restricts RES to groups that own a listed table.
    if res_enabled_tables and set(res_enabled_tables).isdisjoint(
        res_params.table_names
    ):
        logger.info(
            "No table is enabled for raw embedding streaming, "
            f"raw embedding streaming is disabled, {res_enabled_tables=} {res_params.table_names=}"
        )
        return (False, res_params)

    res_params.table_offsets = [
        emb_tbl.local_metadata.shard_offsets[0]
        for emb_tbl in config.embedding_tables
        if emb_tbl.local_metadata is not None
        and emb_tbl.local_metadata.shard_offsets is not None
        and len(emb_tbl.local_metadata.shard_offsets) >= 1
    ]
    return (enable_raw_embedding_streaming, res_params)
152+
101153

102154
def _populate_ssd_tbe_params(config: GroupedEmbeddingConfig) -> Dict[str, Any]:
103155
"""
@@ -186,22 +238,9 @@ def _populate_ssd_tbe_params(config: GroupedEmbeddingConfig) -> Dict[str, Any]:
186238
ssd_tbe_params["cache_sets"] = int(max_cache_sets)
187239
ssd_tbe_params["table_names"] = [table.name for table in config.embedding_tables]
188240

189-
# populate res_params, which is used for raw embedding streaming
190-
# here only populates the params available in fused_params and TBE configs
191-
res_params: RESParams = RESParams()
192-
res_params.table_names = [table.name for table in config.embedding_tables]
193-
res_params.table_offsets = []
194-
for emb_tbl in config.embedding_tables:
195-
local_metadata = emb_tbl.local_metadata
196-
if (
197-
local_metadata is not None
198-
and local_metadata.shard_offsets is not None
199-
and len(local_metadata.shard_offsets) >= 1
200-
):
201-
res_params.table_offsets.append(local_metadata.shard_offsets[0])
202-
if "res_store_shards" in fused_params:
203-
res_params.res_store_shards = fused_params.get("res_store_shards")
241+
enable_res, res_params = _populate_res_params(config)
204242
ssd_tbe_params["res_params"] = res_params
243+
ssd_tbe_params[ENABLE_RAW_EMBEDDING_STREAMING_STR] = enable_res
205244

206245
return ssd_tbe_params
207246

@@ -2190,6 +2229,9 @@ def __init__(
21902229
if "cache_precision" not in fused_params:
21912230
fused_params["cache_precision"] = weights_precision
21922231

2232+
enable_res, res_params = _populate_res_params(config)
2233+
fused_params[ENABLE_RAW_EMBEDDING_STREAMING_STR] = enable_res
2234+
21932235
self._emb_module: SplitTableBatchedEmbeddingBagsCodegen = (
21942236
SplitTableBatchedEmbeddingBagsCodegen(
21952237
embedding_specs=list(
@@ -2208,6 +2250,7 @@ def __init__(
22082250
self._col_offset,
22092251
)
22102252
),
2253+
res_params=res_params,
22112254
**fused_params,
22122255
)
22132256
)
@@ -3041,6 +3084,10 @@ def __init__(
30413084
fused_params["cache_precision"] = weights_precision
30423085
if weights_precision == SparseType.NFP8:
30433086
fused_params["cache_precision"] = SparseType.FP16
3087+
3088+
enable_res, res_params = _populate_res_params(config)
3089+
fused_params[ENABLE_RAW_EMBEDDING_STREAMING_STR] = enable_res
3090+
30443091
self._emb_module: SplitTableBatchedEmbeddingBagsCodegen = (
30453092
SplitTableBatchedEmbeddingBagsCodegen(
30463093
embedding_specs=list(
@@ -3059,6 +3106,7 @@ def __init__(
30593106
self._col_offset,
30603107
)
30613108
),
3109+
res_params=res_params,
30623110
**fused_params,
30633111
)
30643112
)

0 commit comments

Comments
 (0)