98
98
99
99
logger : logging .Logger = logging .getLogger (__name__ )
100
100
101
# Keys looked up in ``fused_params`` to configure raw embedding streaming (RES).
RES_ENABLED_TABLES_STR = "res_enabled_tables"
RES_STORE_SHARDS_STR = "res_store_shards"
ENABLE_RAW_EMBEDDING_STREAMING_STR = "enable_raw_embedding_streaming"


def _populate_res_params(config: GroupedEmbeddingConfig) -> Tuple[bool, RESParams]:
    """Build the raw embedding streaming (RES) parameters for a table group.

    Only the params available in ``config.fused_params`` and in the embedding
    table configs are populated here.

    Side effect: when ``config.fused_params`` is a non-empty dict, the
    ``res_store_shards`` and ``res_enabled_tables`` keys are removed from it,
    because they are not constructor arguments. The
    ``enable_raw_embedding_streaming`` key is read but left in place.

    Args:
        config: grouped embedding config providing ``fused_params`` and
            ``embedding_tables``.

    Returns:
        A ``(enabled, res_params)`` tuple. ``enabled`` is ``False`` when the
        ``enable_raw_embedding_streaming`` flag is missing or falsy, or when
        ``res_enabled_tables`` names at least one table and none of the named
        tables belongs to this group. In those cases ``res_params`` is
        returned only partially populated.
    """
    res_params: RESParams = RESParams()
    fused_params = config.fused_params or {}

    # read and clean up the fused_params that are not in the constructor
    if RES_STORE_SHARDS_STR in fused_params:
        res_params.res_store_shards = fused_params.pop(RES_STORE_SHARDS_STR)

    res_enabled_tables: Optional[List[str]] = None
    if RES_ENABLED_TABLES_STR in fused_params:
        raw_enabled_tables = fused_params.pop(RES_ENABLED_TABLES_STR)
        if raw_enabled_tables is not None:
            # Strip whitespace and drop empty entries: "".split(",") yields
            # [""], which would otherwise never match a table name and would
            # silently disable streaming; "a, b" would never match "b".
            res_enabled_tables = [
                name.strip() for name in raw_enabled_tables.split(",") if name.strip()
            ]

    # Truthiness test (rather than identity with None/False) so that falsy
    # non-bool values such as 0 also disable streaming, and so that the first
    # element of the returned tuple is always a real bool.
    if not fused_params.get(ENABLE_RAW_EMBEDDING_STREAMING_STR):
        return (False, res_params)

    res_params.table_names = [table.name for table in config.embedding_tables]

    # A non-empty filter that matches none of this group's tables turns
    # streaming off; a missing or empty filter places no restriction.
    if res_enabled_tables and not (
        set(res_enabled_tables) & set(res_params.table_names)
    ):
        logger.info(
            f"No table is enabled for raw embedding streaming, "
            f"raw embedding streaming is disabled, { res_enabled_tables = } { res_params .table_names = } "
        )
        return (False, res_params)

    # First shard offset of every table that carries local shard metadata.
    # NOTE(review): tables without metadata are skipped, so table_offsets can
    # be shorter than table_names - confirm consumers tolerate that.
    res_params.table_offsets = []
    for emb_tbl in config.embedding_tables:
        local_metadata = emb_tbl.local_metadata
        if (
            local_metadata is not None
            and local_metadata.shard_offsets is not None
            and len(local_metadata.shard_offsets) >= 1
        ):
            res_params.table_offsets.append(local_metadata.shard_offsets[0])
    return (True, res_params)
152
+
101
153
102
154
def _populate_ssd_tbe_params (config : GroupedEmbeddingConfig ) -> Dict [str , Any ]:
103
155
"""
@@ -186,22 +238,9 @@ def _populate_ssd_tbe_params(config: GroupedEmbeddingConfig) -> Dict[str, Any]:
186
238
ssd_tbe_params ["cache_sets" ] = int (max_cache_sets )
187
239
ssd_tbe_params ["table_names" ] = [table .name for table in config .embedding_tables ]
188
240
189
- # populate res_params, which is used for raw embedding streaming
190
- # here only populates the params available in fused_params and TBE configs
191
- res_params : RESParams = RESParams ()
192
- res_params .table_names = [table .name for table in config .embedding_tables ]
193
- res_params .table_offsets = []
194
- for emb_tbl in config .embedding_tables :
195
- local_metadata = emb_tbl .local_metadata
196
- if (
197
- local_metadata is not None
198
- and local_metadata .shard_offsets is not None
199
- and len (local_metadata .shard_offsets ) >= 1
200
- ):
201
- res_params .table_offsets .append (local_metadata .shard_offsets [0 ])
202
- if "res_store_shards" in fused_params :
203
- res_params .res_store_shards = fused_params .get ("res_store_shards" )
241
+ enable_res , res_params = _populate_res_params (config )
204
242
ssd_tbe_params ["res_params" ] = res_params
243
+ ssd_tbe_params [ENABLE_RAW_EMBEDDING_STREAMING_STR ] = enable_res
205
244
206
245
return ssd_tbe_params
207
246
@@ -2190,6 +2229,9 @@ def __init__(
2190
2229
if "cache_precision" not in fused_params :
2191
2230
fused_params ["cache_precision" ] = weights_precision
2192
2231
2232
+ enable_res , res_params = _populate_res_params (config )
2233
+ fused_params [ENABLE_RAW_EMBEDDING_STREAMING_STR ] = enable_res
2234
+
2193
2235
self ._emb_module : SplitTableBatchedEmbeddingBagsCodegen = (
2194
2236
SplitTableBatchedEmbeddingBagsCodegen (
2195
2237
embedding_specs = list (
@@ -2208,6 +2250,7 @@ def __init__(
2208
2250
self ._col_offset ,
2209
2251
)
2210
2252
),
2253
+ res_params = res_params ,
2211
2254
** fused_params ,
2212
2255
)
2213
2256
)
@@ -3041,6 +3084,10 @@ def __init__(
3041
3084
fused_params ["cache_precision" ] = weights_precision
3042
3085
if weights_precision == SparseType .NFP8 :
3043
3086
fused_params ["cache_precision" ] = SparseType .FP16
3087
+
3088
+ enable_res , res_params = _populate_res_params (config )
3089
+ fused_params [ENABLE_RAW_EMBEDDING_STREAMING_STR ] = enable_res
3090
+
3044
3091
self ._emb_module : SplitTableBatchedEmbeddingBagsCodegen = (
3045
3092
SplitTableBatchedEmbeddingBagsCodegen (
3046
3093
embedding_specs = list (
@@ -3059,6 +3106,7 @@ def __init__(
3059
3106
self ._col_offset ,
3060
3107
)
3061
3108
),
3109
+ res_params = res_params ,
3062
3110
** fused_params ,
3063
3111
)
3064
3112
)
0 commit comments